37 switch (ext_arg_column_type) {
94 switch (ext_arg_column_list_type) {
148 switch (ext_arg_array_type) {
174 const bool is_arg_literal,
176 int32_t& penalty_score) {
177 const auto arg_type = arg_type_info.
get_type();
183 const auto sig_type = sig_type_info.get_type();
206 const bool is_integer_to_fp_cast = (arg_type ==
kTINYINT || arg_type ==
kSMALLINT ||
211 CHECK_GE(arg_type_relative_scale, 1);
212 CHECK_LE(arg_type_relative_scale, 8);
213 auto sig_type_relative_scale = sig_type_info.get_numeric_scalar_scale();
214 CHECK_GE(sig_type_relative_scale, 1);
215 CHECK_LE(sig_type_relative_scale, 8);
217 if (is_integer_to_fp_cast) {
219 sig_type_relative_scale = (3 - (sig_type_relative_scale >> 2)) << 3;
224 CHECK_GE(sig_type_relative_scale, arg_type_relative_scale);
227 const auto sig_type_scale_gain_ratio =
228 sig_type_relative_scale / arg_type_relative_scale;
229 CHECK_GE(sig_type_scale_gain_ratio, 1);
235 const int32_t type_family_cast_penalty_score = is_integer_to_fp_cast ? 1001000 : 1000;
237 int32_t scale_cast_penalty_score;
251 if (is_arg_literal) {
252 scale_cast_penalty_score =
253 (8000 / arg_type_relative_scale) - (1000 * sig_type_scale_gain_ratio);
255 scale_cast_penalty_score = (1000 * sig_type_scale_gain_ratio);
258 const auto cast_penalty_score =
259 type_family_cast_penalty_score + scale_cast_penalty_score;
261 penalty_score += cast_penalty_score;
266 const bool is_arg_literal,
268 const std::vector<ExtArgumentType>& sig_types,
269 int& penalty_score) {
292 int max_pos = sig_types.size() - 1;
293 if (sig_pos > max_pos) {
296 auto sig_type = sig_types[sig_pos];
315 penalty_score += 1000;
323 penalty_score += 1000;
332 penalty_score += 1000;
335 penalty_score += 1000;
345 penalty_score += 1000;
350 const auto sig_type_ti =
353 sig_type_ti.get_type() ==
kTINYINT) {
355 penalty_score += 1000;
358 penalty_score += 1000;
370 penalty_score += 1000;
373 penalty_score += 1000;
384 penalty_score += 1000;
387 penalty_score += 1000;
397 penalty_score += 1000;
404 const auto sig_type_ti =
407 sig_type_ti.get_type() ==
kARRAY) {
409 sig_type_ti.get_elem_type().get_type()) {
410 penalty_score += 1000;
416 sig_type_ti.get_type() ==
kTINYINT) {
418 penalty_score += 1000;
421 penalty_score += 1000;
431 const auto sig_type_ti =
434 sig_type_ti.get_type() ==
kARRAY) {
436 sig_type_ti.get_elem_type().get_type()) {
437 penalty_score += 1000;
443 sig_type_ti.get_type() ==
kTINYINT) {
445 penalty_score += 10000;
448 penalty_score += 10000;
461 penalty_score += 1000;
473 penalty_score += 1000;
479 penalty_score += 1000;
488 penalty_score += 1000;
494 penalty_score += 1000;
501 penalty_score += 1000;
517 throw std::runtime_error(std::string(__FILE__) +
"#" +
std::to_string(__LINE__) +
532 if (!(std::isalpha(str[0]) || str[0] ==
'_')) {
536 for (
size_t i = 1; i < str.size(); i++) {
537 if (!(std::isalnum(str[i]) || str[i] ==
'_')) {
547 template <
typename T>
551 const std::vector<T>& ext_funcs,
552 const std::string processor) {
585 "Cannot bind function with invalid UDF/UDTF function name: " + name);
588 std::vector<SQLTypeInfo> type_infos_input;
589 std::vector<bool> args_are_constants;
590 for (
auto atype : func_args) {
591 if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
592 if (dynamic_cast<const Analyzer::ColumnVar*>(atype.get())) {
595 if (ti.get_subtype() ==
kNULLT) {
596 throw std::runtime_error(std::string(__FILE__) +
"#" +
598 ": column support for type info " +
599 type_info.
to_string() +
" is not implemented");
602 type_infos_input.push_back(ti);
607 type_infos_input.push_back(atype->get_type_info());
608 if (dynamic_cast<const Analyzer::Constant*>(atype.get())) {
609 args_are_constants.push_back(
true);
611 args_are_constants.push_back(
false);
614 CHECK_EQ(type_infos_input.size(), args_are_constants.size());
616 if (type_infos_input.size() == 0 && ext_funcs.size() > 0) {
617 CHECK_EQ(ext_funcs.size(),
static_cast<size_t>(1));
618 CHECK_EQ(ext_funcs[0].getInputArgs().size(), static_cast<size_t>(0));
619 if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
620 CHECK(ext_funcs[0].hasNonUserSpecifiedOutputSize());
622 std::vector<SQLTypeInfo> empty_type_info_variant(0);
623 return {ext_funcs[0], empty_type_info_variant};
626 int minimal_score = std::numeric_limits<int>::max();
629 int optimal_variant = -1;
630 std::vector<std::vector<SQLTypeInfo>> type_infos_variants;
670 for (
const auto& ext_func : ext_funcs) {
673 const auto& ext_func_args = ext_func.getInputArgs();
675 int penalty_score = 0;
677 int original_input_idx = 0;
678 type_infos_variants.emplace_back();
680 for (
size_t i = 0; i < type_infos_input.size(); i++) {
683 if ((
size_t)pos >= ext_func_args.size()) {
689 args_are_constants[original_input_idx],
701 size_t args_left = ext_func_args.size() - pos - 1;
702 while ((type_infos_input.size() - j > args_left) and
709 type_infos_variants.back().push_back(ti_col_list);
712 original_input_idx = j;
716 args_are_constants[original_input_idx],
722 type_infos_variants.back().push_back(ti);
723 original_input_idx += 1;
732 if ((
size_t)pos == ext_func_args.size()) {
733 CHECK_EQ(args_are_constants.size(), original_input_idx);
736 if (penalty_score < minimal_score) {
738 minimal_score = penalty_score;
739 optimal_variant = type_infos_variants.size() - 1;
749 if (!ext_funcs.size()) {
750 message =
"Function " + name +
"(" + sarg_types +
") not supported.";
753 if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
754 message =
"Could not bind " + name +
"(" + sarg_types +
") to any " + processor +
755 " UDTF implementation.";
756 }
else if constexpr (std::is_same_v<T, ExtensionFunction>) {
757 message =
"Could not bind " + name +
"(" + sarg_types +
") to any " + processor +
758 " UDF implementation.";
760 LOG(
FATAL) <<
"bind_function: unknown extension function type "
763 message +=
"\n Existing extension function implementations:";
764 for (
const auto& ext_func : ext_funcs) {
766 if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
767 if (ext_func.useDefaultSizer()) {
771 message +=
"\n " + ext_func.toStringSQL();
778 if constexpr (std::is_same_v<T, table_functions::TableFunction>) {
779 if (ext_funcs[optimal].hasUserSpecifiedOutputSizeMultiplier() &&
780 ext_funcs[optimal].useDefaultSizer()) {
781 std::string name = ext_funcs[optimal].getName();
784 for (
size_t i = 0; i < ext_funcs.size(); i++) {
785 if (ext_funcs[i].getName() ==
name) {
787 std::vector<SQLTypeInfo> type_info = type_infos_variants[optimal_variant];
788 size_t sizer = ext_funcs[optimal].getOutputRowSizeParameter();
789 type_info.insert(type_info.begin() + sizer - 1,
SQLTypeInfo(
kINT,
true));
790 return {ext_funcs[optimal], type_info};
797 return {ext_funcs[optimal], type_infos_variants[optimal_variant]};
800 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
803 const std::vector<table_functions::TableFunction>& table_funcs,
805 std::string processor = (is_gpu ?
"GPU" :
"CPU");
806 return bind_function<table_functions::TableFunction>(
807 name, input_args, table_funcs, processor);
815 std::string processor =
"GPU";
817 if (!ext_funcs.size()) {
824 bind_function<ExtensionFunction>(
name, func_args, ext_funcs, processor));
828 processor =
"GPU|CPU";
831 bind_function<ExtensionFunction>(
name, func_args, ext_funcs, processor));
842 std::vector<ExtensionFunction> ext_funcs =
844 std::string processor = (is_gpu ?
"GPU" :
"CPU");
846 bind_function<ExtensionFunction>(
name, func_args, ext_funcs, processor));
854 for (
size_t i = 0; i < function_oper->
getArity(); ++i) {
855 func_args.push_back(function_oper->
getOwnArg(i));
860 const std::tuple<table_functions::TableFunction, std::vector<SQLTypeInfo>>
865 std::vector<table_functions::TableFunction> table_funcs =
ExtArgumentType get_array_arg_elem_type(const ExtArgumentType ext_arg_array_type)
static std::vector< TableFunction > get_table_funcs()
static std::vector< ExtensionFunction > get_ext_funcs(const std::string &name)
bool is_ext_arg_type_column(const ExtArgumentType ext_arg_type)
static int match_arguments(const SQLTypeInfo &arg_type, const bool is_arg_literal, int sig_pos, const std::vector< ExtArgumentType > &sig_types, int &penalty_score)
ExtArgumentType get_column_list_arg_elem_type(const ExtArgumentType ext_arg_column_list_type)
#define DEFAULT_ROW_MULTIPLIER_SUFFIX
HOST DEVICE SQLTypes get_type() const
std::string to_string() const
bool is_ext_arg_type_column_list(const ExtArgumentType ext_arg_type)
std::shared_ptr< Analyzer::Expr > getOwnArg(const size_t i) const
bool has_same_itemtype(const SQLTypeInfo &other) const
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
bool supportsFlatBuffer() const
bool is_numeric_scalar_auto_castable(const SQLTypeInfo &new_type_info) const
returns true if the sql_type can be cast to the type specified by new_type_info with no loss of preci...
ExtArgumentType get_column_arg_elem_type(const ExtArgumentType ext_arg_column_type)
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
Argument type based extension function binding.
static int match_numeric_argument(const SQLTypeInfo &arg_type_info, const bool is_arg_literal, const ExtArgumentType &sig_ext_arg_type, int32_t &penalty_score)
auto generate_column_type(const SQLTypeInfo &elem_ti)
HOST DEVICE EncodingType get_compression() const
auto generate_column_list_type(const SQLTypeInfo &elem_ti)
void set_dimension(int d)
std::string get_type_name() const
int32_t get_numeric_scalar_scale() const
returns integer between 1 and 8 indicating what is roughly equivalent to the logical byte size of a s...
const std::tuple< table_functions::TableFunction, std::vector< SQLTypeInfo > > bind_table_function(std::string name, Analyzer::ExpressionPtrVector input_args, const std::vector< table_functions::TableFunction > &table_funcs, const bool is_gpu)
static std::string toString(const std::vector< ExtensionFunction > &ext_funcs, std::string tab="")
std::vector< ExpressionPtr > ExpressionPtrVector
bool is_valid_identifier(std::string str)
std::string getName() const
SQLTypeInfo get_elem_type() const
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)