17 #include <llvm/Transforms/Utils/Cloning.h>
// Parameter list and member-initializer list of the GpuSharedMemCodeBuilder
// constructor. Only part of the list is visible here.
// The visible initializers copy the executor id, the LLVM module pointer, the
// query memory descriptor (qmd) and the initial aggregate values into members;
// reduction_func_ starts as nullptr.
// NOTE(review): presumably reduction_func_ is assigned later by
// createReductionFunction() — confirm against the full constructor/codegen path.
25 llvm::Module* llvm_module,
26 llvm::LLVMContext& context,
28 const std::vector<TargetInfo>& targets,
29 const std::vector<int64_t>& init_agg_values,
30 const size_t executor_id)
31 : executor_id_(executor_id)
32 , module_(llvm_module)
34 , reduction_func_(nullptr)
36 , query_mem_desc_(qmd)
38 , init_agg_values_(init_agg_values) {
// Emits the body of the reduction function (presumably
// GpuSharedMemCodeBuilder::codegenReduction — the header is not visible here).
//
// Step 1: give readable IR names to the three values obtained through arg_it:
// "dest_buffer_ptr", "src_buffer_ptr" and "buffer_size".
102 auto dest_buffer_ptr = &*arg_it;
103 dest_buffer_ptr->setName(
"dest_buffer_ptr");
105 auto src_buffer_ptr = &*arg_it;
106 src_buffer_ptr->setName(
"src_buffer_ptr");
108 auto buffer_size = &*arg_it;
109 buffer_size->setName(
"buffer_size");
// Step 2: start emitting into bb_entry and call get_thread_index; its result is
// used as this thread's index below.
114 llvm::IRBuilder<> ir_builder(bb_entry);
120 const auto func_thread_index =
getFunction(
"get_thread_index");
121 const auto thread_idx = ir_builder.CreateCall(func_thread_index, {},
"thread_index");
// Step 3: bounds guard. A signed "thread_idx < entry_count" compare sends
// in-range threads to bb_body and all others straight to bb_exit.
125 const auto entry_count_i32 =
127 const auto is_thread_inbound =
128 ir_builder.CreateICmpSLT(thread_idx, entry_count,
"is_thread_inbound");
129 ir_builder.CreateCondBr(is_thread_inbound, bb_body, bb_exit);
131 ir_builder.SetInsertPoint(bb_body);
// Step 4: view both buffers as i8* byte streams in address space 0.
134 auto src_byte_stream = ir_builder.CreatePointerCast(
135 src_buffer_ptr, llvm::Type::getInt8PtrTy(
context_, 0),
"src_byte_stream");
136 const auto dest_byte_stream = ir_builder.CreatePointerCast(
137 dest_buffer_ptr, llvm::Type::getInt8PtrTy(
context_, 0),
"dest_byte_stream");
// Step 5: generate the reduction code with GpuReductionHelperJIT, then stamp the
// resulting module with a 64-bit NVPTX data layout and target triple before
// linking it in.
141 auto rs_reduction_jit = std::make_unique<GpuReductionHelperJIT>(
142 fixup_query_mem_desc,
146 auto reduction_code = rs_reduction_jit->codegen();
147 CHECK(reduction_code.module);
148 reduction_code.module->setDataLayout(
149 "e-p:64:64:64-i1:8:8-i8:8:8-"
150 "i16:16:16-i32:32:32-i64:64:64-"
151 "f32:32:32-f64:64:64-v16:16:16-"
152 "v32:32:32-v64:64:64-v128:128:128-n16:32:64");
153 reduction_code.module->setTargetTriple(
"nvptx64-nvidia-cuda");
// `owner` wraps reduction_code.module in a unique_ptr, which is then moved into
// the linker call.
155 std::unique_ptr<llvm::Module> owner(reduction_code.module);
156 bool link_error = linker.linkInModule(std::move(owner));
// Step 6: rewrite reduce_one_entry so that calls to "agg_*" functions call the
// function of the same name with "_shared" appended (looked up via getFunction).
// The outer while loop repeats the scan as long as the previous pass replaced a
// call. Presumably the scan restarts because replacing an instruction
// invalidates the instruction iterator — TODO confirm.
161 auto reduce_one_entry_func =
getFunction(
"reduce_one_entry");
162 bool agg_func_found =
true;
163 while (agg_func_found) {
164 agg_func_found =
false;
165 for (
auto it = llvm::inst_begin(reduce_one_entry_func);
166 it != llvm::inst_end(reduce_one_entry_func);
// Only call instructions are of interest.
168 if (!llvm::isa<llvm::CallInst>(*it)) {
171 auto& func_call = llvm::cast<llvm::CallInst>(*it);
174 std::string_view func_name_str = *func_name;
175 if (func_name_str.substr(0, 4) ==
"agg_") {
// Names that already end in "_shared" are handled separately (presumably
// skipped; the branch body is not visible here).
// NOTE(review): length() - 7 wraps around for names shorter than 7 characters
// (the "agg_" check only guarantees 4), and string_view::substr then throws
// std::out_of_range. Guard with a length check.
176 if (func_name_str.substr(func_name_str.length() - 7) ==
"_shared") {
179 agg_func_found =
true;
// Copy the call's arguments. getNumOperands() - 1 is used as the argument count
// (for a plain call the last operand is the callee).
180 std::vector<llvm::Value*>
args;
181 args.reserve(func_call.getNumOperands());
182 for (
size_t i = 0; i < func_call.getNumOperands() - 1; ++i) {
183 args.push_back(func_call.getArgOperand(i));
185 auto gpu_agg_func =
getFunction(std::string(func_name_str) +
"_shared");
186 llvm::ReplaceInstWithInst(&func_call,
187 llvm::CallInst::Create(gpu_agg_func, args,
""));
// Step 7: call reduce_one_entry_idx (its arguments are not visible here), then
// branch to bb_exit, which receives a bare return instruction.
193 const auto reduce_one_entry_idx_func =
getFunction(
"reduce_one_entry_idx");
194 CHECK(reduce_one_entry_idx_func);
199 const auto null_ptr_ll =
200 llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(
context_, 0));
201 const auto thread_idx_i32 = ir_builder.CreateCast(
203 ir_builder.CreateCall(reduce_one_entry_idx_func,
212 ir_builder.CreateBr(bb_exit);
213 llvm::ReturnInst::Create(
context_, bb_exit);
// Tail of the parameter list and body of codegen_smem_dest_slot_ptr: builds a
// typed pointer to one destination slot and returns it.
// The result is a GEP into dest_byte_stream, pointer-cast to the type chosen by
// ptr_type. How byte_offset enters the GEP is not visible here.
222 llvm::IRBuilder<>& ir_builder,
223 const size_t slot_idx,
225 llvm::Value* dest_byte_stream,
226 llvm::Value* byte_offset) {
// ptr_type picks the pointer type for a slot, always in address space 3
// (on NVPTX this is the shared address space per the LLVM NVPTX guide — TODO
// confirm that is the intent):
//   - 4-byte slot -> i32*
//   - 8-byte slot -> i64* (any other width trips the CHECK)
// The last `return i32*` belongs to a branch whose condition is not visible here.
229 auto ptr_type = [&context](
const size_t slot_bytes,
const SQLTypeInfo& sql_type) {
230 if (slot_bytes ==
sizeof(int32_t)) {
231 return llvm::Type::getInt32PtrTy(context, 3);
233 CHECK(slot_bytes ==
sizeof(int64_t));
234 return llvm::Type::getInt64PtrTy(context, 3);
237 return llvm::Type::getInt32PtrTy(context, 3);
// The GEP element type is taken from dest_byte_stream's own pointee type.
240 const auto casted_dest_slot_address = ir_builder.CreatePointerCast(
241 ir_builder.CreateGEP(
242 dest_byte_stream->getType()->getScalarType()->getPointerElementType(),
245 ptr_type(slot_bytes, sql_type),
247 return casted_dest_slot_address;
// Emits the body of the initialization function (presumably
// GpuSharedMemCodeBuilder::codegenInitialization — the header is not visible
// here). The generated code stores an initial value into every slot of the row
// selected by the thread index, then returns the shared-memory buffer pointer.
//
// Preconditions enforced here: the output is not columnar and keyless hash is
// in use.
263 CHECK(!fixup_query_mem_desc.didOutputColumnar());
264 CHECK(fixup_query_mem_desc.hasKeylessHash());
// Start emitting into bb_entry and obtain the thread index.
271 llvm::IRBuilder<> ir_builder(bb_entry);
272 const auto func_thread_index =
getFunction(
"get_thread_index");
273 const auto thread_idx = ir_builder.CreateCall(func_thread_index, {},
"thread_index");
// The buffer to initialize is whatever declare_dynamic_shared_memory returns.
276 const auto declare_smem_func =
getFunction(
"declare_dynamic_shared_memory");
277 const auto shared_mem_buffer =
278 ir_builder.CreateCall(declare_smem_func, {},
"shared_mem_buffer");
// Bounds guard: only threads with thread_idx < entry_count (signed compare)
// enter bb_body; the rest jump to bb_exit.
280 const auto entry_count =
ll_int(fixup_query_mem_desc.getEntryCount(),
context_);
281 const auto is_thread_inbound =
282 ir_builder.CreateICmpSLT(thread_idx, entry_count,
"is_thread_inbound");
283 ir_builder.CreateCondBr(is_thread_inbound, bb_body, bb_exit);
285 ir_builder.SetInsertPoint(bb_body);
// The starting byte offset is row width * thread index.
287 const auto row_size_bytes =
ll_int(fixup_query_mem_desc.getRowWidth(),
context_);
288 auto byte_offset_ll = ir_builder.CreateMul(row_size_bytes, thread_idx,
"byte_offset");
290 const auto dest_byte_stream = ir_builder.CreatePointerCast(
291 shared_mem_buffer, llvm::Type::getInt8PtrTy(
context_),
"dest_byte_stream");
// Walk every target, and for each target every slot from the first to the last
// slot index reported by the column-slot context.
294 const auto& col_slot_context = fixup_query_mem_desc.getColSlotContext();
295 size_t init_agg_idx = 0;
296 for (
size_t target_logical_idx = 0; target_logical_idx <
targets_.size();
297 ++target_logical_idx) {
298 const auto& target_info =
targets_[target_logical_idx];
299 const auto& slots_for_target = col_slot_context.getSlotsForCol(target_logical_idx);
300 for (
size_t slot_idx = slots_for_target.front(); slot_idx <= slots_for_target.back();
302 const auto slot_size = fixup_query_mem_desc.getPaddedSlotWidthBytes(slot_idx);
305 fixup_query_mem_desc,
// The initial value is built per slot width (4- or 8-byte branches; their bodies
// are not visible here) and stored at the slot's address.
312 llvm::Value* init_value_ll =
nullptr;
313 if (slot_size ==
sizeof(int32_t)) {
316 }
else if (slot_size ==
sizeof(int64_t)) {
322 ir_builder.CreateStore(init_value_ll, casted_dest_slot_address);
// Advance the running offset by this slot's width, except after the final slot.
325 if (slot_idx != (col_slot_context.getSlotCount() - 1)) {
326 byte_offset_ll = ir_builder.CreateAdd(
327 byte_offset_ll,
ll_int(static_cast<size_t>(slot_size),
context_));
// Both paths meet in bb_exit, which returns the shared-memory buffer pointer.
332 ir_builder.CreateBr(bb_exit);
334 ir_builder.SetInsertPoint(bb_exit);
338 ir_builder.CreateRet(shared_mem_buffer);
// Body of createReductionFunction (the header is not visible here): declares
// the reduction function in module_ and returns it.
// Declared signature: void reduce_from_smem_to_gmem(i64*, i64*, i32), external
// linkage, not variadic. No body is emitted here.
342 std::vector<llvm::Type*> input_arguments;
343 input_arguments.push_back(llvm::Type::getInt64PtrTy(
context_));
344 input_arguments.push_back(llvm::Type::getInt64PtrTy(
context_));
345 input_arguments.push_back(llvm::Type::getInt32Ty(
context_));
347 llvm::FunctionType* ft =
348 llvm::FunctionType::get(llvm::Type::getVoidTy(
context_), input_arguments,
false);
// Passing module_ to Function::Create inserts the new function into that module.
349 const auto reduction_function = llvm::Function::Create(
350 ft, llvm::Function::ExternalLinkage,
"reduce_from_smem_to_gmem",
module_);
351 return reduction_function;
// Body of createInitFunction (the header is not visible here): declares the
// initialization function in module_ and returns it.
// Declared signature: i64* init_smem_func(i64*, i32), external linkage, not
// variadic. No body is emitted here.
355 std::vector<llvm::Type*> input_arguments;
356 input_arguments.push_back(
357 llvm::Type::getInt64PtrTy(
context_));
358 input_arguments.push_back(llvm::Type::getInt32Ty(
context_));
360 llvm::FunctionType* ft = llvm::FunctionType::get(
361 llvm::Type::getInt64PtrTy(
context_), input_arguments,
false);
// Passing module_ to Function::Create inserts the new function into that module.
362 const auto init_function = llvm::Function::Create(
363 ft, llvm::Function::ExternalLinkage,
"init_smem_func",
module_);
364 return init_function;
// Body of getFunction (the header is not visible here): looks func_name up in
// module_ and CHECK-fails with a message naming the function when the module
// has no function of that name.
// The return statement is not visible; presumably it returns `function`.
368 const auto function =
module_->getFunction(func_name);
369 CHECK(
function) << func_name <<
" is not found in the module.";
// Tail of the parameter list and body of replace_called_function_with: scans
// the instructions of main_func for a call whose callee name equals
// target_func_name and replaces that call with a call to replace_func carrying
// the same arguments.
// Falling through to UNREACHABLE() means no matching call was found.
// Presumably the function returns right after the first replacement — confirm.
380 const std::string& target_func_name,
381 llvm::Function* replace_func) {
382 for (
auto it = llvm::inst_begin(main_func), e = llvm::inst_end(main_func); it != e;
// Only call instructions are candidates.
384 if (!llvm::isa<llvm::CallInst>(*it)) {
387 auto& instruction = llvm::cast<llvm::CallInst>(*it);
389 if (inst_func_name && *inst_func_name == target_func_name) {
// Copy the call's arguments. getNumOperands() - 1 is used as the argument count
// (for a plain call the last operand is the callee).
390 std::vector<llvm::Value*>
args;
391 for (
size_t i = 0; i < instruction.getNumOperands() - 1; ++i) {
392 args.push_back(instruction.getArgOperand(i));
394 llvm::ReplaceInstWithInst(&instruction,
395 llvm::CallInst::Create(replace_func, args,
""));
// NOTE(review): the message says "was not found in" and then prints the name of
// replace_func, but the function that was searched is main_func. This looks
// like the wrong name — confirm and switch to main_func->getName().
399 UNREACHABLE() <<
"Target function " << target_func_name <<
" was not found in "
400 << replace_func->getName().str();
std::optional< std::string_view > getCalledFunctionName(llvm::CallInst &call_inst)
size_t getEntryCount() const
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
bool hasKeylessHash() const
const QueryMemoryDescriptor query_mem_desc_
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
llvm::Function * createInitFunction() const
const SQLTypeInfo get_compact_type(const TargetInfo &target)
void verify_function_ir(const llvm::Function *func)
llvm::LLVMContext & context_
void codegenInitialization()
GpuSharedMemCodeBuilder(llvm::Module *module, llvm::LLVMContext &context, const QueryMemoryDescriptor &qmd, const std::vector< TargetInfo > &targets, const std::vector< int64_t > &init_agg_values, const size_t executor_id)
const int8_t getPaddedSlotWidthBytes(const size_t slot_idx) const
std::string toString() const
QueryDescriptionType getQueryDescriptionType() const
std::vector< int64_t > initialize_target_values_for_storage(const std::vector< TargetInfo > &targets)
const std::vector< int64_t > init_agg_values_
llvm::Value * codegen_smem_dest_slot_ptr(llvm::LLVMContext &context, const QueryMemoryDescriptor &query_mem_desc, llvm::IRBuilder<> &ir_builder, const size_t slot_idx, const TargetInfo &target_info, llvm::Value *dest_byte_stream, llvm::Value *byte_offset)
std::string serialize_llvm_object(const T *llvm_obj)
llvm::Function * init_func_
static QueryMemoryDescriptor fixupQueryMemoryDescriptor(const QueryMemoryDescriptor &)
bool didOutputColumnar() const
const std::vector< TargetInfo > targets_
void replace_called_function_with(llvm::Function *main_func, const std::string &target_func_name, llvm::Function *replace_func)
#define DEBUG_TIMER(name)
__device__ void sync_threadblock()
llvm::Function * getFunction(const std::string &func_name) const
void injectFunctionsInto(llvm::Function *query_func)
llvm::Function * reduction_func_
llvm::Function * createReductionFunction() const