24 #include <llvm/IR/Instructions.h>
47 return llvm::Type::getVoidTy(ctx);
59 return llvm::Type::getFloatPtrTy(ctx);
62 return llvm::Type::getDoublePtrTy(ctx);
68 return llvm::PointerType::get(llvm::PointerType::get(
get_int_type(64, ctx), 0), 0);
71 LOG(
FATAL) <<
"Argument type not supported: " <<
static_cast<int>(
type);
85 return llvm::ICmpInst::ICMP_EQ;
88 return llvm::ICmpInst::ICMP_NE;
91 LOG(
FATAL) <<
"Invalid predicate: " <<
static_cast<int>(predicate);
95 return llvm::ICmpInst::ICMP_EQ;
102 return llvm::Instruction::Add;
105 return llvm::Instruction::Mul;
108 LOG(
FATAL) <<
"Invalid binary operator: " <<
static_cast<int>(op);
112 return llvm::Instruction::Add;
119 return llvm::Instruction::Trunc;
122 return llvm::Instruction::SExt;
125 return llvm::Instruction::BitCast;
128 LOG(
FATAL) <<
"Invalid cast operator: " <<
static_cast<int>(op);
132 return llvm::Instruction::SExt;
141 llvm::Function* func,
145 auto& ctx = cgen_state->context_;
146 const auto early_return = llvm::BasicBlock::Create(ctx,
".early_return", func, 0);
147 const auto do_reduction = llvm::BasicBlock::Create(ctx,
".do_reduction", func, 0);
148 cgen_state->ir_builder_.CreateCondBr(cond, early_return, do_reduction);
149 cgen_state->ir_builder_.SetInsertPoint(early_return);
151 if (func->getReturnType()->isVoidTy()) {
152 cgen_state->ir_builder_.CreateRetVoid();
155 cgen_state->ir_builder_.CreateRet(error_code);
158 cgen_state->ir_builder_.SetInsertPoint(do_reduction);
163 const std::unordered_map<const Value*, llvm::Value*>& m) {
165 const auto it = m.find(val);
166 CHECK(it != m.end());
175 const Function*
function,
176 const std::unordered_map<const Function*, llvm::Function*>&
f) {
177 const auto it = f.find(
function);
178 CHECK(it != f.end()) << function->name() <<
" not found.";
185 const std::vector<const Value*>
args,
186 const std::unordered_map<const Value*, llvm::Value*>& m) {
189 args.begin(), args.end(), std::back_inserter(llvm_args), [&m](
const Value* value) {
196 Function* ir_reduce_loop,
198 std::unordered_map<const Value*, llvm::Value*>& m,
199 const std::unordered_map<const Function*, llvm::Function*>&
f);
203 const Function*
function,
204 llvm::Function* llvm_function,
206 std::unordered_map<const Value*, llvm::Value*>& m,
207 const std::unordered_map<const Function*, llvm::Function*>&
f) {
210 auto& ctx = cgen_state->context_;
211 for (
const auto& instr : body) {
212 const auto instr_ptr = instr.get();
213 llvm::Value* translated{
nullptr};
214 if (
auto gep = dynamic_cast<const GetElementPtr*>(instr_ptr)) {
216 translated = cgen_state->ir_builder_.CreateGEP(
217 base->getType()->getScalarType()->getPointerElementType(),
221 }
else if (
auto load = dynamic_cast<const Load*>(instr_ptr)) {
223 translated = cgen_state->ir_builder_.CreateLoad(
224 value->getType()->getPointerElementType(), value,
load->label());
225 }
else if (
auto icmp = dynamic_cast<const ICmp*>(instr_ptr)) {
226 translated = cgen_state->ir_builder_.CreateICmp(
llvm_predicate(icmp->predicate()),
230 }
else if (
auto binary_operator = dynamic_cast<const BinaryOperator*>(instr_ptr)) {
232 cgen_state->ir_builder_.CreateBinOp(
llvm_binary_op(binary_operator->op()),
235 binary_operator->label());
236 }
else if (
auto cast = dynamic_cast<const Cast*>(instr_ptr)) {
237 translated = cgen_state->ir_builder_.CreateCast(
llvm_cast_op(cast->op()),
241 }
else if (
auto ret = dynamic_cast<const Ret*>(instr_ptr)) {
243 cgen_state->ir_builder_.CreateRet(
mapped_value(ret->value(), m));
245 cgen_state->ir_builder_.CreateRetVoid();
247 }
else if (
auto call = dynamic_cast<const Call*>(instr_ptr)) {
249 const auto args = call->arguments();
252 std::back_inserter(llvm_args),
254 if (call->callee()) {
255 translated = cgen_state->ir_builder_.CreateCall(
258 translated = cgen_state->emitCall(call->callee_name(),
llvm_args);
260 }
else if (
auto external_call = dynamic_cast<const ExternalCall*>(instr_ptr)) {
261 translated = cgen_state->emitExternalCall(external_call->callee_name(),
263 llvm_args(external_call->arguments(), m));
264 }
else if (
auto alloca = dynamic_cast<const Alloca*>(instr_ptr)) {
265 translated = cgen_state->ir_builder_.CreateAlloca(
269 }
else if (
auto memcpy = dynamic_cast<const MemCpy*>(instr_ptr)) {
270 cgen_state->ir_builder_.CreateMemCpy(
mapped_value(memcpy->dest(), m),
275 }
else if (
auto ret_early = dynamic_cast<const ReturnEarly*>(instr_ptr)) {
280 }
else if (
auto for_loop = dynamic_cast<const For*>(instr_ptr)) {
283 LOG(
FATAL) <<
"Instruction not supported yet";
286 const auto it_ok = m.emplace(instr_ptr, translated);
294 Function* ir_reduce_loop,
296 std::unordered_map<const Value*, llvm::Value*>& m,
297 const std::unordered_map<const Function*, llvm::Function*>&
f) {
300 const auto bb_entry = cgen_state->ir_builder_.GetInsertBlock();
301 auto& ctx = cgen_state->context_;
302 const auto i64_type =
get_int_type(64, cgen_state->context_);
306 const auto iteration_count =
307 cgen_state->ir_builder_.CreateSub(end_index, start_index,
"iteration_count");
308 const auto upper_bound = cgen_state->ir_builder_.CreateSExt(iteration_count, i64_type);
310 llvm::BasicBlock::Create(ctx,
".exit",
mapped_function(ir_reduce_loop, f));
314 [
upper_bound](
const std::vector<llvm::Value*>& v) {
327 [cgen_state, for_loop, ir_reduce_loop, &
f, &m, &reduction_code](
328 const std::vector<llvm::Value*>& iterators) {
329 const auto loop_body_bb = llvm::BasicBlock::Create(
330 cgen_state->context_,
332 cgen_state->ir_builder_.GetInsertBlock()->getParent());
333 cgen_state->ir_builder_.SetInsertPoint(loop_body_bb);
335 const auto loop_iter =
336 cgen_state->ir_builder_.CreateTrunc(iterators.back(),
338 "relative_entry_idx");
339 m.emplace(for_loop->iter(), loop_iter);
351 cgen_state->ir_builder_.SetInsertPoint(bb_entry);
352 cgen_state->ir_builder_.CreateBr(bb_loop_body);
353 cgen_state->ir_builder_.SetInsertPoint(bb_exit);
359 const auto bb_entry =
360 llvm::BasicBlock::Create(cgen_state->
context_,
".entry",
function, 0);
367 llvm::Function* llvm_function,
369 const std::unordered_map<const Function*, llvm::Function*>&
f) {
374 std::unordered_map<const Value*, llvm::Value*> m;
375 auto llvm_arg_it = llvm_function->arg_begin();
376 for (
size_t arg_idx = 0; arg_idx <
function->arg_types().size(); ++arg_idx) {
377 llvm::Value* llvm_arg = &(*llvm_arg_it);
378 const auto it_ok = m.emplace(function->arg(arg_idx), llvm_arg);
383 for (
const auto& constant : function->constants()) {
384 llvm::Value* constant_llvm{
nullptr};
385 switch (constant->type()) {
388 cgen_state->llInt<int8_t>(
static_cast<ConstantInt*
>(constant.get())->value());
392 constant_llvm = cgen_state->llInt<int32_t>(
393 static_cast<ConstantInt*
>(constant.get())->value());
397 constant_llvm = cgen_state->llInt<int64_t>(
398 static_cast<ConstantInt*
>(constant.get())->value());
402 constant_llvm = cgen_state->llFp(
403 static_cast<float>(static_cast<ConstantFP*>(constant.get())->value()));
408 cgen_state->llFp(static_cast<ConstantFP*>(constant.get())->value());
412 LOG(
FATAL) <<
"Constant type not supported: "
413 <<
static_cast<int>(constant->type());
416 CHECK(constant_llvm);
417 const auto it_ok = m.emplace(constant.get(), constant_llvm);
420 translate_body(function->body(),
function, llvm_function, reduction_code, m,
f);
DEVICE auto upper_bound(ARGS &&...args)
void create_entry_block(llvm::Function *function, CgenState *cgen_state)
void translate_for(const For *for_loop, Function *ir_reduce_loop, const ReductionCode &reduction_code, std::unordered_map< const Value *, llvm::Value * > &m, const std::unordered_map< const Function *, llvm::Function * > &f)
std::unique_ptr< Function > ir_reduce_loop
void load(Archive &ar, ExplainedQueryHint &query_hint, const unsigned int version)
llvm::IRBuilder ir_builder_
std::vector< llvm::Value * > llvm_args(const std::vector< const Value * > args, const std::unordered_map< const Value *, llvm::Value * > &m)
llvm::Type * get_fp_type(const int width, llvm::LLVMContext &context)
llvm::ICmpInst::Predicate llvm_predicate(const ICmp::Predicate predicate)
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
void translate_function(const Function *function, llvm::Function *llvm_function, const ReductionCode &reduction_code, const std::unordered_map< const Function *, llvm::Function * > &f)
const Value * end() const
const Value * start() const
void verify_function_ir(const llvm::Function *func)
llvm::LLVMContext & context_
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value * > &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, CgenState *cgen_state)
Type pointee_type(const Type pointer)
llvm::Instruction::CastOps llvm_cast_op(const Cast::CastOp op)
OUTPUT transform(INPUT const &input, FUNC const &func)
llvm::Type * llvm_type(const Type type, llvm::LLVMContext &ctx)
void return_early(llvm::Value *cond, const ReductionCode &reduction_code, llvm::Function *func, llvm::Value *error_code)
llvm::Value * upper_bound
void translate_body(const std::vector< std::unique_ptr< Instruction >> &body, const Function *function, llvm::Function *llvm_function, const ReductionCode &reduction_code, std::unordered_map< const Value *, llvm::Value * > &m, const std::unordered_map< const Function *, llvm::Function * > &f)
#define LLVM_MAYBE_ALIGN(alignment)
torch::Tensor f(torch::Tensor x, torch::Tensor W_target, torch::Tensor b_target)
llvm::Function * mapped_function(const Function *function, const std::unordered_map< const Function *, llvm::Function * > &f)
llvm::Value * mapped_value(const Value *val, const std::unordered_map< const Value *, llvm::Value * > &m)
llvm::BinaryOperator::BinaryOps llvm_binary_op(const BinaryOperator::BinaryOp op)