OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
JoinLoop.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "JoinLoop.h"
18 #include "../CgenState.h"
19 #include "Logger/Logger.h"
20 
21 #include <llvm/IR/Type.h>
22 
23 #include <stack>
24 
26  const JoinType type,
27  const std::function<JoinLoopDomain(const std::vector<llvm::Value*>&)>&
28  iteration_domain_codegen,
29  const std::function<llvm::Value*(const std::vector<llvm::Value*>&)>&
30  outer_condition_match,
31  const std::function<void(llvm::Value*)>& found_outer_matches,
32  const HoistedFiltersCallback& hoisted_filters,
33  const std::function<llvm::Value*(const std::vector<llvm::Value*>&,
34  llvm::Value*)>& is_deleted,
35  const bool nested_loop_join,
36  const std::string& name)
37  : kind_(kind)
38  , type_(type)
39  , iteration_domain_codegen_(iteration_domain_codegen)
40  , outer_condition_match_(outer_condition_match)
41  , found_outer_matches_(found_outer_matches)
42  , hoisted_filters_(hoisted_filters)
43  , is_deleted_(is_deleted)
44  , nested_loop_join_(nested_loop_join)
45  , name_(name) {
46  CHECK(outer_condition_match == nullptr || type == JoinType::LEFT);
47  CHECK_EQ(static_cast<bool>(found_outer_matches), (type == JoinType::LEFT));
48 }
49 
50 llvm::BasicBlock* JoinLoop::codegen(
51  const std::vector<JoinLoop>& join_loops,
52  const std::function<llvm::BasicBlock*(const std::vector<llvm::Value*>&)>&
53  body_codegen,
54  llvm::Value* outer_iter,
55  llvm::BasicBlock* exit_bb,
56  CgenState* cgen_state) {
57  AUTOMATIC_IR_METADATA(cgen_state);
58  llvm::IRBuilder<>& builder = cgen_state->ir_builder_;
59  llvm::BasicBlock* prev_exit_bb{exit_bb};
60  llvm::BasicBlock* prev_iter_advance_bb{nullptr};
61  llvm::BasicBlock* last_head_bb{nullptr};
62  auto& context = builder.getContext();
63  const auto parent_func = builder.GetInsertBlock()->getParent();
64  llvm::Value* prev_comparison_result{nullptr};
65  llvm::BasicBlock* entry{nullptr};
66  std::vector<llvm::Value*> iterators;
67  iterators.push_back(outer_iter);
68  JoinType prev_join_type{JoinType::INVALID};
69  for (const auto& join_loop : join_loops) {
70  switch (join_loop.kind_) {
72  case JoinLoopKind::Set:
74  const auto preheader_bb = llvm::BasicBlock::Create(
75  context, "ub_iter_preheader_" + join_loop.name_, parent_func);
76 
77  llvm::BasicBlock* filter_bb{nullptr};
78  if (join_loop.hoisted_filters_) {
79  filter_bb = join_loop.hoisted_filters_(
80  preheader_bb, prev_exit_bb, join_loop.name_, parent_func, cgen_state);
81  }
82 
83  if (!entry) {
84  entry = filter_bb ? filter_bb : preheader_bb;
85  }
86 
87  if (prev_comparison_result) {
88  builder.CreateCondBr(
89  prev_comparison_result,
90  filter_bb ? filter_bb : preheader_bb,
91  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
92  }
93  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
94  builder.SetInsertPoint(preheader_bb);
95 
96  const auto iteration_counter_ptr = builder.CreateAlloca(
97  get_int_type(64, context), nullptr, "ub_iter_counter_ptr_" + join_loop.name_);
98  llvm::Value* found_an_outer_match_ptr{nullptr};
99  llvm::Value* current_condition_match_ptr{nullptr};
100  if (join_loop.type_ == JoinType::LEFT) {
101  found_an_outer_match_ptr = builder.CreateAlloca(
102  get_int_type(1, context), nullptr, "found_an_outer_match");
103  builder.CreateStore(ll_bool(false, context), found_an_outer_match_ptr);
104  current_condition_match_ptr = builder.CreateAlloca(
105  get_int_type(1, context), nullptr, "outer_condition_current_match");
106  }
107  builder.CreateStore(ll_int(int64_t(0), context), iteration_counter_ptr);
108  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
109 
110  const auto head_bb = llvm::BasicBlock::Create(
111  context, "ub_iter_head_" + join_loop.name_, parent_func);
112 
113  if (iteration_domain.error_code) {
114  cgen_state->needs_error_check_ = true;
115  auto ub_iter_success_code = ll_int(int32_t(0), context);
116  const auto ub_iter_error_condition =
117  builder.CreateICmpEQ(iteration_domain.error_code, ub_iter_success_code);
118  auto error_bb =
119  llvm::BasicBlock::Create(context, "ub_iter_error_exit", parent_func);
120  builder.CreateCondBr(ub_iter_error_condition, head_bb, error_bb);
121 
122  builder.SetInsertPoint(error_bb);
123  builder.CreateRet(iteration_domain.error_code);
124  } else {
125  builder.CreateBr(head_bb);
126  }
127 
128  builder.SetInsertPoint(head_bb);
129  llvm::Value* iteration_counter =
130  builder.CreateLoad(iteration_counter_ptr->getType()->getPointerElementType(),
131  iteration_counter_ptr,
132  "ub_iter_counter_val_" + join_loop.name_);
133  auto iteration_val = iteration_counter;
134  CHECK(join_loop.kind_ == JoinLoopKind::Set ||
135  join_loop.kind_ == JoinLoopKind::MultiSet ||
136  !iteration_domain.values_buffer);
137  if (join_loop.kind_ == JoinLoopKind::Set ||
138  join_loop.kind_ == JoinLoopKind::MultiSet) {
139  CHECK(iteration_domain.values_buffer->getType()->isPointerTy());
140  const auto ptr_type =
141  static_cast<llvm::PointerType*>(iteration_domain.values_buffer->getType());
142  if (ptr_type->getPointerElementType()->isArrayTy()) {
143  iteration_val = builder.CreateGEP(
144  iteration_domain.values_buffer->getType()
145  ->getScalarType()
146  ->getPointerElementType(),
147  iteration_domain.values_buffer,
148  std::vector<llvm::Value*>{
149  llvm::ConstantInt::get(get_int_type(64, context), 0),
150  iteration_counter},
151  "ub_iter_counter_" + join_loop.name_);
152  } else {
153  iteration_val = builder.CreateGEP(iteration_domain.values_buffer->getType()
154  ->getScalarType()
155  ->getPointerElementType(),
156  iteration_domain.values_buffer,
157  iteration_counter,
158  "ub_iter_counter_" + join_loop.name_);
159  }
160  }
161  iterators.push_back(iteration_val);
162  const auto have_more_inner_rows = builder.CreateICmpSLT(
163  iteration_counter,
164  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
165  : iteration_domain.element_count,
166  "have_more_inner_rows");
167  const auto iter_advance_bb = llvm::BasicBlock::Create(
168  context, "ub_iter_advance_" + join_loop.name_, parent_func);
169  llvm::BasicBlock* row_not_deleted_bb{nullptr};
170  if (join_loop.is_deleted_) {
171  row_not_deleted_bb = llvm::BasicBlock::Create(
172  context, "row_not_deleted_" + join_loop.name_, parent_func);
173  const auto row_is_deleted =
174  join_loop.is_deleted_(iterators, have_more_inner_rows);
175  builder.CreateCondBr(row_is_deleted, iter_advance_bb, row_not_deleted_bb);
176  builder.SetInsertPoint(row_not_deleted_bb);
177  }
178  if (join_loop.type_ == JoinType::LEFT) {
179  std::tie(last_head_bb, prev_comparison_result) =
180  evaluateOuterJoinCondition(join_loop,
181  iteration_domain,
182  iterators,
183  iteration_counter,
184  have_more_inner_rows,
185  found_an_outer_match_ptr,
186  current_condition_match_ptr,
187  cgen_state);
188  } else {
189  prev_comparison_result = have_more_inner_rows;
190  last_head_bb = row_not_deleted_bb ? row_not_deleted_bb : head_bb;
191  }
192  builder.SetInsertPoint(iter_advance_bb);
193  const auto iteration_counter_next_val =
194  builder.CreateAdd(iteration_counter, ll_int(int64_t(1), context));
195  builder.CreateStore(iteration_counter_next_val, iteration_counter_ptr);
196  if (join_loop.type_ == JoinType::LEFT) {
197  const auto no_more_inner_rows =
198  builder.CreateICmpSGT(iteration_counter_next_val,
199  join_loop.kind_ == JoinLoopKind::UpperBound
200  ? iteration_domain.upper_bound
201  : iteration_domain.element_count,
202  "no_more_inner_rows");
203  builder.CreateCondBr(no_more_inner_rows, prev_exit_bb, head_bb);
204  } else {
205  builder.CreateBr(head_bb);
206  }
207  builder.SetInsertPoint(last_head_bb);
208  prev_iter_advance_bb = iter_advance_bb;
209  break;
210  }
212  const auto true_bb = llvm::BasicBlock::Create(
213  context, "singleton_true_" + join_loop.name_, parent_func);
214 
215  llvm::BasicBlock* filter_bb{nullptr};
216  if (join_loop.hoisted_filters_) {
217  filter_bb = join_loop.hoisted_filters_(
218  true_bb, prev_exit_bb, join_loop.name_, parent_func, cgen_state);
219  }
220 
221  if (!entry) {
222  entry = filter_bb ? filter_bb : true_bb;
223  }
224 
225  if (prev_comparison_result) {
226  builder.CreateCondBr(
227  prev_comparison_result,
228  filter_bb ? filter_bb : true_bb,
229  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
230  }
231  prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
232 
233  builder.SetInsertPoint(true_bb);
234  const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
235  CHECK(!iteration_domain.values_buffer);
236  iterators.push_back(iteration_domain.slot_lookup_result);
237  auto join_cond_match = builder.CreateICmpSGE(iteration_domain.slot_lookup_result,
238  ll_int<int64_t>(0, context));
239  llvm::Value* remaining_cond_match = builder.CreateAlloca(
240  get_int_type(1, context), nullptr, "remaining_outer_cond_match");
241  builder.CreateStore(ll_bool(true, context), remaining_cond_match);
242 
243  if (join_loop.type_ == JoinType::LEFT && join_loop.outer_condition_match_) {
244  const auto parent_func = builder.GetInsertBlock()->getParent();
245  const auto evaluate_remaining_outer_cond_bb = llvm::BasicBlock::Create(
246  context, "eval_remaining_outer_cond_" + join_loop.name_, parent_func);
247  const auto after_evaluate_outer_cond_bb = llvm::BasicBlock::Create(
248  context, "after_eval_outer_cond_" + join_loop.name_, parent_func);
249  builder.CreateCondBr(join_cond_match,
250  evaluate_remaining_outer_cond_bb,
251  after_evaluate_outer_cond_bb);
252  builder.SetInsertPoint(evaluate_remaining_outer_cond_bb);
253  const auto outer_cond_match = join_loop.outer_condition_match_(iterators);
254  const auto true_left_cond_match =
255  builder.CreateAnd(outer_cond_match, join_cond_match);
256  builder.CreateStore(true_left_cond_match, remaining_cond_match);
257  builder.CreateBr(after_evaluate_outer_cond_bb);
258  builder.SetInsertPoint(after_evaluate_outer_cond_bb);
259  }
260  auto match_found = builder.CreateAnd(
261  join_cond_match,
262  builder.CreateLoad(remaining_cond_match->getType()->getPointerElementType(),
263  remaining_cond_match));
264  CHECK(match_found);
265  if (join_loop.is_deleted_) {
266  match_found = builder.CreateAnd(
267  match_found, builder.CreateNot(join_loop.is_deleted_(iterators, nullptr)));
268  }
269  auto match_found_bb = builder.GetInsertBlock();
270  switch (join_loop.type_) {
271  case JoinType::INNER:
272  case JoinType::SEMI: {
273  prev_comparison_result = match_found;
274  break;
275  }
276  case JoinType::ANTI: {
277  auto match_found_for_anti_join = builder.CreateICmpSLT(
278  iteration_domain.slot_lookup_result, ll_int<int64_t>(0, context));
279  prev_comparison_result = match_found_for_anti_join;
280  break;
281  }
282  case JoinType::LEFT: {
283  join_loop.found_outer_matches_(match_found);
284  // For outer joins, do the iteration regardless of the result of the match.
285  prev_comparison_result = ll_bool(true, context);
286  break;
287  }
288  default:
289  CHECK(false);
290  }
291  if (!prev_iter_advance_bb) {
292  prev_iter_advance_bb = prev_exit_bb;
293  }
294  last_head_bb = match_found_bb;
295  break;
296  }
297  default:
298  CHECK(false);
299  }
300  prev_join_type = join_loop.type_;
301  }
302 
303  const auto body_bb = body_codegen(iterators);
304  builder.CreateBr(prev_iter_advance_bb);
305  builder.SetInsertPoint(last_head_bb);
306  builder.CreateCondBr(
307  prev_comparison_result,
308  body_bb,
309  prev_join_type == JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
310  return entry;
311 }
312 
313 std::pair<llvm::BasicBlock*, llvm::Value*> JoinLoop::evaluateOuterJoinCondition(
314  const JoinLoop& join_loop,
315  const JoinLoopDomain& iteration_domain,
316  const std::vector<llvm::Value*>& iterators,
317  llvm::Value* iteration_counter,
318  llvm::Value* have_more_inner_rows,
319  llvm::Value* found_an_outer_match_ptr,
320  llvm::Value* current_condition_match_ptr,
321  CgenState* cgen_state) {
322  AUTOMATIC_IR_METADATA(cgen_state);
323  llvm::IRBuilder<>& builder = cgen_state->ir_builder_;
324  auto& context = builder.getContext();
325  const auto parent_func = builder.GetInsertBlock()->getParent();
326  builder.CreateStore(ll_bool(false, context), current_condition_match_ptr);
327  const auto evaluate_outer_condition_bb = llvm::BasicBlock::Create(
328  context, "eval_outer_cond_" + join_loop.name_, parent_func);
329  const auto after_evaluate_outer_condition_bb = llvm::BasicBlock::Create(
330  context, "after_eval_outer_cond_" + join_loop.name_, parent_func);
331  builder.CreateCondBr(have_more_inner_rows,
332  evaluate_outer_condition_bb,
333  after_evaluate_outer_condition_bb);
334  builder.SetInsertPoint(evaluate_outer_condition_bb);
335  const auto current_condition_match = join_loop.outer_condition_match_
336  ? join_loop.outer_condition_match_(iterators)
337  : ll_bool(true, context);
338  builder.CreateStore(current_condition_match, current_condition_match_ptr);
339  const auto updated_condition_match = builder.CreateOr(
340  current_condition_match,
341  builder.CreateLoad(found_an_outer_match_ptr->getType()->getPointerElementType(),
342  found_an_outer_match_ptr));
343  builder.CreateStore(updated_condition_match, found_an_outer_match_ptr);
344  builder.CreateBr(after_evaluate_outer_condition_bb);
345  builder.SetInsertPoint(after_evaluate_outer_condition_bb);
346  const auto no_matches_found = builder.CreateNot(
347  builder.CreateLoad(found_an_outer_match_ptr->getType()->getPointerElementType(),
348  found_an_outer_match_ptr));
349  const auto no_more_inner_rows = builder.CreateICmpEQ(
350  iteration_counter,
351  join_loop.kind_ == JoinLoopKind::UpperBound ? iteration_domain.upper_bound
352  : iteration_domain.element_count);
353  // Do the iteration if the outer condition is true or it's the last iteration and no
354  // matches have been found.
355  const auto do_iteration = builder.CreateOr(
356  builder.CreateLoad(current_condition_match_ptr->getType()->getPointerElementType(),
357  current_condition_match_ptr),
358  builder.CreateAnd(no_matches_found, no_more_inner_rows));
359  join_loop.found_outer_matches_(
360  builder.CreateLoad(current_condition_match_ptr->getType()->getPointerElementType(),
361  current_condition_match_ptr));
362  return {after_evaluate_outer_condition_bb, do_iteration};
363 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
JoinType
Definition: sqldefs.h:238
llvm::Value * element_count
Definition: JoinLoop.h:46
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
llvm::IRBuilder ir_builder_
Definition: CgenState.h:384
std::function< llvm::BasicBlock *(llvm::BasicBlock *, llvm::BasicBlock *, const std::string &, llvm::Function *, CgenState *)> HoistedFiltersCallback
Definition: JoinLoop.h:62
const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> outer_condition_match_
Definition: JoinLoop.h:110
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
const JoinLoopKind kind_
Definition: JoinLoop.h:100
const std::function< void(llvm::Value *)> found_outer_matches_
Definition: JoinLoop.h:113
const std::string name_
Definition: JoinLoop.h:126
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value * > &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, CgenState *cgen_state)
Definition: JoinLoop.cpp:50
bool needs_error_check_
Definition: CgenState.h:405
#define AUTOMATIC_IR_METADATA(CGENSTATE)
llvm::Value * upper_bound
Definition: JoinLoop.h:45
JoinLoop(const JoinLoopKind, const JoinType, const std::function< JoinLoopDomain(const std::vector< llvm::Value * > &)> &iteration_domain_codegen, const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> &outer_condition_match, const std::function< void(llvm::Value *)> &found_outer_matches, const HoistedFiltersCallback &hoisted_filters, const std::function< llvm::Value *(const std::vector< llvm::Value * > &prev_iters, llvm::Value *)> &is_deleted, const bool nested_loop_join=false, const std::string &name="")
Definition: JoinLoop.cpp:25
#define CHECK(condition)
Definition: Logger.h:291
llvm::ConstantInt * ll_bool(const bool v, llvm::LLVMContext &context)
JoinLoopKind
Definition: JoinLoop.h:31
string name
Definition: setup.in.py:72
static std::pair< llvm::BasicBlock *, llvm::Value * > evaluateOuterJoinCondition(const JoinLoop &join_loop, const JoinLoopDomain &iteration_domain, const std::vector< llvm::Value * > &iterators, llvm::Value *iteration_counter, llvm::Value *have_more_inner_rows, llvm::Value *found_an_outer_match_ptr, llvm::Value *current_condition_match_ptr, CgenState *cgen_state)
Definition: JoinLoop.cpp:313