OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionFunctionsWhitelist.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <boost/algorithm/string/join.hpp>
20 #include <iostream>
21 
25 #include "Shared/StringTransform.h"
26 
27 // Get the list of all type specializations for the given function name.
28 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get(
29  const std::string& name) {
30  const auto it = functions_.find(to_upper(name));
31  if (it == functions_.end()) {
32  return nullptr;
33  }
34  return &it->second;
35 }
36 
37 std::vector<ExtensionFunction>* ExtensionFunctionsWhitelist::get_udf(
38  const std::string& name) {
39  const auto it = udf_functions_.find(to_upper(name));
40  if (it == udf_functions_.end()) {
41  return nullptr;
42  }
43  return &it->second;
44 }
45 
46 // Get the list of all udfs
47 std::unordered_set<std::string> ExtensionFunctionsWhitelist::get_udfs_name(
48  const bool is_runtime) {
49  std::unordered_set<std::string> names;
50  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
51  for (auto funcs : collections) {
52  for (auto& pair : *funcs) {
53  ExtensionFunction udf = pair.second.at(0);
54  if (udf.isRuntime() == is_runtime) {
55  names.insert(udf.getName(/* keep_suffix */ false));
56  }
57  }
58  }
59  return names;
60 }
61 
62 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
63  const std::string& name) {
64  std::vector<ExtensionFunction> ext_funcs = {};
65  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
66  const auto uname = to_upper(name);
67  for (auto funcs : collections) {
68  const auto it = funcs->find(uname);
69  if (it == funcs->end()) {
70  continue;
71  }
72  auto ext_func_sigs = it->second;
73  std::copy(ext_func_sigs.begin(), ext_func_sigs.end(), std::back_inserter(ext_funcs));
74  }
75  return ext_funcs;
76 }
77 
78 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
79  const std::string& name,
80  const bool is_gpu) {
81  std::vector<ExtensionFunction> ext_funcs = {};
82  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
83  const auto uname = to_upper(name);
84  for (auto funcs : collections) {
85  const auto it = funcs->find(uname);
86  if (it == funcs->end()) {
87  continue;
88  }
89  auto ext_func_sigs = it->second;
90  std::copy_if(ext_func_sigs.begin(),
91  ext_func_sigs.end(),
92  std::back_inserter(ext_funcs),
93  [is_gpu](auto sig) { return (is_gpu ? sig.isGPU() : sig.isCPU()); });
94  }
95  return ext_funcs;
96 }
97 
98 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
99  const std::string& name,
100  size_t arity) {
101  std::vector<ExtensionFunction> ext_funcs = {};
102  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
103  const auto uname = to_upper(name);
104  for (auto funcs : collections) {
105  const auto it = funcs->find(uname);
106  if (it == funcs->end()) {
107  continue;
108  }
109  auto ext_func_sigs = it->second;
110  std::copy_if(ext_func_sigs.begin(),
111  ext_func_sigs.end(),
112  std::back_inserter(ext_funcs),
113  [arity](auto sig) { return arity == sig.getInputArgs().size(); });
114  }
115  return ext_funcs;
116 }
117 
118 std::vector<ExtensionFunction> ExtensionFunctionsWhitelist::get_ext_funcs(
119  const std::string& name,
120  size_t arity,
121  const SQLTypeInfo& rtype) {
122  std::vector<ExtensionFunction> ext_funcs = {};
123  const auto collections = {&functions_, &udf_functions_, &rt_udf_functions_};
124  const auto uname = to_upper(name);
125  for (auto funcs : collections) {
126  const auto it = funcs->find(uname);
127  if (it == funcs->end()) {
128  continue;
129  }
130  auto ext_func_sigs = it->second;
131  std::copy_if(ext_func_sigs.begin(),
132  ext_func_sigs.end(),
133  std::back_inserter(ext_funcs),
134  [arity, rtype](auto sig) {
135  // Ideally, arity should be equal to the number of
136  // sig arguments but there seems to be many cases
137  // where some sig arguments will be represented
138  // with multiple arguments, for instance, array
139  // argument is translated to data pointer and array
140  // size arguments.
141  if (arity > sig.getInputArgs().size()) {
142  return false;
143  }
144  auto rt = rtype.get_type();
145  auto st = ext_arg_type_to_type_info(sig.getRet()).get_type();
146  return (st == rt || (st == kTINYINT && rt == kBOOLEAN));
147  });
148  }
149  return ext_funcs;
150 }
151 
152 namespace {
153 
154 // Returns the LLVM name for `type`.
156  bool byval = true,
157  bool declare = false) {
158  switch (type) {
160  return "i8"; // clang converts bool to i8
162  return "i8";
164  return "i16";
166  return "i32";
168  return "i64";
170  return "float";
172  return "double";
174  return "void";
176  return "i8*";
178  return "i16*";
180  return "i32*";
182  return "i64*";
184  return "float*";
186  return "double*";
188  return "i1*";
190  return (declare ? "{i8*, i64, i8}*" : "Array<i8>");
192  return (declare ? "{i16*, i64, i8}*" : "Array<i16>");
194  return (declare ? "{i32*, i64, i8}*" : "Array<i32>");
196  return (declare ? "{i64*, i64, i8}*" : "Array<i64>");
198  return (declare ? "{float*, i64, i8}*" : "Array<float>");
200  return (declare ? "{double*, i64, i8}*" : "Array<double>");
202  return (declare ? "{i1*, i64, i8}*" : "Array<i1>");
204  return (declare ? "{i32*, i64, i8}*" : "Array<TextEncodingDict>");
206  return (declare ? "{i8*, i32, i32, i32, i32}*" : "GeoPoint");
208  return (declare ? "{i8*, i32, i32, i32, i32}*" : "GeoMultiPoint");
210  return (declare ? "{i8*, i32, i32, i32, i32}*" : "GeoLineString");
212  return (declare ? "{i8*, i32, i8*, i32, i32, i32, i32}*" : "GeoMultiLineSting");
214  return (declare ? "{i8*, i32, i8*, i32, i32, i32, i32}*" : "GeoPolygon");
216  return (declare ? "{i8*, i32, i8*, i32, i8*, i32, i32, i32, i32}*"
217  : "GeoMultiPolygon");
219  return "cursor";
221  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "Column<i8>");
223  return (declare ? (byval ? "{i16*, i64}" : "i8*") : "Column<i16>");
225  return (declare ? (byval ? "{i32*, i64}" : "i8*") : "Column<i32>");
227  return (declare ? (byval ? "{i64*, i64}" : "i8*") : "Column<i64>");
229  return (declare ? (byval ? "{float*, i64}" : "i8*") : "Column<float>");
231  return (declare ? (byval ? "{double*, i64}" : "i8*") : "Column<double>");
233  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "Column<bool>");
235  return (declare ? (byval ? "{i32*, i64}" : "i8*") : "Column<TextEncodingDict>");
237  return (declare ? (byval ? "{i64*, i64}" : "i8*") : "Column<Timestamp>");
239  return (declare ? (byval ? "{i8*, i64}" : "i8*") : "TextEncodingNone");
241  return (declare ? "{ i32 }" : "TextEncodingDict");
243  return (declare ? "{ i64 }" : "Timestamp");
245  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i8>");
247  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i16>");
249  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i32>");
251  return (declare ? "{i8**, i64, i64}*" : "ColumnList<i64>");
253  return (declare ? "{i8**, i64, i64}*" : "ColumnList<float>");
255  return (declare ? "{i8**, i64, i64}*" : "ColumnList<double>");
257  return (declare ? "{i8**, i64, i64}*" : "ColumnList<bool>");
259  return (declare ? "{i8**, i64, i64}*" : "ColumnList<TextEncodingDict>");
261  return (declare ? "{i8*, i64}*" : "Column<Array<i8>>");
263  return (declare ? "{i8*, i64}*" : "Column<Array<i16>>");
265  return (declare ? "{i8*, i64}*" : "Column<Array<i32>>");
267  return (declare ? "{i8*, i64}*" : "Column<Array<i64>>");
269  return (declare ? "{i8*, i64}*" : "Column<Array<float>>");
271  return (declare ? "{i8*, i64}*" : "Column<Array<double>>");
273  return (declare ? "{i8*, i64}*" : "Column<Array<bool>>");
275  return (declare ? "{i8*, i64}" : "Column<Array<TextEncodingDict>>");
277  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i8>");
279  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i16>");
281  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i32>");
283  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<i64>");
285  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<float>");
287  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<double>");
289  return (declare ? "{i8**, i64, i64}*" : "ColumnListArray<bool>");
291  return (declare ? "{i8**, i64, i64}" : "ColumnList<Array<TextEncodingDict>>");
293  return (declare ? "{ i64 }" : "DayTimeInterval");
295  return (declare ? "{ i64 }" : "YearMonthTimeInterval");
297  return (declare ? "{i8*, i64}*" : "Column<GeoPoint>");
299  return (declare ? "{i8*, i64}*" : "Column<GeoLineString>");
301  return (declare ? "{i8*, i64}*" : "Column<GeoPolygon>");
303  return (declare ? "{i8*, i64}*" : "Column<GeoMultiPoint>");
305  return (declare ? "{i8*, i64}*" : "Column<GeoMultiLineString>");
307  return (declare ? "{i8*, i64}*" : "Column<GeoMultiPolygon>");
309  return (declare ? "{i8*, i64, i64}*" : "ColumnList<GeoPoint>");
311  return (declare ? "{i8*, i64, i64}*" : "ColumnList<GeoLineString>");
313  return (declare ? "{i8*, i64, i64}*" : "ColumnList<GeoPolygon>");
315  return (declare ? "{i8*, i64, i64}*" : "ColumnList<GeoMultiPoint>");
317  return (declare ? "{i8*, i64, i64}*" : "ColumnList<GeoMultiLineString>");
319  return (declare ? "{i8*, i64, i64}*" : "ColumnList<GeoMultiPolygon>");
321  return (declare ? "{i8*, i64}*" : "Column<TextEncodingNone>");
323  return (declare ? "{i8*, i64}*" : "Column<Array<TextEncodingNone>>");
325  return (declare ? "{i8*, i64, i64}*" : "ColumnList<TextEncodingNone>");
327  return (declare ? "{i8*, i64, i64}*" : "ColumnList<Array<TextEncodingNone>>");
328  default:
329  CHECK(false);
330  }
331  CHECK(false);
332  return "";
333 }
334 
335 std::string drop_suffix(const std::string& str) {
336  const auto idx = str.find("__");
337  if (idx == std::string::npos) {
338  return str;
339  }
340  CHECK_GT(idx, std::string::size_type(0));
341  return str.substr(0, idx);
342 }
343 
344 } // namespace
345 
347  SQLTypes type = kNULLT;
348  int d = 0;
349  int s = 0;
350  bool n = false;
352  int p = 0;
353  SQLTypes subtype = kNULLT;
354 
355 #define EXTARGTYPECASE(EXTARGTYPE, ELEMTYPE, ENCODING, ARRAYENCODING) \
356  case ExtArgumentType::EXTARGTYPE: \
357  type = ELEMTYPE; \
358  c = kENCODING_##ENCODING; \
359  break; \
360  case ExtArgumentType::Array##EXTARGTYPE: \
361  type = kARRAY; \
362  c = kENCODING_##ENCODING; \
363  subtype = ELEMTYPE; \
364  n = true; \
365  break; \
366  case ExtArgumentType::Column##EXTARGTYPE: \
367  type = kCOLUMN; \
368  c = kENCODING_##ENCODING; \
369  subtype = ELEMTYPE; \
370  break; \
371  case ExtArgumentType::ColumnList##EXTARGTYPE: \
372  type = kCOLUMN_LIST; \
373  c = kENCODING_##ENCODING; \
374  subtype = ELEMTYPE; \
375  break; \
376  case ExtArgumentType::ColumnArray##EXTARGTYPE: \
377  type = kCOLUMN; \
378  subtype = ELEMTYPE; \
379  c = kENCODING_##ARRAYENCODING; \
380  break; \
381  case ExtArgumentType::ColumnListArray##EXTARGTYPE: \
382  type = kCOLUMN_LIST; \
383  subtype = ELEMTYPE; \
384  c = kENCODING_##ARRAYENCODING; \
385  break;
386 
387 #define EXTARGGEOTYPECASE(GEOTYPE, KTYPE) \
388  case ExtArgumentType::Geo##GEOTYPE: \
389  type = KTYPE; \
390  subtype = kGEOMETRY; \
391  break; \
392  case ExtArgumentType::ColumnGeo##GEOTYPE: \
393  type = kCOLUMN; \
394  subtype = KTYPE; \
395  break; \
396  case ExtArgumentType::ColumnListGeo##GEOTYPE: \
397  type = kCOLUMN_LIST; \
398  subtype = KTYPE; \
399  break;
400 
401  switch (ext_arg_type) {
402  EXTARGTYPECASE(Bool, kBOOLEAN, NONE, ARRAY);
403  EXTARGTYPECASE(Int8, kTINYINT, NONE, ARRAY);
405  EXTARGTYPECASE(Int32, kINT, NONE, ARRAY);
406  EXTARGTYPECASE(Int64, kBIGINT, NONE, ARRAY);
407  EXTARGTYPECASE(Float, kFLOAT, NONE, ARRAY);
408  EXTARGTYPECASE(Double, kDOUBLE, NONE, ARRAY);
410  EXTARGTYPECASE(TextEncodingDict, kTEXT, DICT, ARRAY_DICT);
411  // TODO: EXTARGTYPECASE(Timestamp, kTIMESTAMP, NONE, ARRAY);
413  type = kTIMESTAMP;
414  c = kENCODING_NONE;
415  d = 9;
416  break;
418  type = kCOLUMN;
419  subtype = kTIMESTAMP;
420  c = kENCODING_NONE;
421  d = 9;
422  break;
424  type = kINTERVAL_DAY_TIME;
425  break;
427  type = kINTERVAL_YEAR_MONTH;
428  break;
429  EXTARGGEOTYPECASE(Point, kPOINT);
430  EXTARGGEOTYPECASE(LineString, kLINESTRING);
431  EXTARGGEOTYPECASE(Polygon, kPOLYGON);
432  EXTARGGEOTYPECASE(MultiPoint, kMULTIPOINT);
433  EXTARGGEOTYPECASE(MultiLineString, kMULTILINESTRING);
434  EXTARGGEOTYPECASE(MultiPolygon, kMULTIPOLYGON);
435  default:
436  LOG(FATAL) << "ExtArgumentType `" << serialize_type(ext_arg_type)
437  << "` cannot be converted to SQLTypes.";
438  UNREACHABLE();
439  }
440  return SQLTypeInfo(type, d, s, n, c, p, subtype);
441 }
442 
444  const std::vector<ExtensionFunction>& ext_funcs,
445  std::string tab) {
446  std::string r = "";
447  for (auto sig : ext_funcs) {
448  r += tab + sig.toString() + "\n";
449  }
450  return r;
451 }
452 
454  const std::vector<SQLTypeInfo>& arg_types) {
455  std::string r = "";
456  for (auto sig = arg_types.begin(); sig != arg_types.end();) {
457  r += sig->get_type_name();
458  sig++;
459  if (sig != arg_types.end()) {
460  r += ", ";
461  }
462  }
463  return r;
464 }
465 
467  const std::vector<ExtArgumentType>& sig_types) {
468  std::string r = "";
469  for (auto t = sig_types.begin(); t != sig_types.end();) {
470  r += serialize_type(*t, /* byval */ false, /* declare */ false);
471  t++;
472  if (t != sig_types.end()) {
473  r += ", ";
474  }
475  }
476  return r;
477 }
478 
480  const std::vector<ExtArgumentType>& sig_types) {
481  std::string r = "";
482  for (auto t = sig_types.begin(); t != sig_types.end();) {
484  t++;
485  if (t != sig_types.end()) {
486  r += ", ";
487  }
488  }
489  return r;
490 }
491 
493  return serialize_type(sig_type, /* byval */ false, /* declare */ false);
494 }
495 
497  switch (sig_type) {
499  return "TINYINT";
501  return "SMALLINT";
503  return "INTEGER";
505  return "BIGINT";
507  return "FLOAT";
509  return "DOUBLE";
511  return "BOOLEAN";
513  return "TINYINT[]";
515  return "SMALLINT[]";
517  return "INT[]";
519  return "BIGINT[]";
521  return "FLOAT[]";
523  return "DOUBLE[]";
525  return "BOOLEAN[]";
527  return "ARRAY<TINYINT>";
529  return "ARRAY<SMALLINT>";
531  return "ARRAY<INT>";
533  return "ARRAY<BIGINT>";
535  return "ARRAY<FLOAT>";
537  return "ARRAY<DOUBLE>";
539  return "ARRAY<BOOLEAN>";
541  return "ARRAY<TEXT ENCODING DICT>";
543  return "ARRAY<TEXT ENCODING NONE>";
545  return "COLUMN<TINYINT>";
547  return "COLUMN<SMALLINT>";
549  return "COLUMN<INT>";
551  return "COLUMN<BIGINT>";
553  return "COLUMN<FLOAT>";
555  return "COLUMN<DOUBLE>";
557  return "COLUMN<BOOLEAN>";
559  return "COLUMN<TEXT ENCODING DICT>";
561  return "COLUMN<TEXT ENCODING NONE>";
563  return "COLUMN<TIMESTAMP(9)>";
565  return "CURSOR";
567  return "POINT";
569  return "MULTIPOINT";
571  return "LINESTRING";
573  return "MULTILINESTRING";
575  return "POLYGON";
577  return "MULTIPOLYGON";
579  return "VOID";
581  return "TEXT ENCODING NONE";
583  return "TEXT ENCODING DICT";
585  return "TIMESTAMP(9)";
587  return "COLUMNLIST<TINYINT>";
589  return "COLUMNLIST<SMALLINT>";
591  return "COLUMNLIST<INT>";
593  return "COLUMNLIST<BIGINT>";
595  return "COLUMNLIST<FLOAT>";
597  return "COLUMNLIST<DOUBLE>";
599  return "COLUMNLIST<BOOLEAN>";
601  return "COLUMNLIST<TEXT ENCODING DICT>";
603  return "COLUMNLIST<TEXT ENCODING NONE>";
605  return "COLUMN<ARRAY<TINYINT>>";
607  return "COLUMN<ARRAY<SMALLINT>>";
609  return "COLUMN<ARRAY<INT>>";
611  return "COLUMN<ARRAY<BIGINT>>";
613  return "COLUMN<ARRAY<FLOAT>>";
615  return "COLUMN<ARRAY<DOUBLE>>";
617  return "COLUMN<ARRAY<BOOLEAN>>";
619  return "COLUMN<ARRAY<TEXT ENCODING DICT>>";
621  return "COLUMN<ARRAY<TEXT ENCODING NONE>>";
623  return "COLUMNLIST<ARRAY<TINYINT>>";
625  return "COLUMNLIST<ARRAY<SMALLINT>>";
627  return "COLUMNLIST<ARRAY<INT>>";
629  return "COLUMNLIST<ARRAY<BIGINT>>";
631  return "COLUMNLIST<ARRAY<FLOAT>>";
633  return "COLUMNLIST<ARRAY<DOUBLE>>";
635  return "COLUMNLIST<ARRAY<BOOLEAN>>";
637  return "COLUMNLIST<ARRAY<TEXT ENCODING DICT>>";
639  return "COLUMNLIST<ARRAY<TEXT ENCODING NONE>>";
641  return "DAY TIME INTERVAL";
643  return "YEAR MONTH INTERVAL";
645  return "COLUMN<GEOPOINT>";
647  return "COLUMN<GEOLINESTRING>";
649  return "COLUMN<GEOPOLYGON>";
651  return "COLUMN<GEOMULTIPOINT>";
653  return "COLUMN<GEOMULTILINESTRING>";
655  return "COLUMN<GEOMULTIPOLYGON>";
657  return "COLUMNLIST<GEOPOINT>";
659  return "COLUMNLIST<GEOLINESTRING>";
661  return "COLUMNLIST<GEOPOLYGON>";
663  return "COLUMNLIST<GEOMULTIPOINT>";
665  return "COLUMNLIST<GEOMULTILINESTRING>";
667  return "COLUMNLIST<GEOMULTIPOLYGON>";
668  default:
669  UNREACHABLE() << " sig_type=" << static_cast<int>(sig_type);
670  }
671  return "";
672 }
673 
675  // if-else exists to keep compatibility with older versions of RBC
676  if (annotations_.empty()) {
677  return false;
678  } else {
679  auto func_annotations = annotations_.back();
680  auto mgr_annotation = func_annotations.find("uses_manager");
681  if (mgr_annotation != func_annotations.end()) {
682  return boost::algorithm::to_lower_copy(mgr_annotation->second) == "true";
683  }
684  return false;
685  }
686 }
687 
688 const std::string ExtensionFunction::getName(bool keep_suffix) const {
689  return (keep_suffix ? name_ : drop_suffix(name_));
690 }
691 
692 std::string ExtensionFunction::toString() const {
693  return getName() + "(" + ExtensionFunctionsWhitelist::toString(args_) + ") -> " +
695 }
696 
697 std::string ExtensionFunction::toStringSQL() const {
698  return getName(/* keep_suffix = */ false) + "(" +
701 }
702 
703 std::string ExtensionFunction::toSignature() const {
704  return "(" + ExtensionFunctionsWhitelist::toString(args_) + ") -> " +
706 }
707 
708 // Converts the extension function signatures to their LLVM representation.
710  const std::unordered_set<std::string>& udf_decls,
711  const bool is_gpu) {
712  std::vector<std::string> declarations;
713  for (const auto& kv : functions_) {
714  const std::vector<ExtensionFunction>& ext_funcs = kv.second;
715  CHECK(!ext_funcs.empty());
716  for (const auto& ext_func : ext_funcs) {
717  // If there is a udf function declaration matching an extension function signature
718  // do not emit a duplicate declaration.
719  if (!udf_decls.empty() && udf_decls.find(ext_func.getName()) != udf_decls.end()) {
720  continue;
721  }
722 
723  std::string decl_prefix;
724  std::vector<std::string> arg_strs;
725 
726  if (is_ext_arg_type_array(ext_func.getRet())) {
727  decl_prefix = "declare void @" + ext_func.getName();
728  arg_strs.emplace_back(
729  serialize_type(ext_func.getRet(), /* byval */ true, /* declare */ true));
730  } else {
731  decl_prefix =
732  "declare " +
733  serialize_type(ext_func.getRet(), /* byval */ true, /* declare */ true) +
734  " @" + ext_func.getName();
735  }
736 
737  // if the extension function uses a Row Function Manager, append "i8*" as the first
738  // arg
739  if (ext_func.usesManager()) {
740  arg_strs.emplace_back("i8*");
741  }
742 
743  for (const auto arg : ext_func.getInputArgs()) {
744  arg_strs.emplace_back(serialize_type(arg, /* byval */ false, /* declare */ true));
745  }
746  declarations.emplace_back(decl_prefix + "(" +
747  boost::algorithm::join(arg_strs, ", ") + ");");
748  }
749  }
750 
752  if (kv.second.isRuntime() || kv.second.useDefaultSizer()) {
753  // Runtime UDTFs are defined in LLVM/NVVM IR module
754  // UDTFs using default sizer share LLVM IR
755  continue;
756  }
757  if (!((is_gpu && kv.second.isGPU()) || (!is_gpu && kv.second.isCPU()))) {
758  continue;
759  }
760  std::string decl_prefix{
761  "declare " +
762  serialize_type(ExtArgumentType::Int32, /* byval */ true, /* declare */ true) +
763  " @" + kv.first};
764  std::vector<std::string> arg_strs;
765  for (const auto arg : kv.second.getArgs(/* ensure_column = */ true)) {
766  arg_strs.push_back(
767  serialize_type(arg, /* byval= */ kv.second.isRuntime(), /* declare= */ true));
768  }
769  declarations.push_back(decl_prefix + "(" + boost::algorithm::join(arg_strs, ", ") +
770  ");");
771  }
772  return declarations;
773 }
774 
775 namespace {
776 
778  if (type_name == "bool" || type_name == "i1") {
779  return ExtArgumentType::Bool;
780  }
781  if (type_name == "i8") {
782  return ExtArgumentType::Int8;
783  }
784  if (type_name == "i16") {
785  return ExtArgumentType::Int16;
786  }
787  if (type_name == "i32") {
788  return ExtArgumentType::Int32;
789  }
790  if (type_name == "i64") {
791  return ExtArgumentType::Int64;
792  }
793  if (type_name == "float") {
794  return ExtArgumentType::Float;
795  }
796  if (type_name == "double") {
798  }
799  if (type_name == "void") {
800  return ExtArgumentType::Void;
801  }
802  if (type_name == "i8*") {
803  return ExtArgumentType::PInt8;
804  }
805  if (type_name == "i16*") {
807  }
808  if (type_name == "i32*") {
810  }
811  if (type_name == "i64*") {
813  }
814  if (type_name == "float*") {
816  }
817  if (type_name == "double*") {
819  }
820  if (type_name == "i1*" || type_name == "bool*") {
821  return ExtArgumentType::PBool;
822  }
823  if (type_name == "Array<i8>") {
825  }
826  if (type_name == "Array<i16>") {
828  }
829  if (type_name == "Array<i32>") {
831  }
832  if (type_name == "Array<i64>") {
834  }
835  if (type_name == "Array<float>") {
837  }
838  if (type_name == "Array<double>") {
840  }
841  if (type_name == "Array<bool>" || type_name == "Array<i1>") {
843  }
844  if (type_name == "Array<TextEncodingDict>") {
846  }
847  if (type_name == "Array<TextEncodingNone>") {
849  }
850  if (type_name == "GeoPoint") {
852  }
853  if (type_name == "GeoMultiPoint") {
855  }
856  if (type_name == "GeoLineString") {
858  }
859  if (type_name == "GeoMultiLineString") {
861  }
862  if (type_name == "GeoPolygon") {
864  }
865  if (type_name == "GeoMultiPolygon") {
867  }
868  if (type_name == "cursor") {
870  }
871  if (type_name == "Column<i8>") {
873  }
874  if (type_name == "Column<i16>") {
876  }
877  if (type_name == "Column<i32>") {
879  }
880  if (type_name == "Column<i64>") {
882  }
883  if (type_name == "Column<float>") {
885  }
886  if (type_name == "Column<double>") {
888  }
889  if (type_name == "Column<bool>") {
891  }
892  if (type_name == "Column<TextEncodingDict>") {
894  }
895  if (type_name == "Column<Timestamp>") {
897  }
898  if (type_name == "TextEncodingNone") {
900  }
901  if (type_name == "TextEncodingDict") {
903  }
904  if (type_name == "timestamp") {
906  }
907  if (type_name == "ColumnList<i8>") {
909  }
910  if (type_name == "ColumnList<i16>") {
912  }
913  if (type_name == "ColumnList<i32>") {
915  }
916  if (type_name == "ColumnList<i64>") {
918  }
919  if (type_name == "ColumnList<float>") {
921  }
922  if (type_name == "ColumnList<double>") {
924  }
925  if (type_name == "ColumnList<bool>") {
927  }
928  if (type_name == "ColumnList<TextEncodingDict>") {
930  }
931  if (type_name == "Column<Array<i8>>") {
933  }
934  if (type_name == "Column<Array<i16>>") {
936  }
937  if (type_name == "Column<Array<i32>>") {
939  }
940  if (type_name == "Column<Array<i64>>") {
942  }
943  if (type_name == "Column<Array<float>>") {
945  }
946  if (type_name == "Column<Array<double>>") {
948  }
949  if (type_name == "Column<Array<bool>>") {
951  }
952  if (type_name == "Column<Array<TextEncodingDict>>") {
954  }
955  if (type_name == "ColumnList<Array<i8>>") {
957  }
958  if (type_name == "ColumnList<Array<i16>>") {
960  }
961  if (type_name == "ColumnList<Array<i32>>") {
963  }
964  if (type_name == "ColumnList<Array<i64>>") {
966  }
967  if (type_name == "ColumnList<Array<float>>") {
969  }
970  if (type_name == "ColumnList<Array<double>>") {
972  }
973  if (type_name == "ColumnList<Array<bool>>") {
975  }
976  if (type_name == "ColumnList<Array<TextEncodingDict>>") {
978  }
979  if (type_name == "DayTimeInterval") {
981  }
982  if (type_name == "YearMonthTimeInterval") {
984  }
985  if (type_name == "Column<GeoPoint>") {
987  }
988  if (type_name == "Column<GeoLineString>") {
990  }
991  if (type_name == "Column<GeoPolygon>") {
993  }
994  if (type_name == "Column<GeoMultiPoint>") {
996  }
997  if (type_name == "Column<GeoMultiLineString>") {
999  }
1000  if (type_name == "Column<GeoMultiPolygon>") {
1002  }
1003  if (type_name == "ColumnList<GeoPoint>") {
1005  }
1006  if (type_name == "ColumnList<GeoLineString>") {
1008  }
1009  if (type_name == "ColumnList<GeoPolygon>") {
1011  }
1012  if (type_name == "ColumnList<GeoMultiPoint>") {
1014  }
1015  if (type_name == "ColumnList<GeoMultiLineString>") {
1017  }
1018  if (type_name == "ColumnList<GeoMultiPolygon>") {
1020  }
1021  CHECK(false);
1022  return ExtArgumentType::Int16;
1023 }
1024 
1025 } // namespace
1026 
// Map from function name to the list of signatures registered under it; same
// shape as functions_, udf_functions_ and rt_udf_functions_.
using SignatureMap = std::unordered_map<std::string, std::vector<ExtensionFunction>>;
1028 
1030  const std::string& json_func_sigs,
1031  const bool is_runtime) {
1032  rapidjson::Document func_sigs;
1033  func_sigs.Parse(json_func_sigs.c_str());
1034  CHECK(func_sigs.IsArray());
1035  for (auto func_sigs_it = func_sigs.Begin(); func_sigs_it != func_sigs.End();
1036  ++func_sigs_it) {
1037  CHECK(func_sigs_it->IsObject());
1038  const auto name = json_str(field(*func_sigs_it, "name"));
1039  const auto ret = deserialize_type(json_str(field(*func_sigs_it, "ret")));
1040  std::vector<ExtArgumentType> args;
1041  const auto& args_serialized = field(*func_sigs_it, "args");
1042  CHECK(args_serialized.IsArray());
1043  for (auto args_serialized_it = args_serialized.Begin();
1044  args_serialized_it != args_serialized.End();
1045  ++args_serialized_it) {
1046  args.push_back(deserialize_type(json_str(*args_serialized_it)));
1047  }
1048 
1049  std::vector<std::map<std::string, std::string>> annotations;
1050  const auto& anns = field(*func_sigs_it, "annotations");
1051  CHECK(anns.IsArray());
1052  static const std::map<std::string, std::string> map_empty = {};
1053  for (auto obj = anns.Begin(); obj != anns.End(); ++obj) {
1054  CHECK(obj->IsObject());
1055  if (obj->ObjectEmpty()) {
1056  annotations.push_back(map_empty);
1057  } else {
1058  std::map<std::string, std::string> m;
1059  for (auto kv = obj->MemberBegin(); kv != obj->MemberEnd(); ++kv) {
1060  m[kv->name.GetString()] = kv->value.GetString();
1061  }
1062  annotations.push_back(m);
1063  }
1064  }
1065  signatures[to_upper(drop_suffix(name))].emplace_back(
1066  name, args, ret, annotations, is_runtime);
1067  }
1068 }
1069 
1070 // Calcite loads the available extensions from `ExtensionFunctions.ast`, adds
1071 // them to its operator table and shares the list with the execution layer in
1072 // JSON format. Build an in-memory representation of that list here so that it
1073 // can be used by getLLVMDeclarations(), when the LLVM IR codegen asks for it.
1074 void ExtensionFunctionsWhitelist::add(const std::string& json_func_sigs) {
1075  // Valid json_func_sigs example:
1076  // [
1077  // {
1078  // "name":"sum",
1079  // "ret":"i32",
1080  // "args":[
1081  // "i32",
1082  // "i32"
1083  // ]
1084  // }
1085  // ]
1086 
1087  addCommon(functions_, json_func_sigs, /* is_runtime */ false);
1088 }
1089 
1090 void ExtensionFunctionsWhitelist::addUdfs(const std::string& json_func_sigs) {
1091  if (!json_func_sigs.empty()) {
1092  addCommon(udf_functions_, json_func_sigs, /* is_runtime */ false);
1093  }
1094 }
1095 
1097  rt_udf_functions_.clear();
1098 }
1099 
1100 void ExtensionFunctionsWhitelist::addRTUdfs(const std::string& json_func_sigs) {
1101  if (!json_func_sigs.empty()) {
1102  addCommon(rt_udf_functions_, json_func_sigs, /* is_runtime */ true);
1103  }
1104 }
1105 
1106 std::unordered_map<std::string, std::vector<ExtensionFunction>>
1108 
1109 std::unordered_map<std::string, std::vector<ExtensionFunction>>
1111 
1112 std::unordered_map<std::string, std::vector<ExtensionFunction>>
1114 
1115 std::string toString(const ExtArgumentType& sig_type) {
1116  return ExtensionFunctionsWhitelist::toString(sig_type);
1117 }
static void addUdfs(const std::string &json_func_sigs)
static void addCommon(std::unordered_map< std::string, std::vector< ExtensionFunction >> &sigs, const std::string &json_func_sigs, const bool is_runtime)
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
static std::unordered_map< std::string, std::vector< ExtensionFunction > > udf_functions_
SQLTypes
Definition: sqltypes.h:65
static std::vector< ExtensionFunction > * get(const std::string &name)
#define LOG(tag)
Definition: Logger.h:285
std::string toSignature() const
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:46
static std::unordered_map< std::string, std::vector< ExtensionFunction > > rt_udf_functions_
std::string join(T const &container, std::string const &delim)
const std::vector< ExtArgumentType > args_
static void add(const std::string &json_func_sigs)
#define UNREACHABLE()
Definition: Logger.h:338
std::unordered_map< std::string, std::vector< ExtensionFunction >> SignatureMap
const std::string getName(bool keep_suffix=true) const
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:391
std::string toStringSQL() const
#define CHECK_GT(x, y)
Definition: Logger.h:305
static std::unordered_map< std::string, std::vector< ExtensionFunction > > functions_
const rapidjson::Value & field(const rapidjson::Value &obj, const char field[]) noexcept
Definition: JsonAccessors.h:33
EncodingType
Definition: sqltypes.h:240
Supported runtime functions management and retrieval.
static std::vector< ExtensionFunction > * get_udf(const std::string &name)
#define EXTARGTYPECASE(EXTARGTYPE, ELEMTYPE, ENCODING, ARRAYENCODING)
DEVICE auto copy(ARGS &&...args)
Definition: gpu_enabled.h:51
ExtArgumentType deserialize_type(const std::string &type_name)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
static std::unordered_set< std::string > get_udfs_name(const bool is_runtime)
Checked json field retrieval.
std::string toString(const Executor::ExtModuleKinds &kind)
Definition: Execute.h:1703
Argument type based extension function binding.
std::string to_upper(const std::string &str)
Definition: sqltypes.h:79
#define EXTARGGEOTYPECASE(GEOTYPE, KTYPE)
std::string serialize_type(const ExtArgumentType type, bool byval=true, bool declare=false)
const std::vector< std::map< std::string, std::string > > annotations_
const ExtArgumentType ret_
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
#define CHECK(condition)
Definition: Logger.h:291
static std::unordered_map< std::string, TableFunction > functions_
Definition: sqltypes.h:72
string name
Definition: setup.in.py:72
constexpr double n
Definition: Utm.h:38
constexpr auto type_name() noexcept
static std::vector< std::string > getLLVMDeclarations(const std::unordered_set< std::string > &udf_decls, const bool is_gpu=false)
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
static std::string toStringSQL(const std::vector< ExtArgumentType > &sig_types)
static void addRTUdfs(const std::string &json_func_sigs)