19 #include <boost/algorithm/string.hpp>
22 #include <string_view>
23 #include <unordered_map>
30 namespace table_functions {
35 switch (ext_arg_type) {
58 switch (ext_arg_type) {
97 int32_t scalar_args = 0;
113 if (ann.find(
"require") != ann.end()) {
121 const size_t idx)
const {
124 static const std::map<std::string, std::string> empty = {};
131 const size_t input_arg_idx)
const {
137 const std::string& key,
138 const std::string& default_)
const {
140 const auto& it = ann.find(key);
141 if (it != ann.end()) {
148 const size_t output_arg_idx)
const {
154 const std::string& key,
155 const std::string& default_)
const {
157 const auto& it = ann.find(key);
158 if (it != ann.end()) {
169 const std::string& key,
170 const std::string& default_)
const {
172 const auto& it = ann.find(key);
173 if (it != ann.end()) {
180 const size_t sql_idx)
const {
181 std::vector<std::string> fields;
184 static const std::vector<std::string> empty = {};
187 std::string substr = line.substr(1, line.size() - 2);
188 boost::split(fields, substr, boost::is_any_of(
", "), boost::token_compress_on);
193 if (use_input_args) {
194 std::vector<std::string> arg_types;
196 for (
size_t sql_idx = 0; sql_idx <
sql_args_.size(); sql_idx++) {
197 const std::vector<std::string> cursor_fields =
getCursorFields(sql_idx);
198 if (cursor_fields.empty()) {
200 arg_types.emplace_back(
203 std::vector<std::string> vec;
204 for (
size_t i = 0; i < cursor_fields.size(); i++) {
217 std::vector<std::string> names;
218 if (use_input_args) {
219 for (
size_t idx = 0; idx <
sql_args_.size(); idx++) {
221 if (cursor_fields.empty()) {
223 names.emplace_back(name);
229 for (
size_t idx = 0; idx <
output_args_.size(); idx++) {
231 names.emplace_back(name);
239 std::vector<std::string> default_values;
240 default_values.reserve(
sql_args_.size());
241 for (
size_t idx = 0; idx <
sql_args_.size(); idx++) {
243 default_values.emplace_back(name);
251 #define PREFIX_LENGTH 5
253 auto annot = annotation.find(
"input_id");
254 if (annot == annotation.end()) {
279 return std::make_pair(lo, 0);
287 const std::string& input_id = annot->second;
289 if (input_id ==
"args<-1>") {
294 size_t comma = input_id.find(
",");
295 int32_t gt = input_id.size() - 1;
296 int32_t lo = std::stoi(input_id.substr(
PREFIX_LENGTH, comma - 1));
298 if (comma == std::string::npos) {
299 return std::make_pair(lo, 0);
301 int32_t hi = std::stoi(input_id.substr(comma + 1, gt - comma - 1));
302 return std::make_pair(lo, hi);
320 int32_t ext_arg_index = 0, sql_arg_index = 0;
327 while ((
size_t)ext_arg_index < sizer) {
328 if ((
size_t)ext_arg_index == sizer - 1) {
329 return sql_arg_index;
333 const auto& sql_arg =
sql_args_[sql_arg_index];
335 if (same_kind(ext_arg, sql_arg)) {
352 template <
size_t... I>
354 std::index_sequence<I...>) {
355 return ((list[I] < list[I + 1]) && ...);
362 constexpr std::string_view whitelisted_table_functions[]{
364 "decision_tree_reg_fit",
366 "generate_random_strings",
368 "get_decision_trees",
375 "random_forest_reg_fit",
376 "random_forest_reg_var_importance",
377 "supported_ml_frameworks",
378 "tf_compute_dwell_times",
379 "tf_cross_section_1d",
380 "tf_cross_section_2d",
381 "tf_feature_self_similarity",
382 "tf_feature_similarity",
383 "tf_geo_multi_rasterize",
385 "tf_geo_rasterize_slope",
386 "tf_graph_shortest_path",
387 "tf_graph_shortest_paths_distances",
388 "tf_load_point_cloud",
390 "tf_mandelbrot_cuda",
391 "tf_mandelbrot_cuda_float",
392 "tf_mandelbrot_float",
393 "tf_point_cloud_metadata",
394 "tf_raster_contour_lines",
395 "tf_raster_contour_polygons",
396 "tf_raster_graph_shortest_slope_weighted_path"
398 constexpr
auto whitelisted_table_functions_len =
399 sizeof(whitelisted_table_functions) /
sizeof(*whitelisted_table_functions);
403 constexpr std::string_view ml_table_functions[]{
"dbscan",
404 "decision_tree_reg_fit",
406 "get_decision_trees",
413 "random_forest_reg_fit",
414 "random_forest_reg_var_importance",
415 "supported_ml_frameworks"};
416 constexpr
auto ml_table_functions_len =
417 sizeof(ml_table_functions) /
sizeof(*ml_table_functions);
421 whitelisted_table_functions,
422 std::make_index_sequence<whitelisted_table_functions_len - 1>{}));
424 ml_table_functions, std::make_index_sequence<ml_table_functions_len - 1>{}));
426 if (!std::binary_search(whitelisted_table_functions,
427 whitelisted_table_functions + whitelisted_table_functions_len,
432 !std::binary_search(ml_table_functions,
433 ml_table_functions + ml_table_functions_len,
441 const std::string&
name,
443 const std::vector<ExtArgumentType>& input_args,
444 const std::vector<ExtArgumentType>& output_args,
445 const std::vector<ExtArgumentType>& sql_args,
446 const std::vector<std::map<std::string, std::string>>& annotations,
448 static const std::map<std::string, std::string> empty = {};
450 auto func_annotations =
451 (annotations.size() == sql_args.size() + output_args.size() + 1 ? annotations.back()
453 auto mgr_annotation = func_annotations.find(
"uses_manager");
454 bool uses_manager = mgr_annotation != func_annotations.end() &&
455 boost::algorithm::to_lower_copy(mgr_annotation->second) ==
"true";
465 const auto tf_name = tf.getName(
true ,
true );
470 auto sig = tf.getSignature(
true,
false);
472 if (it->second.getName() ==
name) {
473 if (it->second.isRuntime()) {
475 <<
"Overriding existing run-time table function (reset not called?): "
479 throw std::runtime_error(
"Will not override existing load-time table function: " +
483 if (sig == it->second.getSignature(
true,
485 ((tf.isCPU() && it->second.isCPU()) || (tf.isGPU() && it->second.isGPU()))) {
487 <<
"The existing (1) and added (2) table functions have the same signature `"
489 <<
" 1: " << it->second.toString() <<
"\n 2: " << tf.toString() <<
"\n";
497 auto input_args2 = input_args;
498 input_args2.erase(input_args2.begin() + sizer.
val - 1);
500 auto sql_args2 = sql_args;
501 auto sql_sizer_pos = tf.getSqlOutputRowSizeParameter();
502 sql_args2.erase(sql_args2.begin() + sql_sizer_pos);
504 auto annotations2 = annotations;
505 annotations2.erase(annotations2.begin() + sql_sizer_pos);
515 auto sig = tf2.getSignature(
true,
false);
517 if (sig == it->second.getSignature(
true,
519 ((tf2.isCPU() && it->second.isCPU()) || (tf2.isGPU() && it->second.isGPU()))) {
521 <<
"The existing (1) and added (2) table functions have the same signature `"
523 <<
" 1: " << it->second.toString() <<
"\n 2: " << tf2.toString() <<
"\n";
542 if (it->second.isRuntime()) {
553 const auto idx = str.find(
"__");
554 if (idx == std::string::npos) {
557 CHECK_GT(idx, std::string::size_type(0));
558 return str.substr(0, idx);
575 const bool include_output)
const {
582 std::vector<std::string>
args;
583 for (
size_t sql_idx = 0; sql_idx <
sql_args_.size(); sql_idx++) {
584 const std::vector<std::string> cursor_fields =
getCursorFields(sql_idx);
585 if (cursor_fields.empty()) {
590 std::vector<std::string> vec;
591 for (
size_t i = 0; i < cursor_fields.size(); i++) {
593 const auto&
name = cursor_fields[i];
600 if (include_output) {
613 std::vector<TableFunction> table_funcs;
615 if (is_gpu ? tf.isGPU() : tf.isCPU()) {
616 table_funcs.emplace_back(tf);
623 const std::string&
name) {
624 std::vector<TableFunction> table_funcs;
625 auto table_func_name =
name;
629 if (fname == table_func_name) {
630 table_funcs.push_back(pair.second);
637 std::vector<TableFunction> table_funcs;
639 if (pair.second.isRuntime() == is_runtime) {
640 table_funcs.push_back(pair.second);
647 std::vector<TableFunction> table_funcs;
649 table_funcs.push_back(pair.second);
SQLTypeInfo getOutputSQLType(const size_t idx) const
const std::string getOutputAnnotation(const size_t output_arg_idx, const std::string &key, const std::string &default_) const
std::string drop_suffix(const std::string &str)
static std::vector< TableFunction > get_table_funcs()
bool is_ext_arg_type_scalar(const ExtArgumentType ext_arg_type)
static void add(const std::string &name, const TableFunctionOutputRowSizer sizer, const std::vector< ExtArgumentType > &input_args, const std::vector< ExtArgumentType > &output_args, const std::vector< ExtArgumentType > &sql_args, const std::vector< std::map< std::string, std::string >> &annotations, bool is_runtime=false)
bool hasUserSpecifiedOutputSizeMultiplier() const
SQLTypeInfo ext_arg_pointer_type_to_type_info(const ExtArgumentType ext_arg_type)
size_t getSqlOutputRowSizeParameter() const
SQLTypeInfo ext_arg_type_to_type_info_output(const ExtArgumentType ext_arg_type)
const std::map< std::string, std::string > getFunctionAnnotations() const
const std::vector< std::map< std::string, std::string > > annotations_
const std::vector< ExtArgumentType > output_args_
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
std::pair< int32_t, int32_t > getInputID(const size_t idx) const
const std::string getFunctionAnnotation(const std::string &key, const std::string &default_) const
size_t getOutputRowSizeParameter() const
bool is_table_function_whitelisted(std::string_view const function_name)
std::string getSignature(const bool include_name, const bool include_output) const
bool containsPreFlightFn() const
bool is_ext_arg_type_nonscalar(const ExtArgumentType ext_arg_type)
const std::vector< ExtArgumentType > sql_args_
SQLTypeInfo getInputSQLType(const size_t idx) const
const std::string getArgNames(const bool use_input_args) const
std::string drop_suffix_impl(const std::string &str)
int32_t countScalarArgs() const
std::string getPreFlightFnName() const
const std::map< std::string, std::string > getOutputAnnotations(const size_t output_arg_idx) const
const std::string getInputAnnotation(const size_t input_arg_idx, const std::string &key, const std::string &default_) const
bool g_enable_dev_table_functions
std::string getName(const bool drop_suffix=false, const bool lower=false) const
const std::string getArgTypes(const bool use_input_args) const
const std::string getInputArgsDefaultValues() const
bool hasPreFlightOutputSizer() const
constexpr bool string_view_array_sorted_and_distinct(std::string_view const *list, std::index_sequence< I...>)
const std::vector< std::string > getCursorFields(const size_t sql_idx) const
bool g_enable_ml_functions
const std::map< std::string, std::string > getInputAnnotations(const size_t input_arg_idx) const
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
static std::unordered_map< std::string, TableFunction > functions_
const std::vector< ExtArgumentType > input_args_
SQLTypeInfo get_elem_type() const
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
bool g_enable_table_functions
const std::vector< std::map< std::string, std::string > > & getAnnotations() const