18 #include "../CgenState.h"
21 #include <llvm/IR/Type.h>
// JoinLoop constructor.
// NOTE(review): this listing omits lines of the original definition: the
// `JoinLoop::JoinLoop(kind, type,` head, the hoisted_filters parameter, and any
// remaining initializers/body. The complete signature is the JoinLoop(...)
// declaration listed at the end of this file.
//
// The visible initializers copy the caller-supplied callbacks and flags into members:
//   iteration_domain_codegen - invoked by codegen() with the current iterator vector to
//                              obtain this level's JoinLoopDomain.
//   outer_condition_match    - invoked with the iterators to evaluate the additional
//                              outer-join condition (used on the LEFT join paths).
//   found_outer_matches      - receives the i1 "match found" value of a singleton level.
//   is_deleted               - optional; returns an i1 that codegen() uses as a branch
//                              condition to skip deleted inner rows.
//   nested_loop_join         - flag stored as-is (declared default: false).
//   name                     - presumably stored as name_, which suffixes generated
//                              basic-block names; initializer not visible here.
27 const std::function<
JoinLoopDomain(
const std::vector<llvm::Value*>&)>&
28 iteration_domain_codegen,
29 const std::function<llvm::Value*(
const std::vector<llvm::Value*>&)>&
30 outer_condition_match,
31 const std::function<
void(llvm::Value*)>& found_outer_matches,
33 const std::function<llvm::Value*(
const std::vector<llvm::Value*>&,
34 llvm::Value*)>& is_deleted,
35 const bool nested_loop_join,
36 const std::string&
name)
39 , iteration_domain_codegen_(iteration_domain_codegen)
40 , outer_condition_match_(outer_condition_match)
41 , found_outer_matches_(found_outer_matches)
// hoisted_filters is a constructor parameter whose declaration line is elided above.
42 , hoisted_filters_(hoisted_filters)
43 , is_deleted_(is_deleted)
44 , nested_loop_join_(nested_loop_join)
// JoinLoop::codegen (prologue).
// NOTE(review): the function head and some parameter lines (e.g. the name of the
// body_codegen callback and the trailing cgen_state parameter) are elided in this
// listing. The full signature is the static codegen declaration at the end of the file.
//
// Emits one nested IR loop level per element of join_loops, in order. The first
// element is outermost, and each later level is entered from the previous level's
// head. The body callback then receives the accumulated iterator vector and
// produces the innermost body.
//   join_loops - loop levels to emit, outermost first.
//   outer_iter - first element of the iterator vector passed to every callback.
//   exit_bb    - block reached once the outermost level has no more rows.
51 const std::vector<JoinLoop>& join_loops,
52 const std::function<llvm::BasicBlock*(
const std::vector<llvm::Value*>&)>&
54 llvm::Value* outer_iter,
55 llvm::BasicBlock* exit_bb,
58 llvm::IRBuilder<>& builder = cgen_state->
ir_builder_;
// Where a level branches when it runs out of rows. It starts as the caller's exit
// block and is re-pointed to the enclosing level's advance block as levels nest.
59 llvm::BasicBlock* prev_exit_bb{exit_bb};
// Advance (counter increment) block of the most recently emitted level; null until
// a level has been emitted.
60 llvm::BasicBlock* prev_iter_advance_bb{
nullptr};
// Block in which the previous level's "continue into the next level" branch is emitted.
61 llvm::BasicBlock* last_head_bb{
nullptr};
62 auto& context = builder.getContext();
63 const auto parent_func = builder.GetInsertBlock()->getParent();
// i1 deciding whether the previous level proceeds into the next level (or the body).
64 llvm::Value* prev_comparison_result{
nullptr};
// Entry block of the generated loop nest. It is taken from a level's filter or
// preheader/true block. The guard that presumably restricts this to the first
// level is elided in this listing.
65 llvm::BasicBlock* entry{
nullptr};
// Iterator values visible to callbacks: seeded with the outer iterator and
// extended with one value per emitted level.
66 std::vector<llvm::Value*> iterators;
67 iterators.push_back(outer_iter);
// Emit each level in turn; every level is nested inside the one emitted before it.
69 for (
const auto& join_loop : join_loops) {
70 switch (join_loop.kind_) {
// Counted ("ub_iter_") loop level. The case label is elided in this listing;
// presumably this is the upper-bound loop kind.
74 const auto preheader_bb = llvm::BasicBlock::Create(
75 context,
"ub_iter_preheader_" + join_loop.name_, parent_func);
77 llvm::BasicBlock* filter_bb{
nullptr};
// Optional hoisted filters may emit a block in front of the preheader. The callback
// receives the preheader and the current exit block, and returns its own block.
78 if (join_loop.hoisted_filters_) {
79 filter_bb = join_loop.hoisted_filters_(
80 preheader_bb, prev_exit_bb, join_loop.name_, parent_func, cgen_state);
84 entry = filter_bb ? filter_bb : preheader_bb;
// Link the previous level, if any. When its comparison holds, enter this level.
// Otherwise go to the previous level's advance block if that level was a LEFT join,
// else to the previous level's exit. (The opening line of this CreateCondBr call is
// elided in this listing.)
87 if (prev_comparison_result) {
89 prev_comparison_result,
90 filter_bb ? filter_bb : preheader_bb,
91 prev_join_type ==
JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
// Rows exhausted at this level resume the enclosing level's advance block. For the
// outermost level they leave the nest via exit_bb.
93 prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
94 builder.SetInsertPoint(preheader_bb);
// Preheader: allocate the i64 iteration counter and the i1 outer-match flags that
// evaluateOuterJoinCondition reads/writes, reset them, then build the domain.
// NOTE(review): these allocas are not in the function's entry block. LLVM treats such
// allocas as dynamic stack allocations, and mem2reg only promotes entry-block
// allocas. When this preheader is re-entered from an enclosing loop, stack usage may
// grow per outer iteration. Consider hoisting them to the entry block.
96 const auto iteration_counter_ptr = builder.CreateAlloca(
97 get_int_type(64, context),
nullptr,
"ub_iter_counter_ptr_" + join_loop.name_);
98 llvm::Value* found_an_outer_match_ptr{
nullptr};
99 llvm::Value* current_condition_match_ptr{
nullptr};
// The condition guarding these two allocations is elided in this listing;
// presumably it is a LEFT-join check.
101 found_an_outer_match_ptr = builder.CreateAlloca(
102 get_int_type(1, context),
nullptr,
"found_an_outer_match");
// Reset on every preheader entry, so "found a match" is tracked per pass over the
// inner rows.
103 builder.CreateStore(
ll_bool(
false, context), found_an_outer_match_ptr);
104 current_condition_match_ptr = builder.CreateAlloca(
105 get_int_type(1, context),
nullptr,
"outer_condition_current_match");
107 builder.CreateStore(
ll_int(int64_t(0), context), iteration_counter_ptr);
108 const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
110 const auto head_bb = llvm::BasicBlock::Create(
111 context,
"ub_iter_head_" + join_loop.name_, parent_func);
// If the domain reports an error code, continue to the head only when the code is
// 0. Any other value makes the generated function return that code.
113 if (iteration_domain.error_code) {
115 auto ub_iter_success_code =
ll_int(int32_t(0), context);
116 const auto ub_iter_error_condition =
117 builder.CreateICmpEQ(iteration_domain.error_code, ub_iter_success_code);
119 llvm::BasicBlock::Create(context,
"ub_iter_error_exit", parent_func);
120 builder.CreateCondBr(ub_iter_error_condition, head_bb, error_bb);
122 builder.SetInsertPoint(error_bb);
123 builder.CreateRet(iteration_domain.error_code);
// No error code: fall straight through to the loop head.
125 builder.CreateBr(head_bb);
// Loop head: load the current counter value.
// NOTE(review): getPointerElementType() is deprecated and then removed with LLVM's
// opaque pointers. The loaded/GEP element type will need another source on newer LLVM.
128 builder.SetInsertPoint(head_bb);
129 llvm::Value* iteration_counter =
130 builder.CreateLoad(iteration_counter_ptr->getType()->getPointerElementType(),
131 iteration_counter_ptr,
132 "ub_iter_counter_val_" + join_loop.name_);
// By default the iterator is the counter itself. When the domain supplies a
// values_buffer pointer, it is instead a GEP into that buffer, with separate handling
// for array-typed and plain pointees. The GEP indices are elided in this listing.
133 auto iteration_val = iteration_counter;
136 !iteration_domain.values_buffer);
139 CHECK(iteration_domain.values_buffer->getType()->isPointerTy());
140 const auto ptr_type =
141 static_cast<llvm::PointerType*
>(iteration_domain.values_buffer->getType());
142 if (ptr_type->getPointerElementType()->isArrayTy()) {
143 iteration_val = builder.CreateGEP(
144 iteration_domain.values_buffer->getType()
146 ->getPointerElementType(),
147 iteration_domain.values_buffer,
148 std::vector<llvm::Value*>{
151 "ub_iter_counter_" + join_loop.name_);
153 iteration_val = builder.CreateGEP(iteration_domain.values_buffer->getType()
155 ->getPointerElementType(),
156 iteration_domain.values_buffer,
158 "ub_iter_counter_" + join_loop.name_);
161 iterators.push_back(iteration_val);
// counter < bound (signed). The bound is upper_bound or element_count; the selecting
// condition is elided in this listing.
162 const auto have_more_inner_rows = builder.CreateICmpSLT(
165 : iteration_domain.element_count,
166 "have_more_inner_rows");
167 const auto iter_advance_bb = llvm::BasicBlock::Create(
168 context,
"ub_iter_advance_" + join_loop.name_, parent_func);
// Optional delete check: deleted rows jump straight to the advance block.
169 llvm::BasicBlock* row_not_deleted_bb{
nullptr};
170 if (join_loop.is_deleted_) {
171 row_not_deleted_bb = llvm::BasicBlock::Create(
172 context,
"row_not_deleted_" + join_loop.name_, parent_func);
173 const auto row_is_deleted =
174 join_loop.is_deleted_(iterators, have_more_inner_rows);
175 builder.CreateCondBr(row_is_deleted, iter_advance_bb, row_not_deleted_bb);
176 builder.SetInsertPoint(row_not_deleted_bb);
// One path, presumably LEFT joins (condition elided), delegates to
// evaluateOuterJoinCondition, which returns the block to continue from and the
// "proceed" i1. The other path proceeds while have_more_inner_rows holds.
179 std::tie(last_head_bb, prev_comparison_result) =
184 have_more_inner_rows,
185 found_an_outer_match_ptr,
186 current_condition_match_ptr,
189 prev_comparison_result = have_more_inner_rows;
190 last_head_bb = row_not_deleted_bb ? row_not_deleted_bb : head_bb;
// Advance block: store counter + 1.
192 builder.SetInsertPoint(iter_advance_bb);
193 const auto iteration_counter_next_val =
194 builder.CreateAdd(iteration_counter,
ll_int(int64_t(1), context));
195 builder.CreateStore(iteration_counter_next_val, iteration_counter_ptr);
// Under a condition elided here, the loop exits only once next > bound (SGT), so the
// head is re-entered one extra time with counter == bound. This is presumably for the
// outer-join "no match" pass; confirm against evaluateOuterJoinCondition.
197 const auto no_more_inner_rows =
198 builder.CreateICmpSGT(iteration_counter_next_val,
200 ? iteration_domain.upper_bound
201 : iteration_domain.element_count,
202 "no_more_inner_rows");
203 builder.CreateCondBr(no_more_inner_rows, prev_exit_bb, head_bb);
// Otherwise always loop back. The exit then happens from the head through
// prev_comparison_result.
205 builder.CreateBr(head_bb);
// The next level (or the body) is linked from last_head_bb.
207 builder.SetInsertPoint(last_head_bb);
208 prev_iter_advance_bb = iter_advance_bb;
// Singleton level: a single lookup result instead of a counted loop. The case label
// is elided in this listing.
212 const auto true_bb = llvm::BasicBlock::Create(
213 context,
"singleton_true_" + join_loop.name_, parent_func);
215 llvm::BasicBlock* filter_bb{
nullptr};
// Optional hoisted filters may emit a block in front of true_bb.
216 if (join_loop.hoisted_filters_) {
217 filter_bb = join_loop.hoisted_filters_(
218 true_bb, prev_exit_bb, join_loop.name_, parent_func, cgen_state);
222 entry = filter_bb ? filter_bb : true_bb;
// Link the previous level exactly as in the counted-loop case.
225 if (prev_comparison_result) {
226 builder.CreateCondBr(
227 prev_comparison_result,
228 filter_bb ? filter_bb : true_bb,
229 prev_join_type ==
JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
231 prev_exit_bb = prev_iter_advance_bb ? prev_iter_advance_bb : exit_bb;
233 builder.SetInsertPoint(true_bb);
// A singleton domain must not provide a values buffer; its iterator is the
// slot lookup result, and a non-negative slot counts as a join-condition match.
234 const auto iteration_domain = join_loop.iteration_domain_codegen_(iterators);
235 CHECK(!iteration_domain.values_buffer);
236 iterators.push_back(iteration_domain.slot_lookup_result);
237 auto join_cond_match = builder.CreateICmpSGE(iteration_domain.slot_lookup_result,
238 ll_int<int64_t>(0, context));
// i1 slot holding the result of the remaining outer condition (defaults to true).
// NOTE(review): like the counted-loop allocas, this alloca is outside the entry block.
239 llvm::Value* remaining_cond_match = builder.CreateAlloca(
240 get_int_type(1, context),
nullptr,
"remaining_outer_cond_match");
241 builder.CreateStore(
ll_bool(
true, context), remaining_cond_match);
// LEFT join with an extra outer condition: evaluate that condition only when the
// slot lookup matched, and store (outer_cond_match && join_cond_match).
243 if (join_loop.type_ ==
JoinType::LEFT && join_loop.outer_condition_match_) {
// NOTE(review): redundant shadow of the outer parent_func. The current insert block
// (true_bb) was created in that same function.
244 const auto parent_func = builder.GetInsertBlock()->getParent();
245 const auto evaluate_remaining_outer_cond_bb = llvm::BasicBlock::Create(
246 context,
"eval_remaining_outer_cond_" + join_loop.name_, parent_func);
247 const auto after_evaluate_outer_cond_bb = llvm::BasicBlock::Create(
248 context,
"after_eval_outer_cond_" + join_loop.name_, parent_func);
249 builder.CreateCondBr(join_cond_match,
250 evaluate_remaining_outer_cond_bb,
251 after_evaluate_outer_cond_bb);
252 builder.SetInsertPoint(evaluate_remaining_outer_cond_bb);
253 const auto outer_cond_match = join_loop.outer_condition_match_(iterators);
254 const auto true_left_cond_match =
255 builder.CreateAnd(outer_cond_match, join_cond_match);
256 builder.CreateStore(true_left_cond_match, remaining_cond_match);
257 builder.CreateBr(after_evaluate_outer_cond_bb);
258 builder.SetInsertPoint(after_evaluate_outer_cond_bb);
// match_found = <elided first operand> && remaining_cond_match. The optional delete
// check then ANDs in !is_deleted.
260 auto match_found = builder.CreateAnd(
262 builder.CreateLoad(remaining_cond_match->getType()->getPointerElementType(),
263 remaining_cond_match));
265 if (join_loop.is_deleted_) {
266 match_found = builder.CreateAnd(
267 match_found, builder.CreateNot(join_loop.is_deleted_(iterators,
nullptr)));
269 auto match_found_bb = builder.GetInsertBlock();
// Case labels are elided in this listing.
//  - First visible case: proceed only on a match.
//  - Anti-join case: proceed only when the slot lookup failed (slot < 0).
//  - Remaining case (presumably LEFT): report match_found to found_outer_matches_ and
//    always proceed.
270 switch (join_loop.type_) {
273 prev_comparison_result = match_found;
277 auto match_found_for_anti_join = builder.CreateICmpSLT(
278 iteration_domain.slot_lookup_result, ll_int<int64_t>(0, context));
279 prev_comparison_result = match_found_for_anti_join;
283 join_loop.found_outer_matches_(match_found);
285 prev_comparison_result =
ll_bool(
true, context);
// A singleton has no advance block of its own. If no enclosing level provided one,
// "advance" means leaving via the current exit.
291 if (!prev_iter_advance_bb) {
292 prev_iter_advance_bb = prev_exit_bb;
294 last_head_bb = match_found_bb;
// Remember this level's join type; it selects the false target when the next
// level (or the body) is linked.
300 prev_join_type = join_loop.type_;
// Epilogue: emit the innermost body with all iterators. After the body, control goes
// to the innermost level's advance block.
// NOTE(review): if join_loops is empty, prev_iter_advance_bb and last_head_bb are
// still nullptr here. Confirm that callers, or a check elided from this listing,
// guarantee at least one loop level.
303 const auto body_bb = body_codegen(iterators);
304 builder.CreateBr(prev_iter_advance_bb);
// The innermost level's comparison decides whether the body runs. The true target
// line is elided here, presumably body_bb. On false, a LEFT join goes to its advance
// block; any other join type goes to its exit block. (The return statement is not
// visible in this listing.)
305 builder.SetInsertPoint(last_head_bb);
306 builder.CreateCondBr(
307 prev_comparison_result,
309 prev_join_type ==
JoinType::LEFT ? prev_iter_advance_bb : prev_exit_bb);
// JoinLoop::evaluateOuterJoinCondition.
// NOTE(review): the head line and the join_loop / iteration_domain / cgen_state
// parameters are elided in this listing. The full signature is the static declaration
// at the end of the file. Some body lines are elided too, including the computation of
// current_condition_match and the operands of no_more_inner_rows.
//
// Emits, at the current insert point of a counted loop's head:
//   - the current row's outer-condition result, evaluated only while
//     have_more_inner_rows holds and stored in *current_condition_match_ptr
//     (reset to false first);
//   - a running OR of those results in *found_an_outer_match_ptr.
// Returns {block to continue emitting from, do_iteration}, where
//   do_iteration = current_condition_match || (!found_an_outer_match && no_more_inner_rows).
// This presumably lets a LEFT join visit the outer row once more when no inner row
// matched.
316 const std::vector<llvm::Value*>& iterators,
317 llvm::Value* iteration_counter,
318 llvm::Value* have_more_inner_rows,
319 llvm::Value* found_an_outer_match_ptr,
320 llvm::Value* current_condition_match_ptr,
323 llvm::IRBuilder<>& builder = cgen_state->
ir_builder_;
324 auto& context = builder.getContext();
325 const auto parent_func = builder.GetInsertBlock()->getParent();
326 builder.CreateStore(
ll_bool(
false, context), current_condition_match_ptr);
327 const auto evaluate_outer_condition_bb = llvm::BasicBlock::Create(
328 context,
"eval_outer_cond_" + join_loop.
name_, parent_func);
329 const auto after_evaluate_outer_condition_bb = llvm::BasicBlock::Create(
330 context,
"after_eval_outer_cond_" + join_loop.
name_, parent_func);
// Evaluate the condition only for a real inner row (counter < bound).
331 builder.CreateCondBr(have_more_inner_rows,
332 evaluate_outer_condition_bb,
333 after_evaluate_outer_condition_bb);
334 builder.SetInsertPoint(evaluate_outer_condition_bb);
338 builder.CreateStore(current_condition_match, current_condition_match_ptr);
// Accumulate: found_an_outer_match |= current_condition_match.
339 const auto updated_condition_match = builder.CreateOr(
340 current_condition_match,
341 builder.CreateLoad(found_an_outer_match_ptr->getType()->getPointerElementType(),
342 found_an_outer_match_ptr));
343 builder.CreateStore(updated_condition_match, found_an_outer_match_ptr);
344 builder.CreateBr(after_evaluate_outer_condition_bb);
345 builder.SetInsertPoint(after_evaluate_outer_condition_bb);
346 const auto no_matches_found = builder.CreateNot(
347 builder.CreateLoad(found_an_outer_match_ptr->getType()->getPointerElementType(),
348 found_an_outer_match_ptr));
349 const auto no_more_inner_rows = builder.CreateICmpEQ(
355 const auto do_iteration = builder.CreateOr(
356 builder.CreateLoad(current_condition_match_ptr->getType()->getPointerElementType(),
357 current_condition_match_ptr),
358 builder.CreateAnd(no_matches_found, no_more_inner_rows));
// The consumer of this load is elided in this listing.
360 builder.CreateLoad(current_condition_match_ptr->getType()->getPointerElementType(),
361 current_condition_match_ptr));
362 return {after_evaluate_outer_condition_bb, do_iteration};
llvm::Value * element_count
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
llvm::IRBuilder ir_builder_
std::function< llvm::BasicBlock *(llvm::BasicBlock *, llvm::BasicBlock *, const std::string &, llvm::Function *, CgenState *)> HoistedFiltersCallback
const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> outer_condition_match_
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
const std::function< void(llvm::Value *)> found_outer_matches_
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value * > &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, CgenState *cgen_state)
llvm::Value * upper_bound
JoinLoop(const JoinLoopKind, const JoinType, const std::function< JoinLoopDomain(const std::vector< llvm::Value * > &)> &iteration_domain_codegen, const std::function< llvm::Value *(const std::vector< llvm::Value * > &)> &outer_condition_match, const std::function< void(llvm::Value *)> &found_outer_matches, const HoistedFiltersCallback &hoisted_filters, const std::function< llvm::Value *(const std::vector< llvm::Value * > &prev_iters, llvm::Value *)> &is_deleted, const bool nested_loop_join=false, const std::string &name="")
llvm::ConstantInt * ll_bool(const bool v, llvm::LLVMContext &context)
static std::pair< llvm::BasicBlock *, llvm::Value * > evaluateOuterJoinCondition(const JoinLoop &join_loop, const JoinLoopDomain &iteration_domain, const std::vector< llvm::Value * > &iterators, llvm::Value *iteration_counter, llvm::Value *have_more_inner_rows, llvm::Value *found_an_outer_match_ptr, llvm::Value *current_condition_match_ptr, CgenState *cgen_state)