OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
QueryTemplateGenerator.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "QueryTemplateGenerator.h"
18 #include "IRCodegenUtils.h"
19 #include "Logger/Logger.h"
20 
21 #include <llvm/IR/Constants.h>
22 #include <llvm/IR/IRBuilder.h>
23 #include <llvm/IR/Instructions.h>
24 #include <llvm/IR/Verifier.h>
25 
26 // This file was pretty much auto-generated by running:
27 // llc -march=cpp RuntimeFunctions.ll
28 // and formatting the results to be more readable.
29 
30 namespace {
31 
32 template <typename... ATTRS>
33 llvm::AttributeList make_attribute_list(llvm::Module const* const mod,
34  unsigned const index,
35  ATTRS const... attrs) {
36  static_assert((std::is_same_v<llvm::Attribute::AttrKind, ATTRS> && ...));
37  // llvm::AttrBuilder basically wraps a llvm::SmallVector<llvm::Attribute, 8>.
38  static_assert(sizeof...(ATTRS) <= 8, "Use a llvm::SmallVector with a larger size.");
39 #if 14 <= LLVM_VERSION_MAJOR
40  llvm::AttrBuilder attr_builder(mod->getContext());
41 #else
42  llvm::AttrBuilder attr_builder;
43 #endif
44  (attr_builder.addAttribute(attrs), ...);
45  return llvm::AttributeList::get(mod->getContext(), index, attr_builder);
46 }
47 
48 // NTYPES = max number of types
49 template <size_t NTYPES>
50 class Params {
51  llvm::Module const* const mod_;
52  llvm::SmallVector<llvm::Type*, NTYPES> types_;
53  llvm::SmallVector<char const*, NTYPES> names_;
54  // +1 is for extra call to addAttributes() to add UWTable function attribute
55  llvm::SmallVector<llvm::AttributeList, NTYPES + 1> attrs_;
56 
57  public:
58  Params(llvm::Module const* const mod) : mod_(mod) {}
59 
60  template <typename... ATTRS>
61  void addAttributes(unsigned const index, ATTRS const... attrs) {
62  static_assert((std::is_same_v<llvm::Attribute::AttrKind, ATTRS> && ...));
63  attrs_.push_back(make_attribute_list(mod_, index, attrs...));
64  }
65 
66  llvm::AttributeList attributeList() const {
67  return llvm::AttributeList::get(mod_->getContext(), attrs_);
68  }
69 
70  template <typename... ATTRS>
71  void pushBack(llvm::Type* const type, char const* const name, ATTRS const... attrs) {
72  static_assert((std::is_same_v<llvm::Attribute::AttrKind, ATTRS> && ...));
73  types_.push_back(type);
74  names_.push_back(name);
75  if constexpr (0u < sizeof...(ATTRS)) {
76  static_assert(1u == llvm::AttributeList::AttrIndex::FirstArgIndex);
77  addAttributes(types_.size(), attrs...);
78  }
79  }
80 
81  void setNames(llvm::Function::arg_iterator itr) const {
82  for (char const* const name : names_) {
83  itr++->setName(name);
84  }
85  }
86 
87  auto& types() { return types_; }
88 };
89 
90 // NTYPES = max number of types. Used by llvm::SmallVector to avoid dynamic memory allocs.
91 template <bool IS_GROUP_BY, size_t NTYPES = 13u>
92 Params<NTYPES> make_params(llvm::Module const* const mod, bool const hoist_literals) {
93  constexpr llvm::Attribute::AttrKind NoCapture = llvm::Attribute::NoCapture;
94  auto* const i8_type = llvm::IntegerType::get(mod->getContext(), 8);
95  auto* const i32_type = llvm::IntegerType::get(mod->getContext(), 32);
96  auto* const i64_type = llvm::IntegerType::get(mod->getContext(), 64);
97  auto* const pi8_type = llvm::PointerType::get(i8_type, 0);
98  auto* const ppi8_type = llvm::PointerType::get(pi8_type, 0);
99  auto* const pi32_type = llvm::PointerType::get(i32_type, 0);
100  auto* const pi64_type = llvm::PointerType::get(i64_type, 0);
101  auto* const ppi64_type = llvm::PointerType::get(pi64_type, 0);
102 
103  // Must match parameter order in QueryExecutionContext::launchCpuCode()
104  // hoist_literals is true iff literals is included in the parameter list.
105  // NTYPES should equal the max number of parameters to avoid dynamic memory allocation.
106  Params<NTYPES> params(mod);
107  params.pushBack(pi32_type, "error_code");
108  params.pushBack(pi32_type, "total_matched");
109  params.pushBack(ppi64_type, IS_GROUP_BY ? "group_by_buffers" : "out");
110  params.pushBack(i32_type, "frag_idx");
111  if constexpr (IS_GROUP_BY) {
112  constexpr llvm::Attribute::AttrKind ReadOnly = llvm::Attribute::ReadOnly;
113  constexpr llvm::Attribute::AttrKind UWTable = llvm::Attribute::UWTable;
114  params.pushBack(pi32_type, "row_index_resume", NoCapture, ReadOnly);
115  params.pushBack(ppi8_type, "byte_stream", NoCapture, ReadOnly);
116  if (hoist_literals) {
117  params.pushBack(pi8_type, "literals", NoCapture, ReadOnly);
118  }
119  params.pushBack(pi64_type, "row_count_ptr", NoCapture, ReadOnly);
120  params.pushBack(pi64_type, "frag_row_off_ptr", NoCapture, ReadOnly);
121  params.pushBack(pi32_type, "max_matched_ptr", NoCapture, ReadOnly);
122  params.pushBack(pi64_type, "agg_init_val", NoCapture, ReadOnly);
123  params.pushBack(pi64_type, "join_hash_tables", NoCapture, ReadOnly);
124  params.pushBack(pi8_type, "row_func_mgr", NoCapture, ReadOnly);
125  params.addAttributes(llvm::AttributeList::AttrIndex::FunctionIndex, UWTable);
126  } else {
127  // For an unknown reason, commit 70ab189189cc0599d973f3f021169a6846298cf5
128  // removed the ReadOnly and UWTable attributes for (non-group_by) query_template()
129  // but kept them for query_group_by_template().
130  params.pushBack(pi32_type, "row_index_resume", NoCapture); // start_rowid
131  params.pushBack(ppi8_type, "byte_stream", NoCapture); // col_buffers
132  if (hoist_literals) {
133  params.pushBack(pi8_type, "literals", NoCapture);
134  }
135  params.pushBack(pi64_type, "row_count_ptr", NoCapture); // num_rows
136  params.pushBack(pi64_type, "frag_row_off_ptr", NoCapture);
137  params.pushBack(pi32_type, "max_matched_ptr", NoCapture);
138  params.pushBack(pi64_type, "agg_init_val", NoCapture);
139  params.pushBack(pi64_type, "join_hash_tables", NoCapture);
140  params.pushBack(pi8_type, "row_func_mgr", NoCapture);
141  }
142  return params;
143 }
144 
145 inline llvm::Type* get_pointer_element_type(llvm::Value* value) {
146  CHECK(value);
147  auto type = value->getType();
148  CHECK(type && type->isPointerTy());
149  auto pointer_type = llvm::dyn_cast<llvm::PointerType>(type);
150  CHECK(pointer_type);
151  return pointer_type->getPointerElementType();
152 }
153 
154 llvm::Function* default_func_builder(llvm::Module* mod, const std::string& name) {
155  using namespace llvm;
156 
157  std::vector<Type*> func_args;
158  FunctionType* func_type = FunctionType::get(
159  /*Result=*/IntegerType::get(mod->getContext(), 32),
160  /*Params=*/func_args,
161  /*isVarArg=*/false);
162 
163  auto func_ptr = mod->getFunction(name);
164  if (!func_ptr) {
165  func_ptr = Function::Create(
166  /*Type=*/func_type,
167  /*Linkage=*/GlobalValue::ExternalLinkage,
168  /*Name=*/name,
169  mod); // (external, no body)
170  func_ptr->setCallingConv(CallingConv::C);
171  }
172  func_ptr->setAttributes(
173  make_attribute_list(mod, llvm::AttributeList::AttrIndex::FunctionIndex));
174  return func_ptr;
175 }
176 
177 llvm::Function* pos_start(llvm::Module* mod) {
178  return default_func_builder(mod, "pos_start");
179 }
180 
181 llvm::Function* group_buff_idx(llvm::Module* mod) {
182  return default_func_builder(mod, "group_buff_idx");
183 }
184 
185 llvm::Function* pos_step(llvm::Module* mod) {
186  using namespace llvm;
187 
188  std::vector<Type*> func_args;
189  FunctionType* func_type = FunctionType::get(
190  /*Result=*/IntegerType::get(mod->getContext(), 32),
191  /*Params=*/func_args,
192  /*isVarArg=*/false);
193 
194  auto func_ptr = mod->getFunction("pos_step");
195  if (!func_ptr) {
196  func_ptr = Function::Create(
197  /*Type=*/func_type,
198  /*Linkage=*/GlobalValue::ExternalLinkage,
199  /*Name=*/"pos_step",
200  mod); // (external, no body)
201  func_ptr->setCallingConv(CallingConv::C);
202  }
203  func_ptr->setAttributes(
204  make_attribute_list(mod, llvm::AttributeList::AttrIndex::FunctionIndex));
205  return func_ptr;
206 }
207 
208 llvm::Function* row_process(llvm::Module* mod,
209  const size_t aggr_col_count,
210  const bool hoist_literals) {
211  using namespace llvm;
212 
213  std::vector<Type*> func_args;
214  auto i8_type = IntegerType::get(mod->getContext(), 8);
215  auto i32_type = IntegerType::get(mod->getContext(), 32);
216  auto i64_type = IntegerType::get(mod->getContext(), 64);
217  auto pi32_type = PointerType::get(i32_type, 0);
218  auto pi64_type = PointerType::get(i64_type, 0);
219 
220  if (aggr_col_count) {
221  for (size_t i = 0; i < aggr_col_count; ++i) {
222  func_args.push_back(pi64_type);
223  }
224  } else { // group by query
225  func_args.push_back(pi64_type); // groups buffer
226  func_args.push_back(pi64_type); // varlen output buffer
227  func_args.push_back(pi32_type); // 1 iff current row matched, else 0
228  func_args.push_back(pi32_type); // total rows matched from the caller
229  func_args.push_back(pi32_type); // total rows matched before atomic increment
230  func_args.push_back(pi32_type); // max number of slots in the output buffer
231  }
232 
233  func_args.push_back(pi64_type); // aggregate init values
234 
235  func_args.push_back(i64_type); // pos
236  func_args.push_back(pi64_type); // frag_row_off_ptr
237  func_args.push_back(pi64_type); // row_count_ptr
238  if (hoist_literals) {
239  func_args.push_back(PointerType::get(i8_type, 0)); // literals
240  }
241  FunctionType* func_type = FunctionType::get(
242  /*Result=*/i32_type,
243  /*Params=*/func_args,
244  /*isVarArg=*/false);
245 
246  std::string func_name{"row_process"};
247  auto func_ptr = mod->getFunction(func_name);
248 
249  if (!func_ptr) {
250  func_ptr = Function::Create(
251  /*Type=*/func_type,
252  /*Linkage=*/GlobalValue::ExternalLinkage,
253  /*Name=*/func_name,
254  mod); // (external, no body)
255  func_ptr->setCallingConv(CallingConv::C);
256  func_ptr->setAttributes(
257  make_attribute_list(mod, llvm::AttributeList::AttrIndex::FunctionIndex));
258  }
259 
260  return func_ptr;
261 }
262 
263 } // namespace
264 
265 // Return pair (query_func, row_func_call)
266 std::tuple<llvm::Function*, llvm::CallInst*> query_template(
267  llvm::Module* mod,
268  const size_t aggr_col_count,
269  const bool hoist_literals,
270  const bool is_estimate_query,
271  const GpuSharedMemoryContext& gpu_smem_context) {
272  using namespace llvm;
273 
274  auto* const i32_type = llvm::IntegerType::get(mod->getContext(), 32);
275  auto* const i64_type = llvm::IntegerType::get(mod->getContext(), 64);
276 
277  llvm::Function* const func_pos_start = pos_start(mod);
278  CHECK(func_pos_start);
279  llvm::Function* const func_pos_step = pos_step(mod);
280  CHECK(func_pos_step);
281  llvm::Function* const func_group_buff_idx = group_buff_idx(mod);
282  CHECK(func_group_buff_idx);
283  llvm::Function* const func_row_process =
284  row_process(mod, is_estimate_query ? 1 : aggr_col_count, hoist_literals);
285  CHECK(func_row_process);
286 
287  constexpr bool IS_GROUP_BY = false;
288  Params query_func_params = make_params<IS_GROUP_BY>(mod, hoist_literals);
289 
290  FunctionType* query_func_type = FunctionType::get(
291  /*Result=*/Type::getVoidTy(mod->getContext()),
292  /*Params=*/query_func_params.types(),
293  /*isVarArg=*/false);
294 
295  std::string query_template_name{"query_template"};
296  auto query_func_ptr = mod->getFunction(query_template_name);
297  CHECK(!query_func_ptr);
298 
299  query_func_ptr = Function::Create(
300  /*Type=*/query_func_type,
301  /*Linkage=*/GlobalValue::ExternalLinkage,
302  /*Name=*/query_template_name,
303  mod);
304  query_func_ptr->setCallingConv(CallingConv::C);
305  query_func_ptr->setAttributes(query_func_params.attributeList());
306  query_func_params.setNames(query_func_ptr->arg_begin());
307 
308  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
309  auto bb_preheader =
310  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
311  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".for.body", query_func_ptr, 0);
312  auto bb_crit_edge =
313  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
314  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
315 
316  // Block (.entry)
317  llvm::Value* const agg_init_val = get_arg_by_name(query_func_ptr, "agg_init_val");
318  std::vector<Value*> result_ptr_vec;
319  llvm::CallInst* smem_output_buffer{nullptr};
320  if (!is_estimate_query) {
321  for (size_t i = 0; i < aggr_col_count; ++i) {
322  auto result_ptr = new AllocaInst(i64_type, 0, "result", bb_entry);
323  result_ptr->setAlignment(LLVM_ALIGN(8));
324  result_ptr_vec.push_back(result_ptr);
325  }
326  if (gpu_smem_context.isSharedMemoryUsed()) {
327  auto init_smem_func = mod->getFunction("init_shared_mem");
328  CHECK(init_smem_func);
329  // only one slot per aggregate column is needed, and so we can initialize shared
330  // memory buffer for intermediate results to be exactly like the agg_init_val array
331  smem_output_buffer = CallInst::Create(
332  init_smem_func,
333  std::vector<llvm::Value*>{
334  agg_init_val,
335  llvm::ConstantInt::get(i32_type, aggr_col_count * sizeof(int64_t))},
336  "smem_buffer",
337  bb_entry);
338  }
339  }
340 
341  llvm::Value* const row_count_ptr = get_arg_by_name(query_func_ptr, "row_count_ptr");
342  LoadInst* row_count = new LoadInst(get_pointer_element_type(row_count_ptr),
343  row_count_ptr,
344  "row_count",
345  false,
346  bb_entry);
347  row_count->setAlignment(LLVM_ALIGN(8));
348  row_count->setName("row_count");
349  std::vector<Value*> agg_init_val_vec;
350  if (!is_estimate_query) {
351  for (size_t i = 0; i < aggr_col_count; ++i) {
352  auto idx_lv = ConstantInt::get(i32_type, i);
353  auto agg_init_gep = GetElementPtrInst::CreateInBounds(
354  agg_init_val->getType()->getPointerElementType(),
355  agg_init_val,
356  idx_lv,
357  "",
358  bb_entry);
359  auto agg_init_val = new LoadInst(
360  get_pointer_element_type(agg_init_gep), agg_init_gep, "", false, bb_entry);
361  agg_init_val->setAlignment(LLVM_ALIGN(8));
362  agg_init_val_vec.push_back(agg_init_val);
363  auto init_val_st = new StoreInst(agg_init_val, result_ptr_vec[i], false, bb_entry);
364  init_val_st->setAlignment(LLVM_ALIGN(8));
365  }
366  }
367 
368  CallInst* pos_start = CallInst::Create(func_pos_start, "pos_start", bb_entry);
369  pos_start->setCallingConv(CallingConv::C);
370  pos_start->setTailCall(true);
371  llvm::AttributeList pos_start_pal;
372  pos_start->setAttributes(pos_start_pal);
373 
374  CallInst* pos_step = CallInst::Create(func_pos_step, "pos_step", bb_entry);
375  pos_step->setCallingConv(CallingConv::C);
376  pos_step->setTailCall(true);
377  llvm::AttributeList pos_step_pal;
378  pos_step->setAttributes(pos_step_pal);
379 
380  CallInst* group_buff_idx = nullptr;
381  if (!is_estimate_query) {
382  group_buff_idx = CallInst::Create(func_group_buff_idx, "group_buff_idx", bb_entry);
383  group_buff_idx->setCallingConv(CallingConv::C);
384  group_buff_idx->setTailCall(true);
385  llvm::AttributeList group_buff_idx_pal;
386  group_buff_idx->setAttributes(group_buff_idx_pal);
387  }
388 
389  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
390  ICmpInst* enter_or_not =
391  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
392  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
393 
394  // Block .loop.preheader
395  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
396  BranchInst::Create(bb_forbody, bb_preheader);
397 
398  // Block .forbody
399  Argument* pos_inc_pre = new Argument(i64_type);
400  PHINode* pos = PHINode::Create(i64_type, 2, "pos", bb_forbody);
401  pos->addIncoming(pos_start_i64, bb_preheader);
402  pos->addIncoming(pos_inc_pre, bb_forbody);
403 
404  std::vector<Value*> row_process_params;
405  llvm::Value* const out = get_arg_by_name(query_func_ptr, "out");
406  row_process_params.insert(
407  row_process_params.end(), result_ptr_vec.begin(), result_ptr_vec.end());
408  if (is_estimate_query) {
409  row_process_params.push_back(
410  new LoadInst(get_pointer_element_type(out), out, "", false, bb_forbody));
411  }
412  row_process_params.push_back(agg_init_val);
413  row_process_params.push_back(pos);
414  row_process_params.push_back(get_arg_by_name(query_func_ptr, "frag_row_off_ptr"));
415  row_process_params.push_back(row_count_ptr);
416  if (hoist_literals) {
417  row_process_params.push_back(get_arg_by_name(query_func_ptr, "literals"));
418  }
419  CallInst* row_process =
420  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
421  row_process->setCallingConv(CallingConv::C);
422  row_process->setTailCall(false);
423  llvm::AttributeList row_process_pal;
424  row_process->setAttributes(row_process_pal);
425 
426  BinaryOperator* pos_inc =
427  BinaryOperator::CreateNSW(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
428  ICmpInst* loop_or_exit =
429  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
430  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
431 
432  // Block ._crit_edge
433  std::vector<Instruction*> result_vec_pre;
434  if (!is_estimate_query) {
435  for (size_t i = 0; i < aggr_col_count; ++i) {
436  auto result = new LoadInst(get_pointer_element_type(result_ptr_vec[i]),
437  result_ptr_vec[i],
438  ".pre.result",
439  false,
440  bb_crit_edge);
441  result->setAlignment(LLVM_ALIGN(8));
442  result_vec_pre.push_back(result);
443  }
444  }
445 
446  BranchInst::Create(bb_exit, bb_crit_edge);
447 
448  // Block .exit
460  if (!is_estimate_query) {
461  std::vector<PHINode*> result_vec;
462  for (int64_t i = aggr_col_count - 1; i >= 0; --i) {
463  auto result =
464  PHINode::Create(IntegerType::get(mod->getContext(), 64), 2, "", bb_exit);
465  result->addIncoming(result_vec_pre[i], bb_crit_edge);
466  result->addIncoming(agg_init_val_vec[i], bb_entry);
467  result_vec.insert(result_vec.begin(), result);
468  }
469 
470  llvm::Value* const frag_idx = get_arg_by_name(query_func_ptr, "frag_idx");
471  for (size_t i = 0; i < aggr_col_count; ++i) {
472  auto col_idx = ConstantInt::get(i32_type, i);
473  if (gpu_smem_context.isSharedMemoryUsed()) {
474  auto target_addr = GetElementPtrInst::CreateInBounds(
475  smem_output_buffer->getType()->getPointerElementType(),
476  smem_output_buffer,
477  col_idx,
478  "",
479  bb_exit);
480  // TODO: generalize this once we want to support other types of aggregate
481  // functions besides COUNT.
482  auto agg_func = mod->getFunction("agg_sum_shared");
483  CHECK(agg_func);
484  CallInst::Create(
485  agg_func, std::vector<llvm::Value*>{target_addr, result_vec[i]}, "", bb_exit);
486  } else {
487  auto out_gep = GetElementPtrInst::CreateInBounds(
488  out->getType()->getPointerElementType(), out, col_idx, "", bb_exit);
489  auto col_buffer =
490  new LoadInst(get_pointer_element_type(out_gep), out_gep, "", false, bb_exit);
491  col_buffer->setAlignment(LLVM_ALIGN(8));
492  auto slot_idx = BinaryOperator::CreateAdd(
494  BinaryOperator::CreateMul(frag_idx, pos_step, "", bb_exit),
495  "",
496  bb_exit);
497  auto target_addr = GetElementPtrInst::CreateInBounds(
498  col_buffer->getType()->getPointerElementType(),
499  col_buffer,
500  slot_idx,
501  "",
502  bb_exit);
503  StoreInst* result_st = new StoreInst(result_vec[i], target_addr, false, bb_exit);
504  result_st->setAlignment(LLVM_ALIGN(8));
505  }
506  }
507  if (gpu_smem_context.isSharedMemoryUsed()) {
508  // final reduction of results from shared memory buffer back into global memory.
509  auto sync_thread_func = mod->getFunction("sync_threadblock");
510  CHECK(sync_thread_func);
511  CallInst::Create(sync_thread_func, std::vector<llvm::Value*>{}, "", bb_exit);
512  auto reduce_smem_to_gmem_func = mod->getFunction("write_back_non_grouped_agg");
513  CHECK(reduce_smem_to_gmem_func);
514  // each thread reduce the aggregate target corresponding to its own thread ID.
515  // If there are more targets than threads we do not currently use shared memory
516  // optimization. This can be relaxed if necessary
517  for (size_t i = 0; i < aggr_col_count; i++) {
518  auto out_gep =
519  GetElementPtrInst::CreateInBounds(out->getType()->getPointerElementType(),
520  out,
521  ConstantInt::get(i32_type, i),
522  "",
523  bb_exit);
524  auto gmem_output_buffer = new LoadInst(get_pointer_element_type(out_gep),
525  out_gep,
526  "gmem_output_buffer_" + std::to_string(i),
527  false,
528  bb_exit);
529  CallInst::Create(
530  reduce_smem_to_gmem_func,
531  std::vector<llvm::Value*>{
532  smem_output_buffer, gmem_output_buffer, ConstantInt::get(i32_type, i)},
533  "",
534  bb_exit);
535  }
536  }
537  }
538 
539  ReturnInst::Create(mod->getContext(), bb_exit);
540 
541  // Resolve Forward References
542  pos_inc_pre->replaceAllUsesWith(pos_inc);
543  delete pos_inc_pre;
544 
545  if (verifyFunction(*query_func_ptr, &llvm::errs())) {
546  LOG(FATAL) << "Generated invalid code.";
547  }
548 
549  return {query_func_ptr, row_process};
550 }
551 
552 // Return pair (query_func, row_func_call)
553 std::tuple<llvm::Function*, llvm::CallInst*> query_group_by_template(
554  llvm::Module* mod,
555  const bool hoist_literals,
557  const ExecutorDeviceType device_type,
558  const bool check_scan_limit,
559  const GpuSharedMemoryContext& gpu_smem_context) {
560  if (gpu_smem_context.isSharedMemoryUsed()) {
561  CHECK(device_type == ExecutorDeviceType::GPU);
562  }
563  using namespace llvm;
564 
565  auto* const i32_type = llvm::IntegerType::get(mod->getContext(), 32);
566  auto* const i64_type = llvm::IntegerType::get(mod->getContext(), 64);
567 
568  llvm::Function* const func_pos_start = pos_start(mod);
569  CHECK(func_pos_start);
570  llvm::Function* const func_pos_step = pos_step(mod);
571  CHECK(func_pos_step);
572  llvm::Function* const func_group_buff_idx = group_buff_idx(mod);
573  CHECK(func_group_buff_idx);
574  llvm::Function* const func_row_process = row_process(mod, 0, hoist_literals);
575  CHECK(func_row_process);
576  llvm::Function* const func_init_shared_mem =
577  gpu_smem_context.isSharedMemoryUsed() ? mod->getFunction("init_shared_mem")
578  : mod->getFunction("init_shared_mem_nop");
579  CHECK(func_init_shared_mem);
580 
581  auto func_write_back = mod->getFunction("write_back_nop");
582  CHECK(func_write_back);
583 
584  constexpr bool IS_GROUP_BY = true;
585  Params query_func_params = make_params<IS_GROUP_BY>(mod, hoist_literals);
586 
587  FunctionType* query_func_type = FunctionType::get(
588  /*Result=*/Type::getVoidTy(mod->getContext()),
589  /*Params=*/query_func_params.types(),
590  /*isVarArg=*/false);
591 
592  std::string query_name{"query_group_by_template"};
593  auto query_func_ptr = mod->getFunction(query_name);
594  CHECK(!query_func_ptr);
595 
596  query_func_ptr = Function::Create(
597  /*Type=*/query_func_type,
598  /*Linkage=*/GlobalValue::ExternalLinkage,
599  /*Name=*/"query_group_by_template",
600  mod);
601  query_func_ptr->setCallingConv(CallingConv::C);
602  query_func_ptr->setAttributes(query_func_params.attributeList());
603  query_func_params.setNames(query_func_ptr->arg_begin());
604 
605  auto bb_entry = BasicBlock::Create(mod->getContext(), ".entry", query_func_ptr, 0);
606  auto bb_preheader =
607  BasicBlock::Create(mod->getContext(), ".loop.preheader", query_func_ptr, 0);
608  auto bb_forbody = BasicBlock::Create(mod->getContext(), ".forbody", query_func_ptr, 0);
609  auto bb_crit_edge =
610  BasicBlock::Create(mod->getContext(), "._crit_edge", query_func_ptr, 0);
611  auto bb_exit = BasicBlock::Create(mod->getContext(), ".exit", query_func_ptr, 0);
612 
613  // Block .entry
614  llvm::Value* const row_count_ptr = get_arg_by_name(query_func_ptr, "row_count_ptr");
615  LoadInst* row_count = new LoadInst(
616  get_pointer_element_type(row_count_ptr), row_count_ptr, "", false, bb_entry);
617  row_count->setAlignment(LLVM_ALIGN(8));
618  row_count->setName("row_count");
619 
620  llvm::Value* const max_matched_ptr = get_arg_by_name(query_func_ptr, "max_matched_ptr");
621  LoadInst* max_matched = new LoadInst(
622  get_pointer_element_type(max_matched_ptr), max_matched_ptr, "", false, bb_entry);
623  max_matched->setAlignment(LLVM_ALIGN(8));
624 
625  auto crt_matched_ptr = new AllocaInst(i32_type, 0, "crt_matched", bb_entry);
626  auto old_total_matched_ptr = new AllocaInst(i32_type, 0, "old_total_matched", bb_entry);
627  CallInst* pos_start = CallInst::Create(func_pos_start, "", bb_entry);
628  pos_start->setCallingConv(CallingConv::C);
629  pos_start->setTailCall(true);
630  llvm::AttributeList pos_start_pal;
631  pos_start->setAttributes(pos_start_pal);
632 
633  CallInst* pos_step = CallInst::Create(func_pos_step, "", bb_entry);
634  pos_step->setCallingConv(CallingConv::C);
635  pos_step->setTailCall(true);
636  llvm::AttributeList pos_step_pal;
637  pos_step->setAttributes(pos_step_pal);
638 
639  CallInst* group_buff_idx_call = CallInst::Create(func_group_buff_idx, "", bb_entry);
640  group_buff_idx_call->setCallingConv(CallingConv::C);
641  group_buff_idx_call->setTailCall(true);
642  llvm::AttributeList group_buff_idx_pal;
643  group_buff_idx_call->setAttributes(group_buff_idx_pal);
644  Value* group_buff_idx = group_buff_idx_call;
645 
646  auto* const group_by_buffers = get_arg_by_name(query_func_ptr, "group_by_buffers");
647  const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
648  CHECK(Ty);
649 
650  Value* varlen_output_buffer{nullptr};
651  if (query_mem_desc.hasVarlenOutput()) {
652  // make the varlen buffer the _first_ 8 byte value in the group by buffers double ptr,
653  // and offset the group by buffers index by 8 bytes
654  auto varlen_output_buffer_gep = GetElementPtrInst::Create(
655  Ty->getPointerElementType(),
656  group_by_buffers,
657  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 0),
658  "",
659  bb_entry);
660  varlen_output_buffer =
661  new LoadInst(get_pointer_element_type(varlen_output_buffer_gep),
662  varlen_output_buffer_gep,
663  "varlen_output_buffer",
664  false,
665  bb_entry);
666 
667  group_buff_idx = BinaryOperator::Create(
668  Instruction::Add,
670  llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 1),
671  "group_buff_idx_varlen_offset",
672  bb_entry);
673  } else {
674  varlen_output_buffer =
675  ConstantPointerNull::get(Type::getInt64PtrTy(mod->getContext()));
676  }
677  CHECK(varlen_output_buffer);
678 
679  CastInst* pos_start_i64 = new SExtInst(pos_start, i64_type, "", bb_entry);
680  GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
681  Ty->getPointerElementType(), group_by_buffers, group_buff_idx, "", bb_entry);
682  LoadInst* col_buffer = new LoadInst(get_pointer_element_type(group_by_buffers_gep),
683  group_by_buffers_gep,
684  "",
685  false,
686  bb_entry);
687  col_buffer->setName("col_buffer");
688  col_buffer->setAlignment(LLVM_ALIGN(8));
689 
690  llvm::ConstantInt* shared_mem_bytes_lv =
691  ConstantInt::get(i32_type, gpu_smem_context.getSharedMemorySize());
692  // TODO(Saman): change this further, normal path should not go through this
693  llvm::CallInst* result_buffer =
694  CallInst::Create(func_init_shared_mem,
695  std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
696  "result_buffer",
697  bb_entry);
698 
699  ICmpInst* enter_or_not =
700  new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count, "");
701  BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
702 
703  // Block .loop.preheader
704  CastInst* pos_step_i64 = new SExtInst(pos_step, i64_type, "", bb_preheader);
705  BranchInst::Create(bb_forbody, bb_preheader);
706 
707  // Block .forbody
708  Argument* pos_pre = new Argument(i64_type);
709  PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2, "pos", bb_forbody);
710 
711  std::vector<Value*> row_process_params;
712  row_process_params.push_back(result_buffer);
713  row_process_params.push_back(varlen_output_buffer);
714  row_process_params.push_back(crt_matched_ptr);
715  row_process_params.push_back(get_arg_by_name(query_func_ptr, "total_matched"));
716  row_process_params.push_back(old_total_matched_ptr);
717  row_process_params.push_back(max_matched_ptr);
718  row_process_params.push_back(get_arg_by_name(query_func_ptr, "agg_init_val"));
719  row_process_params.push_back(pos);
720  row_process_params.push_back(get_arg_by_name(query_func_ptr, "frag_row_off_ptr"));
721  row_process_params.push_back(row_count_ptr);
722  if (hoist_literals) {
723  row_process_params.push_back(get_arg_by_name(query_func_ptr, "literals"));
724  }
725  if (check_scan_limit) {
726  new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
727  crt_matched_ptr,
728  bb_forbody);
729  }
730  CallInst* row_process =
731  CallInst::Create(func_row_process, row_process_params, "", bb_forbody);
732  row_process->setCallingConv(CallingConv::C);
733  row_process->setTailCall(true);
734  llvm::AttributeList row_process_pal;
735  row_process->setAttributes(row_process_pal);
736 
737  // Forcing all threads within a warp to be synchronized (Compute >= 7.x)
738  if (query_mem_desc.isWarpSyncRequired(device_type)) {
739  auto func_sync_warp_protected = mod->getFunction("sync_warp_protected");
740  CHECK(func_sync_warp_protected);
741  CallInst::Create(func_sync_warp_protected,
742  std::vector<llvm::Value*>{pos, row_count},
743  "",
744  bb_forbody);
745  }
746 
747  BinaryOperator* pos_inc =
748  BinaryOperator::Create(Instruction::Add, pos, pos_step_i64, "", bb_forbody);
749  ICmpInst* loop_or_exit =
750  new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count, "");
751  if (check_scan_limit) {
752  auto crt_matched = new LoadInst(get_pointer_element_type(crt_matched_ptr),
753  crt_matched_ptr,
754  "crt_matched",
755  false,
756  bb_forbody);
757  auto filter_match = BasicBlock::Create(
758  mod->getContext(), "filter_match", query_func_ptr, bb_crit_edge);
759  llvm::Value* new_total_matched =
760  new LoadInst(get_pointer_element_type(old_total_matched_ptr),
761  old_total_matched_ptr,
762  "",
763  false,
764  filter_match);
765  new_total_matched =
766  BinaryOperator::CreateAdd(new_total_matched, crt_matched, "", filter_match);
767  CHECK(new_total_matched);
768  ICmpInst* limit_not_reached = new ICmpInst(*filter_match,
769  ICmpInst::ICMP_SLT,
770  new_total_matched,
771  max_matched,
772  "limit_not_reached");
773  BranchInst::Create(
774  bb_forbody,
775  bb_crit_edge,
776  BinaryOperator::Create(
777  BinaryOperator::And, loop_or_exit, limit_not_reached, "", filter_match),
778  filter_match);
779  auto filter_nomatch = BasicBlock::Create(
780  mod->getContext(), "filter_nomatch", query_func_ptr, bb_crit_edge);
781  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
782  ICmpInst* crt_matched_nz = new ICmpInst(
783  *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0), "");
784  BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
785  pos->addIncoming(pos_start_i64, bb_preheader);
786  pos->addIncoming(pos_pre, filter_match);
787  pos->addIncoming(pos_pre, filter_nomatch);
788  } else {
789  pos->addIncoming(pos_start_i64, bb_preheader);
790  pos->addIncoming(pos_pre, bb_forbody);
791  BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
792  }
793 
794  // Block ._crit_edge
795  BranchInst::Create(bb_exit, bb_crit_edge);
796 
797  // Block .exit
798  CallInst::Create(func_write_back,
799  std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
800  "",
801  bb_exit);
802 
803  ReturnInst::Create(mod->getContext(), bb_exit);
804 
805  // Resolve Forward References
806  pos_pre->replaceAllUsesWith(pos_inc);
807  delete pos_pre;
808 
809  if (verifyFunction(*query_func_ptr, &llvm::errs())) {
810  LOG(FATAL) << "Generated invalid code. ";
811  }
812 
813  return {query_func_ptr, row_process};
814 }
void addAttributes(unsigned const index, ATTRS const ...attrs)
llvm::SmallVector< llvm::AttributeList, NTYPES+1 > attrs_
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)
#define LOG(tag)
Definition: Logger.h:285
size_t getSharedMemorySize() const
llvm::Function * group_buff_idx(llvm::Module *mod)
llvm::AttributeList make_attribute_list(llvm::Module const *const mod, unsigned const index, ATTRS const ...attrs)
std::tuple< llvm::Function *, llvm::CallInst * > query_template(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals, const bool is_estimate_query, const GpuSharedMemoryContext &gpu_smem_context)
AGG_TYPE agg_func(AGG_TYPE const lhs, AGG_TYPE const rhs)
Type pointer_type(const Type pointee)
#define LLVM_ALIGN(alignment)
std::tuple< llvm::Function *, llvm::CallInst * > query_group_by_template(llvm::Module *mod, const bool hoist_literals, const QueryMemoryDescriptor &query_mem_desc, const ExecutorDeviceType device_type, const bool check_scan_limit, const GpuSharedMemoryContext &gpu_smem_context)
llvm::Function * default_func_builder(llvm::Module *mod, const std::string &name)
ExecutorDeviceType
std::string to_string(char const *&&v)
llvm::Value * get_arg_by_name(llvm::Function *func, const std::string &name)
Definition: Execute.h:168
Params< NTYPES > make_params(llvm::Module const *const mod, bool const hoist_literals)
dictionary params
Definition: report.py:27
bool isWarpSyncRequired(const ExecutorDeviceType) const
void setNames(llvm::Function::arg_iterator itr) const
#define CHECK(condition)
Definition: Logger.h:291
string name
Definition: setup.in.py:72
llvm::Type * get_pointer_element_type(llvm::Value *value)
void pushBack(llvm::Type *const type, char const *const name, ATTRS const ...attrs)