563 using namespace llvm;
565 auto*
const i32_type = llvm::IntegerType::get(mod->getContext(), 32);
566 auto*
const i64_type = llvm::IntegerType::get(mod->getContext(), 64);
568 llvm::Function*
const func_pos_start =
pos_start(mod);
569 CHECK(func_pos_start);
570 llvm::Function*
const func_pos_step =
pos_step(mod);
571 CHECK(func_pos_step);
573 CHECK(func_group_buff_idx);
574 llvm::Function*
const func_row_process =
row_process(mod, 0, hoist_literals);
575 CHECK(func_row_process);
576 llvm::Function*
const func_init_shared_mem =
578 : mod->getFunction(
"init_shared_mem_nop");
579 CHECK(func_init_shared_mem);
581 auto func_write_back = mod->getFunction(
"write_back_nop");
582 CHECK(func_write_back);
584 constexpr
bool IS_GROUP_BY =
true;
585 Params query_func_params = make_params<IS_GROUP_BY>(mod, hoist_literals);
587 FunctionType* query_func_type = FunctionType::get(
588 Type::getVoidTy(mod->getContext()),
589 query_func_params.types(),
592 std::string query_name{
"query_group_by_template"};
593 auto query_func_ptr = mod->getFunction(query_name);
594 CHECK(!query_func_ptr);
596 query_func_ptr = Function::Create(
598 GlobalValue::ExternalLinkage,
599 "query_group_by_template",
601 query_func_ptr->setCallingConv(CallingConv::C);
602 query_func_ptr->setAttributes(query_func_params.attributeList());
603 query_func_params.setNames(query_func_ptr->arg_begin());
605 auto bb_entry = BasicBlock::Create(mod->getContext(),
".entry", query_func_ptr, 0);
607 BasicBlock::Create(mod->getContext(),
".loop.preheader", query_func_ptr, 0);
608 auto bb_forbody = BasicBlock::Create(mod->getContext(),
".forbody", query_func_ptr, 0);
610 BasicBlock::Create(mod->getContext(),
"._crit_edge", query_func_ptr, 0);
611 auto bb_exit = BasicBlock::Create(mod->getContext(),
".exit", query_func_ptr, 0);
614 llvm::Value*
const row_count_ptr =
get_arg_by_name(query_func_ptr,
"row_count_ptr");
615 LoadInst* row_count =
new LoadInst(
618 row_count->setName(
"row_count");
620 llvm::Value*
const max_matched_ptr =
get_arg_by_name(query_func_ptr,
"max_matched_ptr");
621 LoadInst* max_matched =
new LoadInst(
625 auto crt_matched_ptr =
new AllocaInst(i32_type, 0,
"crt_matched", bb_entry);
626 auto old_total_matched_ptr =
new AllocaInst(i32_type, 0,
"old_total_matched", bb_entry);
627 CallInst*
pos_start = CallInst::Create(func_pos_start,
"", bb_entry);
628 pos_start->setCallingConv(CallingConv::C);
630 llvm::AttributeList pos_start_pal;
633 CallInst*
pos_step = CallInst::Create(func_pos_step,
"", bb_entry);
634 pos_step->setCallingConv(CallingConv::C);
636 llvm::AttributeList pos_step_pal;
637 pos_step->setAttributes(pos_step_pal);
639 CallInst* group_buff_idx_call = CallInst::Create(func_group_buff_idx,
"", bb_entry);
640 group_buff_idx_call->setCallingConv(CallingConv::C);
641 group_buff_idx_call->setTailCall(
true);
642 llvm::AttributeList group_buff_idx_pal;
643 group_buff_idx_call->setAttributes(group_buff_idx_pal);
646 auto*
const group_by_buffers =
get_arg_by_name(query_func_ptr,
"group_by_buffers");
647 const PointerType* Ty = dyn_cast<PointerType>(group_by_buffers->getType());
650 Value* varlen_output_buffer{
nullptr};
654 auto varlen_output_buffer_gep = GetElementPtrInst::Create(
655 Ty->getPointerElementType(),
657 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 0),
660 varlen_output_buffer =
662 varlen_output_buffer_gep,
663 "varlen_output_buffer",
670 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mod->getContext()), 1),
671 "group_buff_idx_varlen_offset",
674 varlen_output_buffer =
675 ConstantPointerNull::get(Type::getInt64PtrTy(mod->getContext()));
677 CHECK(varlen_output_buffer);
679 CastInst* pos_start_i64 =
new SExtInst(
pos_start, i64_type,
"", bb_entry);
680 GetElementPtrInst* group_by_buffers_gep = GetElementPtrInst::Create(
681 Ty->getPointerElementType(), group_by_buffers,
group_buff_idx,
"", bb_entry);
683 group_by_buffers_gep,
687 col_buffer->setName(
"col_buffer");
690 llvm::ConstantInt* shared_mem_bytes_lv =
693 llvm::CallInst* result_buffer =
694 CallInst::Create(func_init_shared_mem,
695 std::vector<llvm::Value*>{col_buffer, shared_mem_bytes_lv},
699 ICmpInst* enter_or_not =
700 new ICmpInst(*bb_entry, ICmpInst::ICMP_SLT, pos_start_i64, row_count,
"");
701 BranchInst::Create(bb_preheader, bb_exit, enter_or_not, bb_entry);
704 CastInst* pos_step_i64 =
new SExtInst(
pos_step, i64_type,
"", bb_preheader);
705 BranchInst::Create(bb_forbody, bb_preheader);
709 PHINode* pos = PHINode::Create(i64_type, check_scan_limit ? 3 : 2,
"pos", bb_forbody);
711 std::vector<Value*> row_process_params;
712 row_process_params.push_back(result_buffer);
713 row_process_params.push_back(varlen_output_buffer);
714 row_process_params.push_back(crt_matched_ptr);
715 row_process_params.push_back(
get_arg_by_name(query_func_ptr,
"total_matched"));
716 row_process_params.push_back(old_total_matched_ptr);
717 row_process_params.push_back(max_matched_ptr);
718 row_process_params.push_back(
get_arg_by_name(query_func_ptr,
"agg_init_val"));
719 row_process_params.push_back(pos);
720 row_process_params.push_back(
get_arg_by_name(query_func_ptr,
"frag_row_off_ptr"));
721 row_process_params.push_back(row_count_ptr);
722 if (hoist_literals) {
723 row_process_params.push_back(
get_arg_by_name(query_func_ptr,
"literals"));
725 if (check_scan_limit) {
726 new StoreInst(ConstantInt::get(IntegerType::get(mod->getContext(), 32), 0),
731 CallInst::Create(func_row_process, row_process_params,
"", bb_forbody);
734 llvm::AttributeList row_process_pal;
739 auto func_sync_warp_protected = mod->getFunction(
"sync_warp_protected");
740 CHECK(func_sync_warp_protected);
741 CallInst::Create(func_sync_warp_protected,
742 std::vector<llvm::Value*>{pos, row_count},
748 BinaryOperator::Create(Instruction::Add, pos, pos_step_i64,
"", bb_forbody);
749 ICmpInst* loop_or_exit =
750 new ICmpInst(*bb_forbody, ICmpInst::ICMP_SLT, pos_inc, row_count,
"");
751 if (check_scan_limit) {
757 auto filter_match = BasicBlock::Create(
758 mod->getContext(),
"filter_match", query_func_ptr, bb_crit_edge);
759 llvm::Value* new_total_matched =
761 old_total_matched_ptr,
766 BinaryOperator::CreateAdd(new_total_matched, crt_matched,
"", filter_match);
767 CHECK(new_total_matched);
768 ICmpInst* limit_not_reached =
new ICmpInst(*filter_match,
772 "limit_not_reached");
776 BinaryOperator::Create(
777 BinaryOperator::And, loop_or_exit, limit_not_reached,
"", filter_match),
779 auto filter_nomatch = BasicBlock::Create(
780 mod->getContext(),
"filter_nomatch", query_func_ptr, bb_crit_edge);
781 BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, filter_nomatch);
782 ICmpInst* crt_matched_nz =
new ICmpInst(
783 *bb_forbody, ICmpInst::ICMP_NE, crt_matched, ConstantInt::get(i32_type, 0),
"");
784 BranchInst::Create(filter_match, filter_nomatch, crt_matched_nz, bb_forbody);
785 pos->addIncoming(pos_start_i64, bb_preheader);
786 pos->addIncoming(pos_pre, filter_match);
787 pos->addIncoming(pos_pre, filter_nomatch);
789 pos->addIncoming(pos_start_i64, bb_preheader);
790 pos->addIncoming(pos_pre, bb_forbody);
791 BranchInst::Create(bb_forbody, bb_crit_edge, loop_or_exit, bb_forbody);
795 BranchInst::Create(bb_exit, bb_crit_edge);
798 CallInst::Create(func_write_back,
799 std::vector<Value*>{col_buffer, result_buffer, shared_mem_bytes_lv},
803 ReturnInst::Create(mod->getContext(), bb_exit);
806 pos_pre->replaceAllUsesWith(pos_inc);
809 if (verifyFunction(*query_func_ptr, &llvm::errs())) {
810 LOG(
FATAL) <<
"Generated invalid code. ";
llvm::Function * pos_start(llvm::Module *mod)
llvm::Function * row_process(llvm::Module *mod, const size_t aggr_col_count, const bool hoist_literals)
llvm::Function * pos_step(llvm::Module *mod)
bool hasVarlenOutput() const
size_t getSharedMemorySize() const
llvm::Function * group_buff_idx(llvm::Module *mod)
#define LLVM_ALIGN(alignment)
bool isSharedMemoryUsed() const
llvm::Value * get_arg_by_name(llvm::Function *func, const std::string &name)
bool isWarpSyncRequired(const ExecutorDeviceType) const
llvm::Type * get_pointer_element_type(llvm::Value *value)