23 #include <llvm/ExecutionEngine/MCJIT.h>
24 #include <llvm/IR/BasicBlock.h>
25 #include <llvm/IR/Function.h>
26 #include <llvm/IR/IRBuilder.h>
27 #include <llvm/IR/Module.h>
28 #include <llvm/IR/Type.h>
29 #include <llvm/IR/Verifier.h>
30 #include <llvm/Support/TargetSelect.h>
31 #include <llvm/Support/raw_os_ostream.h>
39 printf(
"%ld, %ld, %ld\n", i, j, k);
47 std::stringstream err_ss;
48 llvm::raw_os_ostream err_os(err_ss);
49 if (llvm::verifyFunction(*func, &err_os)) {
50 func->print(llvm::outs());
57 const std::vector<llvm::Value*>
args,
58 llvm::Module* llvm_module,
59 llvm::IRBuilder<>& builder) {
60 std::vector<llvm::Type*> arg_types;
61 for (
const auto arg : args) {
62 arg_types.push_back(arg->getType());
64 auto func_ty = llvm::FunctionType::get(ret_type, arg_types,
false);
65 auto func_p = llvm_module->getOrInsertFunction(fname, func_ty);
67 llvm::Value*
result = builder.CreateCall(func_p, args);
69 CHECK_EQ(result->getType(), ret_type);
74 llvm::Module* llvm_module,
75 const std::vector<JoinLoop>& join_loops) {
76 std::vector<llvm::Type*> argument_types;
78 llvm::FunctionType::get(llvm::Type::getVoidTy(context), argument_types,
false);
79 const auto func = llvm::Function::Create(
80 ft, llvm::Function::ExternalLinkage,
"loop_test_func", llvm_module);
81 const auto entry_bb = llvm::BasicBlock::Create(context,
"entry", func);
82 const auto exit_bb = llvm::BasicBlock::Create(context,
"exit", func);
83 llvm::IRBuilder<> builder(context);
84 builder.SetInsertPoint(exit_bb);
85 builder.CreateRetVoid();
88 [&builder, llvm_module](
const std::vector<llvm::Value*>& iterators) {
89 const auto loop_body_bb = llvm::BasicBlock::Create(
90 builder.getContext(),
"loop_body", builder.GetInsertBlock()->getParent());
91 builder.SetInsertPoint(loop_body_bb);
92 const std::vector<llvm::Value*>
args(iterators.begin() + 1, iterators.end());
94 llvm::Type::getVoidTy(builder.getContext()),
103 builder.SetInsertPoint(entry_bb);
104 builder.CreateBr(loop_body_bb);
110 return std::make_unique<llvm::Module>(
"Nested loops JIT",
g_global_context);
114 std::unique_ptr<llvm::Module>& llvm_module,
115 llvm::Function* func) {
116 llvm::ExecutionEngine* execution_engine{
nullptr};
118 auto init_err = llvm::InitializeNativeTarget();
121 llvm::InitializeAllTargetMCs();
122 llvm::InitializeNativeTargetAsmPrinter();
123 llvm::InitializeNativeTargetAsmParser();
126 llvm::EngineBuilder eb(std::move(llvm_module));
127 eb.setErrorStr(&err_str);
128 eb.setEngineKind(llvm::EngineKind::JIT);
129 llvm::TargetOptions to;
130 to.EnableFastISel =
true;
131 eb.setTargetOptions(to);
132 execution_engine = eb.create();
133 CHECK(execution_engine);
135 execution_engine->finalizeObject();
136 auto native_code = execution_engine->getPointerToFunction(func);
139 return {native_code, std::unique_ptr<llvm::ExecutionEngine>(execution_engine)};
143 const unsigned cond_mask,
145 std::vector<JoinLoop> join_loops;
147 for (
size_t i = 0; i < upper_bounds.size(); ++i) {
148 if (mask & (1 << i)) {
149 const bool cond_is_true = cond_mask & (1 << cond_idx);
150 join_loops.emplace_back(
153 [i, cond_is_true](
const std::vector<llvm::Value*>& v) {
158 ?
ll_int(int64_t(99), g_global_context)
159 :
ll_int(int64_t(-1), g_global_context);
170 join_loops.emplace_back(
173 [i,
upper_bound](
const std::vector<llvm::Value*>& v) {
194 for (
unsigned mask = 0; mask < static_cast<unsigned>(1 <<
upper_bounds.size());
197 for (
unsigned cond_mask = 0; cond_mask < static_cast<unsigned>(1 <<
mask_bitcount);
201 const auto function =
204 reinterpret_cast<int64_t (*)()
>(func_and_ee.first)();
DEVICE auto upper_bound(ARGS &&...args)
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
void verify_function_ir(const llvm::Function *func)
llvm::Value * emit_external_call(const std::string &fname, llvm::Type *ret_type, const std::vector< llvm::Value * > args, llvm::Module *llvm_module, llvm::IRBuilder<> &builder)
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value * > &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, CgenState *cgen_state)
RUNTIME_EXPORT void print_iterators(const int64_t i, const int64_t j, const int64_t k)
std::pair< void *, std::unique_ptr< llvm::ExecutionEngine > > native_codegen(std::unique_ptr< llvm::Module > &llvm_module, llvm::Function *func)
llvm::Value * slot_lookup_result
llvm::ManagedStatic< llvm::LLVMContext > g_global_context
llvm::Value * upper_bound
std::unique_ptr< llvm::Module > create_loop_test_module()
std::vector< JoinLoop > generate_descriptors(const unsigned mask, const unsigned cond_mask, const std::vector< int64_t > &upper_bounds)
llvm::Function * create_loop_test_function(llvm::LLVMContext &context, llvm::Module *llvm_module, const std::vector< JoinLoop > &join_loops)