OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExtensionsIR.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CodeGenerator.h"
18 #include "Execute.h"
19 #include "ExtensionFunctions.hpp"
22 
23 #include <tuple>
24 
// LLVM modules declared here and defined in another translation unit.
// NOTE(review): judging by the names these hold user-defined-function IR for
// the GPU and the CPU respectively — confirm against the defining file.
25 extern std::unique_ptr<llvm::Module> udf_gpu_module;
26 extern std::unique_ptr<llvm::Module> udf_cpu_module;
27 
28 namespace {
29 
30 llvm::StructType* get_buffer_struct_type(CgenState* cgen_state,
31  const std::string& ext_func_name,
32  size_t param_num,
33  llvm::Type* elem_type) {
34  CHECK(elem_type);
35  CHECK(elem_type->isPointerTy());
36  llvm::StructType* generated_struct_type =
37  llvm::StructType::get(cgen_state->context_,
38  {elem_type,
39  llvm::Type::getInt64Ty(cgen_state->context_),
40  llvm::Type::getInt8Ty(cgen_state->context_)},
41  false);
42  llvm::Function* udf_func = cgen_state->module_->getFunction(ext_func_name);
43  if (udf_func) {
44  // Compare expected array struct type with type from the function
45  // definition from the UDF module, but use the type from the
46  // module
47  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
48  CHECK_LE(param_num, udf_func_type->getNumParams());
49  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
50  CHECK(param_pointer_type->isPointerTy());
51  llvm::Type* param_type = param_pointer_type->getPointerElementType();
52  CHECK(param_type->isStructTy());
53  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
54  CHECK_GE(struct_type->getStructNumElements(),
55  generated_struct_type->getStructNumElements())
56  << serialize_llvm_object(struct_type);
57 
58  const auto expected_elems = generated_struct_type->elements();
59  const auto current_elems = struct_type->elements();
60  for (size_t i = 0; i < expected_elems.size(); i++) {
61  CHECK_EQ(expected_elems[i], current_elems[i])
62  << "[" << ::toString(expected_elems[i]) << ", " << ::toString(current_elems[i])
63  << "]";
64  }
65 
66  if (struct_type->isLiteral()) {
67  return struct_type;
68  }
69 
70  llvm::StringRef struct_name = struct_type->getStructName();
71 #if LLVM_VERSION_MAJOR >= 12
72  return struct_type->getTypeByName(cgen_state->context_, struct_name);
73 #else
74  return cgen_state->module_->getTypeByName(struct_name);
75 #endif
76  }
77  return generated_struct_type;
78 }
79 
81  llvm::LLVMContext& ctx) {
 // Maps an ExtArgumentType value to an LLVM type created in `ctx`.
 // NOTE(review): the first line of the signature and most `case` labels are
 // missing from this listing, so the enumerator that each `return` below
 // belongs to cannot be confirmed from here.
82  switch (ext_arg_type) {
83  case ExtArgumentType::Bool: // pass thru to Int8
85  return get_int_type(8, ctx);
87  return get_int_type(16, ctx);
89  return get_int_type(32, ctx);
91  return get_int_type(64, ctx);
93  return llvm::Type::getFloatTy(ctx);
95  return llvm::Type::getDoubleTy(ctx);
97  return get_int_type(32, ctx);
122  return llvm::Type::getVoidTy(ctx);
123  default:
 // Any value not handled by a case above is a fatal error.
124  CHECK(false);
125  }
 // Not expected to be reached; the nullptr return follows a failing CHECK.
126  CHECK(false);
127  return nullptr;
128 }
129 
131  CHECK(ll_type);
 // Translates the LLVM type `ll_type` into a SQLTypeInfo, using only whether
 // it is a floating point type and its primitive size in bits.
 // NOTE(review): the signature line is missing from this listing; the second
 // SQLTypeInfo constructor argument (`false`) is presumably the not-null
 // flag — confirm.
132  const auto bits = ll_type->getPrimitiveSizeInBits();
133 
134  if (ll_type->isFloatingPointTy()) {
 // Floating point: 32 bits -> kFLOAT, 64 bits -> kDOUBLE, anything else is fatal.
135  switch (bits) {
136  case 32:
137  return SQLTypeInfo(kFLOAT, false);
138  case 64:
139  return SQLTypeInfo(kDOUBLE, false);
140  default:
141  LOG(FATAL) << "Unsupported llvm floating point type: " << bits
142  << ", only 32 and 64 bit floating point is supported.";
143  }
144  } else {
 // Everything else is treated as an integer-like type selected by bit width:
 // 1 -> kBOOLEAN, 8 -> kTINYINT, 16 -> kSMALLINT, 32 -> kINT, 64 -> kBIGINT.
145  switch (bits) {
146  case 1:
147  return SQLTypeInfo(kBOOLEAN, false);
148  case 8:
149  return SQLTypeInfo(kTINYINT, false);
150  case 16:
151  return SQLTypeInfo(kSMALLINT, false);
152  case 32:
153  return SQLTypeInfo(kINT, false);
154  case 64:
155  return SQLTypeInfo(kBIGINT, false);
156  default:
157  LOG(FATAL) << "Unrecognized llvm type for SQL type: "
158  << bits; // TODO let's get the real name here
159  }
160  }
 // Both switches either return or log a fatal error, so this is not expected
 // to be reached.
161  UNREACHABLE();
162  return SQLTypeInfo();
163 }
164 
166  llvm::LLVMContext& ctx) {
 // Returns the LLVM pointer type used for the data of the buffer type `ti`.
 // NOTE(review): the first line of the signature is missing from this listing.
 // `ti` must satisfy is_buffer().
167  CHECK(ti.is_buffer());
 // None-encoded text is addressed as bytes.
168  if (ti.is_text_encoding_none()) {
169  return llvm::Type::getInt8PtrTy(ctx);
170  }
171 
 // All remaining cases are decided by the element type.
172  const auto& elem_ti = ti.get_elem_type();
173  if (elem_ti.is_fp()) {
 // 4-byte -> float*, 8-byte -> double*. Any other size falls out of the
 // switch and continues with the checks below.
174  switch (elem_ti.get_size()) {
175  case 4:
176  return llvm::Type::getFloatPtrTy(ctx);
177  case 8:
178  return llvm::Type::getDoublePtrTy(ctx);
179  }
180  }
181 
 // Dictionary-encoded text elements are returned as i32*.
182  if (elem_ti.is_text_encoding_dict()) {
183  return llvm::Type::getInt32PtrTy(ctx);
184  }
185 
 // Boolean elements are returned as i8*.
186  if (elem_ti.is_boolean()) {
187  return llvm::Type::getInt8PtrTy(ctx);
188  }
189 
 // Only integer elements may reach this point; the pointer width follows the
 // element size in bytes.
190  CHECK(elem_ti.is_integer());
191  switch (elem_ti.get_size()) {
192  case 1:
193  return llvm::Type::getInt8PtrTy(ctx);
194  case 2:
195  return llvm::Type::getInt16PtrTy(ctx);
196  case 4:
197  return llvm::Type::getInt32PtrTy(ctx);
198  case 8:
199  return llvm::Type::getInt64PtrTy(ctx);
200  }
201 
 // Reached only for an integer element size other than 1, 2, 4 or 8.
202  UNREACHABLE();
203  return nullptr;
204 }
205 
207  const auto& func_ti = function_oper->get_type_info();
 // Decides whether a call-site NULL check has to be generated for the
 // arguments of `function_oper`. Arguments are inspected in order and the
 // first argument that matches one of the two returning branches decides the
 // result; if none does, no check is required.
 // NOTE(review): the signature line is missing from this listing.
208  for (size_t i = 0; i < function_oper->getArity(); ++i) {
209  const auto arg = function_oper->getArg(i);
210  const auto& arg_ti = arg->get_type_info();
 // No check when the return type and this argument are both arrays, both
 // none-encoded text, both dict-encoded text, or a dict-encoded text array
 // paired with a dict-encoded text argument.
211  if ((func_ti.is_array() && arg_ti.is_array()) ||
212  (func_ti.is_text_encoding_none() && arg_ti.is_text_encoding_none()) ||
213  (func_ti.is_text_encoding_dict() && arg_ti.is_text_encoding_dict()) ||
214  (func_ti.is_text_encoding_dict_array() && arg_ti.is_text_encoding_dict())) {
215  // If the function returns an array and any of the arguments are arrays, allow NULL
216  // scalars.
217  // TODO: Make this a property of the FunctionOper following `RETURN NULL ON NULL`
218  // semantics.
219  return false;
220  } else if (!arg_ti.get_notnull() && !arg_ti.is_buffer()) {
 // Any nullable argument that is not a buffer requires the check.
221  // Nullable geometry args will trigger a null check
222  return true;
223  } else {
224  continue;
225  }
226  }
227  return false;
228 }
229 
230 } // namespace
231 
233  int8_t* buffer) {
 // Registers `buffer` with the row set memory owner of the executor passed in
 // as the opaque value `exec`; a null `buffer` is ignored.
 // NOTE(review): the first line of the signature is missing from this listing.
 // Presumably the memory owner takes over the lifetime of the buffer —
 // confirm against addVarlenBuffer.
234  Executor* exec_ptr = reinterpret_cast<Executor*>(exec);
235  if (buffer != nullptr) {
236  exec_ptr->getRowSetMemoryOwner()->addVarlenBuffer(buffer);
237  }
238 }
239 
241  const Analyzer::FunctionOper* function_oper,
242  const CompilationOptions& co) {
 // Generates the IR for calling the extension function bound to
 // `function_oper`: binds the function signature, generates and collects the
 // argument values, wraps the call in a NULL check of the arguments when one
 // is required, emits the external call and casts the result to the type of
 // the FunctionOper when needed. Returns the resulting value.
 // NOTE(review): several lines of this function (including the first line of
 // the signature and the conditions opening some `if` blocks) are missing
 // from this listing.
 //
 // Bind the extension function. In the first branch a binding failure is
 // logged and turned into QueryMustRunOnCpu; in the second it is logged and
 // rethrown.
244  ExtensionFunction ext_func_sig = [=]() {
246  try {
247  return bind_function(function_oper, /* is_gpu= */ true);
248  } catch (ExtensionFunctionBindingError& e) {
249  LOG(WARNING) << "codegenFunctionOper[GPU]: " << e.what() << " Redirecting "
250  << function_oper->getName() << " to run on CPU.";
251  throw QueryMustRunOnCpu();
252  }
253  } else {
254  try {
255  return bind_function(function_oper, /* is_gpu= */ false);
256  } catch (ExtensionFunctionBindingError& e) {
257  LOG(WARNING) << "codegenFunctionOper[CPU]: " << e.what();
258  throw;
259  }
260  }
261  }();
262 
 // Only integer, floating point, boolean, buffer and dict-encoded text
 // return types are supported.
263  const auto& ret_ti = function_oper->get_type_info();
264  CHECK(ret_ti.is_integer() || ret_ti.is_fp() || ret_ti.is_boolean() ||
265  ret_ti.is_buffer() || ret_ti.is_text_encoding_dict());
266  if (ret_ti.is_buffer() && co.device_type == ExecutorDeviceType::GPU) {
267  // TODO: This is not necessary for runtime UDFs because RBC does
268  // not generate GPU LLVM IR when the UDF is using Buffer objects.
269  // However, we cannot remove it until C++ UDFs can be defined for
270  // different devices independently.
271  throw QueryMustRunOnCpu();
272  }
273 
274  auto ret_ty = ext_arg_type_to_llvm_type(ext_func_sig.getRet(), cgen_state_->context_);
275  const auto current_bb = cgen_state_->ir_builder_.GetInsertBlock();
 // Reuse a previously generated call for an equal FunctionOper, but only if
 // the cached value is an instruction in the current basic block.
276  for (auto it : cgen_state_->ext_call_cache_) {
277  if (*it.foper == *function_oper) {
278  auto inst = llvm::dyn_cast<llvm::Instruction>(it.lv);
279  if (inst && inst->getParent() == current_bb) {
280  return it.lv;
281  }
282  }
283  }
 // orig_arg_lvs: flattened generated values of all arguments.
 // orig_arg_lvs_index: for each argument, the position of its first value in
 // orig_arg_lvs.
 // const_arr_size: size value recorded for array pointers whose size is
 // generated alongside the pointer.
284  std::vector<llvm::Value*> orig_arg_lvs;
285  std::vector<size_t> orig_arg_lvs_index;
286  std::unordered_map<llvm::Value*, llvm::Value*> const_arr_size;
287 
288  for (size_t i = 0; i < function_oper->getArity(); ++i) {
289  orig_arg_lvs_index.push_back(orig_arg_lvs.size());
290  const auto arg = function_oper->getArg(i);
 // Look through a cast to find out whether the argument is a locally
 // allocated array expression.
291  const auto arg_cast = dynamic_cast<const Analyzer::UOper*>(arg);
292  const auto arg0 =
293  (arg_cast && arg_cast->get_optype() == kCAST) ? arg_cast->get_operand() : arg;
294  const auto array_expr_arg = dynamic_cast<const Analyzer::ArrayExpr*>(arg0);
295  auto is_local_alloc = array_expr_arg && array_expr_arg->isLocalAlloc();
296  const auto& arg_ti = arg->get_type_info();
297  const auto arg_lvs = codegen(arg, true, co);
298  auto geo_uoper_arg = dynamic_cast<const Analyzer::GeoUOper*>(arg);
299  auto geo_binoper_arg = dynamic_cast<const Analyzer::GeoBinOper*>(arg);
300  auto geo_expr_arg = dynamic_cast<const Analyzer::GeoExpr*>(arg);
301  // TODO(adb / d): Assuming no const array cols for geo (for now)
302  if ((geo_uoper_arg || geo_binoper_arg) && arg_ti.is_geometry()) {
303  // Extract arr sizes and put them in the map, forward arr pointers
304  CHECK_EQ(2 * static_cast<size_t>(arg_ti.get_physical_coord_cols()), arg_lvs.size());
 // arg_lvs is consumed as (array pointer, size) pairs: the index is
 // advanced once in the body and once in the loop header. Note that this
 // `i` shadows the index of the enclosing loop over the arguments.
305  for (size_t i = 0; i < arg_lvs.size(); i++) {
306  auto arr = arg_lvs[i++];
307  auto size = arg_lvs[i];
308  orig_arg_lvs.push_back(arr);
309  const_arr_size[arr] = size;
310  }
311  } else if (geo_expr_arg && geo_expr_arg->get_type_info().is_geometry()) {
 // A GeoExpr argument has to be a POINT that generated exactly two values;
 // both are forwarded.
312  CHECK(geo_expr_arg->get_type_info().get_type() == kPOINT);
313  CHECK_EQ(arg_lvs.size(), size_t(2));
314  for (size_t j = 0; j < arg_lvs.size(); j++) {
315  orig_arg_lvs.push_back(arg_lvs[j]);
316  }
317  } else if (arg_ti.is_geometry()) {
 // Any other geometry argument: one value per physical coordinate column.
318  CHECK_EQ(static_cast<size_t>(arg_ti.get_physical_coord_cols()), arg_lvs.size());
319  for (size_t j = 0; j < arg_lvs.size(); j++) {
320  orig_arg_lvs.push_back(arg_lvs[j]);
321  }
322  } else if (arg_ti.is_text_encoding_none()) {
323  if (arg_lvs.size() == 3) {
324  // arg_lvs contains:
325  // arg_lvs[0] StringView struct { i8*, i64 }
326  // arg_lvs[1] i8* pointer
327  // arg_lvs[2] i32 string length (truncated from i64)
328  std::copy(
329  std::begin(arg_lvs), std::end(arg_lvs), std::back_inserter(orig_arg_lvs));
330  } else if (arg_lvs.size() == 1) {
331  // TextEncodingNone*
332  // orig_arg_lvs should contain:
333  // orig_arg_lvs[0]: TextEncodingNone struct { i8*, i64, i8* }
334  // orig_arg_lvs[1]: i8*
335  // orig_arg_lvs[2]: i32 string length (truncated from i64)
336  CHECK(arg_lvs[0]->getType()->isPointerTy());
337  auto none_enc_string = cgen_state_->ir_builder_.CreateLoad(
338  arg_lvs[0]->getType()->getPointerElementType(), arg_lvs[0]);
339  orig_arg_lvs.push_back(none_enc_string);
340  orig_arg_lvs.push_back(
341  cgen_state_->ir_builder_.CreateExtractValue(none_enc_string, 0));
342  orig_arg_lvs.push_back(cgen_state_->ir_builder_.CreateTrunc(
343  cgen_state_->ir_builder_.CreateExtractValue(none_enc_string, 1),
344  llvm::Type::getInt32Ty(cgen_state_->context_)));
345  }
 // NOTE(review): any other arg_lvs size pushes nothing for this argument.
346  } else if (arg_ti.is_text_encoding_dict()) {
347  CHECK_EQ(size_t(1), arg_lvs.size());
348  orig_arg_lvs.push_back(arg_lvs[0]);
349  } else {
350  if (arg_lvs.size() > 1) {
 // Two values are only expected for arrays: the pointer and its size.
351  CHECK(arg_ti.is_array());
352  CHECK_EQ(size_t(2), arg_lvs.size());
353  const_arr_size[arg_lvs.front()] = arg_lvs.back();
354  } else {
355  CHECK_EQ(size_t(1), arg_lvs.size());
356  /* arg_lvs contains:
357  &col_buf1
358  */
 // For a locally allocated array expression with a positive type size,
 // that size is recorded as the array size.
359  if (is_local_alloc && arg_ti.get_size() > 0) {
360  const_arr_size[arg_lvs.front()] = cgen_state_->llInt(arg_ti.get_size());
361  }
362  }
363  orig_arg_lvs.push_back(arg_lvs.front());
364  }
365  }
366  // The extension function implementations don't handle NULL, they work under
367  // the assumption that the inputs are validated before calling them. Generate
368  // code to do the check at the call site: if any argument is NULL, return NULL
369  // without calling the function at all.
370  const auto [bbs, null_buffer_ptr] = beginArgsNullcheck(function_oper, orig_arg_lvs);
371  CHECK_GE(orig_arg_lvs.size(), function_oper->getArity());
372  // Arguments must be converted to the types the extension function can handle.
374  function_oper, &ext_func_sig, orig_arg_lvs, orig_arg_lvs_index, const_arr_size, co);
375 
 // Functions that use the manager receive the row function's "row_func_mgr"
 // argument as an additional first argument.
376  if (ext_func_sig.usesManager()) {
378  throw QueryMustRunOnCpu();
379  }
380  llvm::Value* row_func_mgr = get_arg_by_name(cgen_state_->row_func_, "row_func_mgr");
381  args.insert(args.begin(), row_func_mgr);
382  }
383 
384  llvm::Value* buffer_ret{nullptr};
385  if (ret_ti.is_buffer()) {
386  // codegen buffer return as first arg
 // The call itself then returns void; the result is written through the
 // allocated struct that is inserted as the first argument.
387  CHECK(ret_ti.is_array() || ret_ti.is_text_encoding_none());
388  ret_ty = llvm::Type::getVoidTy(cgen_state_->context_);
389  const auto struct_ty = get_buffer_struct_type(
390  cgen_state_,
391  function_oper->getName(),
392  0,
394  buffer_ret = cgen_state_->ir_builder_.CreateAlloca(struct_ty);
395  args.insert(args.begin(), buffer_ret);
396  }
397 
398  const auto ext_call = cgen_state_->emitExternalCall(
399  ext_func_sig.getName(), ret_ty, args, {}, ret_ti.is_buffer());
 // Close the NULL-check control flow; for buffer returns the value carried
 // forward is the result struct pointer rather than the call instruction.
400  auto ext_call_nullcheck = endArgsNullcheck(
401  bbs, ret_ti.is_buffer() ? buffer_ret : ext_call, null_buffer_ptr, function_oper);
402 
403  // Cast the return of the extension function to match the FunctionOper
404  if (!(ret_ti.is_buffer() || ret_ti.is_text_encoding_dict())) {
405  const auto extension_ret_ti = get_sql_type_from_llvm_type(ret_ty);
 // The cast is only generated when a NULL check was emitted and the types
 // differ.
406  if (bbs.args_null_bb &&
407  extension_ret_ti.get_type() != function_oper->get_type_info().get_type() &&
408  // Skip i1-->i8 casts for ST_ functions.
409  // function_oper ret type is i1, extension ret type is 'upgraded' to i8
410  // during type deserialization to 'handle' NULL returns, hence i1-->i8.
411  // ST_ functions can't return NULLs, we just need to check arg nullness
412  // and if any args are NULL then ST_ function is not called
413  function_oper->getName().substr(0, 3) != std::string("ST_")) {
414  ext_call_nullcheck = codegenCast(ext_call_nullcheck,
415  extension_ret_ti,
416  function_oper->get_type_info(),
417  false,
418  co);
419  }
420  }
421 
 // Remember the generated value so an equal FunctionOper can reuse it (see
 // the cache lookup above).
422  cgen_state_->ext_call_cache_.push_back({function_oper, ext_call_nullcheck});
423  return ext_call_nullcheck;
424 }
425 
426 // Start the control flow needed for a call site check of NULL arguments.
// Returns the basic blocks involved (null, not-null and originating block)
// together with an alloca for a NULL buffer result; the alloca is only created
// when a check is required and the function returns a buffer, otherwise it is
// nullptr. When no check is required the null / not-null blocks are nullptr
// and the insert point is left unchanged.
// NOTE(review): some lines of this function (including the line carrying its
// name) are missing from this listing.
427 std::tuple<CodeGenerator::ArgNullcheckBBs, llvm::Value*>
429  const std::vector<llvm::Value*>& orig_arg_lvs) {
431  llvm::BasicBlock* args_null_bb{nullptr};
432  llvm::BasicBlock* args_notnull_bb{nullptr};
433  llvm::BasicBlock* orig_bb = cgen_state_->ir_builder_.GetInsertBlock();
434  llvm::Value* null_array_alloca{nullptr};
435  // Only generate the check if required (at least one argument must be nullable).
436  if (ext_func_call_requires_nullcheck(function_oper)) {
437  const auto func_ti = function_oper->get_type_info();
438  if (func_ti.is_buffer()) {
 // Allocate the struct that represents a NULL buffer result; it is
 // allocated here, before the conditional branch is emitted.
439  const auto arr_struct_ty = get_buffer_struct_type(
440  cgen_state_,
441  function_oper->getName(),
442  0,
444  null_array_alloca = cgen_state_->ir_builder_.CreateAlloca(arr_struct_ty);
445  }
 // Branch to args_notnull when no argument is NULL, otherwise to the null
 // block, and continue generating code in args_notnull.
446  const auto args_notnull_lv = cgen_state_->ir_builder_.CreateNot(
447  codegenFunctionOperNullArg(function_oper, orig_arg_lvs));
448  args_notnull_bb = llvm::BasicBlock::Create(
449  cgen_state_->context_, "args_notnull", cgen_state_->current_func_);
450  args_null_bb = llvm::BasicBlock::Create(
452  cgen_state_->ir_builder_.CreateCondBr(args_notnull_lv, args_notnull_bb, args_null_bb);
453  cgen_state_->ir_builder_.SetInsertPoint(args_notnull_bb);
454  }
455  return std::make_tuple(
456  CodeGenerator::ArgNullcheckBBs{args_null_bb, args_notnull_bb, orig_bb},
457  null_array_alloca);
458 }
459 
460 // Wrap up the control flow needed for NULL argument handling.
// When `bbs` carries a null block, branches into it and returns a PHI node
// that selects `fn_ret_lv` when coming from the not-null block and a NULL
// value (or `null_array_ptr` for buffer returns) when coming from the
// originating block. Without a null block `fn_ret_lv` is returned unchanged.
// NOTE(review): some lines of this function (including the line carrying its
// name) are missing from this listing.
462  const ArgNullcheckBBs& bbs,
463  llvm::Value* fn_ret_lv,
464  llvm::Value* null_array_ptr,
465  const Analyzer::FunctionOper* function_oper) {
467  if (bbs.args_null_bb) {
468  CHECK(bbs.args_notnull_bb);
 // Terminate the block the call was generated in and continue in the null
 // block, where both paths meet.
 // NOTE(review): the PHI below names args_notnull_bb as predecessor, which
 // assumes the builder is still positioned in that block here — confirm.
469  cgen_state_->ir_builder_.CreateBr(bbs.args_null_bb);
470  cgen_state_->ir_builder_.SetInsertPoint(bbs.args_null_bb);
471 
472  llvm::PHINode* ext_call_phi{nullptr};
473  llvm::Value* null_lv{nullptr};
474  const auto func_ti = function_oper->get_type_info();
475  if (!func_ti.is_buffer()) {
476  // The pre-cast SQL equivalent of the type returned by the extension function.
477  const auto extension_ret_ti = get_sql_type_from_llvm_type(fn_ret_lv->getType());
478 
 // The PHI type is rebuilt from the SQL type's size: a floating point or
 // integer type of size * 8 bits.
479  ext_call_phi = cgen_state_->ir_builder_.CreatePHI(
480  extension_ret_ti.is_fp()
481  ? get_fp_type(extension_ret_ti.get_size() * 8, cgen_state_->context_)
482  : get_int_type(extension_ret_ti.get_size() * 8, cgen_state_->context_),
483  2);
484 
 // Inline NULL value of the matching kind.
485  null_lv =
486  extension_ret_ti.is_fp()
487  ? static_cast<llvm::Value*>(cgen_state_->inlineFpNull(extension_ret_ti))
488  : static_cast<llvm::Value*>(cgen_state_->inlineIntNull(extension_ret_ti));
489  } else {
490  const auto arr_struct_ty = get_buffer_struct_type(
491  cgen_state_,
492  function_oper->getName(),
493  0,
495  ext_call_phi =
496  cgen_state_->ir_builder_.CreatePHI(llvm::PointerType::get(arr_struct_ty, 0), 2);
497 
 // Mark the pre-allocated struct as NULL: field 2 is set to 1 and field 1
 // to 0 (codegenBufferArgs stores the null flag and the size into these
 // fields).
498  CHECK(null_array_ptr);
499  const auto arr_null_bool =
500  cgen_state_->ir_builder_.CreateStructGEP(arr_struct_ty, null_array_ptr, 2);
501  cgen_state_->ir_builder_.CreateStore(
502  llvm::ConstantInt::get(get_int_type(8, cgen_state_->context_), 1),
503  arr_null_bool);
504 
505  const auto arr_null_size =
506  cgen_state_->ir_builder_.CreateStructGEP(arr_struct_ty, null_array_ptr, 1);
507  cgen_state_->ir_builder_.CreateStore(
508  llvm::ConstantInt::get(get_int_type(64, cgen_state_->context_), 0),
509  arr_null_size);
510  }
 // Function result when the call was made, NULL representation otherwise.
511  ext_call_phi->addIncoming(fn_ret_lv, bbs.args_notnull_bb);
512  ext_call_phi->addIncoming(func_ti.is_buffer() ? null_array_ptr : null_lv,
513  bbs.orig_bb);
514 
515  return ext_call_phi;
516  }
517  return fn_ret_lv;
518 }
519 
520 namespace {
521 
523  const auto& ret_ti = function_oper->get_type_info();
 // Returns true when the return type or any argument type of `function_oper`
 // is neither an integer nor a floating point type, false otherwise.
 // NOTE(review): the signature line is missing from this listing.
524  if (!ret_ti.is_integer() && !ret_ti.is_fp()) {
525  return true;
526  }
527  for (size_t i = 0; i < function_oper->getArity(); ++i) {
528  const auto arg = function_oper->getArg(i);
529  const auto& arg_ti = arg->get_type_info();
530  if (!arg_ti.is_integer() && !arg_ti.is_fp()) {
531  return true;
532  }
533  }
534  return false;
535 }
536 
537 } // namespace
538 
541  const CompilationOptions& co) {
 // Generates code for function calls that need custom type handling: FLOOR
 // and CEIL on a decimal argument, and ROUND with a decimal first argument.
 // Any other call that requires custom type handling raises
 // std::runtime_error; calls that do not require it are forwarded to
 // codegenFunctionOper.
 // NOTE(review): several lines of this function (including the first lines of
 // the signature and the declarations of `bbs` / `bbs0`) are missing from this
 // listing.
543  if (call_requires_custom_type_handling(function_oper)) {
544  // Some functions need the return type to be the same as the input type.
545  if (function_oper->getName() == "FLOOR" || function_oper->getName() == "CEIL") {
546  CHECK_EQ(size_t(1), function_oper->getArity());
547  const auto arg = function_oper->getArg(0);
548  const auto& arg_ti = arg->get_type_info();
549  CHECK(arg_ti.is_decimal());
550  const auto arg_lvs = codegen(arg, true, co);
551  CHECK_EQ(size_t(1), arg_lvs.size());
552  const auto arg_lv = arg_lvs.front();
553  CHECK(arg_lv->getType()->isIntegerTy(64));
555  std::tie(bbs, std::ignore) = beginArgsNullcheck(function_oper, {arg_lvs});
556  const std::string func_name =
557  (function_oper->getName() == "FLOOR") ? "decimal_floor" : "decimal_ceil";
 // The helper receives the 64-bit decimal value and
 // exp_to_scale(argument scale).
558  const auto covar_result_lv = cgen_state_->emitCall(
559  func_name, {arg_lv, cgen_state_->llInt(exp_to_scale(arg_ti.get_scale()))});
560  const auto ret_ti = function_oper->get_type_info();
561  CHECK(ret_ti.is_decimal());
562  CHECK_EQ(0, ret_ti.get_scale());
 // The result type has scale 0, so the helper's result is divided by the
 // same scale factor.
563  const auto result_lv = cgen_state_->ir_builder_.CreateSDiv(
564  covar_result_lv, cgen_state_->llInt(exp_to_scale(arg_ti.get_scale())));
565  return endArgsNullcheck(bbs, result_lv, nullptr, function_oper);
566  } else if (function_oper->getName() == "ROUND" &&
567  function_oper->getArg(0)->get_type_info().is_decimal()) {
568  CHECK_EQ(size_t(2), function_oper->getArity());
569 
570  const auto arg0 = function_oper->getArg(0);
571  const auto& arg0_ti = arg0->get_type_info();
572  const auto arg0_lvs = codegen(arg0, true, co);
573  CHECK_EQ(size_t(1), arg0_lvs.size());
574  const auto arg0_lv = arg0_lvs.front();
575  CHECK(arg0_lv->getType()->isIntegerTy(64));
576 
577  const auto arg1 = function_oper->getArg(1);
578  const auto& arg1_ti = arg1->get_type_info();
579  CHECK(arg1_ti.is_integer());
580  const auto arg1_lvs = codegen(arg1, true, co);
 // NOTE(review): unlike arg0_lvs, the size of arg1_lvs is not checked
 // before front() is used.
581  auto arg1_lv = arg1_lvs.front();
 // The second argument is passed to the external function as an INT.
582  if (arg1_ti.get_type() != kINT) {
583  arg1_lv = codegenCast(arg1_lv, arg1_ti, SQLTypeInfo(kINT, true), false, co);
584  }
585 
 // The NULL check is generated on the values before the cast above.
587  std::tie(bbs0, std::ignore) =
588  beginArgsNullcheck(function_oper, {arg0_lv, arg1_lvs.front()});
589 
590  const std::string func_name = "Round__4";
591  const auto ret_ti = function_oper->get_type_info();
592  CHECK(ret_ti.is_decimal());
 // Arguments: the decimal value, the INT second argument and the scale of
 // the first argument.
593  const auto result_lv = cgen_state_->emitExternalCall(
594  func_name,
596  {arg0_lv, arg1_lv, cgen_state_->llInt(arg0_ti.get_scale())});
597 
598  return endArgsNullcheck(bbs0, result_lv, nullptr, function_oper);
599  }
600  throw std::runtime_error("Type combination not supported for function " +
601  function_oper->getName());
602  }
603  return codegenFunctionOper(function_oper, co);
604 }
605 
606 // Generates code which returns true iff at least one of the arguments is NULL.
// Arguments whose type is marked not-null are skipped. `orig_arg_lvs` holds
// the flattened generated values of the arguments; `i` indexes the arguments
// and `j` the corresponding position in `orig_arg_lvs`.
// NOTE(review): some lines of this function (including the line carrying its
// name) are missing from this listing.
608  const Analyzer::FunctionOper* function_oper,
609  const std::vector<llvm::Value*>& orig_arg_lvs) {
 // Running OR of the per-argument NULL tests, starting from false.
611  llvm::Value* one_arg_null =
612  llvm::ConstantInt::get(llvm::IntegerType::getInt1Ty(cgen_state_->context_), false);
613  size_t physical_coord_cols = 0;
 // After each argument `j` advances by max(1, physical_coord_cols), where
 // physical_coord_cols is set in the loop body for the current argument.
 // NOTE(review): codegenFunctionOper pushes three values for a none-encoded
 // text argument; confirm that `j` stays aligned with orig_arg_lvs in that
 // case.
614  for (size_t i = 0, j = 0; i < function_oper->getArity();
615  ++i, j += std::max(size_t(1), physical_coord_cols)) {
616  const auto arg = function_oper->getArg(i);
617  const auto& arg_ti = arg->get_type_info();
618  physical_coord_cols = arg_ti.get_physical_coord_cols();
619  if (arg_ti.get_notnull()) {
620  continue;
621  }
622  auto geo_expr_arg = dynamic_cast<const Analyzer::GeoExpr*>(arg);
623  if (geo_expr_arg && arg_ti.is_geometry()) {
 // GeoExpr POINT: NULL when its first value equals a null pointer of the
 // coordinate pointer type (i32* for GEOINT compression, double*
 // otherwise).
624  CHECK(arg_ti.get_type() == kPOINT);
625  auto is_null_lv = cgen_state_->ir_builder_.CreateICmp(
626  llvm::CmpInst::ICMP_EQ,
627  orig_arg_lvs[j],
628  llvm::ConstantPointerNull::get( // TODO: centralize logic; in geo expr?
629  arg_ti.get_compression() == kENCODING_GEOINT
630  ? llvm::Type::getInt32PtrTy(cgen_state_->context_)
631  : llvm::Type::getDoublePtrTy(cgen_state_->context_)));
632  one_arg_null = cgen_state_->ir_builder_.CreateOr(one_arg_null, is_null_lv);
633  physical_coord_cols = 2; // number of lvs to advance
634  continue;
635  }
636 #ifdef ENABLE_GEOS
637  // If geo arg is coming from geos, skip the null check, assume it's a valid geo
 // NOTE(review): this indexes orig_arg_lvs with the argument index `i`,
 // while every other access in this loop uses `j` — confirm which is
 // intended.
638  if (arg_ti.is_geometry()) {
639  auto* coords_load = llvm::dyn_cast<llvm::LoadInst>(orig_arg_lvs[i]);
640  if (coords_load) {
641  continue;
642  }
643  }
644 #endif
645  if (arg_ti.is_geometry()) {
 // Geometry values that are an alloca or a PHI node are not checked.
646  auto* coords_alloca = llvm::dyn_cast<llvm::AllocaInst>(orig_arg_lvs[j]);
647  auto* coords_phi = llvm::dyn_cast<llvm::PHINode>(orig_arg_lvs[j]);
648  if (coords_alloca || coords_phi) {
649  // TODO: null check dynamically generated geometries
650  continue;
651  }
652  }
 // Dict-encoded text is checked like a number.
653  if (arg_ti.is_text_encoding_dict()) {
654  one_arg_null = cgen_state_->ir_builder_.CreateOr(
655  one_arg_null, codegenIsNullNumber(orig_arg_lvs[j], arg_ti));
656  continue;
657  }
658  if (arg_ti.is_buffer() || arg_ti.is_geometry()) {
659  // POINT [un]compressed coord check requires custom checker and chunk iterator
660  // Non-POINT NULL geographies will have a normally encoded null coord array
661  auto fname =
662  (arg_ti.get_type() == kPOINT) ? "point_coord_array_is_null" : "array_is_null";
663  auto is_null_lv = cgen_state_->emitExternalCall(
664  fname, get_int_type(1, cgen_state_->context_), {orig_arg_lvs[j], posArg(arg)});
665  one_arg_null = cgen_state_->ir_builder_.CreateOr(one_arg_null, is_null_lv);
666  continue;
667  }
 // Everything left has to be a number or a boolean.
668  CHECK(arg_ti.is_number() or arg_ti.is_boolean());
669  one_arg_null = cgen_state_->ir_builder_.CreateOr(
670  one_arg_null, codegenIsNullNumber(orig_arg_lvs[j], arg_ti));
671  }
672  return one_arg_null;
673 }
674 
// Returns an integer constant describing the compression of `type_info`:
// 1 when the compression is kENCODING_GEOINT with a compression parameter of
// 32, and 0 in every other case.
// NOTE(review): one line of this function is missing from this listing.
675 llvm::Value* CodeGenerator::codegenCompression(const SQLTypeInfo& type_info) {
677  int32_t compression = (type_info.get_compression() == kENCODING_GEOINT &&
678  type_info.get_comp_param() == 32)
679  ? 1
680  : 0;
681 
682  return cgen_state_->llInt(compression);
683 }
684 
// Emits calls that fetch the data pointer and the element count of the array
// at `row_pos` in `chunk`, and returns them as a (pointer, length) pair.
// `array_type` is used as the subtype of the array type from which the element
// type is derived. With `cast_and_extend` set, the pointer is cast with
// castArrayPointer for that element type and the length is zero-extended from
// 32 to 64 bits.
// NOTE(review): one line of this function is missing from this listing.
685 std::pair<llvm::Value*, llvm::Value*> CodeGenerator::codegenArrayBuff(
686  llvm::Value* chunk,
687  llvm::Value* row_pos,
688  SQLTypes array_type,
689  bool cast_and_extend) {
 // Build an array type with `array_type` as subtype just to obtain its
 // element type.
691  const auto elem_ti =
692  SQLTypeInfo(
693  SQLTypes::kARRAY, 0, 0, false, EncodingType::kENCODING_NONE, 0, array_type)
694  .get_elem_type();
695 
 // "array_buff" is declared here as returning i32*.
696  auto buff = cgen_state_->emitExternalCall(
697  "array_buff", llvm::Type::getInt32PtrTy(cgen_state_->context_), {chunk, row_pos});
698 
 // "array_size" additionally receives log2 of the element's logical size in
 // bytes and is declared as returning i32.
699  auto len = cgen_state_->emitExternalCall(
700  "array_size",
701  get_int_type(32, cgen_state_->context_),
702  {chunk, row_pos, cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
703 
704  if (cast_and_extend) {
705  buff = castArrayPointer(buff, elem_ti);
706  len =
707  cgen_state_->ir_builder_.CreateZExt(len, get_int_type(64, cgen_state_->context_));
708  }
709 
710  return std::make_pair(buff, len);
711 }
712 
// Packs a buffer argument into a stack-allocated struct and appends the
// pointer to that struct to `output_args`.
// The struct type comes from get_buffer_struct_type for parameter `param_num`
// of `ext_func_name`; field 0 receives `buffer_buf`, field 1 `buffer_size` and
// field 2 `buffer_null` zero-extended to 8 bits.
// NOTE(review): one line of this function is missing from this listing;
// `buffer_null` is used without a CHECK, unlike the other two values.
713 void CodeGenerator::codegenBufferArgs(const std::string& ext_func_name,
714  size_t param_num,
715  llvm::Value* buffer_buf,
716  llvm::Value* buffer_size,
717  llvm::Value* buffer_null,
718  std::vector<llvm::Value*>& output_args) {
720  CHECK(buffer_buf);
721  CHECK(buffer_size);
722 
 // The type of the data pointer doubles as the element type of the struct.
723  auto buffer_abstraction = get_buffer_struct_type(
724  cgen_state_, ext_func_name, param_num, buffer_buf->getType());
725  auto alloc_mem = cgen_state_->ir_builder_.CreateAlloca(buffer_abstraction);
726 
 // Field 0: data pointer.
727  auto buffer_buf_ptr =
728  cgen_state_->ir_builder_.CreateStructGEP(buffer_abstraction, alloc_mem, 0);
729  cgen_state_->ir_builder_.CreateStore(buffer_buf, buffer_buf_ptr);
730 
 // Field 1: size.
731  auto buffer_size_ptr =
732  cgen_state_->ir_builder_.CreateStructGEP(buffer_abstraction, alloc_mem, 1);
733  cgen_state_->ir_builder_.CreateStore(buffer_size, buffer_size_ptr);
734 
 // Field 2: null flag, widened to the i8 stored in the struct.
735  auto bool_extended_type = llvm::Type::getInt8Ty(cgen_state_->context_);
736  auto buffer_null_extended =
737  cgen_state_->ir_builder_.CreateZExt(buffer_null, bool_extended_type);
738  auto buffer_is_null_ptr =
739  cgen_state_->ir_builder_.CreateStructGEP(buffer_abstraction, alloc_mem, 2);
740  cgen_state_->ir_builder_.CreateStore(buffer_null_extended, buffer_is_null_ptr);
741  output_args.push_back(alloc_mem);
742 }
743 
744 llvm::StructType* CodeGenerator::createPointStructType(const std::string& udf_func_name,
745  size_t param_num) {
746  llvm::Module* module_for_lookup = cgen_state_->module_;
747  llvm::Function* udf_func = module_for_lookup->getFunction(udf_func_name);
748 
749  llvm::StructType* generated_struct_type =
750  llvm::StructType::get(cgen_state_->context_,
751  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
752  llvm::Type::getInt32Ty(cgen_state_->context_),
753  llvm::Type::getInt32Ty(cgen_state_->context_),
754  llvm::Type::getInt32Ty(cgen_state_->context_),
755  llvm::Type::getInt32Ty(cgen_state_->context_)},
756  false);
757 
758  if (udf_func) {
759  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
760  CHECK(param_num < udf_func_type->getNumParams());
761  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
762  CHECK(param_pointer_type->isPointerTy());
763  llvm::Type* param_type = param_pointer_type->getPointerElementType();
764  CHECK(param_type->isStructTy());
765  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
766  CHECK_EQ(struct_type->getStructNumElements(), 5u)
767  << serialize_llvm_object(struct_type);
768  const auto expected_elems = generated_struct_type->elements();
769  const auto current_elems = struct_type->elements();
770  for (size_t i = 0; i < expected_elems.size(); i++) {
771  CHECK_EQ(expected_elems[i], current_elems[i]);
772  }
773  if (struct_type->isLiteral()) {
774  return struct_type;
775  }
776 
777  llvm::StringRef struct_name = struct_type->getStructName();
778 #if LLVM_VERSION_MAJOR >= 12
779  llvm::StructType* point_type =
780  struct_type->getTypeByName(cgen_state_->context_, struct_name);
781 #else
782  llvm::StructType* point_type = module_for_lookup->getTypeByName(struct_name);
783 #endif
784  CHECK(point_type);
785 
786  return point_type;
787  }
788  return generated_struct_type;
789 }
790 
// Packs a geo point argument into a stack-allocated struct and appends the
// pointer to that struct to `output_args`.
// The struct type comes from createPointStructType for parameter `param_num`
// of `udf_func_name`; fields 0 to 4 receive `point_buf`, `point_size`,
// `compression`, `input_srid` and `output_srid`, in that order. All five
// values must be non-null.
// NOTE(review): one line of this function is missing from this listing.
791 void CodeGenerator::codegenGeoPointArgs(const std::string& udf_func_name,
792  size_t param_num,
793  llvm::Value* point_buf,
794  llvm::Value* point_size,
795  llvm::Value* compression,
796  llvm::Value* input_srid,
797  llvm::Value* output_srid,
798  std::vector<llvm::Value*>& output_args) {
800  CHECK(point_buf);
801  CHECK(point_size);
802  CHECK(compression);
803  CHECK(input_srid);
804  CHECK(output_srid);
805 
806  auto point_abstraction = createPointStructType(udf_func_name, param_num);
807  auto alloc_mem = cgen_state_->ir_builder_.CreateAlloca(point_abstraction, nullptr);
808 
 // Field 0: coordinate buffer.
809  auto point_buf_ptr =
810  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 0);
811  cgen_state_->ir_builder_.CreateStore(point_buf, point_buf_ptr);
812 
 // Field 1: size.
813  auto point_size_ptr =
814  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 1);
815  cgen_state_->ir_builder_.CreateStore(point_size, point_size_ptr);
816 
 // Field 2: compression.
817  auto point_compression_ptr =
818  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 2);
819  cgen_state_->ir_builder_.CreateStore(compression, point_compression_ptr);
820 
 // Field 3: input SRID.
821  auto input_srid_ptr =
822  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 3);
823  cgen_state_->ir_builder_.CreateStore(input_srid, input_srid_ptr);
824 
 // Field 4: output SRID.
825  auto output_srid_ptr =
826  cgen_state_->ir_builder_.CreateStructGEP(point_abstraction, alloc_mem, 4);
827  cgen_state_->ir_builder_.CreateStore(output_srid, output_srid_ptr);
828 
829  output_args.push_back(alloc_mem);
830 }
831 
833  const std::string& udf_func_name,
834  size_t param_num) {
 // Resolves the LLVM struct type used for a multi point argument of
 // `udf_func_name`. The default layout is the literal struct
 // { i8*, i32, i32, i32, i32 }. If the function exists in the module being
 // generated, the struct type pointed to by its parameter `param_num` is
 // validated against that layout and the type from the module is returned
 // instead.
 // NOTE(review): the first line of the signature is missing from this listing.
835  llvm::Module* module_for_lookup = cgen_state_->module_;
836  llvm::Function* udf_func = module_for_lookup->getFunction(udf_func_name);
837 
838  llvm::StructType* generated_struct_type =
839  llvm::StructType::get(cgen_state_->context_,
840  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
841  llvm::Type::getInt32Ty(cgen_state_->context_),
842  llvm::Type::getInt32Ty(cgen_state_->context_),
843  llvm::Type::getInt32Ty(cgen_state_->context_),
844  llvm::Type::getInt32Ty(cgen_state_->context_)},
845  false);
846 
847  if (udf_func) {
 // The parameter has to be a pointer to a struct with exactly five fields.
848  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
849  CHECK(param_num < udf_func_type->getNumParams());
850  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
851  CHECK(param_pointer_type->isPointerTy());
852  llvm::Type* param_type = param_pointer_type->getPointerElementType();
853  CHECK(param_type->isStructTy());
854  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
855  CHECK(struct_type->isStructTy());
856  CHECK_EQ(struct_type->getStructNumElements(), 5u);
857 
 // Field types must match the default layout one by one.
858  const auto expected_elems = generated_struct_type->elements();
859  const auto current_elems = struct_type->elements();
860  for (size_t i = 0; i < expected_elems.size(); i++) {
861  CHECK_EQ(expected_elems[i], current_elems[i]);
862  }
 // A literal struct has no name to look up and is used as is.
863  if (struct_type->isLiteral()) {
864  return struct_type;
865  }
866 
 // Named struct: resolve it by name and require that the lookup succeeds.
867  llvm::StringRef struct_name = struct_type->getStructName();
868 #if LLVM_VERSION_MAJOR >= 12
869  llvm::StructType* multi_point_type =
870  struct_type->getTypeByName(cgen_state_->context_, struct_name);
871 #else
872  llvm::StructType* multi_point_type = module_for_lookup->getTypeByName(struct_name);
873 #endif
874  CHECK(multi_point_type);
875 
876  return multi_point_type;
877  }
878  return generated_struct_type;
879 }
880 
// Emits IR that packs a multipoint argument into a stack-allocated struct and appends
// the pointer to that struct to `output_args`.
//
// udf_func_name / param_num: forwarded to createMultiPointStructType to pick the struct
//                            type.
// multi_point_buf, multi_point_size, compression, input_srid, output_srid: values stored
//                            into struct fields 0..4, in that order. All must be non-null.
// output_args: receives the alloca holding the populated struct.
881 void CodeGenerator::codegenGeoMultiPointArgs(const std::string& udf_func_name,
882  size_t param_num,
883  llvm::Value* multi_point_buf,
884  llvm::Value* multi_point_size,
885  llvm::Value* compression,
886  llvm::Value* input_srid,
887  llvm::Value* output_srid,
888  std::vector<llvm::Value*>& output_args) {
// NOTE(review): listing line 889 is absent from this listing.
890  CHECK(multi_point_buf);
891  CHECK(multi_point_size);
892  CHECK(compression);
893  CHECK(input_srid);
894  CHECK(output_srid);
895 
896  auto multi_point_abstraction = createMultiPointStructType(udf_func_name, param_num);
897  auto alloc_mem =
898  cgen_state_->ir_builder_.CreateAlloca(multi_point_abstraction, nullptr);
899 
// Field 0: coordinate buffer.
900  auto multi_point_buf_ptr =
901  cgen_state_->ir_builder_.CreateStructGEP(multi_point_abstraction, alloc_mem, 0);
902  cgen_state_->ir_builder_.CreateStore(multi_point_buf, multi_point_buf_ptr);
903 
// Field 1: buffer size.
904  auto multi_point_size_ptr =
905  cgen_state_->ir_builder_.CreateStructGEP(multi_point_abstraction, alloc_mem, 1);
906  cgen_state_->ir_builder_.CreateStore(multi_point_size, multi_point_size_ptr);
907 
// Field 2: compression.
908  auto compression_ptr =
909  cgen_state_->ir_builder_.CreateStructGEP(multi_point_abstraction, alloc_mem, 2);
910  cgen_state_->ir_builder_.CreateStore(compression, compression_ptr);
911 
// Field 3: input SRID.
912  auto input_srid_ptr =
913  cgen_state_->ir_builder_.CreateStructGEP(multi_point_abstraction, alloc_mem, 3);
914  cgen_state_->ir_builder_.CreateStore(input_srid, input_srid_ptr);
915 
// Field 4: output SRID.
916  auto output_srid_ptr =
917  cgen_state_->ir_builder_.CreateStructGEP(multi_point_abstraction, alloc_mem, 4);
918  cgen_state_->ir_builder_.CreateStore(output_srid, output_srid_ptr);
919 
920  output_args.push_back(alloc_mem);
921 }
922 
// Resolves the LLVM struct type used to pass a linestring argument to the extension
// function `udf_func_name`.
// NOTE(review): the line holding the return type and function name (listing line 923)
// is absent from this listing; the caller below refers to it as
// createLineStringStructType -- confirm against the original file.
//
// udf_func_name: name of the function looked up in cgen_state_->module_.
// param_num:     index of the parameter of that function that receives the struct.
// Returns the struct type taken from the function's signature when the function exists
// in the module, otherwise the locally generated literal struct.
924  const std::string& udf_func_name,
925  size_t param_num) {
926  llvm::Module* module_for_lookup = cgen_state_->module_;
927  llvm::Function* udf_func = module_for_lookup->getFunction(udf_func_name);
928 
// Expected layout {i8*, i32, i32, i32, i32}; the fields are filled by
// codegenGeoLineStringArgs as buffer, size, compression, input srid, output srid.
929  llvm::StructType* generated_struct_type =
930  llvm::StructType::get(cgen_state_->context_,
931  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
932  llvm::Type::getInt32Ty(cgen_state_->context_),
933  llvm::Type::getInt32Ty(cgen_state_->context_),
934  llvm::Type::getInt32Ty(cgen_state_->context_),
935  llvm::Type::getInt32Ty(cgen_state_->context_)},
936  false);
937 
938  if (udf_func) {
// The parameter must be a pointer to a 5-element struct.
939  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
940  CHECK(param_num < udf_func_type->getNumParams());
941  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
942  CHECK(param_pointer_type->isPointerTy());
943  llvm::Type* param_type = param_pointer_type->getPointerElementType();
944  CHECK(param_type->isStructTy());
945  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
946  CHECK(struct_type->isStructTy());
947  CHECK_EQ(struct_type->getStructNumElements(), 5u);
948 
// Element-by-element comparison of the function's struct against the expected layout.
949  const auto expected_elems = generated_struct_type->elements();
950  const auto current_elems = struct_type->elements();
951  for (size_t i = 0; i < expected_elems.size(); i++) {
952  CHECK_EQ(expected_elems[i], current_elems[i]);
953  }
// A literal struct has no name to look up, so it is returned directly.
954  if (struct_type->isLiteral()) {
955  return struct_type;
956  }
957 
// Named struct: resolve it by name (the lookup API differs by LLVM major version).
958  llvm::StringRef struct_name = struct_type->getStructName();
959 #if LLVM_VERSION_MAJOR >= 12
960  llvm::StructType* line_string_type =
961  struct_type->getTypeByName(cgen_state_->context_, struct_name);
962 #else
963  llvm::StructType* line_string_type = module_for_lookup->getTypeByName(struct_name);
964 #endif
965  CHECK(line_string_type);
966 
967  return line_string_type;
968  }
// No such function in the module: fall back to the generated layout.
969  return generated_struct_type;
970 }
971 
// Emits IR that packs a linestring argument into a stack-allocated struct and appends
// the pointer to that struct to `output_args`.
//
// udf_func_name / param_num: forwarded to createLineStringStructType to pick the struct
//                            type.
// line_string_buf, line_string_size, compression, input_srid, output_srid: values stored
//                            into struct fields 0..4, in that order. All must be non-null.
// output_args: receives the alloca holding the populated struct.
972 void CodeGenerator::codegenGeoLineStringArgs(const std::string& udf_func_name,
973  size_t param_num,
974  llvm::Value* line_string_buf,
975  llvm::Value* line_string_size,
976  llvm::Value* compression,
977  llvm::Value* input_srid,
978  llvm::Value* output_srid,
979  std::vector<llvm::Value*>& output_args) {
// NOTE(review): listing line 980 is absent from this listing.
981  CHECK(line_string_buf);
982  CHECK(line_string_size);
983  CHECK(compression);
984  CHECK(input_srid);
985  CHECK(output_srid);
986 
987  auto line_string_abstraction = createLineStringStructType(udf_func_name, param_num);
988  auto alloc_mem =
989  cgen_state_->ir_builder_.CreateAlloca(line_string_abstraction, nullptr);
990 
// Field 0: coordinate buffer.
991  auto line_string_buf_ptr =
992  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 0);
993  cgen_state_->ir_builder_.CreateStore(line_string_buf, line_string_buf_ptr);
994 
// Field 1: buffer size.
995  auto line_string_size_ptr =
996  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 1);
997  cgen_state_->ir_builder_.CreateStore(line_string_size, line_string_size_ptr);
998 
// Field 2: compression.
999  auto line_string_compression_ptr =
1000  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 2);
1001  cgen_state_->ir_builder_.CreateStore(compression, line_string_compression_ptr);
1002 
// Field 3: input SRID.
1003  auto input_srid_ptr =
1004  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 3);
1005  cgen_state_->ir_builder_.CreateStore(input_srid, input_srid_ptr);
1006 
// Field 4: output SRID.
1007  auto output_srid_ptr =
1008  cgen_state_->ir_builder_.CreateStructGEP(line_string_abstraction, alloc_mem, 4);
1009  cgen_state_->ir_builder_.CreateStore(output_srid, output_srid_ptr);
1010 
1011  output_args.push_back(alloc_mem);
1012 }
1013 
// Resolves the LLVM struct type used to pass a multilinestring argument to the extension
// function `udf_func_name`.
// NOTE(review): the line holding the return type and function name (listing line 1014)
// is absent from this listing; the caller below refers to it as
// createMultiLineStringStructType -- confirm against the original file.
//
// udf_func_name: name of the function looked up in cgen_state_->module_.
// param_num:     index of the parameter of that function that receives the struct.
// Returns the struct type taken from the function's signature when the function exists
// in the module, otherwise the locally generated literal struct.
1015  const std::string& udf_func_name,
1016  size_t param_num) {
1017  llvm::Module* module_for_lookup = cgen_state_->module_;
1018  llvm::Function* udf_func = module_for_lookup->getFunction(udf_func_name);
1019 
// Expected layout {i8*, i32, i8*, i32, i32, i32, i32}; the fields are filled by
// codegenGeoMultiLineStringArgs as coords, coords size, linestring sizes, linestring
// sizes size, compression, input srid, output srid.
1020  llvm::StructType* generated_struct_type =
1021  llvm::StructType::get(cgen_state_->context_,
1022  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
1023  llvm::Type::getInt32Ty(cgen_state_->context_),
1024  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1025  llvm::Type::getInt32Ty(cgen_state_->context_),
1026  llvm::Type::getInt32Ty(cgen_state_->context_),
1027  llvm::Type::getInt32Ty(cgen_state_->context_),
1028  llvm::Type::getInt32Ty(cgen_state_->context_)},
1029  false);
1030 
1031  if (udf_func) {
// The parameter must be a pointer to a 7-element struct.
1032  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
1033  CHECK(param_num < udf_func_type->getNumParams());
1034  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
1035  CHECK(param_pointer_type->isPointerTy());
1036  llvm::Type* param_type = param_pointer_type->getPointerElementType();
1037  CHECK(param_type->isStructTy());
1038  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
1039  CHECK(struct_type->isStructTy());
1040  CHECK_EQ(struct_type->getStructNumElements(), 7u);
1041 
// Element-by-element comparison of the function's struct against the expected layout.
1042  const auto expected_elems = generated_struct_type->elements();
1043  const auto current_elems = struct_type->elements();
1044  for (size_t i = 0; i < expected_elems.size(); i++) {
1045  CHECK_EQ(expected_elems[i], current_elems[i]);
1046  }
// A literal struct has no name to look up, so it is returned directly.
1047  if (struct_type->isLiteral()) {
1048  return struct_type;
1049  }
1050 
// Named struct: resolve it by name (the lookup API differs by LLVM major version).
1051  llvm::StringRef struct_name = struct_type->getStructName();
1052 #if LLVM_VERSION_MAJOR >= 12
1053  llvm::StructType* multi_linestring_type =
1054  struct_type->getTypeByName(cgen_state_->context_, struct_name);
1055 #else
1056  llvm::StructType* multi_linestring_type =
1057  module_for_lookup->getTypeByName(struct_name);
1058 #endif
1059  CHECK(multi_linestring_type);
1060 
1061  return multi_linestring_type;
1062  }
// No such function in the module: fall back to the generated layout.
1063  return generated_struct_type;
1064 }
1065 
// Emits IR that packs a multilinestring argument into a stack-allocated struct and
// appends the pointer to that struct to `output_args`.
// NOTE(review): the line holding the return type and function name (listing line 1066)
// is absent from this listing; the call site in this file names it
// codegenGeoMultiLineStringArgs -- confirm against the original file.
//
// udf_func_name / param_num: forwarded to createMultiLineStringStructType to pick the
//                            struct type.
// multi_linestring_coords, multi_linestring_coords_size, linestring_sizes,
// linestring_sizes_size, compression, input_srid, output_srid: values stored into struct
//                            fields 0..6, in that order. All must be non-null.
// output_args: receives the alloca holding the populated struct.
1067  const std::string& udf_func_name,
1068  size_t param_num,
1069  llvm::Value* multi_linestring_coords,
1070  llvm::Value* multi_linestring_coords_size,
1071  llvm::Value* linestring_sizes,
1072  llvm::Value* linestring_sizes_size,
1073  llvm::Value* compression,
1074  llvm::Value* input_srid,
1075  llvm::Value* output_srid,
1076  std::vector<llvm::Value*>& output_args) {
// NOTE(review): listing line 1077 is absent from this listing.
1078  CHECK(multi_linestring_coords);
1079  CHECK(multi_linestring_coords_size);
1080  CHECK(linestring_sizes);
1081  CHECK(linestring_sizes_size);
1082  CHECK(compression);
1083  CHECK(input_srid);
1084  CHECK(output_srid);
1085 
1086  auto multi_linestring_abstraction =
1087  createMultiLineStringStructType(udf_func_name, param_num);
1088  auto alloc_mem =
1089  cgen_state_->ir_builder_.CreateAlloca(multi_linestring_abstraction, nullptr);
1090 
// Field 0: coordinate buffer.
1091  auto multi_linestring_coords_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1092  multi_linestring_abstraction, alloc_mem, 0);
1093  cgen_state_->ir_builder_.CreateStore(multi_linestring_coords,
1094  multi_linestring_coords_ptr);
1095 
// Field 1: coordinate buffer size.
1096  auto multi_linestring_coords_size_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1097  multi_linestring_abstraction, alloc_mem, 1);
1098  cgen_state_->ir_builder_.CreateStore(multi_linestring_coords_size,
1099  multi_linestring_coords_size_ptr);
1100 
// Field 2: linestring sizes buffer. The incoming pointer is bitcast to the field's
// declared element type before the store, so its pointee type need not match.
1101  auto linestring_sizes_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1102  multi_linestring_abstraction, alloc_mem, 2);
1103  const auto linestring_sizes_ptr_ty =
1104  llvm::dyn_cast<llvm::PointerType>(linestring_sizes_ptr->getType());
1105  CHECK(linestring_sizes_ptr_ty);
1106  cgen_state_->ir_builder_.CreateStore(
1107  cgen_state_->ir_builder_.CreateBitCast(
1108  linestring_sizes, linestring_sizes_ptr_ty->getPointerElementType()),
1109  linestring_sizes_ptr);
1110 
// Field 3: size of the linestring sizes buffer.
1111  auto linestring_sizes_size_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1112  multi_linestring_abstraction, alloc_mem, 3);
1113  cgen_state_->ir_builder_.CreateStore(linestring_sizes_size, linestring_sizes_size_ptr);
1114 
// Field 4: compression.
1115  auto multi_linestring_compression_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1116  multi_linestring_abstraction, alloc_mem, 4);
1117  cgen_state_->ir_builder_.CreateStore(compression, multi_linestring_compression_ptr);
1118 
// Field 5: input SRID.
1119  auto input_srid_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1120  multi_linestring_abstraction, alloc_mem, 5);
1121  cgen_state_->ir_builder_.CreateStore(input_srid, input_srid_ptr);
1122 
// Field 6: output SRID.
1123  auto output_srid_ptr = cgen_state_->ir_builder_.CreateStructGEP(
1124  multi_linestring_abstraction, alloc_mem, 6);
1125  cgen_state_->ir_builder_.CreateStore(output_srid, output_srid_ptr);
1126 
1127  output_args.push_back(alloc_mem);
1128 }
1129 
1130 llvm::StructType* CodeGenerator::createPolygonStructType(const std::string& udf_func_name,
1131  size_t param_num) {
1132  llvm::Module* module_for_lookup = cgen_state_->module_;
1133  llvm::Function* udf_func = module_for_lookup->getFunction(udf_func_name);
1134 
1135  llvm::StructType* generated_struct_type =
1136  llvm::StructType::get(cgen_state_->context_,
1137  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
1138  llvm::Type::getInt32Ty(cgen_state_->context_),
1139  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1140  llvm::Type::getInt32Ty(cgen_state_->context_),
1141  llvm::Type::getInt32Ty(cgen_state_->context_),
1142  llvm::Type::getInt32Ty(cgen_state_->context_),
1143  llvm::Type::getInt32Ty(cgen_state_->context_)},
1144  false);
1145 
1146  if (udf_func) {
1147  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
1148  CHECK(param_num < udf_func_type->getNumParams());
1149  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
1150  CHECK(param_pointer_type->isPointerTy());
1151  llvm::Type* param_type = param_pointer_type->getPointerElementType();
1152  CHECK(param_type->isStructTy());
1153  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
1154 
1155  CHECK(struct_type->isStructTy());
1156  CHECK_EQ(struct_type->getStructNumElements(), 7u);
1157 
1158  const auto expected_elems = generated_struct_type->elements();
1159  const auto current_elems = struct_type->elements();
1160  for (size_t i = 0; i < expected_elems.size(); i++) {
1161  CHECK_EQ(expected_elems[i], current_elems[i]);
1162  }
1163  if (struct_type->isLiteral()) {
1164  return struct_type;
1165  }
1166 
1167  llvm::StringRef struct_name = struct_type->getStructName();
1168 
1169 #if LLVM_VERSION_MAJOR >= 12
1170  llvm::StructType* polygon_type =
1171  struct_type->getTypeByName(cgen_state_->context_, struct_name);
1172 #else
1173  llvm::StructType* polygon_type = module_for_lookup->getTypeByName(struct_name);
1174 #endif
1175  CHECK(polygon_type);
1176 
1177  return polygon_type;
1178  }
1179  return generated_struct_type;
1180 }
1181 
// Emits IR that packs a polygon argument into a stack-allocated struct and appends the
// pointer to that struct to `output_args`.
//
// udf_func_name / param_num: forwarded to createPolygonStructType to pick the struct
//                            type.
// polygon_buf, polygon_size, ring_sizes_buf, num_rings, compression, input_srid,
// output_srid: values stored into struct fields 0..6, in that order. All must be
//                            non-null.
// output_args: receives the alloca holding the populated struct.
1182 void CodeGenerator::codegenGeoPolygonArgs(const std::string& udf_func_name,
1183  size_t param_num,
1184  llvm::Value* polygon_buf,
1185  llvm::Value* polygon_size,
1186  llvm::Value* ring_sizes_buf,
1187  llvm::Value* num_rings,
1188  llvm::Value* compression,
1189  llvm::Value* input_srid,
1190  llvm::Value* output_srid,
1191  std::vector<llvm::Value*>& output_args) {
// NOTE(review): listing line 1192 is absent from this listing.
1193  CHECK(polygon_buf);
1194  CHECK(polygon_size);
1195  CHECK(ring_sizes_buf);
1196  CHECK(num_rings);
1197  CHECK(compression);
1198  CHECK(input_srid);
1199  CHECK(output_srid);
1200 
1201  auto& builder = cgen_state_->ir_builder_;
1202 
1203  auto polygon_abstraction = createPolygonStructType(udf_func_name, param_num);
1204  auto alloc_mem = builder.CreateAlloca(polygon_abstraction, nullptr);
1205 
// Field 0: coordinate buffer.
1206  const auto polygon_buf_ptr = builder.CreateStructGEP(polygon_abstraction, alloc_mem, 0);
1207  builder.CreateStore(polygon_buf, polygon_buf_ptr);
1208 
// Field 1: coordinate buffer size.
1209  const auto polygon_size_ptr =
1210  builder.CreateStructGEP(polygon_abstraction, alloc_mem, 1);
1211  builder.CreateStore(polygon_size, polygon_size_ptr);
1212 
// Field 2: ring sizes buffer. The incoming pointer is bitcast to the field's declared
// element type before the store, so its pointee type need not match.
1213  const auto ring_sizes_buf_ptr =
1214  builder.CreateStructGEP(polygon_abstraction, alloc_mem, 2);
1215  const auto ring_sizes_ptr_ty =
1216  llvm::dyn_cast<llvm::PointerType>(ring_sizes_buf_ptr->getType());
1217  CHECK(ring_sizes_ptr_ty);
1218  builder.CreateStore(
1219  builder.CreateBitCast(ring_sizes_buf, ring_sizes_ptr_ty->getPointerElementType()),
1220  ring_sizes_buf_ptr);
1221 
// Field 3: number of rings.
1222  const auto ring_size_ptr = builder.CreateStructGEP(polygon_abstraction, alloc_mem, 3);
1223  builder.CreateStore(num_rings, ring_size_ptr);
1224 
// Field 4: compression.
1225  const auto polygon_compression_ptr =
1226  builder.CreateStructGEP(polygon_abstraction, alloc_mem, 4);
1227  builder.CreateStore(compression, polygon_compression_ptr);
1228 
// Field 5: input SRID.
1229  const auto input_srid_ptr = builder.CreateStructGEP(polygon_abstraction, alloc_mem, 5);
1230  builder.CreateStore(input_srid, input_srid_ptr);
1231 
// Field 6: output SRID.
1232  const auto output_srid_ptr = builder.CreateStructGEP(polygon_abstraction, alloc_mem, 6);
1233  builder.CreateStore(output_srid, output_srid_ptr);
1234 
1235  output_args.push_back(alloc_mem);
1236 }
1237 
// Resolves the LLVM struct type used to pass a multipolygon argument to the extension
// function `udf_func_name`.
// NOTE(review): the line holding the return type and function name (listing line 1238)
// is absent from this listing; the caller below refers to it as
// createMultiPolygonStructType -- confirm against the original file.
//
// udf_func_name: name of the function looked up in cgen_state_->module_.
// param_num:     index of the parameter of that function that receives the struct.
// Returns the struct type taken from the function's signature when the function exists
// in the module, otherwise the locally generated literal struct.
1239  const std::string& udf_func_name,
1240  size_t param_num) {
1241  llvm::Function* udf_func = cgen_state_->module_->getFunction(udf_func_name);
1242 
// Expected layout {i8*, i32, i8*, i32, i8*, i32, i32, i32, i32}; the fields are filled
// by codegenGeoMultiPolygonArgs as coords, coords size, ring sizes buffer, ring sizes,
// polygon bounds, polygon bounds sizes, compression, input srid, output srid.
1243  llvm::StructType* generated_struct_type =
1244  llvm::StructType::get(cgen_state_->context_,
1245  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
1246  llvm::Type::getInt32Ty(cgen_state_->context_),
1247  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1248  llvm::Type::getInt32Ty(cgen_state_->context_),
1249  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1250  llvm::Type::getInt32Ty(cgen_state_->context_),
1251  llvm::Type::getInt32Ty(cgen_state_->context_),
1252  llvm::Type::getInt32Ty(cgen_state_->context_),
1253  llvm::Type::getInt32Ty(cgen_state_->context_)},
1254  false);
1255 
1256  if (udf_func) {
// The parameter must be a pointer to a 9-element struct.
1257  llvm::FunctionType* udf_func_type = udf_func->getFunctionType();
1258  CHECK(param_num < udf_func_type->getNumParams());
1259  llvm::Type* param_pointer_type = udf_func_type->getParamType(param_num);
1260  CHECK(param_pointer_type->isPointerTy());
1261  llvm::Type* param_type = param_pointer_type->getPointerElementType();
1262  CHECK(param_type->isStructTy());
1263  llvm::StructType* struct_type = llvm::cast<llvm::StructType>(param_type);
1264  CHECK(struct_type->isStructTy());
1265  CHECK_EQ(struct_type->getStructNumElements(), 9u);
// Element-by-element comparison of the function's struct against the expected layout.
1266  const auto expected_elems = generated_struct_type->elements();
1267  const auto current_elems = struct_type->elements();
1268  for (size_t i = 0; i < expected_elems.size(); i++) {
1269  CHECK_EQ(expected_elems[i], current_elems[i]);
1270  }
// A literal struct has no name to look up, so it is returned directly.
1271  if (struct_type->isLiteral()) {
1272  return struct_type;
1273  }
// Named struct: resolve it by name (the lookup API differs by LLVM major version).
1274  llvm::StringRef struct_name = struct_type->getStructName();
1275 
1276 #if LLVM_VERSION_MAJOR >= 12
1277  llvm::StructType* polygon_type =
1278  struct_type->getTypeByName(cgen_state_->context_, struct_name);
1279 #else
1280  llvm::StructType* polygon_type = cgen_state_->module_->getTypeByName(struct_name);
1281 #endif
1282  CHECK(polygon_type);
1283 
1284  return polygon_type;
1285  }
// No such function in the module: fall back to the generated layout.
1286  return generated_struct_type;
1287 }
1288 
// Emits IR that packs a multipolygon argument into a stack-allocated struct and appends
// the pointer to that struct to `output_args`.
//
// udf_func_name / param_num: forwarded to createMultiPolygonStructType to pick the
//                            struct type.
// polygon_coords, polygon_coords_size, ring_sizes_buf, ring_sizes, polygon_bounds,
// polygon_bounds_sizes, compression, input_srid, output_srid: values stored into struct
//                            fields 0..8, in that order. All must be non-null.
// output_args: receives the alloca holding the populated struct.
1289 void CodeGenerator::codegenGeoMultiPolygonArgs(const std::string& udf_func_name,
1290  size_t param_num,
1291  llvm::Value* polygon_coords,
1292  llvm::Value* polygon_coords_size,
1293  llvm::Value* ring_sizes_buf,
1294  llvm::Value* ring_sizes,
1295  llvm::Value* polygon_bounds,
1296  llvm::Value* polygon_bounds_sizes,
1297  llvm::Value* compression,
1298  llvm::Value* input_srid,
1299  llvm::Value* output_srid,
1300  std::vector<llvm::Value*>& output_args) {
// NOTE(review): listing line 1301 is absent from this listing.
1302  CHECK(polygon_coords);
1303  CHECK(polygon_coords_size);
1304  CHECK(ring_sizes_buf);
1305  CHECK(ring_sizes);
1306  CHECK(polygon_bounds);
1307  CHECK(polygon_bounds_sizes);
1308  CHECK(compression);
1309  CHECK(input_srid);
1310  CHECK(output_srid);
1311 
1312  auto& builder = cgen_state_->ir_builder_;
1313 
1314  auto multi_polygon_abstraction = createMultiPolygonStructType(udf_func_name, param_num);
1315  auto alloc_mem = builder.CreateAlloca(multi_polygon_abstraction, nullptr);
1316 
// Field 0: coordinate buffer.
1317  const auto polygon_coords_ptr =
1318  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 0);
1319  builder.CreateStore(polygon_coords, polygon_coords_ptr);
1320 
// Field 1: coordinate buffer size.
1321  const auto polygon_coords_size_ptr =
1322  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 1);
1323  builder.CreateStore(polygon_coords_size, polygon_coords_size_ptr);
1324 
// Field 2: ring sizes buffer, bitcast to the field's declared element type before the
// store.
1325  const auto ring_sizes_buf_ptr =
1326  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 2);
1327  const auto ring_sizes_ptr_ty =
1328  llvm::dyn_cast<llvm::PointerType>(ring_sizes_buf_ptr->getType());
1329  CHECK(ring_sizes_ptr_ty);
1330  builder.CreateStore(
1331  builder.CreateBitCast(ring_sizes_buf, ring_sizes_ptr_ty->getPointerElementType()),
1332  ring_sizes_buf_ptr);
1333 
// Field 3: the `ring_sizes` value.
1334  const auto ring_sizes_ptr =
1335  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 3);
1336  builder.CreateStore(ring_sizes, ring_sizes_ptr);
1337 
// Field 4: polygon bounds buffer, bitcast to the field's declared element type before
// the store.
1338  const auto polygon_bounds_buf_ptr =
1339  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 4);
1340  const auto bounds_ptr_ty =
1341  llvm::dyn_cast<llvm::PointerType>(polygon_bounds_buf_ptr->getType());
1342  CHECK(bounds_ptr_ty);
1343  builder.CreateStore(
1344  builder.CreateBitCast(polygon_bounds, bounds_ptr_ty->getPointerElementType()),
1345  polygon_bounds_buf_ptr);
1346 
// Field 5: the `polygon_bounds_sizes` value.
1347  const auto polygon_bounds_sizes_ptr =
1348  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 5);
1349  builder.CreateStore(polygon_bounds_sizes, polygon_bounds_sizes_ptr);
1350 
// Field 6: compression.
1351  const auto polygon_compression_ptr =
1352  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 6);
1353  builder.CreateStore(compression, polygon_compression_ptr);
1354 
// Field 7: input SRID.
1355  const auto input_srid_ptr =
1356  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 7);
1357  builder.CreateStore(input_srid, input_srid_ptr);
1358 
// Field 8: output SRID.
1359  const auto output_srid_ptr =
1360  builder.CreateStructGEP(multi_polygon_abstraction, alloc_mem, 8);
1361  builder.CreateStore(output_srid, output_srid_ptr);
1362 
1363  output_args.push_back(alloc_mem);
1364 }
1365 
1366 // Generate CAST operations for arguments in `orig_arg_lvs` to the types required by
1367 // `ext_func_sig`.
//
// function_oper:      the function operand whose arguments are converted.
// ext_func_sig:       extension function signature supplying the expected argument types.
// orig_arg_lvs:       already generated LLVM values for the arguments; one argument may
//                     occupy several consecutive entries.
// orig_arg_lvs_index: for argument i, the index of its first entry in `orig_arg_lvs`.
// const_arr_size:     maps a constant-array value to its size value; an entry found here
//                     is used directly instead of emitting buffer/size calls.
// co:                 forwarded to codegenCast for scalar arguments.
// Returns the list of LLVM values to pass to the extension function.
// NOTE(review): several lines of this function are absent from this listing (1368, 1375,
// 1425, 1431, 1433, 1459, 1493, 1495, 1518, 1524, 1526, 1595, 1617, 1635, 1657, 1675,
// 1681, 1706); consult the original file before changing any code here.
1369  const Analyzer::FunctionOper* function_oper,
1370  const ExtensionFunction* ext_func_sig,
1371  const std::vector<llvm::Value*>& orig_arg_lvs,
1372  const std::vector<size_t>& orig_arg_lvs_index,
1373  const std::unordered_map<llvm::Value*, llvm::Value*>& const_arr_size,
1374  const CompilationOptions& co) {
1376  CHECK(ext_func_sig);
1377  const auto& ext_func_args = ext_func_sig->getInputArgs();
1378  CHECK_LE(function_oper->getArity(), ext_func_args.size());
1379  const auto func_ti = function_oper->get_type_info();
1380  std::vector<llvm::Value*> args;
1381  /*
1382  i: argument in RA for the function operand
1383  j: extra offset in ext_func_args
1384  k: origin_arg_lvs counter, equal to orig_arg_lvs_index[i]
1385  ij: ext_func_args counter, equal to i + j
1386  dj: offset when UDF implementation first argument corresponds to return value
1387  */
1388  for (size_t i = 0, j = 0, dj = (func_ti.is_buffer() ? 1 : 0);
1389  i < function_oper->getArity();
1390  ++i) {
1391  size_t k = orig_arg_lvs_index[i];
1392  size_t ij = i + j;
1393  const auto arg = function_oper->getArg(i);
1394  const auto ext_func_arg = ext_func_args[ij];
1395  const auto& arg_ti = arg->get_type_info();
1396  llvm::Value* arg_lv{nullptr};
1397  if (arg_ti.is_text_encoding_none()) {
// None-encoded text: pointer and length are read from entries k + 1 and k + 2, the
// pointer is cast to i8*, the length is zero-extended to 64 bits, and both are packed
// with a zero i8 third field via codegenBufferArgs.
1398  CHECK(ext_func_arg == ExtArgumentType::TextEncodingNone)
1399  << ::toString(ext_func_arg);
1400  const auto ptr_lv = orig_arg_lvs[k + 1];
1401  const auto len_lv = orig_arg_lvs[k + 2];
1402  auto& builder = cgen_state_->ir_builder_;
1403  auto string_buf_arg = builder.CreatePointerCast(
1404  ptr_lv, llvm::Type::getInt8PtrTy(cgen_state_->context_));
1405  auto string_size_arg =
1406  builder.CreateZExt(len_lv, get_int_type(64, cgen_state_->context_));
1407  auto padding = ll_int<int8_t>(0, cgen_state_->context_);
1408  codegenBufferArgs(ext_func_sig->getName(),
1409  ij + dj,
1410  string_buf_arg,
1411  string_size_arg,
1412  padding,
1413  args);
1414  } else if (arg_ti.is_text_encoding_dict()) {
// Dictionary-encoded text is passed through unchanged.
1415  CHECK(ext_func_arg == ExtArgumentType::TextEncodingDict)
1416  << ::toString(ext_func_arg);
1417  arg_lv = orig_arg_lvs[k];
1418  args.push_back(arg_lv);
1419  } else if (arg_ti.is_array()) {
// Arrays: a value registered in const_arr_size is used as the buffer directly with
// its recorded size; otherwise buffer and size come from the "array_buff" /
// "array_size" calls below (the lines opening those calls are absent from this listing).
1420  bool const_arr = (const_arr_size.count(orig_arg_lvs[k]) > 0);
1421  const auto elem_ti = arg_ti.get_elem_type();
1422  // TODO: switch to fast fixlen variants
1423  const auto ptr_lv = (const_arr)
1424  ? orig_arg_lvs[k]
1426  "array_buff",
1427  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1428  {orig_arg_lvs[k], posArg(arg)});
1429  const auto len_lv =
1430  (const_arr) ? const_arr_size.at(orig_arg_lvs[k])
1432  "array_size",
1434  {orig_arg_lvs[k],
1435  posArg(arg),
1436  cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
1437 
1438  if (is_ext_arg_type_pointer(ext_func_arg)) {
// Pointer-style signature: the array expands into (typed pointer, 64-bit length), so
// one extra slot of ext_func_args is consumed.
1439  args.push_back(castArrayPointer(ptr_lv, elem_ti));
1440  args.push_back(cgen_state_->ir_builder_.CreateZExt(
1441  len_lv, get_int_type(64, cgen_state_->context_)));
1442  j++;
1443  } else if (is_ext_arg_type_array(ext_func_arg)) {
// Array-struct signature: pointer, 64-bit size and a null flag are packed via
// codegenBufferArgs.
1444  auto array_buf_arg = castArrayPointer(ptr_lv, elem_ti);
1445  auto& builder = cgen_state_->ir_builder_;
1446  auto array_size_arg =
1447  builder.CreateZExt(len_lv, get_int_type(64, cgen_state_->context_));
1448  llvm::Value* array_null_arg = nullptr;
1449  if (auto gep = llvm::dyn_cast<llvm::GetElementPtrInst>(ptr_lv)) {
1450  CHECK(gep->getSourceElementType()->isArrayTy());
1451  // gep has the form
1452  // %17 = getelementptr [9 x i32], [9 x i32]* %7, i32 0
1453  // and was created by passing a const array to the UDF function:
1454  // select array_append({11, 22, 33}, 4);
1455  array_null_arg = ll_bool(false, cgen_state_->context_);
1456  } else {
1457  array_null_arg =
1458  cgen_state_->emitExternalCall("array_is_null",
1460  {orig_arg_lvs[k], posArg(arg)});
1461  }
1462  codegenBufferArgs(ext_func_sig->getName(),
1463  ij + dj,
1464  array_buf_arg,
1465  array_size_arg,
1466  array_null_arg,
1467  args);
1468  } else {
1469  UNREACHABLE();
1470  }
1471 
1472  } else if (arg_ti.is_geometry()) {
// A GeoExpr argument is passed as (i8* bitcast of entry k, entry k + 1 sign-extended to
// 64 bits); it consumes one extra ext_func_args slot and skips the rest of the geo
// handling below.
1473  auto geo_expr_arg = dynamic_cast<const Analyzer::GeoExpr*>(arg);
1474  if (geo_expr_arg) {
1475  auto ptr_lv = cgen_state_->ir_builder_.CreateBitCast(
1476  orig_arg_lvs[k], llvm::Type::getInt8PtrTy(cgen_state_->context_));
1477  args.push_back(ptr_lv);
1478  // TODO: remove when we normalize extension functions geo sizes to int32
1479  auto size_lv = cgen_state_->ir_builder_.CreateSExt(
1480  orig_arg_lvs[k + 1], llvm::Type::getInt64Ty(cgen_state_->context_));
1481  args.push_back(size_lv);
1482  j++;
1483  continue;
1484  }
1485  // Coords
1486  bool const_arr = (const_arr_size.count(orig_arg_lvs[k]) > 0);
1487  // NOTE(adb): We're generating code to handle the TINYINT array only -- the actual
1488  // geo encoding (or lack thereof) does not matter here
1489  const auto elem_ti = SQLTypeInfo(SQLTypes::kARRAY,
1490  0,
1491  0,
1492  false,
1494  0,
1496  .get_elem_type();
1497  llvm::Value* ptr_lv;
1498  llvm::Value* len_lv;
// fixlen stays -1 unless the argument is a POINT column whose physical coords column is
// an array type; then its declared size is used as a known length.
1499  int32_t fixlen = -1;
1500  if (arg_ti.get_type() == kPOINT) {
1501  const auto col_var = dynamic_cast<const Analyzer::ColumnVar*>(arg);
1502  if (col_var) {
1503  const auto coords_cd = executor()->getPhysicalColumnDescriptor(col_var, 1);
1504  if (coords_cd && coords_cd->columnType.get_type() == kARRAY) {
1505  fixlen = coords_cd->columnType.get_size();
1506  }
1507  }
1508  }
1509  if (fixlen > 0) {
// Known length: only the buffer is fetched at runtime; the length is a constant.
1510  ptr_lv =
1511  cgen_state_->emitExternalCall("fast_fixlen_array_buff",
1512  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1513  {orig_arg_lvs[k], posArg(arg)});
1514  len_lv = cgen_state_->llInt(int32_t(fixlen));
1515  } else {
1516  // TODO: remove const_arr and related code if it's not needed
1517  ptr_lv = (const_arr) ? orig_arg_lvs[k]
1519  "array_buff",
1520  llvm::Type::getInt8PtrTy(cgen_state_->context_),
1521  {orig_arg_lvs[k], posArg(arg)});
1522  len_lv = (const_arr)
1523  ? const_arr_size.at(orig_arg_lvs[k])
1525  "array_size",
1527  {orig_arg_lvs[k],
1528  posArg(arg),
1529  cgen_state_->llInt(log2_bytes(elem_ti.get_logical_size()))});
1530  }
1531 
1532  if (is_ext_arg_type_geo(ext_func_arg)) {
// Geo-struct signature: POINT, MULTIPOINT and LINESTRING are fully handled here; the
// remaining geo types are handled by the switch further down.
1533  if (arg_ti.get_type() == kPOINT || arg_ti.get_type() == kLINESTRING ||
1534  arg_ti.get_type() == kMULTIPOINT) {
1535  auto array_buf_arg = castArrayPointer(ptr_lv, elem_ti);
1536  auto compression_val = codegenCompression(arg_ti);
1537  auto input_srid_val = cgen_state_->llInt(arg_ti.get_input_srid());
1538  auto output_srid_val = cgen_state_->llInt(arg_ti.get_output_srid());
1539 
1540  if (arg_ti.get_type() == kPOINT) {
1541  CHECK_EQ(k, ij);
1542  codegenGeoPointArgs(ext_func_sig->getName(),
1543  ij + dj,
1544  array_buf_arg,
1545  len_lv,
1546  compression_val,
1547  input_srid_val,
1548  output_srid_val,
1549  args);
1550  } else if (arg_ti.get_type() == kMULTIPOINT) {
1551  CHECK_EQ(k, ij);
1552  codegenGeoMultiPointArgs(ext_func_sig->getName(),
1553  ij + dj,
1554  array_buf_arg,
1555  len_lv,
1556  compression_val,
1557  input_srid_val,
1558  output_srid_val,
1559  args);
1560  } else {
1561  CHECK_EQ(k, ij);
1562  codegenGeoLineStringArgs(ext_func_sig->getName(),
1563  ij + dj,
1564  array_buf_arg,
1565  len_lv,
1566  compression_val,
1567  input_srid_val,
1568  output_srid_val,
1569  args);
1570  }
1571  }
1572  } else {
// Flat signature: the coords are passed as (i8 pointer, 64-bit length), consuming one
// extra ext_func_args slot.
1573  CHECK(ext_func_arg == ExtArgumentType::PInt8);
1574  args.push_back(castArrayPointer(ptr_lv, elem_ti));
1575  args.push_back(cgen_state_->ir_builder_.CreateZExt(
1576  len_lv, get_int_type(64, cgen_state_->context_)));
1577  j++;
1578  }
1579 
// Additional physical columns per geo type (entries k + 1, k + 2 of orig_arg_lvs).
1580  switch (arg_ti.get_type()) {
1581  case kPOINT:
1582  case kMULTIPOINT:
1583  case kLINESTRING:
1584  break;
1585  case kMULTILINESTRING: {
1586  if (ext_func_arg == ExtArgumentType::GeoMultiLineString) {
1587  auto multi_linestring_coords = castArrayPointer(ptr_lv, elem_ti);
1588  auto compression_val = codegenCompression(arg_ti);
1589  auto input_srid_val = cgen_state_->llInt(arg_ti.get_input_srid());
1590  auto output_srid_val = cgen_state_->llInt(arg_ti.get_output_srid());
1591 
1592  auto [linestring_sizes, linestring_sizes_size] =
1593  codegenArrayBuff(orig_arg_lvs[k + 1],
1594  posArg(arg),
1596  /*cast_and_extend=*/false);
1597  CHECK_EQ(k, ij);
1598  codegenGeoMultiLineStringArgs(ext_func_sig->getName(),
1599  ij + dj,
1600  multi_linestring_coords,
1601  len_lv,
1602  linestring_sizes,
1603  linestring_sizes_size,
1604  compression_val,
1605  input_srid_val,
1606  output_srid_val,
1607  args);
1608  } else {
1609  CHECK(ext_func_arg == ExtArgumentType::PInt8);
1610  // Linestring Sizes
1611  auto const_arr = const_arr_size.count(orig_arg_lvs[k + 1]) > 0;
1612  auto [linestring_sizes, linestring_sizes_size] =
1613  (const_arr) ? std::make_pair(orig_arg_lvs[k + 1],
1614  const_arr_size.at(orig_arg_lvs[k + 1]))
1615  : codegenArrayBuff(orig_arg_lvs[k + 1],
1616  posArg(arg),
1618  /*cast_and_extend=*/true);
1619  args.push_back(linestring_sizes);
1620  args.push_back(linestring_sizes_size);
// Two more ext_func_args slots (sizes pointer and its length) on top of the coords pair.
1621  j += 2;
1622  }
1623  break;
1624  }
1625  case kPOLYGON: {
1626  if (ext_func_arg == ExtArgumentType::GeoPolygon) {
1627  auto array_buf_arg = castArrayPointer(ptr_lv, elem_ti);
1628  auto compression_val = codegenCompression(arg_ti);
1629  auto input_srid_val = cgen_state_->llInt(arg_ti.get_input_srid());
1630  auto output_srid_val = cgen_state_->llInt(arg_ti.get_output_srid());
1631 
1632  auto [ring_size_buff, ring_size] =
1633  codegenArrayBuff(orig_arg_lvs[k + 1],
1634  posArg(arg),
1636  /*cast_and_extend=*/false);
1637  CHECK_EQ(k, ij);
1638  codegenGeoPolygonArgs(ext_func_sig->getName(),
1639  ij + dj,
1640  array_buf_arg,
1641  len_lv,
1642  ring_size_buff,
1643  ring_size,
1644  compression_val,
1645  input_srid_val,
1646  output_srid_val,
1647  args);
1648  } else {
1649  CHECK(ext_func_arg == ExtArgumentType::PInt8);
1650  // Ring Sizes
1651  auto const_arr = const_arr_size.count(orig_arg_lvs[k + 1]) > 0;
1652  auto [ring_size_buff, ring_size] =
1653  (const_arr) ? std::make_pair(orig_arg_lvs[k + 1],
1654  const_arr_size.at(orig_arg_lvs[k + 1]))
1655  : codegenArrayBuff(orig_arg_lvs[k + 1],
1656  posArg(arg),
1658  /*cast_and_extend=*/true);
1659  args.push_back(ring_size_buff);
1660  args.push_back(ring_size);
// Two more ext_func_args slots (ring sizes pointer and its length).
1661  j += 2;
1662  }
1663  break;
1664  }
1665  case kMULTIPOLYGON: {
1666  if (ext_func_arg == ExtArgumentType::GeoMultiPolygon) {
1667  auto array_buf_arg = castArrayPointer(ptr_lv, elem_ti);
1668  auto compression_val = codegenCompression(arg_ti);
1669  auto input_srid_val = cgen_state_->llInt(arg_ti.get_input_srid());
1670  auto output_srid_val = cgen_state_->llInt(arg_ti.get_output_srid());
1671 
1672  auto [ring_size_buff, ring_size] =
1673  codegenArrayBuff(orig_arg_lvs[k + 1],
1674  posArg(arg),
1676  /*cast_and_extend=*/false);
1677 
1678  auto [poly_bounds_buff, poly_bounds_size] =
1679  codegenArrayBuff(orig_arg_lvs[k + 2],
1680  posArg(arg),
1682  /*cast_and_extend=*/false);
1683  CHECK_EQ(k, ij);
1684  codegenGeoMultiPolygonArgs(ext_func_sig->getName(),
1685  ij + dj,
1686  array_buf_arg,
1687  len_lv,
1688  ring_size_buff,
1689  ring_size,
1690  poly_bounds_buff,
1691  poly_bounds_size,
1692  compression_val,
1693  input_srid_val,
1694  output_srid_val,
1695  args);
1696  } else {
1697  CHECK(ext_func_arg == ExtArgumentType::PInt8);
1698  // Ring Sizes
1699  {
1700  auto const_arr = const_arr_size.count(orig_arg_lvs[k + 1]) > 0;
1701  auto [ring_size_buff, ring_size] =
1702  (const_arr) ? std::make_pair(orig_arg_lvs[k + 1],
1703  const_arr_size.at(orig_arg_lvs[k + 1]))
1704  : codegenArrayBuff(orig_arg_lvs[k + 1],
1705  posArg(arg),
1707  /*cast_and_extend=*/true);
1708 
1709  args.push_back(ring_size_buff);
1710  args.push_back(ring_size);
1711  }
1712  // Poly Rings
1713  {
1714  auto const_arr = const_arr_size.count(orig_arg_lvs[k + 2]) > 0;
1715  auto [poly_bounds_buff, poly_bounds_size] =
1716  (const_arr)
1717  ? std::make_pair(orig_arg_lvs[k + 2],
1718  const_arr_size.at(orig_arg_lvs[k + 2]))
1719  : codegenArrayBuff(
1720  orig_arg_lvs[k + 2], posArg(arg), SQLTypes::kINT, true);
1721 
1722  args.push_back(poly_bounds_buff);
1723  args.push_back(poly_bounds_size);
1724  }
// Four more ext_func_args slots: two pointer/length pairs (ring sizes, poly bounds).
1725  j += 4;
1726  }
1727  break;
1728  }
1729  default:
1730  CHECK(false);
1731  }
1732  } else {
// Scalars: cast only when the SQL type differs from the signature's type, then verify
// the resulting LLVM type matches what the signature expects.
1733  CHECK(is_ext_arg_type_scalar(ext_func_arg));
1734  const auto arg_target_ti = ext_arg_type_to_type_info(ext_func_arg);
1735  if (arg_ti.get_type() != arg_target_ti.get_type()) {
1736  arg_lv = codegenCast(orig_arg_lvs[k], arg_ti, arg_target_ti, false, co);
1737  } else {
1738  arg_lv = orig_arg_lvs[k];
1739  }
1740  CHECK_EQ(arg_lv->getType(),
1741  ext_arg_type_to_llvm_type(ext_func_arg, cgen_state_->context_));
1742  args.push_back(arg_lv);
1743  }
1744  }
1745  return args;
1746 }
1747 
// Casts `ptr` to a pointer to the LLVM type that corresponds to the array element type
// `elem_ti`.
//
// ptr:     pointer value to cast.
// elem_ti: element type; kFLOAT and kDOUBLE map to float* and double*, anything else
//          must be integer, boolean or dictionary-encoded string and maps by its size.
// Returns the pointer-cast value.
1748 llvm::Value* CodeGenerator::castArrayPointer(llvm::Value* ptr,
1749  const SQLTypeInfo& elem_ti) {
// NOTE(review): listing line 1750 is absent from this listing.
1751  if (elem_ti.get_type() == kFLOAT) {
1752  return cgen_state_->ir_builder_.CreatePointerCast(
1753  ptr, llvm::Type::getFloatPtrTy(cgen_state_->context_));
1754  }
1755  if (elem_ti.get_type() == kDOUBLE) {
1756  return cgen_state_->ir_builder_.CreatePointerCast(
1757  ptr, llvm::Type::getDoublePtrTy(cgen_state_->context_));
1758  }
1759  CHECK(elem_ti.is_integer() || elem_ti.is_boolean() ||
1760  (elem_ti.is_string() && elem_ti.get_compression() == kENCODING_DICT));
// Sizes 1, 2, 4 and 8 select i8*, i16*, i32* and i64* respectively.
1761  switch (elem_ti.get_size()) {
1762  case 1:
1763  return cgen_state_->ir_builder_.CreatePointerCast(
1764  ptr, llvm::Type::getInt8PtrTy(cgen_state_->context_));
1765  case 2:
1766  return cgen_state_->ir_builder_.CreatePointerCast(
1767  ptr, llvm::Type::getInt16PtrTy(cgen_state_->context_));
1768  case 4:
1769  return cgen_state_->ir_builder_.CreatePointerCast(
1770  ptr, llvm::Type::getInt32PtrTy(cgen_state_->context_));
1771  case 8:
1772  return cgen_state_->ir_builder_.CreatePointerCast(
1773  ptr, llvm::Type::getInt64PtrTy(cgen_state_->context_));
1774  default:
// Any other element size is rejected by the check below.
1775  CHECK(false);
1776  }
// Follows the switch so that every path has a return statement.
1777  return nullptr;
1778 }
1779 
1780 // Reflects struct StringView defined in Shared/Datum.h
// Builds the struct type {i8*, i64}, sets its name to "StringView" and returns it.
// NOTE(review): the line holding the return type and function name (listing line 1781)
// is absent from this listing.
// NOTE(review): the type comes from llvm::StructType::get, which yields literal struct
// types; confirm that calling setName on it is valid for the LLVM versions in use.
1782  auto* const string_view_type =
1783  llvm::StructType::get(cgen_state_->context_,
1784  {llvm::Type::getInt8PtrTy(cgen_state_->context_),
1785  llvm::Type::getInt64Ty(cgen_state_->context_)});
1786  string_view_type->setName("StringView");
1787  return string_view_type;
1788 }
llvm::StructType * createLineStringStructType(const std::string &udf_func_name, size_t param_num)
void codegenGeoMultiPolygonArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *polygon_coords, llvm::Value *polygon_coords_size, llvm::Value *ring_sizes_buf, llvm::Value *ring_sizes, llvm::Value *polygon_bounds, llvm::Value *polygon_bounds_sizes, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
#define CHECK_EQ(x, y)
Definition: Logger.h:301
HOST DEVICE int get_size() const
Definition: sqltypes.h:403
bool is_ext_arg_type_scalar(const ExtArgumentType ext_arg_type)
llvm::BasicBlock * args_notnull_bb
size_t getArity() const
Definition: Analyzer.h:2746
SQLTypes
Definition: sqltypes.h:65
std::unique_ptr< llvm::Module > udf_gpu_module
CgenState * cgen_state_
const ExtArgumentType getRet() const
void codegenGeoPolygonArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *polygon_buf, llvm::Value *polygon_size, llvm::Value *ring_sizes_buf, llvm::Value *num_rings, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
llvm::StructType * createMultiPointStructType(const std::string &udf_func_name, size_t param_num)
#define LOG(tag)
Definition: Logger.h:285
std::vector< llvm::Value * > codegenFunctionOperCastArgs(const Analyzer::FunctionOper *, const ExtensionFunction *, const std::vector< llvm::Value * > &, const std::vector< size_t > &, const std::unordered_map< llvm::Value *, llvm::Value * > &, const CompilationOptions &)
llvm::StructType * createMultiLineStringStructType(const std::string &udf_func_name, size_t param_num)
llvm::Value * codegenFunctionOperNullArg(const Analyzer::FunctionOper *, const std::vector< llvm::Value * > &)
llvm::IRBuilder ir_builder_
Definition: CgenState.h:384
llvm::Value * posArg(const Analyzer::Expr *) const
Definition: ColumnIR.cpp:590
llvm::Value * castArrayPointer(llvm::Value *ptr, const SQLTypeInfo &elem_ti)
#define UNREACHABLE()
Definition: Logger.h:338
#define CHECK_GE(x, y)
Definition: Logger.h:306
Definition: sqldefs.h:51
llvm::Type * get_fp_type(const int width, llvm::LLVMContext &context)
llvm::StructType * createPointStructType(const std::string &udf_func_name, size_t param_num)
bool call_requires_custom_type_handling(const Analyzer::FunctionOper *function_oper)
const std::string getName(bool keep_suffix=true) const
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:391
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
bool ext_func_call_requires_nullcheck(const Analyzer::FunctionOper *function_oper)
SQLTypeInfo get_sql_type_from_llvm_type(const llvm::Type *ll_type)
llvm::StructType * get_buffer_struct_type(CgenState *cgen_state, const std::string &ext_func_name, size_t param_num, llvm::Type *elem_type)
std::vector< FunctionOperValue > ext_call_cache_
Definition: CgenState.h:390
void codegenBufferArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *buffer_buf, llvm::Value *buffer_size, llvm::Value *buffer_is_null, std::vector< llvm::Value * > &output_args)
llvm::Function * row_func_
Definition: CgenState.h:374
RUNTIME_EXPORT void register_buffer_with_executor_rsm(int64_t exec, int8_t *buffer)
std::pair< llvm::Value *, llvm::Value * > codegenArrayBuff(llvm::Value *chunk, llvm::Value *row_pos, SQLTypes array_type, bool cast_and_extend)
llvm::Module * module_
Definition: CgenState.h:373
Supported runtime functions management and retrieval.
llvm::LLVMContext & context_
Definition: CgenState.h:382
llvm::Function * current_func_
Definition: CgenState.h:376
std::tuple< ArgNullcheckBBs, llvm::Value * > beginArgsNullcheck(const Analyzer::FunctionOper *function_oper, const std::vector< llvm::Value * > &orig_arg_lvs)
llvm::Value * emitExternalCall(const std::string &fname, llvm::Type *ret_type, const std::vector< llvm::Value * > args, const std::vector< llvm::Attribute::AttrKind > &fnattrs={}, const bool has_struct_return=false)
Definition: CgenState.cpp:395
llvm::Value * get_arg_by_name(llvm::Function *func, const std::string &name)
Definition: Execute.h:168
bool is_integer() const
Definition: sqltypes.h:567
DEVICE auto copy(ARGS &&...args)
Definition: gpu_enabled.h:51
llvm::ConstantInt * inlineIntNull(const SQLTypeInfo &)
Definition: CgenState.cpp:65
bool is_ext_arg_type_geo(const ExtArgumentType ext_arg_type)
void codegenGeoMultiPointArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *multi_point_buf, llvm::Value *multi_point_size, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
bool is_ext_arg_type_array(const ExtArgumentType ext_arg_type)
llvm::Value * codegenFunctionOper(const Analyzer::FunctionOper *, const CompilationOptions &)
llvm::Type * get_llvm_type_from_sql_array_type(const SQLTypeInfo ti, llvm::LLVMContext &ctx)
bool is_boolean() const
Definition: sqltypes.h:582
llvm::BasicBlock * args_null_bb
#define AUTOMATIC_IR_METADATA(CGENSTATE)
llvm::Type * ext_arg_type_to_llvm_type(const ExtArgumentType ext_arg_type, llvm::LLVMContext &ctx)
void codegenGeoMultiLineStringArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *multi_linestring_coords, llvm::Value *multi_linestring_size, llvm::Value *linestring_sizes, llvm::Value *linestring_sizes_size, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
std::string toString(const Executor::ExtModuleKinds &kind)
Definition: Execute.h:1703
std::tuple< T, std::vector< SQLTypeInfo > > bind_function(std::string name, Analyzer::ExpressionPtrVector func_args, const std::vector< T > &ext_funcs, const std::string processor)
Argument type based extension function binding.
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:79
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:217
bool is_buffer() const
Definition: sqltypes.h:623
ExecutorDeviceType device_type
void codegenGeoPointArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *point_buf, llvm::Value *point_size, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
#define RUNTIME_EXPORT
std::vector< llvm::Value * > codegen(const Analyzer::Expr *, const bool fetch_columns, const CompilationOptions &)
Definition: IRCodegen.cpp:30
#define CHECK_LE(x, y)
Definition: Logger.h:304
HOST DEVICE EncodingType get_compression() const
Definition: sqltypes.h:399
std::string serialize_llvm_object(const T *llvm_obj)
bool isLocalAlloc() const
Definition: Analyzer.h:3024
llvm::StructType * createPolygonStructType(const std::string &udf_func_name, size_t param_num)
const Analyzer::Expr * getArg(const size_t i) const
Definition: Analyzer.h:2748
const Expr * get_operand() const
Definition: Analyzer.h:384
llvm::Value * endArgsNullcheck(const ArgNullcheckBBs &, llvm::Value *, llvm::Value *, const Analyzer::FunctionOper *)
const std::vector< ExtArgumentType > & getInputArgs() const
llvm::StructType * createStringViewStructType()
std::unique_ptr< llvm::Module > udf_cpu_module
HOST DEVICE int get_comp_param() const
Definition: sqltypes.h:402
llvm::Value * codegenFunctionOperWithCustomTypeHandling(const Analyzer::FunctionOperWithCustomTypeHandling *, const CompilationOptions &)
llvm::ConstantInt * llInt(const T v) const
Definition: CgenState.h:249
#define CHECK(condition)
Definition: Logger.h:291
llvm::Value * codegenIsNullNumber(llvm::Value *, const SQLTypeInfo &)
Definition: LogicalIR.cpp:416
uint64_t exp_to_scale(const unsigned exp)
llvm::ConstantInt * ll_bool(const bool v, llvm::LLVMContext &context)
llvm::Value * codegenCompression(const SQLTypeInfo &type_info)
llvm::Value * codegenCast(const Analyzer::UOper *, const CompilationOptions &)
Definition: CastIR.cpp:21
uint32_t log2_bytes(const uint32_t bytes)
Definition: Execute.h:198
Definition: sqltypes.h:72
bool is_string() const
Definition: sqltypes.h:561
std::string getName() const
Definition: Analyzer.h:2744
void codegenGeoLineStringArgs(const std::string &udf_func_name, size_t param_num, llvm::Value *line_string_buf, llvm::Value *line_string_size, llvm::Value *compression, llvm::Value *input_srid, llvm::Value *output_srid, std::vector< llvm::Value * > &output_args)
bool is_text_encoding_none() const
Definition: sqltypes.h:614
bool is_ext_arg_type_pointer(const ExtArgumentType ext_arg_type)
SQLTypeInfo get_elem_type() const
Definition: sqltypes.h:977
bool is_decimal() const
Definition: sqltypes.h:570
int get_physical_coord_cols() const
Definition: sqltypes.h:451
SQLTypeInfo ext_arg_type_to_type_info(const ExtArgumentType ext_arg_type)
llvm::ConstantFP * inlineFpNull(const SQLTypeInfo &)
Definition: CgenState.cpp:104
Executor * executor() const
llvm::StructType * createMultiPolygonStructType(const std::string &udf_func_name, size_t param_num)