OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor Class Reference
+ Inheritance diagram for anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor:
+ Collaboration diagram for anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor:

Public Member Functions

 ConstantFoldingVisitor ()
 
int32_t get_num_overflows ()
 
void reset_num_overflows ()
 
- Public Member Functions inherited from ScalarExprVisitor< std::shared_ptr< Analyzer::Expr > >
std::shared_ptr< Analyzer::Expr > visit (const Analyzer::Expr *expr) const
 

Protected Attributes

bool in_string_op_chain_ {false}
 
std::vector< std::shared_ptr
< Analyzer::Expr > > 
chained_string_op_exprs_
 
std::unordered_map< const
Analyzer::Expr *, const
SQLTypeInfo > 
casts_
 
int32_t num_overflows_
 

Private Member Functions

template<typename T >
bool foldComparison (SQLOps optype, T t1, T t2) const
 
template<typename T >
bool foldLogic (SQLOps optype, T t1, T t2) const
 
template<typename T >
T foldArithmetic (SQLOps optype, T t1, T t2) const
 
bool foldOper (SQLOps optype, SQLTypes type, Datum lhs, Datum rhs, Datum &result, SQLTypes &result_type) const
 
std::shared_ptr< Analyzer::Expr > visitUOper (const Analyzer::UOper *uoper) const override
 
std::shared_ptr< Analyzer::Expr > visitBinOper (const Analyzer::BinOper *bin_oper) const override
 
std::shared_ptr< Analyzer::Expr > visitStringOper (const Analyzer::StringOper *string_oper) const override
 

Additional Inherited Members

- Protected Types inherited from DeepCopyVisitor
using RetType = std::shared_ptr< Analyzer::Expr >
 
- Protected Member Functions inherited from DeepCopyVisitor
RetType visitColumnVar (const Analyzer::ColumnVar *col_var) const override
 
RetType visitColumnVarTuple (const Analyzer::ExpressionTuple *col_var_tuple) const override
 
RetType visitVar (const Analyzer::Var *var) const override
 
RetType visitConstant (const Analyzer::Constant *constant) const override
 
RetType visitGeoExpr (const Analyzer::GeoExpr *geo_expr) const override
 
RetType visitInValues (const Analyzer::InValues *in_values) const override
 
RetType visitInIntegerSet (const Analyzer::InIntegerSet *in_integer_set) const override
 
RetType visitCharLength (const Analyzer::CharLengthExpr *char_length) const override
 
RetType visitKeyForString (const Analyzer::KeyForStringExpr *expr) const override
 
RetType visitSampleRatio (const Analyzer::SampleRatioExpr *expr) const override
 
RetType visitMLPredict (const Analyzer::MLPredictExpr *expr) const override
 
RetType visitPCAProject (const Analyzer::PCAProjectExpr *expr) const override
 
RetType visitCardinality (const Analyzer::CardinalityExpr *cardinality) const override
 
RetType visitLikeExpr (const Analyzer::LikeExpr *like) const override
 
RetType visitRegexpExpr (const Analyzer::RegexpExpr *regexp) const override
 
RetType visitWidthBucket (const Analyzer::WidthBucketExpr *width_bucket_expr) const override
 
RetType visitCaseExpr (const Analyzer::CaseExpr *case_expr) const override
 
RetType visitDatetruncExpr (const Analyzer::DatetruncExpr *datetrunc) const override
 
RetType visitExtractExpr (const Analyzer::ExtractExpr *extract) const override
 
RetType visitArrayOper (const Analyzer::ArrayExpr *array_expr) const override
 
RetType visitGeoUOper (const Analyzer::GeoUOper *geo_expr) const override
 
RetType visitGeoBinOper (const Analyzer::GeoBinOper *geo_expr) const override
 
RetType visitWindowFunction (const Analyzer::WindowFunction *window_func) const override
 
RetType visitFunctionOper (const Analyzer::FunctionOper *func_oper) const override
 
RetType visitDatediffExpr (const Analyzer::DatediffExpr *datediff) const override
 
RetType visitDateaddExpr (const Analyzer::DateaddExpr *dateadd) const override
 
RetType visitFunctionOperWithCustomTypeHandling (const Analyzer::FunctionOperWithCustomTypeHandling *func_oper) const override
 
RetType visitLikelihood (const Analyzer::LikelihoodExpr *likelihood) const override
 
RetType visitAggExpr (const Analyzer::AggExpr *agg) const override
 
RetType visitOffsetInFragment (const Analyzer::OffsetInFragment *) const override
 
- Protected Member Functions inherited from ScalarExprVisitor< std::shared_ptr< Analyzer::Expr > >
virtual std::shared_ptr
< Analyzer::Expr > 
visitRangeJoinOper (const Analyzer::RangeOper *range_oper) const
 
virtual std::shared_ptr
< Analyzer::Expr > 
aggregateResult (const std::shared_ptr< Analyzer::Expr > &aggregate, const std::shared_ptr< Analyzer::Expr > &next_result) const
 
virtual void visitBegin () const
 
virtual std::shared_ptr
< Analyzer::Expr > 
defaultResult () const
 

Detailed Description

Definition at line 224 of file ExpressionRewrite.cpp.

Constructor & Destructor Documentation

anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::ConstantFoldingVisitor ( )
inline

Member Function Documentation

template<typename T >
T anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::foldArithmetic ( SQLOps  optype,
T  t1,
T  t2 
) const
inline private

Definition at line 264 of file ExpressionRewrite.cpp.

References kDIVIDE, kMINUS, kMULTIPLY, and kPLUS.

264  {
265  bool t2_is_zero = (t2 == (t2 - t2));
266  bool t2_is_negative = (t2 < (t2 - t2));
267  switch (optype) {
268  case kPLUS:
269  // The MIN limit for float and double is the smallest representable value,
270  // not the lowest negative value! Switching to C++11 lowest.
271  if ((t2_is_negative && t1 < std::numeric_limits<T>::lowest() - t2) ||
272  (!t2_is_negative && t1 > std::numeric_limits<T>::max() - t2)) {
273  num_overflows_++;
274  throw std::runtime_error("Plus overflow");
275  }
276  return t1 + t2;
277  case kMINUS:
278  if ((t2_is_negative && t1 > std::numeric_limits<T>::max() + t2) ||
279  (!t2_is_negative && t1 < std::numeric_limits<T>::lowest() + t2)) {
280  num_overflows_++;
281  throw std::runtime_error("Minus overflow");
282  }
283  return t1 - t2;
284  case kMULTIPLY: {
285  if (t2_is_zero) {
286  return t2;
287  }
288  auto ct1 = t1;
289  auto ct2 = t2;
290  // Need to keep t2's sign on the left
291  if (t2_is_negative) {
292  if (t1 == std::numeric_limits<T>::lowest() ||
293  t2 == std::numeric_limits<T>::lowest()) {
294  // negation could overflow - bail
295  num_overflows_++;
296  throw std::runtime_error("Mul neg overflow");
297  }
298  ct1 = -t1; // ct1 gets t2's negativity
299  ct2 = -t2; // ct2 is now positive
300  }
301  // Don't check overlow if we are folding FP mul by a fraction
302  bool ct2_is_fraction = (ct2 < (ct2 / ct2));
303  if (!ct2_is_fraction) {
304  if (ct1 > std::numeric_limits<T>::max() / ct2 ||
305  ct1 < std::numeric_limits<T>::lowest() / ct2) {
306  num_overflows_++;
307  throw std::runtime_error("Mul overflow");
308  }
309  }
310  return t1 * t2;
311  }
312  case kDIVIDE:
313  if (t2_is_zero) {
314  throw std::runtime_error("Will not fold division by zero");
315  }
316  return t1 / t2;
317  default:
318  break;
319  }
320  throw std::runtime_error("Unable to fold");
321  }
Definition: sqldefs.h:43
Definition: sqldefs.h:42
template<typename T >
bool anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::foldComparison ( SQLOps  optype,
T  t1,
T  t2 
) const
inline private

Definition at line 226 of file ExpressionRewrite.cpp.

References kEQ, kGE, kGT, kLE, kLT, and kNE.

226  {
227  switch (optype) {
228  case kEQ:
229  return t1 == t2;
230  case kNE:
231  return t1 != t2;
232  case kLT:
233  return t1 < t2;
234  case kLE:
235  return t1 <= t2;
236  case kGT:
237  return t1 > t2;
238  case kGE:
239  return t1 >= t2;
240  default:
241  break;
242  }
243  throw std::runtime_error("Unable to fold");
244  return false;
245  }
Definition: sqldefs.h:37
Definition: sqldefs.h:38
Definition: sqldefs.h:32
Definition: sqldefs.h:36
Definition: sqldefs.h:34
Definition: sqldefs.h:35
template<typename T >
bool anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::foldLogic ( SQLOps  optype,
T  t1,
T  t2 
) const
inline private

Definition at line 248 of file ExpressionRewrite.cpp.

References kAND, kNOT, and kOR.

248  {
249  switch (optype) {
250  case kAND:
251  return t1 && t2;
252  case kOR:
253  return t1 || t2;
254  case kNOT:
255  return !t1;
256  default:
257  break;
258  }
259  throw std::runtime_error("Unable to fold");
260  return false;
261  }
Definition: sqldefs.h:40
Definition: sqldefs.h:39
Definition: sqldefs.h:41
bool anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::foldOper ( SQLOps  optype,
SQLTypes  type,
Datum  lhs,
Datum  rhs,
Datum &  result,
SQLTypes &  result_type 
) const
inline private

Definition at line 323 of file ExpressionRewrite.cpp.

References Datum::bigintval, Datum::boolval, CHECK, Datum::doubleval, Datum::floatval, Datum::intval, IS_ARITHMETIC, IS_COMPARISON, IS_LOGIC, kBIGINT, kBOOLEAN, kDOUBLE, kFLOAT, kINT, kSMALLINT, kTINYINT, Datum::smallintval, Datum::tinyintval, and run_benchmark_import::type.

328  {
329  result_type = type;
330 
331  try {
332  switch (type) {
333  case kBOOLEAN:
334  if (IS_COMPARISON(optype)) {
335  result.boolval = foldComparison<bool>(optype, lhs.boolval, rhs.boolval);
336  result_type = kBOOLEAN;
337  return true;
338  }
339  if (IS_LOGIC(optype)) {
340  result.boolval = foldLogic<bool>(optype, lhs.boolval, rhs.boolval);
341  result_type = kBOOLEAN;
342  return true;
343  }
344  CHECK(!IS_ARITHMETIC(optype));
345  break;
346  case kTINYINT:
347  if (IS_COMPARISON(optype)) {
348  result.boolval =
349  foldComparison<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
350  result_type = kBOOLEAN;
351  return true;
352  }
353  if (IS_ARITHMETIC(optype)) {
354  result.tinyintval =
355  foldArithmetic<int8_t>(optype, lhs.tinyintval, rhs.tinyintval);
356  result_type = kTINYINT;
357  return true;
358  }
359  CHECK(!IS_LOGIC(optype));
360  break;
361  case kSMALLINT:
362  if (IS_COMPARISON(optype)) {
363  result.boolval =
364  foldComparison<int16_t>(optype, lhs.smallintval, rhs.smallintval);
365  result_type = kBOOLEAN;
366  return true;
367  }
368  if (IS_ARITHMETIC(optype)) {
369  result.smallintval =
370  foldArithmetic<int16_t>(optype, lhs.smallintval, rhs.smallintval);
371  result_type = kSMALLINT;
372  return true;
373  }
374  CHECK(!IS_LOGIC(optype));
375  break;
376  case kINT:
377  if (IS_COMPARISON(optype)) {
378  result.boolval = foldComparison<int32_t>(optype, lhs.intval, rhs.intval);
379  result_type = kBOOLEAN;
380  return true;
381  }
382  if (IS_ARITHMETIC(optype)) {
383  result.intval = foldArithmetic<int32_t>(optype, lhs.intval, rhs.intval);
384  result_type = kINT;
385  return true;
386  }
387  CHECK(!IS_LOGIC(optype));
388  break;
389  case kBIGINT:
390  if (IS_COMPARISON(optype)) {
391  result.boolval =
392  foldComparison<int64_t>(optype, lhs.bigintval, rhs.bigintval);
393  result_type = kBOOLEAN;
394  return true;
395  }
396  if (IS_ARITHMETIC(optype)) {
397  result.bigintval =
398  foldArithmetic<int64_t>(optype, lhs.bigintval, rhs.bigintval);
399  result_type = kBIGINT;
400  return true;
401  }
402  CHECK(!IS_LOGIC(optype));
403  break;
404  case kFLOAT:
405  if (IS_COMPARISON(optype)) {
406  result.boolval = foldComparison<float>(optype, lhs.floatval, rhs.floatval);
407  result_type = kBOOLEAN;
408  return true;
409  }
410  if (IS_ARITHMETIC(optype)) {
411  result.floatval = foldArithmetic<float>(optype, lhs.floatval, rhs.floatval);
412  result_type = kFLOAT;
413  return true;
414  }
415  CHECK(!IS_LOGIC(optype));
416  break;
417  case kDOUBLE:
418  if (IS_COMPARISON(optype)) {
419  result.boolval = foldComparison<double>(optype, lhs.doubleval, rhs.doubleval);
420  result_type = kBOOLEAN;
421  return true;
422  }
423  if (IS_ARITHMETIC(optype)) {
424  result.doubleval =
425  foldArithmetic<double>(optype, lhs.doubleval, rhs.doubleval);
426  result_type = kDOUBLE;
427  return true;
428  }
429  CHECK(!IS_LOGIC(optype));
430  break;
431  default:
432  break;
433  }
434  } catch (...) {
435  return false;
436  }
437  return false;
438  }
int8_t tinyintval
Definition: Datum.h:73
#define IS_LOGIC(X)
Definition: sqldefs.h:64
int8_t boolval
Definition: Datum.h:72
int32_t intval
Definition: Datum.h:75
float floatval
Definition: Datum.h:77
int64_t bigintval
Definition: Datum.h:76
int16_t smallintval
Definition: Datum.h:74
#define IS_ARITHMETIC(X)
Definition: sqldefs.h:65
#define CHECK(condition)
Definition: Logger.h:291
Definition: sqltypes.h:72
#define IS_COMPARISON(X)
Definition: sqldefs.h:61
double doubleval
Definition: Datum.h:78
int32_t anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::get_num_overflows ( )
inline
void anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::reset_num_overflows ( )
inline
std::shared_ptr<Analyzer::Expr> anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::visitBinOper ( const Analyzer::BinOper * bin_oper) const
inline override private virtual

Reimplemented from DeepCopyVisitor.

Definition at line 520 of file ExpressionRewrite.cpp.

References Datum::boolval, CHECK, decimal_to_int_type(), f(), Analyzer::Expr::get_contains_agg(), Analyzer::BinOper::get_left_operand(), Analyzer::BinOper::get_optype(), Analyzer::BinOper::get_own_left_operand(), Analyzer::BinOper::get_own_right_operand(), Analyzer::BinOper::get_qualifier(), Analyzer::Expr::get_type_info(), IS_COMPARISON, SQLTypeInfo::is_decimal(), kAND, kBOOLEAN, kDIVIDE, kDOUBLE, kEQ, kFLOAT, kGE, kGT, kLE, kLT, kMINUS, kMULTIPLY, kNE, kOR, and kPLUS.

521  {
522  const auto optype = bin_oper->get_optype();
523  auto ti = bin_oper->get_type_info();
524  auto left_operand = bin_oper->get_own_left_operand();
525  auto right_operand = bin_oper->get_own_right_operand();
526 
527  // Check if bin_oper result is cast to a larger int or fp type
528  if (casts_.find(bin_oper) != casts_.end()) {
529  const auto cast_ti = casts_[bin_oper];
530  const auto& lhs_ti = bin_oper->get_left_operand()->get_type_info();
531  // Propagate cast down to the operands for folding
532  if ((cast_ti.is_integer() || cast_ti.is_fp()) && lhs_ti.is_integer() &&
533  cast_ti.get_size() > lhs_ti.get_size() &&
534  (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY)) {
535  // Before folding, cast the operands to the bigger type to avoid overflows.
536  // Currently upcasting smaller integer types to larger integers or double.
537  left_operand = left_operand->deep_copy()->add_cast(cast_ti);
538  right_operand = right_operand->deep_copy()->add_cast(cast_ti);
539  ti = cast_ti;
540  }
541  }
542 
543  const auto lhs = visit(left_operand.get());
544  const auto rhs = visit(right_operand.get());
545 
546  auto const_lhs = std::dynamic_pointer_cast<Analyzer::Constant>(lhs);
547  auto const_rhs = std::dynamic_pointer_cast<Analyzer::Constant>(rhs);
548  const auto& lhs_ti = lhs->get_type_info();
549  const auto& rhs_ti = rhs->get_type_info();
550  auto lhs_type = lhs_ti.is_decimal() ? decimal_to_int_type(lhs_ti) : lhs_ti.get_type();
551  auto rhs_type = rhs_ti.is_decimal() ? decimal_to_int_type(rhs_ti) : rhs_ti.get_type();
552 
553  if (const_lhs && const_rhs && lhs_type == rhs_type) {
554  auto lhs_datum = const_lhs->get_constval();
555  auto rhs_datum = const_rhs->get_constval();
556  Datum result_datum = {};
557  SQLTypes result_type;
558  if (foldOper(optype, lhs_type, lhs_datum, rhs_datum, result_datum, result_type)) {
559  // Fold all ops that don't take in decimal operands, and also decimal comparisons
560  if (!lhs_ti.is_decimal() || IS_COMPARISON(optype)) {
561  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
562  }
563  // Decimal arithmetic has been done as kBIGINT. Selectively fold some decimal ops,
564  // using result_datum and BinOper expr typeinfo which was adjusted for these ops.
565  if (optype == kMINUS || optype == kPLUS || optype == kMULTIPLY) {
566  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
567  }
568  }
569  }
570 
571  if (optype == kAND && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
572  if (const_rhs && !const_rhs->get_is_null()) {
573  auto rhs_datum = const_rhs->get_constval();
574  if (rhs_datum.boolval == false) {
575  Datum d;
576  d.boolval = false;
577  // lhs && false --> false
578  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
579  }
580  // lhs && true --> lhs
581  return lhs;
582  }
583  if (const_lhs && !const_lhs->get_is_null()) {
584  auto lhs_datum = const_lhs->get_constval();
585  if (lhs_datum.boolval == false) {
586  Datum d;
587  d.boolval = false;
588  // false && rhs --> false
589  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
590  }
591  // true && rhs --> rhs
592  return rhs;
593  }
594  }
595  if (optype == kOR && lhs_type == rhs_type && lhs_type == kBOOLEAN) {
596  if (const_rhs && !const_rhs->get_is_null()) {
597  auto rhs_datum = const_rhs->get_constval();
598  if (rhs_datum.boolval == true) {
599  Datum d;
600  d.boolval = true;
601  // lhs || true --> true
602  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
603  }
604  // lhs || false --> lhs
605  return lhs;
606  }
607  if (const_lhs && !const_lhs->get_is_null()) {
608  auto lhs_datum = const_lhs->get_constval();
609  if (lhs_datum.boolval == true) {
610  Datum d;
611  d.boolval = true;
612  // true || rhs --> true
613  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
614  }
615  // false || rhs --> rhs
616  return rhs;
617  }
618  }
619  if (*lhs == *rhs) {
620  if (!lhs_ti.get_notnull()) {
621  CHECK(!rhs_ti.get_notnull());
622  // We can't fold the ostensible tautaulogy
623  // for nullable lhs and rhs types, as
624  // lhs <> rhs when they are null
625 
626  // We likely could turn this into a lhs is not null
627  // operatation, but is it worth it?
628  return makeExpr<Analyzer::BinOper>(ti,
629  bin_oper->get_contains_agg(),
630  bin_oper->get_optype(),
631  bin_oper->get_qualifier(),
632  lhs,
633  rhs);
634  }
635  CHECK(rhs_ti.get_notnull());
636  // Tautologies: v=v; v<=v; v>=v
637  if (optype == kEQ || optype == kLE || optype == kGE) {
638  Datum d;
639  d.boolval = true;
640  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
641  }
642  // Contradictions: v!=v; v<v; v>v
643  if (optype == kNE || optype == kLT || optype == kGT) {
644  Datum d;
645  d.boolval = false;
646  return makeExpr<Analyzer::Constant>(kBOOLEAN, false, d);
647  }
648  // v-v
649  if (optype == kMINUS) {
650  Datum d = {};
651  return makeExpr<Analyzer::Constant>(lhs_type, false, d);
652  }
653  }
654  // Convert fp division by a constant to multiplication by 1/constant
655  if (optype == kDIVIDE && const_rhs && rhs_ti.is_fp()) {
656  auto rhs_datum = const_rhs->get_constval();
657  std::shared_ptr<Analyzer::Expr> recip_rhs = nullptr;
658  if (rhs_ti.get_type() == kFLOAT) {
659  if (rhs_datum.floatval == 1.0) {
660  return lhs;
661  }
662  auto f = std::fabs(rhs_datum.floatval);
663  if (f > 1.0 || (f != 0.0 && 1.0 < f * std::numeric_limits<float>::max())) {
664  rhs_datum.floatval = 1.0 / rhs_datum.floatval;
665  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
666  }
667  } else if (rhs_ti.get_type() == kDOUBLE) {
668  if (rhs_datum.doubleval == 1.0) {
669  return lhs;
670  }
671  auto d = std::fabs(rhs_datum.doubleval);
672  if (d > 1.0 || (d != 0.0 && 1.0 < d * std::numeric_limits<double>::max())) {
673  rhs_datum.doubleval = 1.0 / rhs_datum.doubleval;
674  recip_rhs = makeExpr<Analyzer::Constant>(rhs_type, false, rhs_datum);
675  }
676  }
677  if (recip_rhs) {
678  return makeExpr<Analyzer::BinOper>(ti,
679  bin_oper->get_contains_agg(),
680  kMULTIPLY,
681  bin_oper->get_qualifier(),
682  lhs,
683  recip_rhs);
684  }
685  }
686 
687  return makeExpr<Analyzer::BinOper>(ti,
688  bin_oper->get_contains_agg(),
689  bin_oper->get_optype(),
690  bin_oper->get_qualifier(),
691  lhs,
692  rhs);
693  }
SQLTypes
Definition: sqltypes.h:65
Definition: sqldefs.h:37
int8_t boolval
Definition: Datum.h:72
Definition: sqldefs.h:38
Definition: sqldefs.h:40
bool get_contains_agg() const
Definition: Analyzer.h:81
Definition: sqldefs.h:32
Definition: sqldefs.h:43
std::shared_ptr< Analyzer::Expr > visit(const Analyzer::Expr *expr) const
SQLOps get_optype() const
Definition: Analyzer.h:452
Definition: sqldefs.h:39
bool foldOper(SQLOps optype, SQLTypes type, Datum lhs, Datum rhs, Datum &result, SQLTypes &result_type) const
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:79
SQLTypes decimal_to_int_type(const SQLTypeInfo &ti)
Definition: Datum.cpp:561
Definition: sqldefs.h:36
Definition: sqldefs.h:42
Definition: sqldefs.h:34
torch::Tensor f(torch::Tensor x, torch::Tensor W_target, torch::Tensor b_target)
std::unordered_map< const Analyzer::Expr *, const SQLTypeInfo > casts_
#define CHECK(condition)
Definition: Logger.h:291
Definition: sqldefs.h:35
const Expr * get_left_operand() const
Definition: Analyzer.h:455
const std::shared_ptr< Analyzer::Expr > get_own_right_operand() const
Definition: Analyzer.h:460
Definition: Datum.h:71
bool is_decimal() const
Definition: sqltypes.h:570
const std::shared_ptr< Analyzer::Expr > get_own_left_operand() const
Definition: Analyzer.h:457
#define IS_COMPARISON(X)
Definition: sqldefs.h:61
SQLQualifier get_qualifier() const
Definition: Analyzer.h:454

+ Here is the call graph for this function:

std::shared_ptr<Analyzer::Expr> anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::visitStringOper ( const Analyzer::StringOper * string_oper) const
inline override private virtual

Reimplemented from DeepCopyVisitor.

Definition at line 695 of file ExpressionRewrite.cpp.

References Parser::StringLiteral::analyzeValue(), StringOps_Namespace::apply_numeric_op_to_literals(), StringOps_Namespace::apply_string_op_to_literals(), CHECK, Analyzer::StringOper::get_kind(), Analyzer::Expr::get_type_info(), Analyzer::StringOper::getArity(), Analyzer::StringOper::getLiteralArgs(), Analyzer::StringOper::getNonLiteralsArity(), Analyzer::StringOper::getOwnArgs(), and IsNullDatum().

696  {
697  // Todo(todd): For clarity and modularity we should move string
698  // operator rewrites into their own visitor class.
699  // String operation rewrites were originally put here as they only
700  // handled string operators on rewrite, but now handle variable
701  // inputs as well.
702  const auto original_args = string_oper->getOwnArgs();
703  std::vector<std::shared_ptr<Analyzer::Expr>> rewritten_args;
704  const auto non_literal_arity = string_oper->getNonLiteralsArity();
705  const auto parent_in_string_op_chain = in_string_op_chain_;
706  const auto in_string_op_chain = non_literal_arity <= 1UL;
707  in_string_op_chain_ = in_string_op_chain;
708 
709  size_t rewritten_arg_literal_arity = 0;
710  for (auto original_arg : original_args) {
711  rewritten_args.emplace_back(visit(original_arg.get()));
712  if (dynamic_cast<const Analyzer::Constant*>(rewritten_args.back().get())) {
713  rewritten_arg_literal_arity++;
714  }
715  }
716  in_string_op_chain_ = parent_in_string_op_chain;
717  const auto kind = string_oper->get_kind();
718  const auto& return_ti = string_oper->get_type_info();
719 
720  if (string_oper->getArity() == rewritten_arg_literal_arity) {
721  Analyzer::StringOper literal_string_oper(
722  kind, string_oper->get_type_info(), rewritten_args);
723  const auto literal_args = literal_string_oper.getLiteralArgs();
724  const auto string_op_info =
725  StringOps_Namespace::StringOpInfo(kind, return_ti, literal_args);
726  if (return_ti.is_string()) {
727  const auto literal_result =
728  StringOps_Namespace::apply_string_op_to_literals(string_op_info);
729  return Parser::StringLiteral::analyzeValue(literal_result.first,
730  literal_result.second);
731  }
732  const auto literal_datum =
733  StringOps_Namespace::apply_numeric_op_to_literals(string_op_info);
734  auto nullable_return_ti = return_ti;
735  nullable_return_ti.set_notnull(false);
736  return makeExpr<Analyzer::Constant>(nullable_return_ti,
737  IsNullDatum(literal_datum, nullable_return_ti),
738  literal_datum);
739  }
740  chained_string_op_exprs_.emplace_back(
741  makeExpr<Analyzer::StringOper>(kind, return_ti, rewritten_args));
742  if (parent_in_string_op_chain && in_string_op_chain) {
743  CHECK(rewritten_args[0]->get_type_info().is_string());
744  return rewritten_args[0]->deep_copy();
745  } else {
746  auto new_string_oper = makeExpr<Analyzer::StringOper>(
747  kind, return_ti, rewritten_args, chained_string_op_exprs_);
748  chained_string_op_exprs_.clear();
749  return new_string_oper;
750  }
751  }
Datum apply_numeric_op_to_literals(const StringOpInfo &string_op_info)
Definition: StringOps.cpp:1283
size_t getArity() const
Definition: Analyzer.h:1674
std::shared_ptr< Analyzer::Expr > visit(const Analyzer::Expr *expr) const
LiteralArgMap getLiteralArgs() const
Definition: Analyzer.cpp:4310
std::pair< std::string, bool > apply_string_op_to_literals(const StringOpInfo &string_op_info)
Definition: StringOps.cpp:1272
bool IsNullDatum(const Datum datum, const SQLTypeInfo &ti)
Definition: Datum.cpp:331
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:79
std::vector< std::shared_ptr< Analyzer::Expr > > chained_string_op_exprs_
Expression class for string functions. The "arg" constructor parameter must be an expression that reso...
Definition: Analyzer.h:1601
static std::shared_ptr< Analyzer::Expr > analyzeValue(const std::string &stringval, const bool is_null)
Definition: ParserNode.cpp:147
SqlStringOpKind get_kind() const
Definition: Analyzer.h:1672
#define CHECK(condition)
Definition: Logger.h:291
size_t getNonLiteralsArity() const
Definition: Analyzer.h:1686
std::vector< std::shared_ptr< Analyzer::Expr > > getOwnArgs() const
Definition: Analyzer.h:1698

+ Here is the call graph for this function:

std::shared_ptr<Analyzer::Expr> anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::visitUOper ( const Analyzer::UOper * uoper) const
inline override private virtual

Reimplemented from DeepCopyVisitor.

Definition at line 440 of file ExpressionRewrite.cpp.

References CHECK_EQ, decimal_to_int_type(), Analyzer::Constant::get_constval(), Analyzer::Expr::get_contains_agg(), Analyzer::UOper::get_operand(), Analyzer::UOper::get_optype(), Analyzer::Expr::get_type_info(), kBOOLEAN, kCAST, kEQ, kMINUS, kNOT, and kUMINUS.

441  {
442  const auto unvisited_operand = uoper->get_operand();
443  const auto optype = uoper->get_optype();
444  const auto& ti = uoper->get_type_info();
445  if (optype == kCAST) {
446  // Cache the cast type so it could be used in operand rewriting/folding
447  casts_.insert({unvisited_operand, ti});
448  }
449  const auto operand = visit(unvisited_operand);
450 
451  const auto& operand_ti = operand->get_type_info();
452  const auto operand_type =
453  operand_ti.is_decimal() ? decimal_to_int_type(operand_ti) : operand_ti.get_type();
454  const auto const_operand =
455  std::dynamic_pointer_cast<const Analyzer::Constant>(operand);
456 
457  if (const_operand) {
458  const auto operand_datum = const_operand->get_constval();
459  Datum zero_datum = {};
460  Datum result_datum = {};
461  SQLTypes result_type;
462  switch (optype) {
463  case kNOT: {
464  if (foldOper(kEQ,
465  operand_type,
466  zero_datum,
467  operand_datum,
468  result_datum,
469  result_type)) {
470  CHECK_EQ(result_type, kBOOLEAN);
471  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
472  }
473  break;
474  }
475  case kUMINUS: {
476  if (foldOper(kMINUS,
477  operand_type,
478  zero_datum,
479  operand_datum,
480  result_datum,
481  result_type)) {
482  if (!operand_ti.is_decimal()) {
483  return makeExpr<Analyzer::Constant>(result_type, false, result_datum);
484  }
485  return makeExpr<Analyzer::Constant>(ti, false, result_datum);
486  }
487  break;
488  }
489  case kCAST: {
490  // Trying to fold number to number casts only
491  if (!ti.is_number() || !operand_ti.is_number()) {
492  break;
493  }
494  // Disallowing folding of FP to DECIMAL casts for now:
495  // allowing them would make this test pass:
496  // update dectest set d=cast( 1234.0 as float );
497  // which is expected to throw in Update.ImplicitCastToNumericTypes
498  // due to cast codegen currently not supporting these casts
499  if (ti.is_decimal() && operand_ti.is_fp()) {
500  break;
501  }
502  auto operand_copy = const_operand->deep_copy();
503  auto cast_operand = operand_copy->add_cast(ti);
504  auto const_cast_operand =
505  std::dynamic_pointer_cast<const Analyzer::Constant>(cast_operand);
506  if (const_cast_operand) {
507  auto const_cast_datum = const_cast_operand->get_constval();
508  return makeExpr<Analyzer::Constant>(ti, false, const_cast_datum);
509  }
510  }
511  default:
512  break;
513  }
514  }
515 
516  return makeExpr<Analyzer::UOper>(
517  uoper->get_type_info(), uoper->get_contains_agg(), optype, operand);
518  }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
SQLTypes
Definition: sqltypes.h:65
bool get_contains_agg() const
Definition: Analyzer.h:81
Definition: sqldefs.h:51
Definition: sqldefs.h:32
std::shared_ptr< Analyzer::Expr > visit(const Analyzer::Expr *expr) const
bool foldOper(SQLOps optype, SQLTypes type, Datum lhs, Datum rhs, Datum &result, SQLTypes &result_type) const
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:79
SQLTypes decimal_to_int_type(const SQLTypeInfo &ti)
Definition: Datum.cpp:561
Definition: sqldefs.h:42
const Expr * get_operand() const
Definition: Analyzer.h:384
Datum get_constval() const
Definition: Analyzer.h:348
std::unordered_map< const Analyzer::Expr *, const SQLTypeInfo > casts_
Definition: Datum.h:71
Definition: sqldefs.h:41
SQLOps get_optype() const
Definition: Analyzer.h:383

+ Here is the call graph for this function:

Member Data Documentation

std::unordered_map<const Analyzer::Expr*, const SQLTypeInfo> anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::casts_
mutable protected

Definition at line 756 of file ExpressionRewrite.cpp.

std::vector<std::shared_ptr<Analyzer::Expr> > anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::chained_string_op_exprs_
mutable protected

Definition at line 755 of file ExpressionRewrite.cpp.

bool anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::in_string_op_chain_ {false}
mutable protected

Definition at line 754 of file ExpressionRewrite.cpp.

int32_t anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::num_overflows_
mutable protected

Definition at line 757 of file ExpressionRewrite.cpp.


The documentation for this class was generated from the following file: