OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ExpressionRewrite.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <algorithm>
20 #include <boost/locale/conversion.hpp>
21 #include <unordered_set>
22 
23 #include "Logger/Logger.h"
24 #include "Parser/ParserNode.h"
26 #include "QueryEngine/Execute.h"
29 #include "Shared/sqldefs.h"
30 #include "StringOps/StringOps.h"
31 
32 namespace {
33 
34 class OrToInVisitor : public ScalarExprVisitor<std::shared_ptr<Analyzer::InValues>> {
35  protected:
36  std::shared_ptr<Analyzer::InValues> visitBinOper(
37  const Analyzer::BinOper* bin_oper) const override {
38  switch (bin_oper->get_optype()) {
39  case kEQ: {
40  const auto rhs_owned = bin_oper->get_own_right_operand();
41  auto rhs_no_cast = extract_cast_arg(rhs_owned.get());
42  if (!dynamic_cast<const Analyzer::Constant*>(rhs_no_cast)) {
43  return nullptr;
44  }
45  const auto arg = bin_oper->get_own_left_operand();
46  const auto& arg_ti = arg->get_type_info();
47  auto rhs = rhs_no_cast->deep_copy()->add_cast(arg_ti);
48  return makeExpr<Analyzer::InValues>(
49  arg, std::list<std::shared_ptr<Analyzer::Expr>>{rhs});
50  }
51  case kOR: {
52  return aggregateResult(visit(bin_oper->get_left_operand()),
53  visit(bin_oper->get_right_operand()));
54  }
55  default:
56  break;
57  }
58  return nullptr;
59  }
60 
61  std::shared_ptr<Analyzer::InValues> visitUOper(
62  const Analyzer::UOper* uoper) const override {
63  return nullptr;
64  }
65 
66  std::shared_ptr<Analyzer::InValues> visitInValues(
67  const Analyzer::InValues*) const override {
68  return nullptr;
69  }
70 
71  std::shared_ptr<Analyzer::InValues> visitInIntegerSet(
72  const Analyzer::InIntegerSet*) const override {
73  return nullptr;
74  }
75 
76  std::shared_ptr<Analyzer::InValues> visitCharLength(
77  const Analyzer::CharLengthExpr*) const override {
78  return nullptr;
79  }
80 
81  std::shared_ptr<Analyzer::InValues> visitKeyForString(
82  const Analyzer::KeyForStringExpr*) const override {
83  return nullptr;
84  }
85 
86  std::shared_ptr<Analyzer::InValues> visitSampleRatio(
87  const Analyzer::SampleRatioExpr*) const override {
88  return nullptr;
89  }
90 
91  std::shared_ptr<Analyzer::InValues> visitMLPredict(
92  const Analyzer::MLPredictExpr*) const override {
93  return nullptr;
94  }
95 
96  std::shared_ptr<Analyzer::InValues> visitPCAProject(
97  const Analyzer::PCAProjectExpr*) const override {
98  return nullptr;
99  }
100 
101  std::shared_ptr<Analyzer::InValues> visitCardinality(
102  const Analyzer::CardinalityExpr*) const override {
103  return nullptr;
104  }
105 
106  std::shared_ptr<Analyzer::InValues> visitLikeExpr(
107  const Analyzer::LikeExpr*) const override {
108  return nullptr;
109  }
110 
111  std::shared_ptr<Analyzer::InValues> visitRegexpExpr(
112  const Analyzer::RegexpExpr*) const override {
113  return nullptr;
114  }
115 
116  std::shared_ptr<Analyzer::InValues> visitCaseExpr(
117  const Analyzer::CaseExpr*) const override {
118  return nullptr;
119  }
120 
121  std::shared_ptr<Analyzer::InValues> visitDatetruncExpr(
122  const Analyzer::DatetruncExpr*) const override {
123  return nullptr;
124  }
125 
126  std::shared_ptr<Analyzer::InValues> visitDatediffExpr(
127  const Analyzer::DatediffExpr*) const override {
128  return nullptr;
129  }
130 
131  std::shared_ptr<Analyzer::InValues> visitDateaddExpr(
132  const Analyzer::DateaddExpr*) const override {
133  return nullptr;
134  }
135 
136  std::shared_ptr<Analyzer::InValues> visitExtractExpr(
137  const Analyzer::ExtractExpr*) const override {
138  return nullptr;
139  }
140 
141  std::shared_ptr<Analyzer::InValues> visitLikelihood(
142  const Analyzer::LikelihoodExpr*) const override {
143  return nullptr;
144  }
145 
146  std::shared_ptr<Analyzer::InValues> visitAggExpr(
147  const Analyzer::AggExpr*) const override {
148  return nullptr;
149  }
150 
151  std::shared_ptr<Analyzer::InValues> aggregateResult(
152  const std::shared_ptr<Analyzer::InValues>& lhs,
153  const std::shared_ptr<Analyzer::InValues>& rhs) const override {
154  if (!lhs || !rhs) {
155  return nullptr;
156  }
157 
158  if (lhs->get_arg()->get_type_info() == rhs->get_arg()->get_type_info() &&
159  (*lhs->get_arg() == *rhs->get_arg())) {
160  auto union_values = lhs->get_value_list();
161  const auto& rhs_values = rhs->get_value_list();
162  union_values.insert(union_values.end(), rhs_values.begin(), rhs_values.end());
163  return makeExpr<Analyzer::InValues>(lhs->get_own_arg(), union_values);
164  }
165  return nullptr;
166  }
167 };
168 
170  protected:
171  std::shared_ptr<Analyzer::Expr> visitBinOper(
172  const Analyzer::BinOper* bin_oper) const override {
173  OrToInVisitor simple_visitor;
174  if (bin_oper->get_optype() == kOR) {
175  auto rewritten = simple_visitor.visit(bin_oper);
176  if (rewritten) {
177  return rewritten;
178  }
179  }
180  auto lhs = bin_oper->get_own_left_operand();
181  auto rhs = bin_oper->get_own_right_operand();
182  auto rewritten_lhs = visit(lhs.get());
183  auto rewritten_rhs = visit(rhs.get());
184  return makeExpr<Analyzer::BinOper>(bin_oper->get_type_info(),
185  bin_oper->get_contains_agg(),
186  bin_oper->get_optype(),
187  bin_oper->get_qualifier(),
188  rewritten_lhs ? rewritten_lhs : lhs,
189  rewritten_rhs ? rewritten_rhs : rhs);
190  }
191 };
192 
194  protected:
196 
// Rebuilds an array literal with each element visited. String elements that are
// none-encoded (kENCODING_NONE) are wrapped in a cast to a dictionary-encoded
// type using TRANSIENT_DICT_ID; all other elements are kept exactly as visited.
197  RetType visitArrayOper(const Analyzer::ArrayExpr* array_expr) const override {
198  std::vector<std::shared_ptr<Analyzer::Expr>> args_copy;
199  for (size_t i = 0; i < array_expr->getElementCount(); ++i) {
200  auto const element_expr_ptr = visit(array_expr->getElement(i));
201  auto const& element_expr_type_info = element_expr_ptr->get_type_info();
202 
// Non-strings and already-encoded strings need no cast.
203  if (!element_expr_type_info.is_string() ||
204  element_expr_type_info.get_compression() != kENCODING_NONE) {
205  args_copy.push_back(element_expr_ptr);
206  } else {
207  auto transient_dict_type_info = element_expr_type_info;
208 
209  transient_dict_type_info.set_compression(kENCODING_DICT);
210  transient_dict_type_info.set_comp_param(TRANSIENT_DICT_ID);
// NOTE(review): the argument line of setStringDictKey (original line 212) is
// missing from this listing; presumably a transient-dictionary key. Confirm
// against upstream.
211  transient_dict_type_info.setStringDictKey(
213  transient_dict_type_info.set_fixed_size();
214  args_copy.push_back(element_expr_ptr->add_cast(transient_dict_type_info));
215  }
216  }
217 
// The array keeps its original type info, null flag and local-alloc flag.
218  const auto& type_info = array_expr->get_type_info();
219  return makeExpr<Analyzer::ArrayExpr>(
220  type_info, args_copy, array_expr->isNull(), array_expr->isLocalAlloc());
221  }
222 };
223 
225  template <typename T>
226  bool foldComparison(SQLOps optype, T t1, T t2) const {
227  switch (optype) {
228  case kEQ:
229  return t1 == t2;
230  case kNE:
231  return t1 != t2;
232  case kLT:
233  return t1 < t2;
234  case kLE:
235  return t1 <= t2;
236  case kGT:
237  return t1 > t2;
238  case kGE:
239  return t1 >= t2;
240  default:
241  break;
242  }
243  throw std::runtime_error("Unable to fold");
244  return false;
245  }
246 
247  template <typename T>
248  bool foldLogic(SQLOps optype, T t1, T t2) const {
249  switch (optype) {
250  case kAND:
251  return t1 && t2;
252  case kOR:
253  return t1 || t2;
254  case kNOT:
255  return !t1;
256  default:
257  break;
258  }
259  throw std::runtime_error("Unable to fold");
260  return false;
261  }
262 
263  template <typename T>
264  T foldArithmetic(SQLOps optype, T t1, T t2) const {
265  bool t2_is_zero = (t2 == (t2 - t2));
266  bool t2_is_negative = (t2 < (t2 - t2));
267  switch (optype) {
268  case kPLUS:
269  // The MIN limit for float and double is the smallest representable value,
270  // not the lowest negative value! Switching to C++11 lowest.
271  if ((t2_is_negative && t1 < std::numeric_limits<T>::lowest() - t2) ||
272  (!t2_is_negative && t1 > std::numeric_limits<T>::max() - t2)) {
273  num_overflows_++;
274  throw std::runtime_error("Plus overflow");
275  }
276  return t1 + t2;
277  case kMINUS:
278  if ((t2_is_negative && t1 > std::numeric_limits<T>::max() + t2) ||
279  (!t2_is_negative && t1 < std::numeric_limits<T>::lowest() + t2)) {
280  num_overflows_++;
281  throw std::runtime_error("Minus overflow");
282  }
283  return t1 - t2;
284  case kMULTIPLY: {
285  if (t2_is_zero) {
286  return t2;
287  }
288  auto ct1 = t1;
289  auto ct2 = t2;
290  // Need to keep t2's sign on the left
291  if (t2_is_negative) {
292  if (t1 == std::numeric_limits<T>::lowest() ||
293  t2 == std::numeric_limits<T>::lowest()) {
294  // negation could overflow - bail
295  num_overflows_++;
296  throw std::runtime_error("Mul neg overflow");
297  }
298  ct1 = -t1; // ct1 gets t2's negativity
299  ct2 = -t2; // ct2 is now positive
300  }
301  // Don't check overlow if we are folding FP mul by a fraction
302  bool ct2_is_fraction = (ct2 < (ct2 / ct2));
303  if (!ct2_is_fraction) {
304  if (ct1 > std::numeric_limits<T>::max() / ct2 ||
305  ct1 < std::numeric_limits<T>::lowest() / ct2) {
306  num_overflows_++;
307  throw std::runtime_error("Mul overflow");
308  }
309  }
310  return t1 * t2;
311  }
312  case kDIVIDE:
313  if (t2_is_zero) {
314  throw std::runtime_error("Will not fold division by zero");
315  }
316  return t1 / t2;
317  default:
318  break;
319  }
320  throw std::runtime_error("Unable to fold");
321  }
322 
323  bool foldOper(SQLOps optype,
324  SQLTypes type,
325  Datum lhs,
326  Datum rhs,
327  Datum& result,
328  SQLTypes& result_type) const {
329  result_type = type;
330 
331  try {
332  switch (type) {
333  case kBOOLEAN:
334  if (IS_COMPARISON(optype)) {
335  result.boolval = foldComparison<bool>(optype, lhs.boolval, rhs.boolval);
336  result_type = kBOOLEAN;
337  return true;
338  }
339  if (IS_LOGIC(optype)) {
340  result.boolval = foldLogic<bool>(optype, lhs.boolval, rhs.boolval);
341  result_type = kBOOLEAN;
342  return true;
343  }
344  CHECK(!IS_ARITHMETIC(optype));
345  break;
346  case kTINYINT:
347  if (IS_COMPARISON(optype)) {
348  result.boolval =
349  foldComparison<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
350  result_type = kBOOLEAN;
351  return true;
352  }
353  if (IS_ARITHMETIC(optype)) {
354  result.tinyintval =
355  foldArithmetic<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
356  result_type = kTINYINT;
357  return true;
358  }
359  CHECK(!IS_LOGIC(optype));
360  break;
361  case kSMALLINT:
362  if (IS_COMPARISON(optype)) {
363  result.boolval =
364  foldComparison<int16_t>(optype, lhs.smallintval, rhs.smallintval);
365  result_type = kBOOLEAN;
366  return true;
367  }
368  if (IS_ARITHMETIC(optype)) {
369  result.smallintval =
370  foldArithmetic<int16_t>(optype, lhs.smallintval, rhs.smallintval);
371  result_type = kSMALLINT;
372  return true;
373  }
374  CHECK(!IS_LOGIC(optype));
375  break;
376  case kINT:
377  if (IS_COMPARISON(optype)) {
378  result.boolval = foldComparison<int32_t>(optype, lhs.intval, rhs.intval);
379  result_type = kBOOLEAN;
380  return true;
381  }
382  if (IS_ARITHMETIC(optype)) {
383  result.intval = foldArithmetic<int32_t>(optype, lhs.intval, rhs.intval);
384  result_type = kINT;
385  return true;
386  }
387  CHECK(!IS_LOGIC(optype));
388  break;
389  case kBIGINT:
390  if (IS_COMPARISON(optype)) {
391  result.boolval =
392  foldComparison<int64_t>(optype, lhs.bigintval, rhs.bigintval);
393  result_type = kBOOLEAN;
394  return true;
395  }
396  if (IS_ARITHMETIC(optype)) {
397  result.bigintval =
398  foldArithmetic<int64_t>(optype, lhs.bigintval, rhs.bigintval);
399  result_type = kBIGINT;
400  return true;
401  }
402  CHECK(!IS_LOGIC(optype));
403  break;
404  case kFLOAT:
405  if (IS_COMPARISON(optype)) {
406  result.boolval = foldComparison<float>(optype, lhs.floatval, rhs.floatval);
407  result_type = kBOOLEAN;
408  return true;
409  }
410  if (IS_ARITHMETIC(optype)) {
411  result.floatval = foldArithmetic<float>(optype, lhs.floatval, rhs.floatval);
412  result_type = kFLOAT;
413  return true;
414  }
415  CHECK(!IS_LOGIC(optype));
416  break;
417  case kDOUBLE:
418  if (IS_COMPARISON(optype)) {
419  result.boolval = foldComparison<double>(optype, lhs.doubleval, rhs.doubleval);
420  result_type = kBOOLEAN;
421  return true;
422  }
423  if (IS_ARITHMETIC(optype)) {
424  result.doubleval =
425  foldArithmetic<double>(optype, lhs.doubleval, rhs.doubleval);
426  result_type = kDOUBLE;
427  return true;
428  }
429  CHECK(!IS_LOGIC(optype));
430  break;
431  default:
432  break;
433  }
434  } catch (...) {
435  return false;
436  }
437  return false;
438  }
439 
440  std::shared_ptr<Analyzer::Expr> visitUOper(
441  const Analyzer::UOper* uoper) const override {
442  const auto unvisited_operand = uoper->get_operand();
443  const auto optype = uoper->get_optype();
444  const auto& ti = uoper->get_type_info();
445  if (optype == kCAST) {
446  // Cache the cast type so it could be used in operand rewriting/folding
447  casts_.insert({unvisited_operand, ti});
448  }
449  const auto operand = visit(unvisited_operand);
450 
451  const auto& operand_ti = operand->get_type_info();
452  const auto operand_type =
453  operand_ti.is_decimal() ? decimal_to_int_type(operand_ti) : operand_ti.get_type();
454  const auto const_operand =
455  std::dynamic_pointer_cast<const Analyzer::Constant>(operand);
456 
457  if (const_operand) {
458  const auto operand_datum = const_operand->get_constval();
459  Datum zero_datum = {};
460  Datum result_datum = {};
461  SQLTypes result_type;
462  switch (optype) {
463  case kNOT: {
464  if (foldOper(kEQ,
465  operand_type,
466  zero_datum,
467  operand_datum,
468  result_datum,
469  result_type)) {
470  CHECK_EQ(result_type, kBOOLEAN);
471  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
472  }
473  break;
474  }
475  case kUMINUS: {
476  if (foldOper(kMINUS,
477  operand_type,
478  zero_datum,
479  operand_datum,
480  result_datum,
481  result_type)) {
482  if (!operand_ti.is_decimal()) {
483  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
484  }
485  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
486  }
487  break;
488  }
489  case kCAST: {
490  // Trying to fold number to number casts only
491  if (!ti.is_number() || !operand_ti.is_number()) {
492  break;
493  }
494  // Disallowing folding of FP to DECIMAL casts for now:
495  // allowing them would make this test pass:
496  // update dectest set d=cast( 1234.0 as float );
497  // which is expected to throw in Update.ImplicitCastToNumericTypes
498  // due to cast codegen currently not supporting these casts
499  if (ti.is_decimal() && operand_ti.is_fp()) {
500  break;
501  }
502  auto operand_copy = const_operand->deep_copy();
503  auto cast_operand = operand_copy->add_cast(ti);
504  auto const_cast_operand =
505  std::dynamic_pointer_cast<const Analyzer::Constant>(cast_operand);
506  if (const_cast_operand) {
507  auto const_cast_datum = const_cast_operand->get_constval();
508  return makeExpr<Analyzer::Constant>(ti, false, const_cast_datum);
509  }
510  }
511  default:
512  break;
513  }
514  }
515 
516  return makeExpr<Analyzer::UOper>(
517  uoper->get_type_info(), uoper->get_contains_agg(), optype, operand);
518  }
519 
520  std::shared_ptr<Analyzer::Expr> visitBinOper(
521  const Analyzer::BinOper* bin_oper) const override {
522  const auto optype = bin_oper->get_optype();
523  auto ti = bin_oper->get_type_info();
524  auto left_operand = bin_oper->get_own_left_operand();
525  auto right_operand = bin_oper->get_own_right_operand();
526 
527  // Check if bin_oper result is cast to a larger int or fp type
528  if (casts_.find(bin_oper) != casts_.end()) {
529  const auto cast_ti = casts_[bin_oper];
530  const auto& lhs_ti = bin_oper->get_left_operand()->get_type_info();
531  // Propagate cast down to the operands for folding
532  if ((cast_ti.is_integer() || cast_ti.is_fp()) && lhs_ti.is_integer() &&
533  cast_ti.get_size() > lhs_ti.get_size() &&
534  (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY)) {
535  // Before folding, cast the operands to the bigger type to avoid overflows.
536  // Currently upcasting smaller integer types to larger integers or double.
537  left_operand = left_operand->deep_copy()->add_cast(cast_ti);
538  right_operand = right_operand->deep_copy()->add_cast(cast_ti);
539  ti = cast_ti;
540  }
541  }
542 
543  const auto lhs = visit(left_operand.get());
544  const auto rhs = visit(right_operand.get());
545 
546  auto const_lhs = std::dynamic_pointer_cast<Analyzer::Constant>(lhs);
547  auto const_rhs = std::dynamic_pointer_cast<Analyzer::Constant>(rhs);
548  const auto& lhs_ti = lhs->get_type_info();
549  const auto& rhs_ti = rhs->get_type_info();
550  auto lhs_type = lhs_ti.is_decimal() ? decimal_to_int_type(lhs_ti) : lhs_ti.get_type();
551  auto rhs_type = rhs_ti.is_decimal() ? decimal_to_int_type(rhs_ti) : rhs_ti.get_type();
552 
553  if (const_lhs && const_rhs && lhs_type == rhs_type) {
554  auto lhs_datum = const_lhs->get_constval();
555  auto rhs_datum = const_rhs->get_constval();
556  Datum result_datum = {};
557  SQLTypes result_type;
558  if (foldOper(optype, lhs_type, lhs_datum, rhs_datum, result_datum, result_type)) {
559  // Fold all ops that don't take in decimal operands, and also decimal comparisons
560  if (!lhs_ti.is_decimal() || IS_COMPARISON(optype)) {
561  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
562  }
563  // Decimal arithmetic has been done as kBIGINT. Selectively fold some decimal ops,
564  // using result_datum and BinOper expr typeinfo which was adjusted for these ops.
565  if (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY) {
566  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
567  }
568  }
569  }
570 
571  if (optype == kAND && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
572  if (const_rhs && !const_rhs->get_is_null()) {
573  auto rhs_datum = const_rhs->get_constval();
574  if (rhs_datum.boolval == false) {
575  Datum d;
576  d.boolval = false;
577  // lhs && false --> false
578  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
579  }
580  // lhs && true --> lhs
581  return lhs;
582  }
583  if (const_lhs && !const_lhs->get_is_null()) {
584  auto lhs_datum = const_lhs->get_constval();
585  if (lhs_datum.boolval == false) {
586  Datum d;
587  d.boolval = false;
588  // false && rhs --> false
589  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
590  }
591  // true && rhs --> rhs
592  return rhs;
593  }
594  }
595  if (optype == kOR && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
596  if (const_rhs && !const_rhs->get_is_null()) {
597  auto rhs_datum = const_rhs->get_constval();
598  if (rhs_datum.boolval == true) {
599  Datum d;
600  d.boolval = true;
601  // lhs || true --> true
602  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
603  }
604  // lhs || false --> lhs
605  return lhs;
606  }
607  if (const_lhs && !const_lhs->get_is_null()) {
608  auto lhs_datum = const_lhs->get_constval();
609  if (lhs_datum.boolval == true) {
610  Datum d;
611  d.boolval = true;
612  // true || rhs --> true
613  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
614  }
615  // false || rhs --> rhs
616  return rhs;
617  }
618  }
619  if (*lhs == *rhs) {
620  if (!lhs_ti.get_notnull()) {
621  CHECK(!rhs_ti.get_notnull());
622  // We can't fold the ostensible tautaulogy
623  // for nullable lhs and rhs types, as
624  // lhs <> rhs when they are null
625 
626  // We likely could turn this into a lhs is not null
627  // operatation, but is it worth it?
628  return makeExpr<Analyzer::BinOper>(ti,
629  bin_oper->get_contains_agg(),
630  bin_oper->get_optype(),
631  bin_oper->get_qualifier(),
632  lhs,
633  rhs);
634  }
635  CHECK(rhs_ti.get_notnull());
636  // Tautologies: v=v; v<=v; v>=v
637  if (optype == kEQ || optype == kLE || optype == kGE) {
638  Datum d;
639  d.boolval = true;
640  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
641  }
642  // Contradictions: v!=v; v<v; v>v
643  if (optype == kNE || optype == kLT || optype == kGT) {
644  Datum d;
645  d.boolval = false;
646  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
647  }
648  // v-v
649  if (optype == kMINUS) {
650  Datum d = {};
651  return makeExpr<Analyzer::Constant>(lhs_type, false, d);
652  }
653  }
654  // Convert fp division by a constant to multiplication by 1/constant
655  if (optype == kDIVIDE && const_rhs && rhs_ti.is_fp()) {
656  auto rhs_datum = const_rhs->get_constval();
657  std::shared_ptr<Analyzer::Expr> recip_rhs = nullptr;
658  if (rhs_ti.get_type() == kFLOAT) {
659  if (rhs_datum.floatval == 1.0) {
660  return lhs;
661  }
662  auto f = std::fabs(rhs_datum.floatval);
663  if (f > 1.0 || (f != 0.0 && 1.0 < f * std::numeric_limits<float>::max())) {
664  rhs_datum.floatval = 1.0 / rhs_datum.floatval;
665  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
666  }
667  } else if (rhs_ti.get_type() == kDOUBLE) {
668  if (rhs_datum.doubleval == 1.0) {
669  return lhs;
670  }
671  auto d = std::fabs(rhs_datum.doubleval);
672  if (d > 1.0 || (d != 0.0 && 1.0 < d * std::numeric_limits<double>::max())) {
673  rhs_datum.doubleval = 1.0 / rhs_datum.doubleval;
674  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
675  }
676  }
677  if (recip_rhs) {
678  return makeExpr<Analyzer::BinOper>(ti,
679  bin_oper->get_contains_agg(),
680  kMULTIPLY,
681  bin_oper->get_qualifier(),
682  lhs,
683  recip_rhs);
684  }
685  }
686 
687  return makeExpr<Analyzer::BinOper>(ti,
688  bin_oper->get_contains_agg(),
689  bin_oper->get_optype(),
690  bin_oper->get_qualifier(),
691  lhs,
692  rhs);
693  }
694 
// Rewrites a string operator; every argument is visited first.
// While the arguments are visited, in_string_op_chain_ records whether this
// operator has at most one non-literal argument; the caller's value is
// restored afterwards.
// If every rewritten argument is an Analyzer::Constant, the operator is
// replaced by a literal built from string_op_info.
// Otherwise the rebuilt operator is appended to chained_string_op_exprs_:
//  - inside an enclosing chain, only a deep copy of the (string) first
//    argument is returned;
//  - the outermost operator receives the accumulated chain and clears it.
695  std::shared_ptr<Analyzer::Expr> visitStringOper(
696  const Analyzer::StringOper* string_oper) const override {
697  // Todo(todd): For clarity and modularity we should move string
698  // operator rewrites into their own visitor class.
699  // String operation rewrites were originally put here as they only
700  // handled string operators on rewrite, but now handle variable
701  // inputs as well.
702  const auto original_args = string_oper->getOwnArgs();
703  std::vector<std::shared_ptr<Analyzer::Expr>> rewritten_args;
704  const auto non_literal_arity = string_oper->getNonLiteralsArity();
705  const auto parent_in_string_op_chain = in_string_op_chain_;
// An operator with at most one non-literal argument can be part of a chain.
706  const auto in_string_op_chain = non_literal_arity <= 1UL;
707  in_string_op_chain_ = in_string_op_chain;
708 
// Count how many arguments became constants after rewriting.
709  size_t rewritten_arg_literal_arity = 0;
710  for (auto original_arg : original_args) {
711  rewritten_args.emplace_back(visit(original_arg.get()));
712  if (dynamic_cast<const Analyzer::Constant*>(rewritten_args.back().get())) {
713  rewritten_arg_literal_arity++;
714  }
715  }
// Restore the caller's chain state.
716  in_string_op_chain_ = parent_in_string_op_chain;
717  const auto kind = string_oper->get_kind();
718  const auto& return_ti = string_oper->get_type_info();
719 
// All arguments are literals: replace the operator with a literal result.
720  if (string_oper->getArity() == rewritten_arg_literal_arity) {
721  Analyzer::StringOper literal_string_oper(
722  kind, string_oper->get_type_info(), rewritten_args);
723  const auto literal_args = literal_string_oper.getLiteralArgs();
724  const auto string_op_info =
725  StringOps_Namespace::StringOpInfo(kind, return_ti, literal_args);
726  if (return_ti.is_string()) {
// NOTE(review): the initializer of literal_result (original line 728) is
// missing from this listing; presumably it evaluates string_op_info. Confirm
// against upstream.
727  const auto literal_result =
729  return Parser::StringLiteral::analyzeValue(literal_result.first,
730  literal_result.second);
731  }
// NOTE(review): the initializer of literal_datum (original line 733) is missing
// from this listing. The constant built below is made nullable, and its null
// flag is taken from IsNullDatum.
732  const auto literal_datum =
734  auto nullable_return_ti = return_ti;
735  nullable_return_ti.set_notnull(false);
736  return makeExpr<Analyzer::Constant>(nullable_return_ti,
737  IsNullDatum(literal_datum, nullable_return_ti),
738  literal_datum);
739  }
740  chained_string_op_exprs_.emplace_back(
741  makeExpr<Analyzer::StringOper>(kind, return_ti, rewritten_args));
742  if (parent_in_string_op_chain && in_string_op_chain) {
743  CHECK(rewritten_args[0]->get_type_info().is_string());
744  return rewritten_args[0]->deep_copy();
745  } else {
// Outermost operator of the chain: attach and reset the collected operators.
746  auto new_string_oper = makeExpr<Analyzer::StringOper>(
747  kind, return_ti, rewritten_args, chained_string_op_exprs_);
748  chained_string_op_exprs_.clear();
749  return new_string_oper;
750  }
751  }
752 
753  protected:
754  mutable bool in_string_op_chain_{false};
755  mutable std::vector<std::shared_ptr<Analyzer::Expr>> chained_string_op_exprs_;
756  mutable std::unordered_map<const Analyzer::Expr*, const SQLTypeInfo> casts_;
757  mutable int32_t num_overflows_;
758 
759  public:
760  ConstantFoldingVisitor() : num_overflows_(0) {}
761  int32_t get_num_overflows() { return num_overflows_; }
762  void reset_num_overflows() { num_overflows_ = 0; }
763 };
764 
766  const auto with_likelihood = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
767  if (!with_likelihood) {
768  return expr;
769  }
770  return with_likelihood->get_arg();
771 }
772 
773 } // namespace
774 
776  return ArrayElementStringLiteralEncodingVisitor().visit(expr);
777 }
778 
// Body of the top-level expression rewrite; its signature line is not present
// in this listing. It tries the SUM and AVG window rewrites first. Otherwise it
// strips a top-level LikelihoodExpr with strip_likelihood() and runs
// RecursiveOrToInVisitor on what remains.
780  const auto sum_window = rewrite_sum_window(expr);
781  if (sum_window) {
782  return sum_window;
783  }
784  const auto avg_window = rewrite_avg_window(expr);
785  if (avg_window) {
786  return avg_window;
787  }
788  const auto expr_no_likelihood = strip_likelihood(expr);
789  // The following check is not strictly needed, but seems silly to transform a
790  // simple string comparison to an IN just to codegen the same thing anyway.
// NOTE(review): no such check follows here; the comment above appears stale.
791 
792  RecursiveOrToInVisitor visitor;
793  auto rewritten_expr = visitor.visit(expr_no_likelihood);
// NOTE(review): this tests rewritten_expr, which was produced from the
// *stripped* expression. As a result, a likelihood on `expr` itself is never
// re-attached, and when the test does match, one LikelihoodExpr ends up nested
// inside another. A dynamic_cast on `expr` was probably intended. Confirm
// against callers before changing, since they currently get no likelihood back.
794  const auto expr_with_likelihood =
795  std::dynamic_pointer_cast<const Analyzer::LikelihoodExpr>(rewritten_expr);
796  if (expr_with_likelihood) {
797  // Add back likelihood
798  return std::make_shared<Analyzer::LikelihoodExpr>(
799  rewritten_expr, expr_with_likelihood->get_likelihood());
800  }
801  return rewritten_expr;
802 }
803 
804 namespace {
805 
807  std::unordered_map<const RelAlgNode*, int>& input_to_nest_level,
808  shared::ColumnKey const& column_key,
809  int target_nest_lv) {
810  for (auto& kv : input_to_nest_level) {
811  auto ra = kv.first;
812  auto table_keys = get_physical_table_inputs(ra);
813  if (std::any_of(table_keys.begin(),
814  table_keys.end(),
815  [column_key](shared::TableKey const& key) {
816  return key.table_id == column_key.table_id &&
817  key.db_id == column_key.db_id;
818  })) {
819  input_to_nest_level[ra] = target_nest_lv;
820  return;
821  }
822  }
823 }
824 
825 int update_input_desc(std::vector<InputDescriptor>& input_descs,
826  shared::ColumnKey const& column_key,
827  int target_nest_lv) {
828  int num_input_descs = static_cast<int>(input_descs.size());
829  for (int i = 0; i < num_input_descs; i++) {
830  auto const tbl_key = input_descs[i].getTableKey();
831  if (tbl_key.db_id == column_key.db_id && tbl_key.table_id == column_key.table_id) {
832  input_descs[i] = InputDescriptor(tbl_key.db_id, tbl_key.table_id, target_nest_lv);
833  return i;
834  }
835  }
836  return -1;
837 }
838 
840  std::list<std::shared_ptr<const InputColDescriptor>>& input_col_desc,
841  shared::ColumnKey const& column_key,
842  int target_nest_lv) {
843  for (auto it = input_col_desc.begin(); it != input_col_desc.end(); it++) {
844  auto const tbl_key = (*it)->getScanDesc().getTableKey();
845  if (tbl_key.db_id == column_key.db_id && tbl_key.table_id == column_key.table_id) {
846  (*it) = std::make_shared<InputColDescriptor>(
847  (*it)->getColId(), tbl_key.table_id, tbl_key.db_id, target_nest_lv);
848  return it;
849  }
850  }
851  return input_col_desc.end();
852 }
853 
854 } // namespace
855 
858  const std::shared_ptr<Analyzer::Expr> expr,
859  std::vector<InputDescriptor>& input_descs,
860  std::unordered_map<const RelAlgNode*, int>& input_to_nest_level,
861  std::vector<size_t>& input_permutation,
862  std::list<std::shared_ptr<const InputColDescriptor>>& input_col_desc,
863  const BoundingBoxIntersectJoinRewriteType rewrite_type,
864  Executor const* executor) {
865  auto collect_table_cardinality = [&executor](const Analyzer::Expr* lhs,
866  const Analyzer::Expr* rhs) {
867  const auto lhs_cv = dynamic_cast<const Analyzer::ColumnVar*>(lhs);
868  const auto rhs_cv = dynamic_cast<const Analyzer::ColumnVar*>(rhs);
869  if (lhs_cv && rhs_cv) {
870  return std::make_pair<int64_t, int64_t>(
871  get_table_cardinality(lhs_cv->getTableKey(), executor),
872  get_table_cardinality(rhs_cv->getTableKey(), executor));
873  }
874  // otherwise, return an invalid table cardinality
875  return std::make_pair<int64_t, int64_t>(-1, -1);
876  };
877 
878  auto has_invalid_join_col_order = [](const Analyzer::Expr* lhs,
879  const Analyzer::Expr* rhs) {
880  // Check for compatible join ordering. If the join ordering does not match expected
881  // ordering for bounding box intersection, the join builder will fail.
882  std::set<int> lhs_rte_idx;
883  lhs->collect_rte_idx(lhs_rte_idx);
884  std::set<int> rhs_rte_idx;
885  rhs->collect_rte_idx(rhs_rte_idx);
886  auto has_invalid_num_join_cols = lhs_rte_idx.size() != 1 || rhs_rte_idx.size() != 1;
887  auto has_invalid_rte_idx = lhs_rte_idx > rhs_rte_idx;
888  return std::make_pair(has_invalid_num_join_cols || has_invalid_rte_idx,
889  has_invalid_rte_idx);
890  };
891  bool swap_args = false;
892  auto convert_to_range_join_oper =
893  [&](std::string_view func_name,
894  const std::shared_ptr<Analyzer::Expr> expr,
895  const Analyzer::BinOper* range_join_expr,
896  const Analyzer::GeoOperator* lhs,
897  const Analyzer::Constant* rhs) -> std::shared_ptr<Analyzer::BinOper> {
899  func_name)) {
900  CHECK_EQ(lhs->size(), size_t(2));
901  auto l_arg = lhs->getOperand(0);
902  // we try to build a join hash table for bounding box intersection based on the rhs
903  auto r_arg = lhs->getOperand(1);
904  const bool is_geography = l_arg->get_type_info().get_subtype() == kGEOGRAPHY ||
905  r_arg->get_type_info().get_subtype() == kGEOGRAPHY;
906  if (is_geography) {
907  VLOG(1) << "Range join not yet supported for geodesic distance "
908  << expr->toString();
909  return nullptr;
910  }
911  // Check for compatible join ordering. If the join ordering does not match expected
912  // ordering for bounding box intersection, the join builder will fail.
913  Analyzer::Expr* range_join_arg = r_arg;
914  Analyzer::Expr* bin_oper_arg = l_arg;
915  auto invalid_range_join_qual =
916  has_invalid_join_col_order(bin_oper_arg, range_join_arg);
917  if (invalid_range_join_qual.first) {
918  LOG(INFO) << "Unable to rewrite " << func_name
919  << " to exploit bounding box intersection. Cannot build hash table "
920  "over LHS type. "
921  "Check join order.\n"
922  << range_join_expr->toString();
923  return nullptr;
924  }
925  // swapping rule for range join argument
926  // lhs | rhs
927  // 1. pt | pt : swap if |lhs| < |rhs| or has invalid rte values
928  // 2. pt | non-pt : return nullptr
929  // 3. non-pt | pt : return nullptr
930  // 4. non-pt | non-pt : return nullptr
931  // todo (yoonmin) : improve logic for cases 2 and 3
932  bool lhs_is_point{l_arg->get_type_info().get_type() == kPOINT};
933  bool rhs_is_point{r_arg->get_type_info().get_type() == kPOINT};
934  if (!lhs_is_point || !rhs_is_point) {
935  // case 2 ~ 4
936  VLOG(1) << "Currently, we only support range hash join for Point-to-Point "
937  "distance query: fallback to a loop join";
938  return nullptr;
939  }
940  auto const card_info = collect_table_cardinality(range_join_arg, bin_oper_arg);
941  if (invalid_range_join_qual.second && card_info.first > 0 && lhs_is_point) {
942  swap_args = true;
943  } else if (card_info.first >= 0 && card_info.first < card_info.second) {
944  swap_args = true;
945  }
946  if (swap_args) {
947  // todo (yoonmin) : find the best reordering scheme when a query has multiple
948  // range join candidates; it needs a (cost-based) plan enumeration logic
949  // in our optimizer
950  auto r_cv = dynamic_cast<Analyzer::ColumnVar*>(lhs->getOperand(1));
951  auto l_cv = dynamic_cast<Analyzer::ColumnVar*>(lhs->getOperand(0));
952  if (r_cv && l_cv && input_descs.size() == 2) {
953  // to exploit range hash join, we need to assign point geometry to rhs
954  // specifically, we need to propagate the changes made here via various
955  // query metadata such as `input_desc`, `input_col_desc` and so on
956  // otherwise, we do not try swapping join arguments for safety
957  // and we do not try argument swapping if the input query has more two
958  // input tables; but it is enough to cover most of immerse use-cases
959  auto const r_col_key = r_cv->getColumnKey();
960  auto const l_col_key = l_cv->getColumnKey();
961  int r_rte_idx = r_cv->get_rte_idx();
962  int l_rte_idx = l_cv->get_rte_idx();
963  r_cv->set_rte_idx(l_rte_idx);
964  l_cv->set_rte_idx(r_rte_idx);
965  update_input_to_nest_lv(input_to_nest_level, r_col_key, l_rte_idx);
966  update_input_to_nest_lv(input_to_nest_level, l_col_key, r_rte_idx);
967  auto const r_input_desc_idx =
968  update_input_desc(input_descs, r_col_key, l_rte_idx);
969  CHECK_GE(r_input_desc_idx, 0);
970  auto const l_input_desc_idx =
971  update_input_desc(input_descs, l_col_key, r_rte_idx);
972  CHECK_GE(l_input_desc_idx, 0);
973  auto r_input_col_desc_it =
974  update_input_col_desc(input_col_desc, r_col_key, l_rte_idx);
975  CHECK(r_input_col_desc_it != input_col_desc.end());
976  auto l_input_col_desc_it =
977  update_input_col_desc(input_col_desc, l_col_key, r_rte_idx);
978  CHECK(l_input_col_desc_it != input_col_desc.end());
979  if (!input_permutation.empty()) {
980  auto r_itr =
981  std::find(input_permutation.begin(), input_permutation.end(), r_rte_idx);
982  CHECK(r_itr != input_permutation.end());
983  auto l_itr =
984  std::find(input_permutation.begin(), input_permutation.end(), l_rte_idx);
985  CHECK(l_itr != input_permutation.end());
986  std::swap(*r_itr, *l_itr);
987  }
988  std::swap(input_descs[r_input_desc_idx], input_descs[l_input_desc_idx]);
989  std::swap(r_input_col_desc_it, l_input_col_desc_it);
990  r_arg = lhs->getOperand(0);
991  l_arg = lhs->getOperand(1);
992  VLOG(1) << "Swap range join qual's input arguments to exploit hash join "
993  "framework for bounding box intersection";
994  invalid_range_join_qual.first = false;
995  }
996  }
997  const bool inclusive = range_join_expr->get_optype() == kLE;
998  auto range_expr = makeExpr<Analyzer::RangeOper>(
999  inclusive, inclusive, r_arg->deep_copy(), rhs->deep_copy());
1000  VLOG(1) << "Successfully converted to range hash join";
1001  return makeExpr<Analyzer::BinOper>(
1002  kBOOLEAN, kBBOX_INTERSECT, kONE, l_arg->deep_copy(), range_expr);
1003  }
1004  return nullptr;
1005  };
1006 
1007  /*
1008  * Currently, our hash join framework for bounding box intersection (bbox-intersect)
1009  * supports limited set of join quals when 1) the FunctionOperator is listed in the
1010  * function list, i.e., is_bbox_intersect_supported_func, 2) the argument order of the
1011  * join qual must match the input argument order of the corresponding native function,
1012  * and 3) input tables match rte index requirement (the column used to build a hash
1013  * table has larger rte compared with that of probing column). Depending on the type
1014  * of the function, we try to convert it to corresponding hash join qual if possible.
1015  * After rewriting, we create a join operator w/ bbox-intersect which is converted from
1016  * the original expression and return BoundingBoxIntersectJoinConjunction object which
1017  * is a pair of 1) the original expr and 2) converted join expr w/ bbox-intersect. Here,
1018  * returning the original expr means we additionally call its corresponding native
1019  * function to compute the result accurately (i.e., bbox-intersect hash join operates a
1020  * kind of filter expression which may include false-positive of the true resultset).
1021  * Note that ST_IntersectsBox is the only function that does not return the original
1022  * expr.
1023  * */
1024  std::shared_ptr<Analyzer::BinOper> bbox_intersect_oper{nullptr};
1025  bool needs_to_return_original_expr = false;
1026  std::string func_name{""};
1028  auto func_oper = dynamic_cast<Analyzer::FunctionOper*>(expr.get());
1029  CHECK(func_oper);
1030  func_name = func_oper->getName();
1033  LOG(WARNING) << "Many-to-many hashjoin support is disabled, unable to rewrite "
1034  << func_oper->toString() << " to use accelerated geo join.";
1036  }
1037  DeepCopyVisitor deep_copy_visitor;
1039  CHECK_GE(func_oper->getArity(), size_t(2));
1040  // this case returns {empty quals, bbox_intersect join quals} b/c our join key
1041  // matching logic for this case is the same as the implementation of
1042  // ST_IntersectsBox function Note that we can build a join hash table for
1043  // bbox_intersect regardless of table ordering and the argument order in this case
1044  // b/c selecting lhs and rhs by arguments 0 and 1 always match the rte index
1045  // requirement (rte_lhs < rte_rhs) so what table ordering we take, the rte index
1046  // requirement satisfies
1047  // TODO(adb): we will likely want to actually check for true bbox_intersect, but
1048  // this works for now
1049  auto lhs = func_oper->getOwnArg(0);
1050  auto rewritten_lhs = deep_copy_visitor.visit(lhs.get());
1051  CHECK(rewritten_lhs);
1052 
1053  auto rhs = func_oper->getOwnArg(1);
1054  auto rewritten_rhs = deep_copy_visitor.visit(rhs.get());
1055  CHECK(rewritten_rhs);
1056  bbox_intersect_oper = makeExpr<Analyzer::BinOper>(
1057  kBOOLEAN, kBBOX_INTERSECT, kONE, rewritten_lhs, rewritten_rhs);
1058  } else if (func_name ==
1060  CHECK_EQ(func_oper->getArity(), size_t(8));
1061  const auto lhs = func_oper->getOwnArg(0);
1062  const auto rhs = func_oper->getOwnArg(1);
1063  // the correctness of geo args used in the ST_DWithin function is checked by
1064  // geo translation logic, i.e., RelAlgTranslator::translateTernaryGeoFunction
1065  const auto distance_const_val =
1066  dynamic_cast<const Analyzer::Constant*>(func_oper->getArg(7));
1067  if (lhs && rhs && distance_const_val) {
1068  std::vector<std::shared_ptr<Analyzer::Expr>> args{lhs, rhs};
1069  auto range_oper = makeExpr<Analyzer::GeoOperator>(
1070  SQLTypeInfo(kDOUBLE, 0, 8, true),
1072  args,
1073  std::nullopt);
1074  auto distance_oper = makeExpr<Analyzer::BinOper>(
1075  kBOOLEAN, kLE, kONE, range_oper, distance_const_val->deep_copy());
1076  VLOG(1) << "Rewrite " << func_oper->getName() << " to ST_Distance_Point_Point";
1077  bbox_intersect_oper = convert_to_range_join_oper(
1079  distance_oper,
1080  distance_oper.get(),
1081  range_oper.get(),
1082  distance_const_val);
1083  needs_to_return_original_expr = true;
1084  }
1086  is_poly_mpoly_rewrite_target_func(func_name)) {
1087  // in the five functions fall into this case,
1088  // ST_Contains is for a pair of polygons, and for ST_Intersect cases they are
1089  // combo of polygon and multipolygon so what table orders we choose, rte index
1090  // requirement for bbox_intersect can be satisfied if we choose lhs and rhs
1091  // from left-to-right order (i.e., get lhs from the arg-1 instead of arg-3)
1092  // Note that we choose them from right-to-left argument order in the past
1093  CHECK_GE(func_oper->getArity(), size_t(4));
1094  auto lhs = func_oper->getOwnArg(1);
1095  auto rewritten_lhs = deep_copy_visitor.visit(lhs.get());
1096  CHECK(rewritten_lhs);
1097  auto rhs = func_oper->getOwnArg(3);
1098  auto rewritten_rhs = deep_copy_visitor.visit(rhs.get());
1099  CHECK(rewritten_rhs);
1100 
1101  bbox_intersect_oper = makeExpr<Analyzer::BinOper>(
1102  kBOOLEAN, kBBOX_INTERSECT, kONE, rewritten_lhs, rewritten_rhs);
1103  needs_to_return_original_expr = true;
1105  is_point_poly_rewrite_target_func(func_name)) {
1106  // now, we try to look at one more chance to exploit bbox_intersect by
1107  // rewriting the qual as: ST_INTERSECT(POLY, POINT) -> ST_INTERSECT(POINT, POLY)
1108  // to support efficient evaluation of 1) ST_Intersects_Point_Polygon and
1109  // 2) ST_Intersects_Point_MultiPolygon based on our hash join framework w/
1110  // bbox_intersect here, we have implementation of native functions for both 1)
1111  // Point-Polygon pair and 2) Polygon-Point pair, but we currently do not support
1112  // hash table generation on top of point column thus, the goal of this rewriting is
1113  // to place a non-point geometry to the right-side of the bbox_intersect_oper (to
1114  // build hash table based on it) iff the inner table is larger than that of
1115  // non-point geometry (to reduce expensive hash join performance)
1116  size_t point_arg_idx = 0;
1117  size_t poly_arg_idx = 2;
1118  if (func_oper->getOwnArg(point_arg_idx)->get_type_info().get_type() != kPOINT) {
1119  point_arg_idx = 2;
1120  poly_arg_idx = 1;
1121  }
1122  auto point_cv = func_oper->getOwnArg(point_arg_idx);
1123  auto poly_cv = func_oper->getOwnArg(poly_arg_idx);
1124  CHECK_EQ(point_cv->get_type_info().get_type(), kPOINT);
1125  CHECK_EQ(poly_cv->get_type_info().get_type(), kARRAY);
1126  auto rewritten_lhs = deep_copy_visitor.visit(point_cv.get());
1127  CHECK(rewritten_lhs);
1128  auto rewritten_rhs = deep_copy_visitor.visit(poly_cv.get());
1129  CHECK(rewritten_rhs);
1130  VLOG(1) << "Rewriting the " << func_name
1131  << " to use bounding box intersection with lhs as "
1132  << rewritten_lhs->toString() << " and rhs as " << rewritten_rhs->toString();
1133  bbox_intersect_oper = makeExpr<Analyzer::BinOper>(
1134  kBOOLEAN, kBBOX_INTERSECT, kONE, rewritten_lhs, rewritten_rhs);
1135  needs_to_return_original_expr = true;
1137  is_poly_point_rewrite_target_func(func_name)) {
1138  // rest of functions reaching here is poly and point geo join query
1139  // to use bbox_intersect in this case, poly column must have its rte == 1
1140  // lhs is the point col_var
1141  auto lhs = func_oper->getOwnArg(2);
1142  auto rewritten_lhs = deep_copy_visitor.visit(lhs.get());
1143  CHECK(rewritten_lhs);
1144  const auto& lhs_ti = rewritten_lhs->get_type_info();
1145 
1146  if (!lhs_ti.is_geometry() && !is_constructed_point(rewritten_lhs.get())) {
1147  // TODO(adb): If ST_Contains is passed geospatial literals instead of columns,
1148  // the function will be expanded during translation rather than during code
1149  // generation. While this scenario does not make sense for the bbox_intersect, we
1150  // need to detect and abort the bbox_intersect rewrite. Adding a
1151  // GeospatialConstant dervied class to the Analyzer may prove to be a better way
1152  // to handle geo literals, but for now we ensure the LHS type is a geospatial
1153  // type, which would mean the function has not been expanded to the physical
1154  // types, yet.
1155  LOG(INFO) << "Unable to rewrite " << func_name
1156  << " to bounding box intersection conjunction. LHS input type is "
1157  "neither a geospatial "
1158  "column nor a constructed point"
1159  << func_oper->toString();
1161  }
1162 
1163  // rhs is coordinates of the poly col
1164  auto rhs = func_oper->getOwnArg(1);
1165  auto rewritten_rhs = deep_copy_visitor.visit(rhs.get());
1166  CHECK(rewritten_rhs);
1167 
1168  if (has_invalid_join_col_order(lhs.get(), rhs.get()).first) {
1169  LOG(INFO) << "Unable to rewrite " << func_name
1170  << " to bounding box intersection conjunction. Cannot build hash table "
1171  "over LHS type. "
1172  "Check join order."
1173  << func_oper->toString();
1175  }
1176 
1177  VLOG(1) << "Rewriting " << func_name
1178  << " to use bounding box intersection with lhs as "
1179  << rewritten_lhs->toString() << " and rhs as " << rewritten_rhs->toString();
1180 
1181  bbox_intersect_oper = makeExpr<Analyzer::BinOper>(
1182  kBOOLEAN, kBBOX_INTERSECT, kONE, rewritten_lhs, rewritten_rhs);
1184  ST_APPROX_OVERLAPS_MULTIPOLYGON_POINT_sv) {
1185  needs_to_return_original_expr = true;
1186  }
1187  }
1188  } else if (rewrite_type == BoundingBoxIntersectJoinRewriteType::RANGE_JOIN) {
1189  auto bin_oper = dynamic_cast<Analyzer::BinOper*>(expr.get());
1190  CHECK(bin_oper);
1191  auto lhs = dynamic_cast<const Analyzer::GeoOperator*>(bin_oper->get_left_operand());
1192  CHECK(lhs);
1193  auto rhs = dynamic_cast<const Analyzer::Constant*>(bin_oper->get_right_operand());
1194  CHECK(rhs);
1195  func_name = lhs->getName();
1196  bbox_intersect_oper = convert_to_range_join_oper(func_name, expr, bin_oper, lhs, rhs);
1197  needs_to_return_original_expr = true;
1198  }
1199  const auto expr_str = !func_name.empty() ? func_name : expr->toString();
1200  if (bbox_intersect_oper) {
1202  res.swap_arguments = swap_args;
1203  BoundingBoxIntersectJoinConjunction bbox_intersect_join_qual;
1204  bbox_intersect_join_qual.join_quals.push_back(bbox_intersect_oper);
1205  if (needs_to_return_original_expr) {
1206  bbox_intersect_join_qual.quals.push_back(expr);
1207  }
1208  res.converted_bbox_intersect_join_info = bbox_intersect_join_qual;
1209  VLOG(1) << "Successfully converted " << expr_str
1210  << " to use bounding box intersection";
1211  return res;
1212  }
1213  VLOG(1) << "Bounding box intersection not enabled for " << expr_str;
1215 }
1216 
1218  JoinQualsPerNestingLevel const& join_quals,
1219  std::vector<InputDescriptor>& input_descs,
1220  std::unordered_map<const RelAlgNode*, int>& input_to_nest_level,
1221  std::vector<size_t>& input_permutation,
1222  std::list<std::shared_ptr<const InputColDescriptor>>& input_col_desc,
1223  Executor const* executor) {
1224  if (!g_enable_bbox_intersect_hashjoin || join_quals.empty()) {
1225  return {join_quals, false, false};
1226  }
1227 
1228  JoinQualsPerNestingLevel join_condition_per_nesting_level;
1229  bool is_reordered{false};
1230  bool has_bbox_intersect{false};
1231  for (const auto& join_condition_in : join_quals) {
1232  JoinCondition join_condition{{}, join_condition_in.type};
1233 
1234  for (const auto& join_qual_expr_in : join_condition_in.quals) {
1235  bool try_to_rewrite_expr_to_bbox_intersect = false;
1238  auto func_oper = dynamic_cast<Analyzer::FunctionOper*>(join_qual_expr_in.get());
1239  if (func_oper) {
1240  const auto func_name = func_oper->getName();
1242  func_name)) {
1243  try_to_rewrite_expr_to_bbox_intersect = true;
1245  }
1246  }
1247  auto bin_oper = dynamic_cast<Analyzer::BinOper*>(join_qual_expr_in.get());
1248  if (bin_oper && (bin_oper->get_optype() == kLE || bin_oper->get_optype() == kLT)) {
1249  auto lhs =
1250  dynamic_cast<const Analyzer::GeoOperator*>(bin_oper->get_left_operand());
1251  auto rhs = dynamic_cast<const Analyzer::Constant*>(bin_oper->get_right_operand());
1252  if (g_enable_distance_rangejoin && lhs && rhs) {
1253  try_to_rewrite_expr_to_bbox_intersect = true;
1255  }
1256  }
1258  if (try_to_rewrite_expr_to_bbox_intersect) {
1259  translation_res =
1261  input_descs,
1262  input_to_nest_level,
1263  input_permutation,
1264  input_col_desc,
1265  rewrite_type,
1266  executor);
1267  }
1268  if (translation_res.converted_bbox_intersect_join_info) {
1269  const auto& bbox_intersect_quals =
1270  *translation_res.converted_bbox_intersect_join_info;
1271  has_bbox_intersect = true;
1272  // Add qual for bounding box intersection
1273  join_condition.quals.insert(join_condition.quals.end(),
1274  bbox_intersect_quals.join_quals.begin(),
1275  bbox_intersect_quals.join_quals.end());
1276  // Add original quals
1277  join_condition.quals.insert(join_condition.quals.end(),
1278  bbox_intersect_quals.quals.begin(),
1279  bbox_intersect_quals.quals.end());
1280  } else {
1281  join_condition.quals.push_back(join_qual_expr_in);
1282  }
1283  is_reordered |= translation_res.swap_arguments;
1284  }
1285  join_condition_per_nesting_level.push_back(join_condition);
1286  }
1287  return {join_condition_per_nesting_level, has_bbox_intersect, is_reordered};
1288 }
1289 
1299  public:
1301  for (const auto& join_condition : join_quals) {
1302  for (const auto& qual : join_condition.quals) {
1303  auto qual_bin_oper = dynamic_cast<Analyzer::BinOper*>(qual.get());
1304  if (qual_bin_oper) {
1305  join_qual_pairs.emplace_back(qual_bin_oper->get_left_operand(),
1306  qual_bin_oper->get_right_operand());
1307  }
1308  }
1309  }
1310  }
1311 
1312  bool visitFunctionOper(const Analyzer::FunctionOper* func_oper) const override {
1314  func_oper->getName())) {
1315  const auto lhs = func_oper->getArg(2);
1316  const auto rhs = func_oper->getArg(1);
1317  for (const auto& qual_pair : join_qual_pairs) {
1318  if (*lhs == *qual_pair.first && *rhs == *qual_pair.second) {
1319  return true;
1320  }
1321  }
1322  }
1323  return false;
1324  }
1325 
1326  bool defaultResult() const override { return false; }
1327 
1328  private:
1329  std::vector<std::pair<const Analyzer::Expr*, const Analyzer::Expr*>> join_qual_pairs;
1330 };
1331 
1332 std::list<std::shared_ptr<Analyzer::Expr>> strip_join_covered_filter_quals(
1333  const std::list<std::shared_ptr<Analyzer::Expr>>& quals,
1334  const JoinQualsPerNestingLevel& join_quals) {
1336  return quals;
1337  }
1338 
1339  if (join_quals.empty()) {
1340  return quals;
1341  }
1342 
1343  std::list<std::shared_ptr<Analyzer::Expr>> quals_to_return;
1344 
1345  JoinCoveredQualVisitor visitor(join_quals);
1346  for (const auto& qual : quals) {
1347  if (!visitor.visit(qual.get())) {
1348  // Not a covered qual, don't elide it from the filtered count
1349  quals_to_return.push_back(qual);
1350  }
1351  }
1352 
1353  return quals_to_return;
1354 }
1355 
1356 std::shared_ptr<Analyzer::Expr> fold_expr(const Analyzer::Expr* expr) {
1357  if (!expr) {
1358  return nullptr;
1359  }
1360  const auto expr_no_likelihood = strip_likelihood(expr);
1361  ConstantFoldingVisitor visitor;
1362  auto rewritten_expr = visitor.visit(expr_no_likelihood);
1363  if (visitor.get_num_overflows() > 0 && rewritten_expr->get_type_info().is_integer() &&
1364  rewritten_expr->get_type_info().get_type() != kBIGINT) {
1365  auto rewritten_expr_const =
1366  std::dynamic_pointer_cast<const Analyzer::Constant>(rewritten_expr);
1367  if (!rewritten_expr_const) {
1368  // Integer expression didn't fold completely the first time due to
1369  // overflows in smaller type subexpressions, trying again with a cast
1370  const auto& ti = SQLTypeInfo(kBIGINT, false);
1371  auto bigint_expr_no_likelihood = expr_no_likelihood->deep_copy()->add_cast(ti);
1372  auto rewritten_expr_take2 = visitor.visit(bigint_expr_no_likelihood.get());
1373  auto rewritten_expr_take2_const =
1374  std::dynamic_pointer_cast<Analyzer::Constant>(rewritten_expr_take2);
1375  if (rewritten_expr_take2_const) {
1376  // Managed to fold, switch to the new constant
1377  rewritten_expr = rewritten_expr_take2_const;
1378  }
1379  }
1380  }
1381  const auto expr_with_likelihood = dynamic_cast<const Analyzer::LikelihoodExpr*>(expr);
1382  if (expr_with_likelihood) {
1383  // Add back likelihood
1384  return std::make_shared<Analyzer::LikelihoodExpr>(
1385  rewritten_expr, expr_with_likelihood->get_likelihood());
1386  }
1387  return rewritten_expr;
1388 }
1389 
1391  const Analyzer::ColumnVar* val_side,
1392  const int max_rte_covered) {
1393  if (key_side->getTableKey() == val_side->getTableKey() &&
1394  key_side->get_rte_idx() == val_side->get_rte_idx() &&
1395  key_side->get_rte_idx() > max_rte_covered) {
1396  return true;
1397  }
1398  return false;
1399 }
1400 
1402  std::unordered_map<int, llvm::Value*>& scan_idx_to_hash_pos) {
1403  int ret = INT32_MIN;
1404  for (auto& kv : scan_idx_to_hash_pos) {
1405  if (kv.first > ret) {
1406  ret = kv.first;
1407  }
1408  }
1409  return ret;
1410 }
1411 
1412 size_t get_table_cardinality(shared::TableKey const& table_key,
1413  Executor const* executor) {
1414  if (table_key.table_id > 0) {
1415  auto const td = Catalog_Namespace::get_metadata_for_table(table_key);
1416  CHECK(td);
1417  CHECK(td->fragmenter);
1418  return td->fragmenter->getNumRows();
1419  }
1420  auto temp_tbl = get_temporary_table(executor->getTemporaryTables(), table_key.table_id);
1421  return temp_tbl->rowCount();
1422 }
int8_t tinyintval
Definition: Datum.h:73
Analyzer::ExpressionPtr rewrite_array_elements(Analyzer::Expr const *expr)
#define CHECK_EQ(x, y)
Definition: Logger.h:301
std::optional< BoundingBoxIntersectJoinConjunction > converted_bbox_intersect_join_info
std::list< std::shared_ptr< Analyzer::Expr > > join_quals
#define IS_LOGIC(X)
Definition: sqldefs.h:64
std::shared_ptr< Analyzer::InValues > visitLikeExpr(const Analyzer::LikeExpr *) const override
std::shared_ptr< Analyzer::InValues > visitDateaddExpr(const Analyzer::DateaddExpr *) const override
Datum apply_numeric_op_to_literals(const StringOpInfo &string_op_info)
Definition: StringOps.cpp:1283
bool self_join_not_covered_by_left_deep_tree(const Analyzer::ColumnVar *key_side, const Analyzer::ColumnVar *val_side, const int max_rte_covered)
SQLTypes
Definition: sqltypes.h:65
std::shared_ptr< Analyzer::WindowFunction > rewrite_avg_window(const Analyzer::Expr *expr)
bool g_strip_join_covered_quals
Definition: Execute.cpp:116
std::shared_ptr< Analyzer::InValues > visitKeyForString(const Analyzer::KeyForStringExpr *) const override
std::shared_ptr< Analyzer::WindowFunction > rewrite_sum_window(const Analyzer::Expr *expr)
static constexpr std::string_view ST_DISTANCE_sv
std::shared_ptr< Analyzer::InValues > visitBinOper(const Analyzer::BinOper *bin_oper) const override
#define LOG(tag)
Definition: Logger.h:285
const Expr * get_right_operand() const
Definition: Analyzer.h:456
bool is_constructed_point(const Analyzer::Expr *expr)
Definition: Execute.h:1682
SQLOps
Definition: sqldefs.h:31
Definition: sqldefs.h:37
size_t get_table_cardinality(shared::TableKey const &table_key, Executor const *executor)
int8_t boolval
Definition: Datum.h:72
Definition: sqldefs.h:38
static bool is_range_join_rewrite_target_func(std::string_view target_func_name)
Definition: sqldefs.h:40
const TableDescriptor * get_metadata_for_table(const ::shared::TableKey &table_key, bool populate_fragmenter)
#define CHECK_GE(x, y)
Definition: Logger.h:306
bool get_contains_agg() const
Definition: Analyzer.h:81
Definition: sqldefs.h:51
Definition: sqldefs.h:32
std::shared_ptr< Analyzer::Expr > ExpressionPtr
Definition: Analyzer.h:184
size_t getArity() const
Definition: Analyzer.h:1674
const Analyzer::Expr * extract_cast_arg(const Analyzer::Expr *expr)
Definition: Execute.h:222
Definition: sqldefs.h:43
std::vector< JoinCondition > JoinQualsPerNestingLevel
T visit(const Analyzer::Expr *expr) const
Analyzer::ExpressionPtr rewrite_expr(const Analyzer::Expr *expr)
bool isNull() const
Definition: Analyzer.h:3025
#define TRANSIENT_DICT_ID
Definition: DbObjectKeys.h:24
std::shared_ptr< Analyzer::InValues > aggregateResult(const std::shared_ptr< Analyzer::InValues > &lhs, const std::shared_ptr< Analyzer::InValues > &rhs) const override
std::list< std::shared_ptr< Analyzer::Expr > > strip_join_covered_filter_quals(const std::list< std::shared_ptr< Analyzer::Expr >> &quals, const JoinQualsPerNestingLevel &join_quals)
std::shared_ptr< Analyzer::InValues > visitUOper(const Analyzer::UOper *uoper) const override
const int get_max_rte_scan_table(std::unordered_map< int, llvm::Value * > &scan_idx_to_hash_pos)
std::shared_ptr< Analyzer::Expr > visitBinOper(const Analyzer::BinOper *bin_oper) const override
std::shared_ptr< Analyzer::Expr > RetType
int32_t intval
Definition: Datum.h:75
static BoundingBoxIntersectJoinTranslationResult createEmptyResult()
std::list< std::shared_ptr< Analyzer::Expr > > quals
std::shared_ptr< Analyzer::InValues > visitCharLength(const Analyzer::CharLengthExpr *) const override
std::shared_ptr< Analyzer::InValues > visitInIntegerSet(const Analyzer::InIntegerSet *) const override
SQLOps get_optype() const
Definition: Analyzer.h:452
float floatval
Definition: Datum.h:77
const ResultSetPtr & get_temporary_table(const TemporaryTables *temporary_tables, const int table_id)
Definition: Execute.h:246
std::shared_ptr< Analyzer::Expr > visitUOper(const Analyzer::UOper *uoper) const override
LiteralArgMap getLiteralArgs() const
Definition: Analyzer.cpp:4310
static constexpr std::string_view ST_INTERSECTSBOX_sv
Classes representing a parse tree.
bool g_enable_hashjoin_many_to_many
Definition: Execute.cpp:113
std::shared_ptr< Analyzer::InValues > visitRegexpExpr(const Analyzer::RegexpExpr *) const override
BoundingBoxIntersectJoinTranslationResult translate_bounding_box_intersect_with_reordering(const std::shared_ptr< Analyzer::Expr > expr, std::vector< InputDescriptor > &input_descs, std::unordered_map< const RelAlgNode *, int > &input_to_nest_level, std::vector< size_t > &input_permutation, std::list< std::shared_ptr< const InputColDescriptor >> &input_col_desc, const BoundingBoxIntersectJoinRewriteType rewrite_type, Executor const *executor)
int64_t bigintval
Definition: Datum.h:76
bool g_enable_distance_rangejoin
Definition: Execute.cpp:112
std::shared_ptr< Analyzer::InValues > visitAggExpr(const Analyzer::AggExpr *) const override
bool visitFunctionOper(const Analyzer::FunctionOper *func_oper) const override
Definition: sqldefs.h:39
bool foldOper(SQLOps optype, SQLTypes type, Datum lhs, Datum rhs, Datum &result, SQLTypes &result_type) const
int16_t smallintval
Definition: Datum.h:74
std::pair< std::string, bool > apply_string_op_to_literals(const StringOpInfo &string_op_info)
Definition: StringOps.cpp:1272
bool defaultResult() const override
std::shared_ptr< Analyzer::Expr > visitStringOper(const Analyzer::StringOper *string_oper) const override
RetType visitArrayOper(const Analyzer::ArrayExpr *array_expr) const override
bool IsNullDatum(const Datum datum, const SQLTypeInfo &ti)
Definition: Datum.cpp:331
BoundingBoxIntersectJoinTranslationInfo convert_bbox_intersect_join(JoinQualsPerNestingLevel const &join_quals, std::vector< InputDescriptor > &input_descs, std::unordered_map< const RelAlgNode *, int > &input_to_nest_level, std::vector< size_t > &input_permutation, std::list< std::shared_ptr< const InputColDescriptor >> &input_col_desc, Executor const *executor)
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:79
std::shared_ptr< Analyzer::Expr > visitBinOper(const Analyzer::BinOper *bin_oper) const override
size_t getElementCount() const
Definition: Analyzer.h:3023
SQLTypes decimal_to_int_type(const SQLTypeInfo &ti)
Definition: Datum.cpp:561
Definition: sqldefs.h:36
void update_input_to_nest_lv(std::unordered_map< const RelAlgNode *, int > &input_to_nest_level, shared::ColumnKey const &column_key, int target_nest_lv)
static bool is_bbox_intersect_supported_func(std::string_view target_func_name)
std::vector< std::shared_ptr< Analyzer::Expr > > chained_string_op_exprs_
Definition: sqldefs.h:42
std::vector< std::pair< const Analyzer::Expr *, const Analyzer::Expr * > > join_qual_pairs
Definition: sqldefs.h:74
Expression class for string functions. The "arg" constructor parameter must be an expression that reso...
Definition: Analyzer.h:1601
std::unordered_set< shared::TableKey > get_physical_table_inputs(const RelAlgNode *ra)
const shared::ColumnKey & getColumnKey() const
Definition: Analyzer.h:198
static std::shared_ptr< Analyzer::Expr > analyzeValue(const std::string &stringval, const bool is_null)
Definition: ParserNode.cpp:147
bool isLocalAlloc() const
Definition: Analyzer.h:3024
SqlStringOpKind get_kind() const
Definition: Analyzer.h:1672
const Analyzer::Expr * getArg(const size_t i) const
Definition: Analyzer.h:2748
#define IS_ARITHMETIC(X)
Definition: sqldefs.h:65
bool g_enable_bbox_intersect_hashjoin
Definition: Execute.cpp:109
const Expr * get_operand() const
Definition: Analyzer.h:384
Datum get_constval() const
Definition: Analyzer.h:348
Definition: sqldefs.h:34
torch::Tensor f(torch::Tensor x, torch::Tensor W_target, torch::Tensor b_target)
std::unordered_map< const Analyzer::Expr *, const SQLTypeInfo > casts_
std::shared_ptr< Analyzer::InValues > visitDatediffExpr(const Analyzer::DatediffExpr *) const override
static const StringDictKey kTransientDictKey
Definition: DbObjectKeys.h:45
std::shared_ptr< Analyzer::InValues > visitLikelihood(const Analyzer::LikelihoodExpr *) const override
auto update_input_col_desc(std::list< std::shared_ptr< const InputColDescriptor >> &input_col_desc, shared::ColumnKey const &column_key, int target_nest_lv)
std::shared_ptr< Analyzer::InValues > visitInValues(const Analyzer::InValues *) const override
#define CHECK(condition)
Definition: Logger.h:291
int update_input_desc(std::vector< InputDescriptor > &input_descs, shared::ColumnKey const &column_key, int target_nest_lv)
static constexpr std::string_view ST_DWITHIN_POINT_POINT_sv
virtual void collect_rte_idx(std::set< int > &rte_idx_set) const
Definition: Analyzer.h:110
Definition: sqldefs.h:35
std::shared_ptr< Analyzer::InValues > visitPCAProject(const Analyzer::PCAProjectExpr *) const override
const Expr * get_left_operand() const
Definition: Analyzer.h:455
Common Enum definitions for SQL processing.
std::shared_ptr< Analyzer::InValues > visitCaseExpr(const Analyzer::CaseExpr *) const override
Definition: sqltypes.h:72
bool any_of(std::vector< Analyzer::Expr * > const &target_exprs)
JoinCoveredQualVisitor(const JoinQualsPerNestingLevel &join_quals)
const std::shared_ptr< Analyzer::Expr > get_own_right_operand() const
Definition: Analyzer.h:460
BoundingBoxIntersectJoinRewriteType
std::string getName() const
Definition: Analyzer.h:2744
int32_t get_rte_idx() const
Definition: Analyzer.h:202
Definition: Datum.h:71
bool is_decimal() const
Definition: sqltypes.h:570
size_t getNonLiteralsArity() const
Definition: Analyzer.h:1686
static bool is_many_to_many_func(std::string_view target_func_name)
std::shared_ptr< Analyzer::InValues > visitDatetruncExpr(const Analyzer::DatetruncExpr *) const override
DEVICE void swap(ARGS &&...args)
Definition: gpu_enabled.h:114
Definition: sqldefs.h:41
std::shared_ptr< Analyzer::InValues > visitMLPredict(const Analyzer::MLPredictExpr *) const override
const std::shared_ptr< Analyzer::Expr > get_own_left_operand() const
Definition: Analyzer.h:457
std::vector< std::shared_ptr< Analyzer::Expr > > getOwnArgs() const
Definition: Analyzer.h:1698
SQLOps get_optype() const
Definition: Analyzer.h:383
#define VLOG(n)
Definition: Logger.h:388
std::shared_ptr< Analyzer::Expr > fold_expr(const Analyzer::Expr *expr)
shared::TableKey getTableKey() const
Definition: Analyzer.h:199
#define IS_COMPARISON(X)
Definition: sqldefs.h:61
double doubleval
Definition: Datum.h:78
std::shared_ptr< Analyzer::InValues > visitExtractExpr(const Analyzer::ExtractExpr *) const override
const Analyzer::Expr * getElement(const size_t i) const
Definition: Analyzer.h:3027
SQLQualifier get_qualifier() const
Definition: Analyzer.h:454
const Analyzer::Expr * strip_likelihood(const Analyzer::Expr *expr)
std::shared_ptr< Analyzer::InValues > visitSampleRatio(const Analyzer::SampleRatioExpr *) const override
std::shared_ptr< Analyzer::InValues > visitCardinality(const Analyzer::CardinalityExpr *) const override