OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
LogicalIR.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CodeGenerator.h"
18 #include "Execute.h"
19 #include "GeoOperators/Codegen.h"
20 #include "NullableValue.h"
21 
22 #include <llvm/IR/MDBuilder.h>
23 
24 namespace {
25 
27  auto is_div = [](const Analyzer::Expr* e) -> bool {
28  auto bin_oper = dynamic_cast<const Analyzer::BinOper*>(e);
29  if (bin_oper && bin_oper->get_optype() == kDIVIDE) {
30  auto rhs = bin_oper->get_right_operand();
31  auto rhs_constant = dynamic_cast<const Analyzer::Constant*>(rhs);
32  if (!rhs_constant || rhs_constant->get_is_null()) {
33  return true;
34  }
35  const auto& datum = rhs_constant->get_constval();
36  const auto& ti = rhs_constant->get_type_info();
37  const auto type = ti.is_decimal() ? decimal_to_int_type(ti) : ti.get_type();
38  if ((type == kBOOLEAN && datum.boolval == 0) ||
39  (type == kTINYINT && datum.tinyintval == 0) ||
40  (type == kSMALLINT && datum.smallintval == 0) ||
41  (type == kINT && datum.intval == 0) ||
42  (type == kBIGINT && datum.bigintval == 0LL) ||
43  (type == kFLOAT && datum.floatval == 0.0) ||
44  (type == kDOUBLE && datum.doubleval == 0.0)) {
45  return true;
46  }
47  }
48  return false;
49  };
50  std::list<const Analyzer::Expr*> binoper_list;
51  expr->find_expr(is_div, binoper_list);
52  return !binoper_list.empty();
53 }
54 
55 bool should_defer_eval(const std::shared_ptr<Analyzer::Expr> expr) {
56  if (std::dynamic_pointer_cast<Analyzer::LikeExpr>(expr)) {
57  return true;
58  }
59  if (std::dynamic_pointer_cast<Analyzer::RegexpExpr>(expr)) {
60  return true;
61  }
62  if (std::dynamic_pointer_cast<Analyzer::FunctionOper>(expr)) {
63  return true;
64  }
65  if (!std::dynamic_pointer_cast<Analyzer::BinOper>(expr)) {
66  return false;
67  }
68  const auto bin_expr = std::static_pointer_cast<Analyzer::BinOper>(expr);
69  if (contains_unsafe_division(bin_expr.get())) {
70  return true;
71  }
72  if (bin_expr->is_bbox_intersect_oper()) {
73  return false;
74  }
75  const auto rhs = bin_expr->get_right_operand();
76  return rhs->get_type_info().is_array();
77 }
78 
80  Likelihood truth{1.0};
81  auto likelihood_expr = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
82  if (likelihood_expr) {
83  return Likelihood(likelihood_expr->get_likelihood());
84  }
85  auto u_oper = dynamic_cast<const Analyzer::UOper*>(expr);
86  if (u_oper) {
87  Likelihood oper_likelihood = get_likelihood(u_oper->get_operand());
88  if (oper_likelihood.isInvalid()) {
89  return Likelihood();
90  }
91  if (u_oper->get_optype() == kNOT) {
92  return truth - oper_likelihood;
93  }
94  return oper_likelihood;
95  }
96  auto bin_oper = dynamic_cast<const Analyzer::BinOper*>(expr);
97  if (bin_oper) {
98  auto lhs = bin_oper->get_left_operand();
99  auto rhs = bin_oper->get_right_operand();
100  Likelihood lhs_likelihood = get_likelihood(lhs);
101  Likelihood rhs_likelihood = get_likelihood(rhs);
102  if (lhs_likelihood.isInvalid() && rhs_likelihood.isInvalid()) {
103  return Likelihood();
104  }
105  const auto optype = bin_oper->get_optype();
106  if (optype == kOR) {
107  auto both_false = (truth - lhs_likelihood) * (truth - rhs_likelihood);
108  return truth - both_false;
109  }
110  if (optype == kAND) {
111  return lhs_likelihood * rhs_likelihood;
112  }
113  return (lhs_likelihood + rhs_likelihood) / 2.0;
114  }
115 
116  return Likelihood();
117 }
118 
119 Weight get_weight(const Analyzer::Expr* expr, int depth = 0) {
120  auto like_expr = dynamic_cast<const Analyzer::LikeExpr*>(expr);
121  if (like_expr) {
122  // heavy weight expr, start valid weight propagation
123  return Weight((like_expr->get_is_simple()) ? 200 : 1000);
124  }
125  auto regexp_expr = dynamic_cast<const Analyzer::RegexpExpr*>(expr);
126  if (regexp_expr) {
127  // heavy weight expr, start valid weight propagation
128  return Weight(2000);
129  }
130  auto u_oper = dynamic_cast<const Analyzer::UOper*>(expr);
131  if (u_oper) {
132  auto weight = get_weight(u_oper->get_operand(), depth + 1);
133  return weight + 1;
134  }
135  auto bin_oper = dynamic_cast<const Analyzer::BinOper*>(expr);
136  if (bin_oper) {
137  auto lhs = bin_oper->get_left_operand();
138  auto rhs = bin_oper->get_right_operand();
139  auto lhs_weight = get_weight(lhs, depth + 1);
140  auto rhs_weight = get_weight(rhs, depth + 1);
141  if (rhs->get_type_info().is_array()) {
142  // heavy weight expr, start valid weight propagation
143  rhs_weight = rhs_weight + Weight(100);
144  }
145  auto weight = lhs_weight + rhs_weight;
146  return weight + 1;
147  }
148 
149  if (depth > 4) {
150  return Weight(1);
151  }
152 
153  return Weight();
154 }
155 
156 } // namespace
157 
159  std::vector<Analyzer::Expr*>& primary_quals,
160  std::vector<Analyzer::Expr*>& deferred_quals,
161  const PlanState::HoistedFiltersSet& hoisted_quals) {
162  for (auto expr : ra_exe_unit.simple_quals) {
163  if (hoisted_quals.find(expr) != hoisted_quals.end()) {
164  continue;
165  }
166  if (should_defer_eval(expr)) {
167  deferred_quals.push_back(expr.get());
168  continue;
169  }
170  primary_quals.push_back(expr.get());
171  }
172 
173  bool short_circuit = false;
174 
175  for (auto expr : ra_exe_unit.quals) {
176  if (hoisted_quals.find(expr) != hoisted_quals.end()) {
177  continue;
178  }
179 
180  if (get_likelihood(expr.get()) < 0.10 && !contains_unsafe_division(expr.get())) {
181  if (!short_circuit) {
182  primary_quals.push_back(expr.get());
183  short_circuit = true;
184  continue;
185  }
186  }
187  if (short_circuit || should_defer_eval(expr)) {
188  deferred_quals.push_back(expr.get());
189  continue;
190  }
191  primary_quals.push_back(expr.get());
192  }
193 
194  return short_circuit;
195 }
196 
198  const CompilationOptions& co) {
// Emits short-circuit evaluation for a logical AND/OR when that is either
// required (one operand may divide by zero) or likely to pay off (one operand
// usually decides the result and the other one is heavy). Returns nullptr when
// there is no reason to short-circuit; otherwise returns the PHI node holding
// the result of the operation.
// NOTE(review): the first signature line and the statement on listing line 199
// are not visible in this listing; confirm against the repository.
200  const auto optype = bin_oper->get_optype();
201  auto lhs = bin_oper->get_left_operand();
202  auto rhs = bin_oper->get_right_operand();
203 
// Pick the operand order: after this chain `lhs` is the operand evaluated
// unconditionally and `rhs` the one evaluated only when needed.
204  if (contains_unsafe_division(rhs)) {
205  // rhs contains a possible div-by-0: short-circuit
206  } else if (contains_unsafe_division(lhs)) {
207  // lhs contains a possible div-by-0: swap and short-circuit
208  std::swap(rhs, lhs);
209  } else if (((optype == kOR && get_likelihood(lhs) > 0.90) ||
210  (optype == kAND && get_likelihood(lhs) < 0.10)) &&
211  get_weight(rhs) > 10) {
212  // short circuit if we're likely to see either (trueA || heavyB) or (falseA && heavyB)
213  } else if (((optype == kOR && get_likelihood(rhs) > 0.90) ||
214  (optype == kAND && get_likelihood(rhs) < 0.10)) &&
215  get_weight(lhs) > 10) {
216  // swap and short circuit if we're likely to see either (heavyA || trueB) or (heavyA
217  // && falseB)
218  std::swap(rhs, lhs);
219  } else {
220  // no motivation to short circuit
221  return nullptr;
222  }
223 
224  const auto& ti = bin_oper->get_type_info();
225  auto lhs_lv = codegen(lhs, true, co).front();
226 
227  // Here the linear control flow will diverge and expressions cached during the
228  // code branch code generation (currently just column decoding) are not going
229  // to be available once we're done generating the short-circuited logic.
230  // Take a snapshot of the cache with FetchCacheAnchor and restore it once
231  // the control flow converges.
// NOTE(review): the statement on listing line 232 is not visible here;
// presumably it creates the FetchCacheAnchor described above. Confirm.
233 
// NOTE(review): the argument lists of the two Create( calls below (listing
// lines 235 and 237) are not visible in this listing; confirm.
234  auto rhs_bb = llvm::BasicBlock::Create(
236  auto ret_bb = llvm::BasicBlock::Create(
238  llvm::BasicBlock* nullcheck_ok_bb{nullptr};
239  llvm::BasicBlock* nullcheck_fail_bb{nullptr};
240 
// The null-check blocks are only created when the result type is nullable.
241  if (!ti.get_notnull()) {
242  // need lhs nullcheck before short circuiting
243  nullcheck_ok_bb = llvm::BasicBlock::Create(
244  cgen_state_->context_, "nullcheck_ok_bb", cgen_state_->current_func_);
245  nullcheck_fail_bb = llvm::BasicBlock::Create(
246  cgen_state_->context_, "nullcheck_fail_bb", cgen_state_->current_func_);
// Widen an i1 to 8 bits so it can be compared with the inline null value.
247  if (lhs_lv->getType()->isIntegerTy(1)) {
248  lhs_lv = cgen_state_->castToTypeIn(lhs_lv, 8);
249  }
250  auto lhs_nullcheck =
251  cgen_state_->ir_builder_.CreateICmpEQ(lhs_lv, cgen_state_->inlineIntNull(ti));
252  cgen_state_->ir_builder_.CreateCondBr(
253  lhs_nullcheck, nullcheck_fail_bb, nullcheck_ok_bb);
254  cgen_state_->ir_builder_.SetInsertPoint(nullcheck_ok_bb);
255  }
256 
// Remember the block that holds the short-circuit branch; it is one of the
// PHI's incoming edges. cnst_lv is 1 for OR and 0 for AND, i.e. the lhs value
// that decides the result on its own.
257  auto sc_check_bb = cgen_state_->ir_builder_.GetInsertBlock();
258  auto cnst_lv = llvm::ConstantInt::get(lhs_lv->getType(), (optype == kOR));
259  // Branch to codegen rhs if NOT getting (true || rhs) or (false && rhs), likelihood of
260  // the branch is < 0.10
261  cgen_state_->ir_builder_.CreateCondBr(
262  cgen_state_->ir_builder_.CreateICmpNE(lhs_lv, cnst_lv),
263  rhs_bb,
264  ret_bb,
265  llvm::MDBuilder(cgen_state_->context_).createBranchWeights(10, 90));
266 
267  // Codegen rhs when unable to short circuit.
268  cgen_state_->ir_builder_.SetInsertPoint(rhs_bb);
269  auto rhs_lv = codegen(rhs, true, co).front();
270  if (!ti.get_notnull()) {
271  // need rhs nullcheck as well
272  if (rhs_lv->getType()->isIntegerTy(1)) {
273  rhs_lv = cgen_state_->castToTypeIn(rhs_lv, 8);
274  }
275  auto rhs_nullcheck =
276  cgen_state_->ir_builder_.CreateICmpEQ(rhs_lv, cgen_state_->inlineIntNull(ti));
277  cgen_state_->ir_builder_.CreateCondBr(rhs_nullcheck, nullcheck_fail_bb, ret_bb);
278  } else {
279  cgen_state_->ir_builder_.CreateBr(ret_bb);
280  }
// Query the insert block after generating rhs: it is used as the PHI's
// incoming block for rhs_lv and need not be rhs_bb itself.
281  auto rhs_codegen_bb = cgen_state_->ir_builder_.GetInsertBlock();
282 
// The null-check failure block only forwards control to the merge block.
283  if (!ti.get_notnull()) {
284  cgen_state_->ir_builder_.SetInsertPoint(nullcheck_fail_bb);
285  cgen_state_->ir_builder_.CreateBr(ret_bb);
286  }
287 
// Merge: the result is the null value (null-check failure), the deciding
// constant (short-circuit taken) or the rhs value (rhs evaluated).
288  cgen_state_->ir_builder_.SetInsertPoint(ret_bb);
289  auto result_phi =
290  cgen_state_->ir_builder_.CreatePHI(lhs_lv->getType(), (!ti.get_notnull()) ? 3 : 2);
291  if (!ti.get_notnull()) {
292  result_phi->addIncoming(cgen_state_->inlineIntNull(ti), nullcheck_fail_bb);
293  }
294  result_phi->addIncoming(cnst_lv, sc_check_bb);
295  result_phi->addIncoming(rhs_lv, rhs_codegen_bb);
296  return result_phi;
297 }
298 
300  const CompilationOptions& co) {
// Generates code for a binary logical operator (the operator type must satisfy
// IS_LOGIC). Short-circuit code is used when codegenLogicalShortCircuit()
// provides it; otherwise both operands are evaluated and combined directly.
// NOTE(review): the first signature line and the statement on listing line 301
// are not visible in this listing; confirm against the repository.
302  const auto optype = bin_oper->get_optype();
303  CHECK(IS_LOGIC(optype));
304 
305  if (llvm::Value* short_circuit = codegenLogicalShortCircuit(bin_oper, co)) {
306  return short_circuit;
307  }
308 
309  const auto lhs = bin_oper->get_left_operand();
310  const auto rhs = bin_oper->get_right_operand();
311  auto lhs_lv = codegen(lhs, true, co).front();
312  auto rhs_lv = codegen(rhs, true, co).front();
313  const auto& ti = bin_oper->get_type_info();
// Non-nullable result: combine the operands as plain booleans.
314  if (ti.get_notnull()) {
315  switch (optype) {
316  case kAND:
317  return cgen_state_->ir_builder_.CreateAnd(toBool(lhs_lv), toBool(rhs_lv));
318  case kOR:
319  return cgen_state_->ir_builder_.CreateOr(toBool(lhs_lv), toBool(rhs_lv));
320  default:
321  CHECK(false);
322  }
323  }
// Nullable result: bring both operands to 8 bits and call the "logical_and" /
// "logical_or" function, passing the inline null value for the result type.
324  CHECK(lhs_lv->getType()->isIntegerTy(1) || lhs_lv->getType()->isIntegerTy(8));
325  CHECK(rhs_lv->getType()->isIntegerTy(1) || rhs_lv->getType()->isIntegerTy(8));
326  if (lhs_lv->getType()->isIntegerTy(1)) {
327  lhs_lv = cgen_state_->castToTypeIn(lhs_lv, 8);
328  }
329  if (rhs_lv->getType()->isIntegerTy(1)) {
330  rhs_lv = cgen_state_->castToTypeIn(rhs_lv, 8);
331  }
332  switch (optype) {
333  case kAND:
334  return cgen_state_->emitCall("logical_and",
335  {lhs_lv, rhs_lv, cgen_state_->inlineIntNull(ti)});
336  case kOR:
337  return cgen_state_->emitCall("logical_or",
338  {lhs_lv, rhs_lv, cgen_state_->inlineIntNull(ti)});
339  default:
340  abort();
341  }
342 }
343 
// Reduces an integer-typed LLVM value to a boolean: values wider than one bit
// are compared "signed greater than zero"; a one-bit value is returned as is.
344 llvm::Value* CodeGenerator::toBool(llvm::Value* lv) {
// NOTE(review): the statement on listing line 345 is not visible in this
// listing; confirm against the repository.
346  CHECK(lv->getType()->isIntegerTy());
347  if (static_cast<llvm::IntegerType*>(lv->getType())->getBitWidth() > 1) {
// Signed comparison: zero and negative values both map to false.
348  return cgen_state_->ir_builder_.CreateICmp(
349  llvm::ICmpInst::ICMP_SGT, lv, llvm::ConstantInt::get(lv->getType(), 0));
350  }
351  return lv;
352 }
353 
354 namespace {
355 
357  const auto bin_oper = dynamic_cast<const Analyzer::BinOper*>(expr);
358  return bin_oper && bin_oper->get_qualifier() != kONE;
359 }
360 
361 } // namespace
362 
364  const CompilationOptions& co) {
// Generates code for logical NOT, the only unary operator accepted here. The
// operand must be boolean-typed and must generate an integer-typed value.
// NOTE(review): the first signature line and the statement on listing line 365
// are not visible in this listing; confirm against the repository.
366  const auto optype = uoper->get_optype();
367  CHECK_EQ(kNOT, optype);
368  const auto operand = uoper->get_operand();
369  const auto& operand_ti = operand->get_type_info();
370  CHECK(operand_ti.is_boolean());
371  const auto operand_lv = codegen(operand, true, co).front();
372  CHECK(operand_lv->getType()->isIntegerTy());
// The operand is handled as non-null when its type says so or when it is a
// binary operator with a qualifier other than kONE.
373  const bool not_null = (operand_ti.get_notnull() || is_qualified_bin_oper(operand));
374  CHECK(not_null || operand_lv->getType()->isIntegerTy(8));
// Non-null operands are negated directly. For nullable operands the
// "logical_not" function receives the operand and the inline null value.
// NOTE(review): the ':' branch's callee (listing line 377) is not visible in
// this listing; presumably a cgen_state_ call-emitting helper. Confirm.
375  return not_null
376  ? cgen_state_->ir_builder_.CreateNot(toBool(operand_lv))
378  "logical_not", {operand_lv, cgen_state_->inlineIntNull(operand_ti)});
379 }
380 
382  const CompilationOptions& co) {
// Generates the IS NULL test for the operand of a unary operator, folding it
// to a constant when the answer is known at code generation time.
// NOTE(review): the first signature line and the statement on listing line 383
// are not visible in this listing; confirm against the repository.
384  const auto operand = uoper->get_operand();
385  if (dynamic_cast<const Analyzer::Constant*>(operand) &&
386  dynamic_cast<const Analyzer::Constant*>(operand)->get_is_null()) {
387  // for null constants, short-circuit to true
388  return llvm::ConstantInt::get(get_int_type(1, cgen_state_->context_), 1);
389  }
390  const auto& ti = operand->get_type_info();
391  CHECK(ti.is_integer() || ti.is_boolean() || ti.is_decimal() || ti.is_time() ||
392  ti.is_string() || ti.is_fp() || ti.is_array() || ti.is_geometry());
393  // if the type is inferred as non null, short-circuit to false
394  if (ti.get_notnull()) {
395  return llvm::ConstantInt::get(get_int_type(1, cgen_state_->context_), 0);
396  }
397  llvm::Value* operand_lv = codegen(operand, true, co).front();
398  // NULL-check array or geo's coords array
// A POINT produced by a geo operator is checked by the function whose name
// spatial_type::Codegen::pointIsNullFunctionName() returns for the type.
399  if (ti.get_type() == kPOINT && dynamic_cast<Analyzer::GeoOperator const*>(operand)) {
400  char const* const fname = spatial_type::Codegen::pointIsNullFunctionName(ti);
401  return cgen_state_->emitCall(fname, {operand_lv});
402  } else if (ti.is_array() || ti.is_geometry()) {
403  // POINT [un]compressed coord check requires custom checker and chunk iterator
404  // Non-POINT NULL geographies will have a normally encoded null coord array
405  auto fname =
406  (ti.get_type() == kPOINT) ? "point_coord_array_is_null" : "array_is_null";
// NOTE(review): the start of this statement (listing line 407) is not visible
// in this listing; the visible arguments are the function name, a 1-bit
// integer type and {operand value, posArg(operand)}. Presumably a returned
// external-call emission; confirm.
408  fname, get_int_type(1, cgen_state_->context_), {operand_lv, posArg(operand)});
409  } else if (ti.is_none_encoded_string()) {
// Take element 0 of the generated aggregate and cast it to 64 bits before the
// numeric null check below.
410  operand_lv = cgen_state_->ir_builder_.CreateExtractValue(operand_lv, 0);
411  operand_lv = cgen_state_->castToTypeIn(operand_lv, sizeof(int64_t) * 8);
412  }
413  return codegenIsNullNumber(operand_lv, ti);
414 }
415 
// Generates the comparison of a scalar value against its null value:
// floating-point types use an ordered-equal float compare, every other type an
// integer equality compare against the inline null value for the type.
416 llvm::Value* CodeGenerator::codegenIsNullNumber(llvm::Value* operand_lv,
417  const SQLTypeInfo& ti) {
// NOTE(review): the statement on listing line 418 is not visible in this
// listing; confirm against the repository.
419  if (ti.is_fp()) {
420  return cgen_state_->ir_builder_.CreateFCmp(llvm::FCmpInst::FCMP_OEQ,
421  operand_lv,
422  ti.get_type() == kFLOAT
// NOTE(review): both branches of this conditional (listing lines 423-424) are
// not visible in this listing; presumably the FLOAT and DOUBLE null constants.
// Confirm.
425  }
426  return cgen_state_->ir_builder_.CreateICmp(
427  llvm::ICmpInst::ICMP_EQ, operand_lv, cgen_state_->inlineIntNull(ti));
428 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
llvm::Value * castToTypeIn(llvm::Value *val, const size_t bit_width)
Definition: CgenState.cpp:150
#define IS_LOGIC(X)
Definition: sqldefs.h:64
#define NULL_DOUBLE
bool should_defer_eval(const std::shared_ptr< Analyzer::Expr > expr)
Definition: LogicalIR.cpp:55
CgenState * cgen_state_
#define NULL_FLOAT
bool is_fp() const
Definition: sqltypes.h:573
const Expr * get_right_operand() const
Definition: Analyzer.h:456
llvm::IRBuilder ir_builder_
Definition: CgenState.h:384
llvm::Value * posArg(const Analyzer::Expr *) const
Definition: ColumnIR.cpp:590
Definition: sqldefs.h:40
std::unordered_set< std::shared_ptr< Analyzer::Expr >> HoistedFiltersSet
Definition: PlanState.h:45
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:391
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
static char const * pointIsNullFunctionName(SQLTypeInfo const &)
Definition: Codegen.cpp:69
llvm::Value * codegenIsNull(const Analyzer::UOper *, const CompilationOptions &)
Definition: LogicalIR.cpp:381
SQLOps get_optype() const
Definition: Analyzer.h:452
Likelihood get_likelihood(const Analyzer::Expr *expr)
Definition: LogicalIR.cpp:79
llvm::LLVMContext & context_
Definition: CgenState.h:382
llvm::Function * current_func_
Definition: CgenState.h:376
llvm::Value * emitExternalCall(const std::string &fname, llvm::Type *ret_type, const std::vector< llvm::Value * > args, const std::vector< llvm::Attribute::AttrKind > &fnattrs={}, const bool has_struct_return=false)
Definition: CgenState.cpp:395
bool is_qualified_bin_oper(const Analyzer::Expr *expr)
Definition: LogicalIR.cpp:356
llvm::ConstantInt * inlineIntNull(const SQLTypeInfo &)
Definition: CgenState.cpp:65
Weight get_weight(const Analyzer::Expr *expr, int depth=0)
Definition: LogicalIR.cpp:119
bool isInvalid() const
Definition: NullableValue.h:33
Definition: sqldefs.h:39
NullableValue< float > Likelihood
Definition: NullableValue.h:99
llvm::ConstantFP * llFp(const float v) const
Definition: CgenState.h:253
#define AUTOMATIC_IR_METADATA(CGENSTATE)
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:79
llvm::Value * emitCall(const std::string &fname, const std::vector< llvm::Value * > &args)
Definition: CgenState.cpp:217
SQLTypes decimal_to_int_type(const SQLTypeInfo &ti)
Definition: Datum.cpp:561
std::vector< llvm::Value * > codegen(const Analyzer::Expr *, const bool fetch_columns, const CompilationOptions &)
Definition: IRCodegen.cpp:30
Definition: sqldefs.h:74
const Expr * get_operand() const
Definition: Analyzer.h:384
Datum get_constval() const
Definition: Analyzer.h:348
llvm::Value * toBool(llvm::Value *)
Definition: LogicalIR.cpp:344
static bool prioritizeQuals(const RelAlgExecutionUnit &ra_exe_unit, std::vector< Analyzer::Expr * > &primary_quals, std::vector< Analyzer::Expr * > &deferred_quals, const PlanState::HoistedFiltersSet &hoisted_quals)
Definition: LogicalIR.cpp:158
std::list< std::shared_ptr< Analyzer::Expr > > quals
#define CHECK(condition)
Definition: Logger.h:291
llvm::Value * codegenIsNullNumber(llvm::Value *, const SQLTypeInfo &)
Definition: LogicalIR.cpp:416
llvm::Value * codegenLogical(const Analyzer::BinOper *, const CompilationOptions &)
Definition: LogicalIR.cpp:299
bool contains_unsafe_division(const Analyzer::Expr *expr)
Definition: LogicalIR.cpp:26
const Expr * get_left_operand() const
Definition: Analyzer.h:455
Definition: sqltypes.h:72
llvm::Value * codegenLogicalShortCircuit(const Analyzer::BinOper *, const CompilationOptions &)
Definition: LogicalIR.cpp:197
virtual void find_expr(std::function< bool(const Expr *)> f, std::list< const Expr * > &expr_list) const
Definition: Analyzer.h:163
DEVICE void swap(ARGS &&...args)
Definition: gpu_enabled.h:114
Definition: sqldefs.h:41
SQLOps get_optype() const
Definition: Analyzer.h:383
bool is_array() const
Definition: sqltypes.h:585
NullableValue< uint64_t > Weight
std::list< std::shared_ptr< Analyzer::Expr > > simple_quals
SQLQualifier get_qualifier() const
Definition: Analyzer.h:454