OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
StringDictionaryProxy.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include "Logger/Logger.h"
20 #include "Shared/ThreadInfo.h"
21 #include "Shared/misc.h"
22 #include "Shared/sqltypes.h"
23 #include "Shared/thread_count.h"
25 #include "StringOps/StringOps.h"
26 #include "Utils/Regexp.h"
27 #include "Utils/StringLike.h"
28 
29 #include <tbb/parallel_for.h>
30 #include <tbb/task_arena.h>
31 
32 #include <algorithm>
33 #include <iomanip>
34 #include <iostream>
35 #include <string>
36 #include <string_view>
37 #include <thread>
38 
39 StringDictionaryProxy::StringDictionaryProxy(std::shared_ptr<StringDictionary> sd,
40  const shared::StringDictKey& string_dict_key,
41  const int64_t generation)
42  : string_dict_(sd), string_dict_key_(string_dict_key), generation_(generation) {}
43 
44 int32_t truncate_to_generation(const int32_t id, const size_t generation) {
46  return id;
47  }
48  CHECK_GE(id, 0);
49  return static_cast<size_t>(id) >= generation ? StringDictionary::INVALID_STR_ID : id;
50 }
51 
53  const std::vector<std::string>& strings) const {
55  std::vector<int32_t> string_ids(strings.size());
56  getTransientBulkImpl(strings, string_ids.data(), true);
57  return string_ids;
58 }
59 
61  const std::vector<std::string>& strings) {
63  const size_t num_strings = strings.size();
64  std::vector<int32_t> string_ids(num_strings);
65  if (num_strings == 0) {
66  return string_ids;
67  }
68  // Since new strings added to a StringDictionaryProxy are not materialized in the
69  // proxy's underlying StringDictionary, we can use the fast parallel
70  // StringDictionary::getBulk method to fetch ids from the underlying dictionary (which
71  // will return StringDictionary::INVALID_STR_ID for strings that don't exist)
72 
73  // Don't need to be under lock here as the string ids for strings in the underlying
74  // materialized dictionary are immutable
75  const size_t num_strings_not_found =
76  string_dict_->getBulk(strings, string_ids.data(), generation_);
77  if (num_strings_not_found > 0) {
78  std::lock_guard<std::shared_mutex> write_lock(rw_mutex_);
79  for (size_t string_idx = 0; string_idx < num_strings; ++string_idx) {
80  if (string_ids[string_idx] == StringDictionary::INVALID_STR_ID) {
81  string_ids[string_idx] = getOrAddTransientUnlocked(strings[string_idx]);
82  }
83  }
84  }
85  return string_ids;
86 }
87 
88 template <typename String>
90  unsigned const new_index = transient_str_to_int_.size();
91  auto transient_id = transientIndexToId(new_index);
92  auto const emplaced = transient_str_to_int_.emplace(str, transient_id);
93  if (emplaced.second) { // (str, transient_id) was added to transient_str_to_int_.
94  transient_string_vec_.push_back(&emplaced.first->first);
95  } else { // str already exists in transient_str_to_int_. Return existing transient_id.
96  transient_id = emplaced.first->second;
97  }
98  return transient_id;
99 }
100 
101 template <typename String>
103  auto const string_id = getIdOfStringFromClient(str);
104  if (string_id != StringDictionary::INVALID_STR_ID) {
105  return string_id;
106  }
107  std::lock_guard<std::shared_mutex> write_lock(rw_mutex_);
108  return getOrAddTransientUnlocked(str);
109 }
110 
111 int32_t StringDictionaryProxy::getOrAddTransient(std::string const& str) {
112  return getOrAddTransientImpl<std::string const&>(str);
113 }
114 
115 int32_t StringDictionaryProxy::getOrAddTransient(std::string_view const sv) {
116  return getOrAddTransientImpl<std::string_view const>(sv);
117 }
118 
119 int32_t StringDictionaryProxy::getIdOfString(const std::string& str) const {
120  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
121  auto const str_id = getIdOfStringFromClient(str);
122  if (str_id != StringDictionary::INVALID_STR_ID || transient_str_to_int_.empty()) {
123  return str_id;
124  }
125  auto it = transient_str_to_int_.find(str);
126  return it != transient_str_to_int_.end() ? it->second
128 }
129 
130 template <typename String>
131 int32_t StringDictionaryProxy::getIdOfStringFromClient(const String& str) const {
132  CHECK_GE(generation_, 0);
133  return truncate_to_generation(string_dict_->getIdOfString(str), generation_);
134 }
135 
136 int32_t StringDictionaryProxy::getIdOfStringNoGeneration(const std::string& str) const {
137  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
138  auto str_id = string_dict_->getIdOfString(str);
139  if (str_id != StringDictionary::INVALID_STR_ID || transient_str_to_int_.empty()) {
140  return str_id;
141  }
142  auto it = transient_str_to_int_.find(str);
143  return it != transient_str_to_int_.end() ? it->second
145 }
146 
148  int8_t* proxy_ptr,
149  int32_t string_id) {
150  CHECK(proxy_ptr != nullptr);
151  auto proxy = reinterpret_cast<StringDictionaryProxy*>(proxy_ptr);
152  auto [c_str, len] = proxy->getStringBytes(string_id);
153  return c_str;
154 }
155 
156 extern "C" DEVICE RUNTIME_EXPORT size_t
157 StringDictionaryProxy_getStringLength(int8_t* proxy_ptr, int32_t string_id) {
158  CHECK(proxy_ptr != nullptr);
159  auto proxy = reinterpret_cast<StringDictionaryProxy*>(proxy_ptr);
160  auto [c_str, len] = proxy->getStringBytes(string_id);
161  return len;
162 }
163 
164 extern "C" DEVICE RUNTIME_EXPORT int32_t
165 StringDictionaryProxy_getStringId(int8_t* proxy_ptr, char* c_str_ptr) {
166  CHECK(proxy_ptr != nullptr);
167  auto proxy = reinterpret_cast<StringDictionaryProxy*>(proxy_ptr);
168  std::string str(c_str_ptr);
169  return proxy->getOrAddTransient(str);
170 }
171 
172 std::string StringDictionaryProxy::getString(int32_t string_id) const {
173  if (inline_int_null_value<int32_t>() == string_id) {
174  return "";
175  }
176  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
177  return getStringUnlocked(string_id);
178 }
179 
180 std::string StringDictionaryProxy::getStringUnlocked(const int32_t string_id) const {
181  if (string_id >= 0 && storageEntryCount() > 0) {
182  return string_dict_->getString(string_id);
183  }
184  unsigned const string_index = transientIdToIndex(string_id);
185  CHECK_LT(string_index, transient_string_vec_.size());
186  return *transient_string_vec_[string_index];
187 }
188 
189 std::vector<std::string> StringDictionaryProxy::getStrings(
190  const std::vector<int32_t>& string_ids) const {
191  std::vector<std::string> strings;
192  if (!string_ids.empty()) {
193  strings.reserve(string_ids.size());
194  for (const auto string_id : string_ids) {
195  if (string_id >= 0) {
196  strings.emplace_back(string_dict_->getString(string_id));
197  } else if (inline_int_null_value<int32_t>() == string_id) {
198  strings.emplace_back("");
199  } else {
200  unsigned const string_index = transientIdToIndex(string_id);
201  strings.emplace_back(*transient_string_vec_[string_index]);
202  }
203  }
204  }
205  return strings;
206 }
207 
208 template <typename String>
210  const String& lookup_string) const {
211  const auto it = transient_str_to_int_.find(lookup_string);
213  : it->second;
214 }
215 
218  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
219  auto timer = DEBUG_TIMER(__func__);
220  CHECK(string_op_infos.size());
222  if (translation_map.empty()) {
223  return translation_map;
224  }
225 
226  const StringOps_Namespace::StringOps string_ops(string_op_infos);
227 
228  const size_t num_transient_entries = translation_map.numTransients();
229  if (num_transient_entries) {
230  const int32_t map_domain_start = translation_map.domainStart();
231  if (num_transient_entries > 10000UL) {
233  tbb::blocked_range<int32_t>(map_domain_start, -1),
234  [&](const tbb::blocked_range<int32_t>& r) {
235  const int32_t start_idx = r.begin();
236  const int32_t end_idx = r.end();
237  for (int32_t source_string_id = start_idx; source_string_id < end_idx;
238  ++source_string_id) {
239  const auto source_string = getStringUnlocked(source_string_id);
240  translation_map[source_string_id] = string_ops.numericEval(source_string);
241  }
242  });
243  } else {
244  for (int32_t source_string_id = map_domain_start; source_string_id < -1;
245  ++source_string_id) {
246  const auto source_string = getStringUnlocked(source_string_id);
247  translation_map[source_string_id] = string_ops.numericEval(source_string);
248  }
249  }
250  }
251 
252  Datum* translation_map_stored_entries_ptr = translation_map.storageData();
253  if (generation_ > 0) {
254  string_dict_->buildDictionaryNumericTranslationMap(
255  translation_map_stored_entries_ptr, generation_, string_op_infos);
256  }
257  translation_map.setNumUntranslatedStrings(0UL);
258 
259  // Todo(todd): Set range start/end with scan
260 
261  return translation_map;
262 }
263 
266  const StringDictionaryProxy* dest_proxy,
267  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
268  auto timer = DEBUG_TIMER(__func__);
269  IdMap id_map = initIdMap();
270 
271  if (id_map.empty()) {
272  return id_map;
273  }
274 
275  const StringOps_Namespace::StringOps string_ops(string_op_infos);
276 
277  // First map transient strings, store at front of vector map
278  const size_t num_transient_entries = id_map.numTransients();
279  size_t num_transient_strings_not_translated = 0UL;
280  if (num_transient_entries) {
281  std::vector<std::string> transient_lookup_strings(num_transient_entries);
282  if (string_ops.size()) {
284  transient_string_vec_.cend(),
285  transient_lookup_strings.rbegin(),
286  [&](std::string const* ptr) { return string_ops(*ptr); });
287  } else {
289  transient_string_vec_.cend(),
290  transient_lookup_strings.rbegin(),
291  [](std::string const* ptr) { return *ptr; });
292  }
293 
294  // This lookup may have a different snapshot of
295  // dest_proxy transients and dictionary than what happens under
296  // the below dest_proxy_read_lock. We may need an unlocked version of
297  // getTransientBulk to ensure consistency (I don't believe
298  // current behavior would cause crashes/races, verify this though)
299 
300  // Todo(mattp): Consider variant of getTransientBulkImp that can take
301  // a vector of pointer-to-strings so we don't have to materialize
302  // transient_string_vec_ into transient_lookup_strings.
303 
304  num_transient_strings_not_translated =
305  dest_proxy->getTransientBulkImpl(transient_lookup_strings, id_map.data(), false);
306  }
307 
308  // Now map strings in dictionary
309  // We place non-transient strings after the transient strings
310  // if they exist, otherwise at index 0
311  int32_t* translation_map_stored_entries_ptr = id_map.storageData();
312 
313  auto dest_transient_lookup_callback = [dest_proxy, translation_map_stored_entries_ptr](
314  const std::string_view& source_string,
315  const int32_t source_string_id) {
316  translation_map_stored_entries_ptr[source_string_id] =
317  dest_proxy->lookupTransientStringUnlocked(source_string);
318  return translation_map_stored_entries_ptr[source_string_id] ==
320  };
321 
322  const size_t num_dest_transients = dest_proxy->transientEntryCountUnlocked();
323  const size_t num_persisted_strings_not_translated =
324  generation_ > 0 ? string_dict_->buildDictionaryTranslationMap(
325  dest_proxy->string_dict_.get(),
326  translation_map_stored_entries_ptr,
327  generation_,
328  dest_proxy->generation_,
329  num_dest_transients > 0UL,
330  dest_transient_lookup_callback,
331  string_op_infos)
332  : 0UL;
333 
334  const size_t num_dest_entries = dest_proxy->entryCountUnlocked();
335  const size_t num_total_entries =
336  id_map.getVectorMap().size() - 1UL /* account for skipped entry -1 */;
337  CHECK_GT(num_total_entries, 0UL);
338  const size_t num_strings_not_translated =
339  num_transient_strings_not_translated + num_persisted_strings_not_translated;
340  CHECK_LE(num_strings_not_translated, num_total_entries);
341  id_map.setNumUntranslatedStrings(num_strings_not_translated);
342 
343  // Below is a conservative setting of range based on the size of the destination proxy,
344  // but probably not worth a scan over the data (or inline computation as we translate)
345  // to compute the actual ranges
346 
347  id_map.setRangeStart(
348  num_dest_transients > 0 ? -1 - static_cast<int32_t>(num_dest_transients) : 0);
349  id_map.setRangeEnd(dest_proxy->storageEntryCount());
350 
351  const size_t num_entries_translated = num_total_entries - num_strings_not_translated;
352  const float match_pct =
353  100.0 * static_cast<float>(num_entries_translated) / num_total_entries;
354  VLOG(1) << std::fixed << std::setprecision(2) << match_pct << "% ("
355  << num_entries_translated << " entries) from dictionary ("
356  << string_dict_->getDictKey() << ") with " << num_total_entries
357  << " total entries ( " << num_transient_entries << " literals)"
358  << " translated to dictionary (" << dest_proxy->string_dict_->getDictKey()
359  << ") with " << num_dest_entries << " total entries ("
360  << dest_proxy->transientEntryCountUnlocked() << " literals).";
361 
362  return id_map;
363 }
364 
365 void order_translation_locks(const shared::StringDictKey& source_dict_key,
366  const shared::StringDictKey& dest_dict_key,
367  std::shared_lock<std::shared_mutex>& source_proxy_read_lock,
368  std::unique_lock<std::shared_mutex>& dest_proxy_write_lock) {
369  if (source_dict_key == dest_dict_key) {
370  // proxies are same, only take one write lock
371  dest_proxy_write_lock.lock();
372  } else if (source_dict_key < dest_dict_key) {
373  source_proxy_read_lock.lock();
374  dest_proxy_write_lock.lock();
375  } else {
376  dest_proxy_write_lock.lock();
377  source_proxy_read_lock.lock();
378  }
379 }
380 
383  const StringDictionaryProxy* dest_proxy,
384  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
385  const auto& source_dict_id = getDictKey();
386  const auto& dest_dict_id = dest_proxy->getDictKey();
387 
388  std::shared_lock<std::shared_mutex> source_proxy_read_lock(rw_mutex_, std::defer_lock);
389  std::unique_lock<std::shared_mutex> dest_proxy_write_lock(dest_proxy->rw_mutex_,
390  std::defer_lock);
392  source_dict_id, dest_dict_id, source_proxy_read_lock, dest_proxy_write_lock);
393  return buildIntersectionTranslationMapToOtherProxyUnlocked(dest_proxy, string_op_infos);
394 }
395 
397  StringDictionaryProxy* dest_proxy,
398  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
399  auto timer = DEBUG_TIMER(__func__);
400 
401  const auto& source_dict_id = getDictKey();
402  const auto& dest_dict_id = dest_proxy->getDictKey();
403  std::shared_lock<std::shared_mutex> source_proxy_read_lock(rw_mutex_, std::defer_lock);
404  std::unique_lock<std::shared_mutex> dest_proxy_write_lock(dest_proxy->rw_mutex_,
405  std::defer_lock);
407  source_dict_id, dest_dict_id, source_proxy_read_lock, dest_proxy_write_lock);
408 
409  auto id_map =
410  buildIntersectionTranslationMapToOtherProxyUnlocked(dest_proxy, string_op_infos);
411  if (id_map.empty()) {
412  return id_map;
413  }
414  const auto num_untranslated_strings = id_map.numUntranslatedStrings();
415  if (num_untranslated_strings > 0) {
416  const size_t total_post_translation_dest_transients =
417  num_untranslated_strings + dest_proxy->transientEntryCountUnlocked();
418  constexpr size_t max_allowed_transients =
419  static_cast<size_t>(std::numeric_limits<int32_t>::max() -
420  2); /* -2 accounts for INVALID_STR_ID and NULL value */
421  if (total_post_translation_dest_transients > max_allowed_transients) {
422  std::stringstream ss;
423  ss << "Union translation to dictionary " << getDictKey() << " would result in "
424  << total_post_translation_dest_transients
425  << " transient entries, which is more than limit of " << max_allowed_transients
426  << " transients.";
427  throw std::runtime_error(ss.str());
428  }
429  const int32_t map_domain_start = id_map.domainStart();
430  const int32_t map_domain_end = id_map.domainEnd();
431 
432  const StringOps_Namespace::StringOps string_ops(string_op_infos);
433  const bool has_string_ops = string_ops.size();
434 
435  // First iterate over transient strings and add to dest map
436  // Todo (todd): Add call to fetch string_views (local) or strings (distributed)
437  // for all non-translated ids to avoid string-by-string fetch
438 
439  for (int32_t source_string_id = map_domain_start; source_string_id < -1;
440  ++source_string_id) {
441  if (id_map[source_string_id] == StringDictionary::INVALID_STR_ID) {
442  const auto source_string = getStringUnlocked(source_string_id);
443  const auto dest_string_id = dest_proxy->getOrAddTransientUnlocked(
444  has_string_ops ? string_ops(source_string) : source_string);
445  id_map[source_string_id] = dest_string_id;
446  }
447  }
448  // Now iterate over stored strings
449  for (int32_t source_string_id = 0; source_string_id < map_domain_end;
450  ++source_string_id) {
451  if (id_map[source_string_id] == StringDictionary::INVALID_STR_ID) {
452  const auto source_string = string_dict_->getString(source_string_id);
453  const auto dest_string_id = dest_proxy->getOrAddTransientUnlocked(
454  has_string_ops ? string_ops(source_string) : source_string);
455  id_map[source_string_id] = dest_string_id;
456  }
457  }
458  }
459  // We may have added transients to the destination proxy, use this to update
460  // our id map range (used downstream for ExpressionRange)
461 
462  const size_t num_dest_transients = dest_proxy->transientEntryCountUnlocked();
463  id_map.setRangeStart(
464  num_dest_transients > 0 ? -1 - static_cast<int32_t>(num_dest_transients) : 0);
465  return id_map;
466 }
467 
468 template <typename T>
469 std::vector<T> StringDictionaryProxy::getLike(const std::string& pattern,
470  const bool icase,
471  const bool is_simple,
472  const char escape) const {
473  CHECK_GE(generation_, 0);
474  auto result = string_dict_->getLike<T>(pattern, icase, is_simple, escape, generation_);
475  auto is_like_impl = icase ? is_simple ? string_ilike_simple : string_ilike
476  : is_simple ? string_like_simple
477  : string_like;
478  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
479  auto const str = *transient_string_vec_[index];
480  if (is_like_impl(str.c_str(), str.size(), pattern.c_str(), pattern.size(), escape)) {
481  result.push_back(transientIndexToId(index));
482  }
483  }
484  return result;
485 }
486 
487 template std::vector<int32_t> StringDictionaryProxy::getLike<int32_t>(
488  const std::string& pattern,
489  const bool icase,
490  const bool is_simple,
491  const char escape) const;
492 
493 template std::vector<int64_t> StringDictionaryProxy::getLike<int64_t>(
494  const std::string& pattern,
495  const bool icase,
496  const bool is_simple,
497  const char escape) const;
498 
namespace {

/// Evaluates `str <comp_operator> pattern` using std::string::compare ordering.
///
/// @param str            Left-hand operand.
/// @param pattern        Right-hand operand.
/// @param comp_operator  One of "<", "<=", "=", ">", ">=", "<>".
/// @return Result of the comparison.
/// @throws std::runtime_error for any other operator string.
bool do_compare(const std::string& str,
                const std::string& pattern,
                const std::string& comp_operator) {
  const int order = str.compare(pattern);
  if (comp_operator == "<") {
    return order < 0;
  }
  if (comp_operator == "<=") {
    return order <= 0;
  }
  if (comp_operator == "=") {
    return order == 0;
  }
  if (comp_operator == ">") {
    return order > 0;
  }
  if (comp_operator == ">=") {
    return order >= 0;
  }
  if (comp_operator == "<>") {
    return order != 0;
  }
  throw std::runtime_error("unsupported string compare operator");
}

}  // namespace
522 
524  const std::string& pattern,
525  const std::string& comp_operator) const {
526  CHECK_GE(generation_, 0);
527  auto result = string_dict_->getCompare(pattern, comp_operator, generation_);
528  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
529  if (do_compare(*transient_string_vec_[index], pattern, comp_operator)) {
530  result.push_back(transientIndexToId(index));
531  }
532  }
533  return result;
534 }
535 
536 namespace {
537 
538 bool is_regexp_like(const std::string& str,
539  const std::string& pattern,
540  const char escape) {
541  return regexp_like(str.c_str(), str.size(), pattern.c_str(), pattern.size(), escape);
542 }
543 
544 } // namespace
545 
546 std::vector<int32_t> StringDictionaryProxy::getRegexpLike(const std::string& pattern,
547  const char escape) const {
548  CHECK_GE(generation_, 0);
549  auto result = string_dict_->getRegexpLike(pattern, escape, generation_);
550  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
551  if (is_regexp_like(*transient_string_vec_[index], pattern, escape)) {
552  result.push_back(transientIndexToId(index));
553  }
554  }
555  return result;
556 }
557 
558 int32_t StringDictionaryProxy::getOrAdd(const std::string& str) noexcept {
559  return string_dict_->getOrAdd(str);
560 }
561 
562 std::pair<const char*, size_t> StringDictionaryProxy::getStringBytes(
563  int32_t string_id) const noexcept {
564  if (string_id >= 0) {
565  return string_dict_.get()->getStringBytes(string_id);
566  }
567  unsigned const string_index = transientIdToIndex(string_id);
568  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
569  CHECK_LT(string_index, transient_string_vec_.size());
570  std::string const* const str_ptr = transient_string_vec_[string_index];
571  return {str_ptr->c_str(), str_ptr->size()};
572 }
573 
575  const size_t num_storage_entries{generation_ == -1 ? string_dict_->storageEntryCount()
576  : generation_};
577  CHECK_LE(num_storage_entries, static_cast<size_t>(std::numeric_limits<int32_t>::max()));
578  return num_storage_entries;
579 }
580 
582  // CHECK_LE(num_storage_entries,
583  // static_cast<size_t>(std::numeric_limits<int32_t>::max()));
584  const size_t num_transient_entries{transient_str_to_int_.size()};
585  CHECK_LE(num_transient_entries,
586  static_cast<size_t>(std::numeric_limits<int32_t>::max()) - 1);
587  return num_transient_entries;
588 }
589 
591  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
593 }
594 
597 }
598 
600  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
601  return entryCountUnlocked();
602 }
603 
604 // Iterate over transient strings, then non-transients.
606  StringDictionary::StringCallback& serial_callback) const {
607  constexpr int32_t max_transient_id = -2;
608  // Iterate over transient strings.
609  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
610  std::string const& str = *transient_string_vec_[index];
611  int32_t const string_id = max_transient_id - index;
612  serial_callback(str, string_id);
613  }
614  // Iterate over non-transient strings.
615  string_dict_->eachStringSerially(generation_, serial_callback);
616 }
617 
618 // For each (string/_view,old_id) pair passed in:
619 // * Get the new_id based on sdp_'s dictionary, or add it as a transient.
620 // * The StringDictionary is local, so call the faster getUnlocked() method.
621 // * Store the old_id -> new_id translation into the id_map_.
625 
626  public:
628  : sdp_(sdp), id_map_(id_map) {}
629  void operator()(std::string const& str, int32_t const string_id) override {
630  operator()(std::string_view(str), string_id);
631  }
632  void operator()(std::string_view const sv, int32_t const old_id) override {
633  int32_t const new_id = sdp_->string_dict_->getUnlocked(sv);
634  id_map_[old_id] = new_id == StringDictionary::INVALID_STR_ID
636  : new_id;
637  }
638 };
639 
640 // For each (string,old_id) pair passed in:
641 // * Get the new_id based on sdp_'s dictionary, or add it as a transient.
642 // * The StringDictionary is not local, so call string_dict_->makeLambdaStringToId()
643 // to make a lookup hash.
644 // * Store the old_id -> new_id translation into the id_map_.
648  using Lambda = std::function<int32_t(std::string const&)>;
650 
651  public:
653  : sdp_(sdp)
654  , id_map_(id_map)
655  , string_to_id_(sdp->string_dict_->makeLambdaStringToId()) {}
656  void operator()(std::string const& str, int32_t const old_id) override {
657  int32_t const new_id = string_to_id_(str);
658  id_map_[old_id] = new_id == StringDictionary::INVALID_STR_ID
660  : new_id;
661  }
662  void operator()(std::string_view const, int32_t const string_id) override {
663  UNREACHABLE() << "StringNetworkCallback requires a std::string.";
664  }
665 };
666 
667 // Union strings from both StringDictionaryProxies into *this as transients.
668 // Return id_map: sdp_rhs:string_id -> this:string_id for each string in sdp_rhs.
670  StringDictionaryProxy const& sdp_rhs) {
671  IdMap id_map = sdp_rhs.initIdMap();
672  // serial_callback cannot be parallelized due to calling getOrAddTransientUnlocked().
673  std::unique_ptr<StringDictionary::StringCallback> serial_callback;
674  if (string_dict_->isClient()) {
675  serial_callback = std::make_unique<StringNetworkCallback>(this, id_map);
676  } else {
677  serial_callback = std::make_unique<StringLocalCallback>(this, id_map);
678  }
679  // Import all non-duplicate strings (transient and non-transient) and add to id_map.
680  sdp_rhs.eachStringSerially(*serial_callback);
681  return id_map;
682 }
683 
684 void StringDictionaryProxy::updateGeneration(const int64_t generation) noexcept {
685  if (generation == -1) {
686  return;
687  }
688  if (generation_ != -1) {
689  CHECK_EQ(generation_, generation);
690  return;
691  }
692  generation_ = generation;
693 }
694 
696  const std::vector<std::string>& strings,
697  int32_t* string_ids,
698  const bool take_read_lock) const {
699  const size_t num_strings = strings.size();
700  if (num_strings == 0) {
701  return 0UL;
702  }
703  // StringDictionary::getBulk returns the number of strings not found
704  if (string_dict_->getBulk(strings, string_ids, generation_) == 0UL) {
705  return 0UL;
706  }
707 
708  // If here, dictionary could not find at least 1 target string,
709  // now look these up in the transient dictionary
710  // transientLookupBulk returns the number of strings not found
711  return transientLookupBulk(strings, string_ids, take_read_lock);
712 }
713 
714 template <typename String>
716  const std::vector<String>& lookup_strings,
717  int32_t* string_ids,
718  const bool take_read_lock) const {
719  const size_t num_strings = lookup_strings.size();
720  auto read_lock = take_read_lock ? std::shared_lock<std::shared_mutex>(rw_mutex_)
721  : std::shared_lock<std::shared_mutex>();
722 
723  if (num_strings == static_cast<size_t>(0) || transient_str_to_int_.empty()) {
724  return 0UL;
725  }
726  constexpr size_t tbb_parallel_threshold{20000};
727  if (num_strings < tbb_parallel_threshold) {
728  return transientLookupBulkUnlocked(lookup_strings, string_ids);
729  } else {
730  return transientLookupBulkParallelUnlocked(lookup_strings, string_ids);
731  }
732 }
733 
734 template <typename String>
736  const std::vector<String>& lookup_strings,
737  int32_t* string_ids) const {
738  const size_t num_strings = lookup_strings.size();
739  size_t num_strings_not_found = 0;
740  for (size_t string_idx = 0; string_idx < num_strings; ++string_idx) {
741  if (string_ids[string_idx] != StringDictionary::INVALID_STR_ID) {
742  continue;
743  }
744  // If we're here it means we need to look up this string as we don't
745  // have a valid id for it
746  string_ids[string_idx] = lookupTransientStringUnlocked(lookup_strings[string_idx]);
747  if (string_ids[string_idx] == StringDictionary::INVALID_STR_ID) {
748  num_strings_not_found++;
749  }
750  }
751  return num_strings_not_found;
752 }
753 
754 template <typename String>
756  const std::vector<String>& lookup_strings,
757  int32_t* string_ids) const {
758  const size_t num_lookup_strings = lookup_strings.size();
759  const size_t target_inputs_per_thread = 20000L;
760  ThreadInfo thread_info(
761  std::thread::hardware_concurrency(), num_lookup_strings, target_inputs_per_thread);
762  CHECK_GE(thread_info.num_threads, 1L);
763  CHECK_GE(thread_info.num_elems_per_thread, 1L);
764 
765  std::vector<size_t> num_strings_not_found_per_thread(thread_info.num_threads, 0UL);
766 
767  tbb::task_arena limited_arena(thread_info.num_threads);
768  limited_arena.execute([&] {
770  tbb::blocked_range<size_t>(
771  0, num_lookup_strings, thread_info.num_elems_per_thread /* tbb grain_size */),
772  [&](const tbb::blocked_range<size_t>& r) {
773  const size_t start_idx = r.begin();
774  const size_t end_idx = r.end();
775  size_t num_local_strings_not_found = 0;
776  for (size_t string_idx = start_idx; string_idx < end_idx; ++string_idx) {
777  if (string_ids[string_idx] != StringDictionary::INVALID_STR_ID) {
778  continue;
779  }
780  string_ids[string_idx] =
781  lookupTransientStringUnlocked(lookup_strings[string_idx]);
782  if (string_ids[string_idx] == StringDictionary::INVALID_STR_ID) {
783  num_local_strings_not_found++;
784  }
785  }
786  const size_t tbb_thread_idx = tbb::this_task_arena::current_thread_index();
787  num_strings_not_found_per_thread[tbb_thread_idx] = num_local_strings_not_found;
788  },
789  tbb::simple_partitioner());
790  });
791  size_t num_strings_not_found = 0;
792  for (int64_t thread_idx = 0; thread_idx < thread_info.num_threads; ++thread_idx) {
793  num_strings_not_found += num_strings_not_found_per_thread[thread_idx];
794  }
795  return num_strings_not_found;
796 }
797 
799  return string_dict_.get();
800 }
801 
802 int64_t StringDictionaryProxy::getGeneration() const noexcept {
803  return generation_;
804 }
805 
807  return string_dict_key_ == rhs.string_dict_key_ &&
809 }
810 
812  return !operator==(rhs);
813 }
void eachStringSerially(StringDictionary::StringCallback &) const
int32_t getOrAddTransientImpl(String)
void setNumUntranslatedStrings(const size_t num_untranslated_strings)
const shared::StringDictKey string_dict_key_
#define CHECK_EQ(x, y)
Definition: Logger.h:301
std::pair< const char *, size_t > getStringBytes(int32_t string_id) const noexcept
size_t transientEntryCountUnlocked() const
StringLocalCallback(StringDictionaryProxy *sdp, StringDictionaryProxy::IdMap &id_map)
int64_t num_elems_per_thread
Definition: ThreadInfo.h:23
StringDictionaryProxy::IdMap & id_map_
size_t entryCount() const
Returns the number of total string entries for this proxy, both stored in the underlying dictionary a...
int32_t getIdOfStringNoGeneration(const std::string &str) const
std::function< int32_t(std::string const &)> Lambda
std::string getStringUnlocked(const int32_t string_id) const
size_t storageEntryCount() const
Returns the number of string entries in the underlying string dictionary, at this proxy's generation_...
#define UNREACHABLE()
Definition: Logger.h:338
StringDictionary * getDictionary() const noexcept
#define CHECK_GE(x, y)
Definition: Logger.h:306
size_t transientLookupBulkUnlocked(const std::vector< String > &lookup_strings, int32_t *string_ids) const
StringDictionaryProxy * sdp_
void operator()(std::string const &str, int32_t const string_id) override
size_t transientLookupBulk(const std::vector< String > &lookup_strings, int32_t *string_ids, const bool take_read_lock) const
std::string getString(int32_t string_id) const
Constants for Builtin SQL Types supported by HEAVY.AI.
RUNTIME_EXPORT DEVICE bool string_ilike_simple(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len, char escape_char)
Definition: StringLike.cpp:61
IdMap buildIntersectionTranslationMapToOtherProxyUnlocked(const StringDictionaryProxy *dest_proxy, const std::vector< StringOps_Namespace::StringOpInfo > &string_op_infos) const
size_t transientLookupBulkParallelUnlocked(const std::vector< String > &lookup_strings, int32_t *string_ids) const
#define CHECK_GT(x, y)
Definition: Logger.h:305
int32_t getIdOfStringFromClient(String const &) const
std::vector< int32_t > getTransientBulk(const std::vector< std::string > &strings) const
Executes read-only lookup of a vector of strings and returns a vector of their integer ids...
TranslationMap< Datum > buildNumericTranslationMap(const std::vector< StringOps_Namespace::StringOpInfo > &string_op_infos) const
Builds a vectorized string_id translation map from this proxy to dest_proxy.
std::vector< int32_t > getCompare(const std::string &pattern, const std::string &comp_operator) const
#define DEVICE
bool is_regexp_like(const std::string &str, const std::string &pattern, const char escape)
StringNetworkCallback(StringDictionaryProxy *sdp, StringDictionaryProxy::IdMap &id_map)
static constexpr int32_t INVALID_STR_ID
std::shared_ptr< StringDictionary > string_dict_
int64_t num_threads
Definition: ThreadInfo.h:22
IdMap transientUnion(StringDictionaryProxy const &)
std::vector< std::string const * > transient_string_vec_
void setRangeEnd(const int32_t range_end)
RUNTIME_EXPORT DEVICE bool string_like(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len, const char escape_char)
Definition: StringLike.cpp:250
void operator()(std::string const &str, int32_t const old_id) override
int32_t lookupTransientStringUnlocked(const String &lookup_string) const
std::vector< std::string > getStrings(const std::vector< int32_t > &string_ids) const
size_t getTransientBulkImpl(const std::vector< std::string > &strings, int32_t *string_ids, const bool take_read_lock) const
void order_translation_locks(const shared::StringDictKey &source_dict_key, const shared::StringDictKey &dest_dict_key, std::shared_lock< std::shared_mutex > &source_read_lock, std::shared_lock< std::shared_mutex > &dest_read_lock)
void operator()(std::string_view const sv, int32_t const old_id) override
static int32_t transientIndexToId(unsigned const index)
void updateGeneration(const int64_t generation) noexcept
size_t transientEntryCount() const
Returns the number of transient string entries for this proxy.
OUTPUT transform(INPUT const &input, FUNC const &func)
Definition: misc.h:329
Functions to support the LIKE and ILIKE operator in SQL. Only single-byte character set is supported ...
IdMap buildUnionTranslationMapToOtherProxy(StringDictionaryProxy *dest_proxy, const std::vector< StringOps_Namespace::StringOpInfo > &string_op_types) const
StringDictionaryProxy(StringDictionaryProxy const &)=delete
void setRangeStart(const int32_t range_start)
int32_t getOrAddTransient(const std::string &)
#define RUNTIME_EXPORT
std::vector< T > getLike(const std::string &pattern, const bool icase, const bool is_simple, const char escape) const
#define CHECK_LT(x, y)
Definition: Logger.h:303
void operator()(std::string_view const, int32_t const string_id) override
RUNTIME_EXPORT DEVICE bool string_like_simple(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len, char escape_char)
Definition: StringLike.cpp:43
bool do_compare(const std::string &str, const std::string &pattern, const std::string &comp_operator)
#define CHECK_LE(x, y)
Definition: Logger.h:304
StringDictionaryProxy * sdp_
int32_t getOrAddTransientUnlocked(String const &)
bool operator!=(StringDictionaryProxy const &) const
std::vector< int32_t > getRegexpLike(const std::string &pattern, const char escape) const
int32_t getOrAdd(const std::string &str) noexcept
bool operator==(StringDictionaryProxy const &) const
std::vector< T > const & getVectorMap() const
void parallel_for(const blocked_range< Int > &range, const Body &body, const Partitioner &p=Partitioner())
std::vector< int32_t > getOrAddTransientBulk(const std::vector< std::string > &strings)
IdMap buildIntersectionTranslationMapToOtherProxy(const StringDictionaryProxy *dest_proxy, const std::vector< StringOps_Namespace::StringOpInfo > &string_op_infos) const
#define CHECK(condition)
Definition: Logger.h:291
DEVICE RUNTIME_EXPORT int32_t StringDictionaryProxy_getStringId(int8_t *proxy_ptr, char *c_str_ptr)
#define DEBUG_TIMER(name)
Definition: Logger.h:412
DEVICE RUNTIME_EXPORT size_t StringDictionaryProxy_getStringLength(int8_t *proxy_ptr, int32_t string_id)
const shared::StringDictKey & getDictKey() const noexcept
Definition: Datum.h:71
RUNTIME_EXPORT DEVICE bool regexp_like(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len, const char escape_char)
Definition: Regexp.cpp:39
int32_t getIdOfString(const std::string &str) const
static unsigned transientIdToIndex(int32_t const id)
int64_t getGeneration() const noexcept
#define VLOG(n)
Definition: Logger.h:388
int32_t truncate_to_generation(const int32_t id, const size_t generation)
DEVICE RUNTIME_EXPORT const char * StringDictionaryProxy_getStringBytes(int8_t *proxy_ptr, int32_t string_id)
StringDictionaryProxy::IdMap & id_map_
RUNTIME_EXPORT DEVICE bool string_ilike(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len, const char escape_char)
Definition: StringLike.cpp:261