OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
StandardConvertletTable.java
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements. See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to you under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License. You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 // clang-format off
19 
20 package org.apache.calcite.sql2rel;
21 
22 // HEAVY.AI new
24 // end HEAVY.AI new
25 
26 import org.apache.calcite.avatica.util.DateTimeUtils;
27 import org.apache.calcite.avatica.util.TimeUnit;
28 import org.apache.calcite.plan.RelOptUtil;
29 import org.apache.calcite.rel.type.RelDataType;
30 import org.apache.calcite.rel.type.RelDataTypeFactory;
31 import org.apache.calcite.rel.type.RelDataTypeFamily;
32 import org.apache.calcite.rex.RexBuilder;
33 import org.apache.calcite.rex.RexCall;
34 import org.apache.calcite.rex.RexCallBinding;
35 import org.apache.calcite.rex.RexLiteral;
36 import org.apache.calcite.rex.RexNode;
37 import org.apache.calcite.rex.RexRangeRef;
38 import org.apache.calcite.rex.RexUtil;
39 import org.apache.calcite.sql.SqlAggFunction;
40 import org.apache.calcite.sql.SqlBinaryOperator;
41 import org.apache.calcite.sql.SqlCall;
42 import org.apache.calcite.sql.SqlDataTypeSpec;
43 import org.apache.calcite.sql.SqlFunction;
44 import org.apache.calcite.sql.SqlFunctionCategory;
45 import org.apache.calcite.sql.SqlIdentifier;
46 import org.apache.calcite.sql.SqlIntervalLiteral;
47 import org.apache.calcite.sql.SqlIntervalQualifier;
48 import org.apache.calcite.sql.SqlJdbcFunctionCall;
49 import org.apache.calcite.sql.SqlKind;
50 import org.apache.calcite.sql.SqlLiteral;
51 import org.apache.calcite.sql.SqlNode;
52 import org.apache.calcite.sql.SqlNodeList;
53 import org.apache.calcite.sql.SqlNumericLiteral;
55 import org.apache.calcite.sql.SqlUtil;
56 import org.apache.calcite.sql.SqlWindowTableFunction;
57 import org.apache.calcite.sql.fun.SqlArrayValueConstructor;
58 import org.apache.calcite.sql.fun.SqlBetweenOperator;
59 import org.apache.calcite.sql.fun.SqlCase;
60 import org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator;
61 import org.apache.calcite.sql.fun.SqlExtractFunction;
62 import org.apache.calcite.sql.fun.SqlJsonValueFunction;
63 import org.apache.calcite.sql.fun.SqlLibraryOperators;
64 import org.apache.calcite.sql.fun.SqlLiteralChainOperator;
65 import org.apache.calcite.sql.fun.SqlMapValueConstructor;
66 import org.apache.calcite.sql.fun.SqlMultisetQueryConstructor;
67 import org.apache.calcite.sql.fun.SqlMultisetValueConstructor;
68 import org.apache.calcite.sql.fun.SqlOverlapsOperator;
69 import org.apache.calcite.sql.fun.SqlRowOperator;
70 import org.apache.calcite.sql.fun.SqlSequenceValueOperator;
71 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
72 import org.apache.calcite.sql.fun.SqlTrimFunction;
73 import org.apache.calcite.sql.parser.SqlParserPos;
74 import org.apache.calcite.sql.type.SqlOperandTypeChecker;
75 import org.apache.calcite.sql.type.SqlTypeFamily;
76 import org.apache.calcite.sql.type.SqlTypeName;
77 import org.apache.calcite.sql.type.SqlTypeUtil;
78 import org.apache.calcite.sql.validate.SqlValidator;
80 import org.apache.calcite.util.Pair;
81 import org.apache.calcite.util.Util;
82 
83 import com.google.common.collect.ImmutableList;
84 import com.google.common.collect.Lists;
85 
86 import java.math.BigDecimal;
87 import java.math.RoundingMode;
88 import java.util.ArrayList;
89 import java.util.List;
90 import java.util.Objects;
91 
95 public class StandardConvertletTable extends ReflectiveConvertletTable {
98 
99  //~ Constructors -----------------------------------------------------------
100 
102  super();
103
    // Builds this table's registry: aliases first, then operator-specific
    // convertlets via registerOp. Operators without an entry here presumably
    // fall back to the superclass (ReflectiveConvertletTable) lookup of
    // convertXxx methods -- TODO confirm against the superclass.
104  // Register aliases (operators which have a different name but
105  // identical behavior to other operators).
106  addAlias(SqlStdOperatorTable.CHARACTER_LENGTH, SqlStdOperatorTable.CHAR_LENGTH);
107  addAlias(SqlStdOperatorTable.IS_UNKNOWN, SqlStdOperatorTable.IS_NULL);
108  addAlias(SqlStdOperatorTable.IS_NOT_UNKNOWN, SqlStdOperatorTable.IS_NOT_NULL);
109  addAlias(SqlStdOperatorTable.PERCENT_REMAINDER, SqlStdOperatorTable.MOD);
110
111  // Register convertlets for specific objects.
    // Both the standard CAST and the infix "::" cast share convertCast.
112  registerOp(SqlStdOperatorTable.CAST, this::convertCast);
113  registerOp(SqlLibraryOperators.INFIX_CAST, this::convertCast);
114
115  // HEAVY.AI new
    // TRY_CAST is converted by convertTryCast into a call to the HeavyDB
    // TRY_CAST operator (not into a standard CAST).
116  registerOp(HeavyDBSqlOperatorTable.TRY_CAST, this::convertTryCast);
117  // end HEAVY.AI new
118
    // The boolean flag selects the negated form (IS NOT DISTINCT FROM).
119  registerOp(SqlStdOperatorTable.IS_DISTINCT_FROM,
120  (cx, call) -> convertIsDistinctFrom(cx, call, false));
121  registerOp(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
122  (cx, call) -> convertIsDistinctFrom(cx, call, true));
123
124  registerOp(SqlStdOperatorTable.PLUS, this::convertPlus);
125
    // MINUS: convert the call generically first. If the first converted operand
    // is DATE, TIME or TIMESTAMP, discard that result and re-convert the call as
    // MINUS_DATE through convertDatetimeMinus; otherwise keep the generic result.
126  registerOp(SqlStdOperatorTable.MINUS, (cx, call) -> {
127  final RexCall e = (RexCall) StandardConvertletTable.this.convertCall(
128  cx, call.getOperator(), call.getOperandList());
129  switch (e.getOperands().get(0).getType().getSqlTypeName()) {
130  case DATE:
131  case TIME:
132  case TIMESTAMP:
133  return convertDatetimeMinus(cx, SqlStdOperatorTable.MINUS_DATE, call);
134  default:
135  return e;
136  }
137  });
138
    // LTRIM / RTRIM reuse TrimConvertlet with a LEADING / TRAILING flag
    // (TrimConvertlet itself is not visible in this view).
139  registerOp(
140  SqlLibraryOperators.LTRIM, new TrimConvertlet(SqlTrimFunction.Flag.LEADING));
141  registerOp(
142  SqlLibraryOperators.RTRIM, new TrimConvertlet(SqlTrimFunction.Flag.TRAILING));
143
    // NOTE(review): LEAST is registered with GreatestConvertlet as well;
    // presumably that convertlet distinguishes GREATEST from LEAST by the
    // call's kind -- confirm in GreatestConvertlet.
144  registerOp(SqlLibraryOperators.GREATEST, new GreatestConvertlet());
145  registerOp(SqlLibraryOperators.LEAST, new GreatestConvertlet());
146
    // NVL(a, b) => CASE WHEN a IS NOT NULL THEN CAST(a AS t) ELSE CAST(b AS t) END,
    // where t is the validated type of the NVL call.
147  registerOp(SqlLibraryOperators.NVL, (cx, call) -> {
148  final RexBuilder rexBuilder = cx.getRexBuilder();
149  final RexNode operand0 = cx.convertExpression(call.getOperandList().get(0));
150  final RexNode operand1 = cx.convertExpression(call.getOperandList().get(1));
151  final RelDataType type = cx.getValidator().getValidatedNodeType(call);
152  return rexBuilder.makeCall(type,
153  SqlStdOperatorTable.CASE,
154  ImmutableList.of(
155  rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, operand0),
156  rexBuilder.makeCast(type, operand0),
157  rexBuilder.makeCast(type, operand1)));
158  });
159
    // DECODE(x, k1, v1, k2, v2, ..., [d]) =>
    //   CASE WHEN <x vs k1> THEN v1 WHEN <x vs k2> THEN v2 ... ELSE d END.
    // Each comparison is built by RelOptUtil.isDistinctFrom with a trailing
    // `true`, presumably the negated form (IS NOT DISTINCT FROM), so NULL keys
    // match NULL inputs. An even operand count means a default d is present;
    // an odd count yields ELSE NULL of the validated type.
160  registerOp(SqlLibraryOperators.DECODE, (cx, call) -> {
161  final RexBuilder rexBuilder = cx.getRexBuilder();
162  final List<RexNode> operands = convertExpressionList(
163  cx, call.getOperandList(), SqlOperandTypeChecker.Consistency.NONE);
164  final RelDataType type = cx.getValidator().getValidatedNodeType(call);
165  final List<RexNode> exprs = new ArrayList<>();
166  for (int i = 1; i < operands.size() - 1; i += 2) {
167  exprs.add(RelOptUtil.isDistinctFrom(
168  rexBuilder, operands.get(0), operands.get(i), true));
169  exprs.add(operands.get(i + 1));
170  }
171  if (operands.size() % 2 == 0) {
172  exprs.add(Util.last(operands));
173  } else {
174  exprs.add(rexBuilder.makeNullLiteral(type));
175  }
176  return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, exprs);
177  });
178
179  // Expand "x NOT LIKE y" into "NOT (x LIKE y)"
180  registerOp(SqlStdOperatorTable.NOT_LIKE,
181  (cx, call)
182  -> cx.convertExpression(
183  SqlStdOperatorTable.NOT.createCall(SqlParserPos.ZERO,
184  SqlStdOperatorTable.LIKE.createCall(
185  SqlParserPos.ZERO, call.getOperandList()))));
186
187  // Expand "x NOT SIMILAR TO y" into "NOT (x SIMILAR TO y)"
188  registerOp(SqlStdOperatorTable.NOT_SIMILAR_TO,
189  (cx, call)
190  -> cx.convertExpression(
191  SqlStdOperatorTable.NOT.createCall(SqlParserPos.ZERO,
192  SqlStdOperatorTable.SIMILAR_TO.createCall(
193  SqlParserPos.ZERO, call.getOperandList()))));
194
195  // Unary "+" has no effect, so expand "+ x" into "x".
196  registerOp(SqlStdOperatorTable.UNARY_PLUS,
197  (cx, call) -> cx.convertExpression(call.operand(0)));
198
199  // "DOT": "a.b" becomes a field access on the converted left operand, using the right operand's text as the field name (the final `false` argument is passed through to makeFieldAccess).
200  registerOp(SqlStdOperatorTable.DOT,
201  (cx, call)
202  -> cx.getRexBuilder().makeFieldAccess(
203  cx.convertExpression(call.operand(0)),
204  call.operand(1).toString(),
205  false));
206  // "AS" has no effect, so expand "x AS id" into "x".
207  registerOp(
208  SqlStdOperatorTable.AS, (cx, call) -> cx.convertExpression(call.operand(0)));
209  // "SQRT(x)" is equivalent to "POWER(x, .5)"
210  registerOp(SqlStdOperatorTable.SQRT,
211  (cx, call)
212  -> cx.convertExpression(SqlStdOperatorTable.POWER.createCall(
213  SqlParserPos.ZERO,
214  call.operand(0),
215  SqlLiteral.createExactNumeric("0.5", SqlParserPos.ZERO))));
216
217  // REVIEW jvs 24-Apr-2006: This only seems to be working from within a
218  // windowed agg. I have added an optimizer rule
219  // org.apache.calcite.rel.rules.AggregateReduceFunctionsRule which handles
220  // other cases post-translation. The reason I did that was to defer the
221  // implementation decision; e.g. we may want to push it down to a foreign
222  // server directly rather than decomposed; decomposition is easier than
223  // recognition.
224
225  // Convert "avg(<expr>)" to "cast(sum(<expr>) / count(<expr>) as
226  // <type>)". We don't need to handle the empty set specially, because
227  // the SUM is already supposed to come out as NULL in cases where the
228  // COUNT is zero, so the null check should take place first and prevent
229  // division by zero. We need the cast because SUM and COUNT may use
230  // different types, say BIGINT.
231  //
232  // Similarly STDDEV_POP and STDDEV_SAMP, VAR_POP and VAR_SAMP.
    // STDDEV and VARIANCE are registered as their _SAMP variants.
233  registerOp(SqlStdOperatorTable.AVG, new AvgVarianceConvertlet(SqlKind.AVG));
234  registerOp(SqlStdOperatorTable.STDDEV_POP,
235  new AvgVarianceConvertlet(SqlKind.STDDEV_POP));
236  registerOp(SqlStdOperatorTable.STDDEV_SAMP,
237  new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
238  registerOp(
239  SqlStdOperatorTable.STDDEV, new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
240  registerOp(SqlStdOperatorTable.VAR_POP, new AvgVarianceConvertlet(SqlKind.VAR_POP));
241  registerOp(SqlStdOperatorTable.VAR_SAMP, new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
242  registerOp(SqlStdOperatorTable.VARIANCE, new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
243  registerOp(SqlStdOperatorTable.COVAR_POP,
244  new RegrCovarianceConvertlet(SqlKind.COVAR_POP));
245  registerOp(SqlStdOperatorTable.COVAR_SAMP,
246  new RegrCovarianceConvertlet(SqlKind.COVAR_SAMP));
247  registerOp(
248  SqlStdOperatorTable.REGR_SXX, new RegrCovarianceConvertlet(SqlKind.REGR_SXX));
249  registerOp(
250  SqlStdOperatorTable.REGR_SYY, new RegrCovarianceConvertlet(SqlKind.REGR_SYY));
251
    // FLOOR and CEIL share a single convertlet instance.
252  final SqlRexConvertlet floorCeilConvertlet = new FloorCeilConvertlet();
253  registerOp(SqlStdOperatorTable.FLOOR, floorCeilConvertlet);
254  registerOp(SqlStdOperatorTable.CEIL, floorCeilConvertlet);
255
256  registerOp(SqlStdOperatorTable.TIMESTAMP_ADD, new TimestampAddConvertlet());
257  registerOp(SqlStdOperatorTable.TIMESTAMP_DIFF, new TimestampDiffConvertlet());
258
259  registerOp(SqlStdOperatorTable.INTERVAL, StandardConvertletTable::convertInterval);
260
    // NOTE: the two registrations below are disabled with `if (false)`, so
    // ELEMENT and ELEMENT_SLICE get no dedicated convertlet here; the code is
    // kept for reference only.
261  // Convert "element(<expr>)" to "$element_slice(<expr>)", if the
262  // expression is a multiset of scalars.
263  if (false) {
264  registerOp(SqlStdOperatorTable.ELEMENT, (cx, call) -> {
265  assert call.operandCount() == 1;
266  final SqlNode operand = call.operand(0);
267  final RelDataType type = cx.getValidator().getValidatedNodeType(operand);
268  if (!type.getComponentType().isStruct()) {
269  return cx.convertExpression(SqlStdOperatorTable.ELEMENT_SLICE.createCall(
270  SqlParserPos.ZERO, operand));
271  }
272
273  // fallback on default behavior
274  return StandardConvertletTable.this.convertCall(cx, call);
275  });
276  }
277
278  // Convert "$element_slice(<expr>)" to "element(<expr>).field#0"
279  if (false) {
280  registerOp(SqlStdOperatorTable.ELEMENT_SLICE, (cx, call) -> {
281  assert call.operandCount() == 1;
282  final SqlNode operand = call.operand(0);
283  final RexNode expr = cx.convertExpression(
284  SqlStdOperatorTable.ELEMENT.createCall(SqlParserPos.ZERO, operand));
285  return cx.getRexBuilder().makeFieldAccess(expr, 0);
286  });
287  }
288  }
289 
294  private static RexNode convertInterval(SqlRexContext cx, SqlCall call) {
295  // "INTERVAL n HOUR" becomes "n * INTERVAL '1' HOUR"
296  final SqlNode n = call.operand(0);
297  final SqlIntervalQualifier intervalQualifier = call.operand(1);
298  final SqlIntervalLiteral literal = SqlLiteral.createInterval(
299  1, "1", intervalQualifier, call.getParserPosition());
300  final SqlCall multiply =
301  SqlStdOperatorTable.MULTIPLY.createCall(call.getParserPosition(), n, literal);
302  return cx.convertExpression(multiply);
303  }
304 
305  //~ Methods ----------------------------------------------------------------
306 
307  private RexNode or(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
308  return rexBuilder.makeCall(SqlStdOperatorTable.OR, a0, a1);
309  }
310 
311  private RexNode eq(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
312  return rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, a0, a1);
313  }
314 
315  private RexNode ge(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
316  return rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, a0, a1);
317  }
318 
319  private RexNode le(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
320  return rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, a0, a1);
321  }
322 
323  private RexNode and(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
324  return rexBuilder.makeCall(SqlStdOperatorTable.AND, a0, a1);
325  }
326 
327  private static RexNode divideInt(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
328  return rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE_INTEGER, a0, a1);
329  }
330 
331  private RexNode plus(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
332  return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, a0, a1);
333  }
334 
335  private RexNode minus(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
336  return rexBuilder.makeCall(SqlStdOperatorTable.MINUS, a0, a1);
337  }
338 
339  private static RexNode multiply(RexBuilder rexBuilder, RexNode a0, RexNode a1) {
340  return rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, a0, a1);
341  }
342 
343  private RexNode case_(RexBuilder rexBuilder, RexNode... args) {
344  return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args);
345  }
346 
347  // SqlNode helpers
348 
349  private SqlCall plus(SqlParserPos pos, SqlNode a0, SqlNode a1) {
350  return SqlStdOperatorTable.PLUS.createCall(pos, a0, a1);
351  }
352 
356  public RexNode convertCase(SqlRexContext cx, SqlCase call) {
357  SqlNodeList whenList = call.getWhenOperands();
358  SqlNodeList thenList = call.getThenOperands();
359  assert whenList.size() == thenList.size();
360 
361  RexBuilder rexBuilder = cx.getRexBuilder();
362  final List<RexNode> exprList = new ArrayList<>();
363  final RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
364  final RexLiteral unknownLiteral =
365  rexBuilder.makeNullLiteral(typeFactory.createSqlType(SqlTypeName.BOOLEAN));
366  final RexLiteral nullLiteral =
367  rexBuilder.makeNullLiteral(typeFactory.createSqlType(SqlTypeName.NULL));
368  for (int i = 0; i < whenList.size(); i++) {
369  if (SqlUtil.isNullLiteral(whenList.get(i), false)) {
370  exprList.add(unknownLiteral);
371  } else {
372  exprList.add(cx.convertExpression(whenList.get(i)));
373  }
374  if (SqlUtil.isNullLiteral(thenList.get(i), false)) {
375  exprList.add(nullLiteral);
376  } else {
377  exprList.add(cx.convertExpression(thenList.get(i)));
378  }
379  }
380  if (SqlUtil.isNullLiteral(call.getElseOperand(), false)) {
381  exprList.add(nullLiteral);
382  } else {
383  exprList.add(cx.convertExpression(call.getElseOperand()));
384  }
385 
386  RelDataType type = rexBuilder.deriveReturnType(call.getOperator(), exprList);
387  for (int i : elseArgs(exprList.size())) {
388  exprList.set(i, rexBuilder.ensureType(type, exprList.get(i), false));
389  }
390  return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, exprList);
391  }
392 
393  public RexNode convertMultiset(
394  SqlRexContext cx, SqlMultisetValueConstructor op, SqlCall call) {
395  final RelDataType originalType = cx.getValidator().getValidatedNodeType(call);
396  RexRangeRef rr = cx.getSubQueryExpr(call);
397  assert rr != null;
398  RelDataType msType = rr.getType().getFieldList().get(0).getType();
399  RexNode expr = cx.getRexBuilder().makeInputRef(msType, rr.getOffset());
400  assert msType.getComponentType().isStruct();
401  if (!originalType.getComponentType().isStruct()) {
402  // If the type is not a struct, the multiset operator will have
403  // wrapped the type as a record. Add a call to the $SLICE operator
404  // to compensate. For example,
405  // if '<ms>' has type 'RECORD (INTEGER x) MULTISET',
406  // then '$SLICE(<ms>) has type 'INTEGER MULTISET'.
407  // This will be removed as the expression is translated.
408  expr = cx.getRexBuilder().makeCall(
409  originalType, SqlStdOperatorTable.SLICE, ImmutableList.of(expr));
410  }
411  return expr;
412  }
413 
414  public RexNode convertArray(
415  SqlRexContext cx, SqlArrayValueConstructor op, SqlCall call) {
416  return convertCall(cx, call);
417  }
418 
419  public RexNode convertMap(SqlRexContext cx, SqlMapValueConstructor op, SqlCall call) {
420  return convertCall(cx, call);
421  }
422 
423  public RexNode convertMultisetQuery(
424  SqlRexContext cx, SqlMultisetQueryConstructor op, SqlCall call) {
425  final RelDataType originalType = cx.getValidator().getValidatedNodeType(call);
426  RexRangeRef rr = cx.getSubQueryExpr(call);
427  assert rr != null;
428  RelDataType msType = rr.getType().getFieldList().get(0).getType();
429  RexNode expr = cx.getRexBuilder().makeInputRef(msType, rr.getOffset());
430  assert msType.getComponentType().isStruct();
431  if (!originalType.getComponentType().isStruct()) {
432  // If the type is not a struct, the multiset operator will have
433  // wrapped the type as a record. Add a call to the $SLICE operator
434  // to compensate. For example,
435  // if '<ms>' has type 'RECORD (INTEGER x) MULTISET',
436  // then '$SLICE(<ms>) has type 'INTEGER MULTISET'.
437  // This will be removed as the expression is translated.
438  expr = cx.getRexBuilder().makeCall(SqlStdOperatorTable.SLICE, expr);
439  }
440  return expr;
441  }
442 
443  public RexNode convertJdbc(SqlRexContext cx, SqlJdbcFunctionCall op, SqlCall call) {
444  // Yuck!! The function definition contains arguments!
445  // TODO: adopt a more conventional definition/instance structure
446  final SqlCall convertedCall = op.getLookupCall();
447  return cx.convertExpression(convertedCall);
448  }
449 
450  protected RexNode convertCast(SqlRexContext cx, final SqlCall call) {
451  RelDataTypeFactory typeFactory = cx.getTypeFactory();
452  assert call.getKind() == SqlKind.CAST;
453  final SqlNode left = call.operand(0);
454  final SqlNode right = call.operand(1);
455  if (right instanceof SqlIntervalQualifier) {
456  final SqlIntervalQualifier intervalQualifier = (SqlIntervalQualifier) right;
457  if (left instanceof SqlIntervalLiteral) {
458  RexLiteral sourceInterval = (RexLiteral) cx.convertExpression(left);
459  BigDecimal sourceValue = (BigDecimal) sourceInterval.getValue();
460  RexLiteral castedInterval =
461  cx.getRexBuilder().makeIntervalLiteral(sourceValue, intervalQualifier);
462  return castToValidatedType(cx, call, castedInterval);
463  } else if (left instanceof SqlNumericLiteral) {
464  RexLiteral sourceInterval = (RexLiteral) cx.convertExpression(left);
465  BigDecimal sourceValue = (BigDecimal) sourceInterval.getValue();
466  final BigDecimal multiplier = intervalQualifier.getUnit().multiplier;
467  sourceValue = sourceValue.multiply(multiplier);
468  RexLiteral castedInterval =
469  cx.getRexBuilder().makeIntervalLiteral(sourceValue, intervalQualifier);
470  return castToValidatedType(cx, call, castedInterval);
471  }
472  return castToValidatedType(cx, call, cx.convertExpression(left));
473  }
474  SqlDataTypeSpec dataType = (SqlDataTypeSpec) right;
475  RelDataType type = dataType.deriveType(cx.getValidator());
476  if (type == null) {
477  type = cx.getValidator().getValidatedNodeType(dataType.getTypeName());
478  }
479  RexNode arg = cx.convertExpression(left);
480  if (arg.getType().isNullable()) {
481  type = typeFactory.createTypeWithNullability(type, true);
482  }
483  if (SqlUtil.isNullLiteral(left, false)) {
484  final SqlValidatorImpl validator = (SqlValidatorImpl) cx.getValidator();
485  validator.setValidatedNodeType(left, type);
486  return cx.convertExpression(left);
487  }
488  if (null != dataType.getCollectionsTypeName()) {
489  final RelDataType argComponentType = arg.getType().getComponentType();
490  final RelDataType componentType = type.getComponentType();
491  if (argComponentType.isStruct() && !componentType.isStruct()) {
492  RelDataType tt = typeFactory.builder()
493  .add(argComponentType.getFieldList().get(0).getName(),
494  componentType)
495  .build();
496  tt = typeFactory.createTypeWithNullability(tt, componentType.isNullable());
497  boolean isn = type.isNullable();
498  type = typeFactory.createMultisetType(tt, -1);
499  type = typeFactory.createTypeWithNullability(type, isn);
500  }
501  }
502  return cx.getRexBuilder().makeCast(type, arg);
503  }
504 
505  // HEAVY.AI new
506  protected RexNode convertTryCast(SqlRexContext cx, final SqlCall call) {
507  RelDataTypeFactory typeFactory = cx.getTypeFactory();
508  // assert call.getKind() == SqlKind.CAST;
509  final SqlNode left = call.operand(0);
510  final SqlNode right = call.operand(1);
511 
512  SqlDataTypeSpec dataType = (SqlDataTypeSpec) right;
513  RelDataType type = dataType.deriveType(cx.getValidator());
514  if (type == null) {
515  type = cx.getValidator().getValidatedNodeType(dataType.getTypeName());
516  }
517  RexNode arg = cx.convertExpression(left);
518  if (arg.getType().isNullable()) {
519  type = typeFactory.createTypeWithNullability(type, true);
520  }
521  if (SqlUtil.isNullLiteral(left, false)) {
522  final SqlValidatorImpl validator = (SqlValidatorImpl) cx.getValidator();
523  validator.setValidatedNodeType(left, type);
524  return cx.convertExpression(left);
525  }
526  return cx.getRexBuilder().makeCall(
527  type, HeavyDBSqlOperatorTable.TRY_CAST, ImmutableList.of(arg));
528  }
529  // end HEAVY.AI new
530 
531  protected RexNode convertFloorCeil(SqlRexContext cx, SqlCall call) {
532  final boolean floor = call.getKind() == SqlKind.FLOOR;
533  // Rewrite floor, ceil of interval
534  if (call.operandCount() == 1 && call.operand(0) instanceof SqlIntervalLiteral) {
535  final SqlIntervalLiteral literal = call.operand(0);
536  SqlIntervalLiteral.IntervalValue interval =
537  literal.getValueAs(SqlIntervalLiteral.IntervalValue.class);
538  BigDecimal val = interval.getIntervalQualifier().getStartUnit().multiplier;
539  RexNode rexInterval = cx.convertExpression(literal);
540 
541  final RexBuilder rexBuilder = cx.getRexBuilder();
542  RexNode zero = rexBuilder.makeExactLiteral(BigDecimal.valueOf(0));
543  RexNode cond = ge(rexBuilder, rexInterval, zero);
544 
545  RexNode pad = rexBuilder.makeExactLiteral(val.subtract(BigDecimal.ONE));
546  RexNode cast = rexBuilder.makeReinterpretCast(
547  rexInterval.getType(), pad, rexBuilder.makeLiteral(false));
548  RexNode sum = floor ? minus(rexBuilder, rexInterval, cast)
549  : plus(rexBuilder, rexInterval, cast);
550 
551  RexNode kase = floor ? case_(rexBuilder, rexInterval, cond, sum)
552  : case_(rexBuilder, sum, cond, rexInterval);
553 
554  RexNode factor = rexBuilder.makeExactLiteral(val);
555  RexNode div = divideInt(rexBuilder, kase, factor);
556  return multiply(rexBuilder, div, factor);
557  }
558 
559  // normal floor, ceil function
560  return convertFunction(cx, (SqlFunction) call.getOperator(), call);
561  }
562 
568  public RexNode convertExtract(SqlRexContext cx, SqlExtractFunction op, SqlCall call) {
569  return convertFunction(cx, (SqlFunction) call.getOperator(), call);
570  }
571 
572  private RexNode mod(
573  RexBuilder rexBuilder, RelDataType resType, RexNode res, BigDecimal val) {
574  if (val.equals(BigDecimal.ONE)) {
575  return res;
576  }
577  return rexBuilder.makeCall(
578  SqlStdOperatorTable.MOD, res, rexBuilder.makeExactLiteral(val, resType));
579  }
580 
581  private static RexNode divide(RexBuilder rexBuilder, RexNode res, BigDecimal val) {
582  if (val.equals(BigDecimal.ONE)) {
583  return res;
584  }
585  // If val is between 0 and 1, rather than divide by val, multiply by its
586  // reciprocal. For example, rather than divide by 0.001 multiply by 1000.
587  if (val.compareTo(BigDecimal.ONE) < 0 && val.signum() == 1) {
588  try {
589  final BigDecimal reciprocal =
590  BigDecimal.ONE.divide(val, RoundingMode.UNNECESSARY);
591  return multiply(rexBuilder, res, rexBuilder.makeExactLiteral(reciprocal));
592  } catch (ArithmeticException e) {
593  // ignore - reciprocal is not an integer
594  }
595  }
596  return divideInt(rexBuilder, res, rexBuilder.makeExactLiteral(val));
597  }
598 
599  public RexNode convertDatetimeMinus(
600  SqlRexContext cx, SqlDatetimeSubtractionOperator op, SqlCall call) {
601  // Rewrite datetime minus
602  final RexBuilder rexBuilder = cx.getRexBuilder();
603  final List<SqlNode> operands = call.getOperandList();
604  final List<RexNode> exprs =
605  convertExpressionList(cx, operands, SqlOperandTypeChecker.Consistency.NONE);
606 
607  final RelDataType resType = cx.getValidator().getValidatedNodeType(call);
608  return rexBuilder.makeCall(resType, op, exprs.subList(0, 2));
609  }
610 
611  public RexNode convertFunction(SqlRexContext cx, SqlFunction fun, SqlCall call) {
612  final List<SqlNode> operands = call.getOperandList();
613  final List<RexNode> exprs =
614  convertExpressionList(cx, operands, SqlOperandTypeChecker.Consistency.NONE);
615  if (fun.getFunctionType() == SqlFunctionCategory.USER_DEFINED_CONSTRUCTOR) {
616  return makeConstructorCall(cx, fun, exprs);
617  }
618  RelDataType returnType = cx.getValidator().getValidatedNodeTypeIfKnown(call);
619  if (returnType == null) {
620  returnType = cx.getRexBuilder().deriveReturnType(fun, exprs);
621  }
622  return cx.getRexBuilder().makeCall(returnType, fun, exprs);
623  }
624 
625  public RexNode convertWindowFunction(
626  SqlRexContext cx, SqlWindowTableFunction fun, SqlCall call) {
627  // The first operand of window function is actually a query, skip that.
628  final List<SqlNode> operands = Util.skip(call.getOperandList(), 1);
629  final List<RexNode> exprs =
630  convertExpressionList(cx, operands, SqlOperandTypeChecker.Consistency.NONE);
631  RelDataType returnType = cx.getValidator().getValidatedNodeTypeIfKnown(call);
632  if (returnType == null) {
633  returnType = cx.getRexBuilder().deriveReturnType(fun, exprs);
634  }
635  return cx.getRexBuilder().makeCall(returnType, fun, exprs);
636  }
637 
638  public RexNode convertJsonValueFunction(
639  SqlRexContext cx, SqlJsonValueFunction fun, SqlCall call) {
640  // For Expression with explicit return type:
641  // i.e. json_value('{"foo":"bar"}', 'lax $.foo', returning varchar(2000))
642  // use the specified type as the return type.
643  List<SqlNode> operands = call.getOperandList();
644  boolean hasExplicitReturningType = SqlJsonValueFunction.hasExplicitTypeSpec(
645  operands.toArray(SqlNode.EMPTY_ARRAY));
646  if (hasExplicitReturningType) {
647  operands = SqlJsonValueFunction.removeTypeSpecOperands(call);
648  }
649  final List<RexNode> exprs =
650  convertExpressionList(cx, operands, SqlOperandTypeChecker.Consistency.NONE);
651  RelDataType returnType = cx.getValidator().getValidatedNodeTypeIfKnown(call);
652  return cx.getRexBuilder().makeCall(returnType, fun, exprs);
653  }
654 
655  public RexNode convertSequenceValue(
656  SqlRexContext cx, SqlSequenceValueOperator fun, SqlCall call) {
657  final List<SqlNode> operands = call.getOperandList();
658  assert operands.size() == 1;
659  assert operands.get(0) instanceof SqlIdentifier;
660  final SqlIdentifier id = (SqlIdentifier) operands.get(0);
661  final String key = Util.listToString(id.names);
662  RelDataType returnType = cx.getValidator().getValidatedNodeType(call);
663  return cx.getRexBuilder().makeCall(
664  returnType, fun, ImmutableList.of(cx.getRexBuilder().makeLiteral(key)));
665  }
666 
667  public RexNode convertAggregateFunction(
668  SqlRexContext cx, SqlAggFunction fun, SqlCall call) {
669  final List<SqlNode> operands = call.getOperandList();
670  final List<RexNode> exprs;
671  if (call.isCountStar()) {
672  exprs = ImmutableList.of();
673  } else {
674  exprs = convertExpressionList(cx, operands, SqlOperandTypeChecker.Consistency.NONE);
675  }
676  RelDataType returnType = cx.getValidator().getValidatedNodeTypeIfKnown(call);
677  final int groupCount = cx.getGroupCount();
678  if (returnType == null) {
679  RexCallBinding binding =
680  new RexCallBinding(cx.getTypeFactory(), fun, exprs, ImmutableList.of()) {
681  @Override
682  public int getGroupCount() {
683  return groupCount;
684  }
685  };
686  returnType = fun.inferReturnType(binding);
687  }
688  return cx.getRexBuilder().makeCall(returnType, fun, exprs);
689  }
690 
691  private static RexNode makeConstructorCall(
692  SqlRexContext cx, SqlFunction constructor, List<RexNode> exprs) {
693  final RexBuilder rexBuilder = cx.getRexBuilder();
694  RelDataType type = rexBuilder.deriveReturnType(constructor, exprs);
695 
696  int n = type.getFieldCount();
697  ImmutableList.Builder<RexNode> initializationExprs = ImmutableList.builder();
698  final InitializerContext initializerContext = new InitializerContext() {
699  public RexBuilder getRexBuilder() {
700  return rexBuilder;
701  }
702 
703  public SqlNode validateExpression(RelDataType rowType, SqlNode expr) {
704  throw new UnsupportedOperationException();
705  }
706 
707  public RexNode convertExpression(SqlNode e) {
708  throw new UnsupportedOperationException();
709  }
710  };
711  for (int i = 0; i < n; ++i) {
712  initializationExprs.add(
713  cx.getInitializerExpressionFactory().newAttributeInitializer(
714  type, constructor, i, exprs, initializerContext));
715  }
716 
717  List<RexNode> defaultCasts = RexUtil.generateCastExpressions(
718  rexBuilder, type, initializationExprs.build());
719 
720  return rexBuilder.makeNewInvocation(type, defaultCasts);
721  }
722 
733  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
734  return convertCall(cx, call.getOperator(), call.getOperandList());
735  }
736 
741  private RexNode convertCall(SqlRexContext cx, SqlOperator op, List<SqlNode> operands) {
742  final RexBuilder rexBuilder = cx.getRexBuilder();
743  final SqlOperandTypeChecker.Consistency consistency =
744  op.getOperandTypeChecker() == null
745  ? SqlOperandTypeChecker.Consistency.NONE
746  : op.getOperandTypeChecker().getConsistency();
747  final List<RexNode> exprs = convertExpressionList(cx, operands, consistency);
748  RelDataType type = rexBuilder.deriveReturnType(op, exprs);
749  return rexBuilder.makeCall(type, op, RexUtil.flatten(exprs, op));
750  }
751 
752  private List<Integer> elseArgs(int count) {
753  // If list is odd, e.g. [0, 1, 2, 3, 4] we get [1, 3, 4]
754  // If list is even, e.g. [0, 1, 2, 3, 4, 5] we get [2, 4, 5]
755  final List<Integer> list = new ArrayList<>();
756  for (int i = count % 2;;) {
757  list.add(i);
758  i += 2;
759  if (i >= count) {
760  list.add(i - 1);
761  break;
762  }
763  }
764  return list;
765  }
766 
767  private static List<RexNode> convertExpressionList(SqlRexContext cx,
768  List<SqlNode> nodes,
769  SqlOperandTypeChecker.Consistency consistency) {
770  final List<RexNode> exprs = new ArrayList<>();
771  for (SqlNode node : nodes) {
772  exprs.add(cx.convertExpression(node));
773  }
774  if (exprs.size() > 1) {
775  final RelDataType type = consistentType(cx, consistency, RexUtil.types(exprs));
776  if (type != null) {
777  final List<RexNode> oldExprs = Lists.newArrayList(exprs);
778  exprs.clear();
779  for (RexNode expr : oldExprs) {
780  exprs.add(cx.getRexBuilder().ensureType(type, expr, true));
781  }
782  }
783  }
784  return exprs;
785  }
786 
787  private static RelDataType consistentType(SqlRexContext cx,
788  SqlOperandTypeChecker.Consistency consistency,
789  List<RelDataType> types) {
790  switch (consistency) {
791  case COMPARE:
792  if (SqlTypeUtil.areSameFamily(types)) {
793  // All arguments are of same family. No need for explicit casts.
794  return null;
795  }
796  final List<RelDataType> nonCharacterTypes = new ArrayList<>();
797  for (RelDataType type : types) {
798  if (type.getFamily() != SqlTypeFamily.CHARACTER) {
799  nonCharacterTypes.add(type);
800  }
801  }
802  if (!nonCharacterTypes.isEmpty()) {
803  final int typeCount = types.size();
804  types = nonCharacterTypes;
805  if (nonCharacterTypes.size() < typeCount) {
806  final RelDataTypeFamily family = nonCharacterTypes.get(0).getFamily();
807  if (family instanceof SqlTypeFamily) {
808  // The character arguments might be larger than the numeric
809  // argument. Give ourselves some headroom.
810  switch ((SqlTypeFamily) family) {
811  case INTEGER:
812  case NUMERIC:
813  nonCharacterTypes.add(
814  cx.getTypeFactory().createSqlType(SqlTypeName.BIGINT));
815  }
816  }
817  }
818  }
819  // fall through
820  case LEAST_RESTRICTIVE:
821  return cx.getTypeFactory().leastRestrictive(types);
822  default:
823  return null;
824  }
825  }
826 
827  private RexNode convertPlus(SqlRexContext cx, SqlCall call) {
828  final RexNode rex = convertCall(cx, call);
829  switch (rex.getType().getSqlTypeName()) {
830  case DATE:
831  case TIME:
832  case TIMESTAMP:
833  // Use special "+" operator for datetime + interval.
834  // Re-order operands, if necessary, so that interval is second.
835  final RexBuilder rexBuilder = cx.getRexBuilder();
836  List<RexNode> operands = ((RexCall) rex).getOperands();
837  if (operands.size() == 2) {
838  final SqlTypeName sqlTypeName = operands.get(0).getType().getSqlTypeName();
839  switch (sqlTypeName) {
840  case INTERVAL_YEAR:
841  case INTERVAL_YEAR_MONTH:
842  case INTERVAL_MONTH:
843  case INTERVAL_DAY:
844  case INTERVAL_DAY_HOUR:
845  case INTERVAL_DAY_MINUTE:
846  case INTERVAL_DAY_SECOND:
847  case INTERVAL_HOUR:
848  case INTERVAL_HOUR_MINUTE:
849  case INTERVAL_HOUR_SECOND:
850  case INTERVAL_MINUTE:
851  case INTERVAL_MINUTE_SECOND:
852  case INTERVAL_SECOND:
853  operands = ImmutableList.of(operands.get(1), operands.get(0));
854  }
855  }
856  return rexBuilder.makeCall(
857  rex.getType(), SqlStdOperatorTable.DATETIME_PLUS, operands);
858  default:
859  return rex;
860  }
861  }
862 
863  private RexNode convertIsDistinctFrom(SqlRexContext cx, SqlCall call, boolean neg) {
864  RexNode op0 = cx.convertExpression(call.operand(0));
865  RexNode op1 = cx.convertExpression(call.operand(1));
866  return RelOptUtil.isDistinctFrom(cx.getRexBuilder(), op0, op1, neg);
867  }
868 
874  public RexNode convertBetween(SqlRexContext cx, SqlBetweenOperator op, SqlCall call) {
875  final List<RexNode> list = convertExpressionList(
876  cx, call.getOperandList(), op.getOperandTypeChecker().getConsistency());
877  final RexNode x = list.get(SqlBetweenOperator.VALUE_OPERAND);
878  final RexNode y = list.get(SqlBetweenOperator.LOWER_OPERAND);
879  final RexNode z = list.get(SqlBetweenOperator.UPPER_OPERAND);
880 
881  final RexBuilder rexBuilder = cx.getRexBuilder();
882  RexNode ge1 = ge(rexBuilder, x, y);
883  RexNode le1 = le(rexBuilder, x, z);
884  RexNode and1 = and(rexBuilder, ge1, le1);
885 
886  RexNode res;
887  final SqlBetweenOperator.Flag symmetric = op.flag;
888  switch (symmetric) {
889  case ASYMMETRIC:
890  res = and1;
891  break;
892  case SYMMETRIC:
893  RexNode ge2 = ge(rexBuilder, x, z);
894  RexNode le2 = le(rexBuilder, x, y);
895  RexNode and2 = and(rexBuilder, ge2, le2);
896  res = or(rexBuilder, and1, and2);
897  break;
898  default:
899  throw Util.unexpected(symmetric);
900  }
901  final SqlBetweenOperator betweenOp = (SqlBetweenOperator) call.getOperator();
902  if (betweenOp.isNegated()) {
903  res = rexBuilder.makeCall(SqlStdOperatorTable.NOT, res);
904  }
905  return res;
906  }
907 
914  public RexNode convertLiteralChain(
915  SqlRexContext cx, SqlLiteralChainOperator op, SqlCall call) {
916  Util.discard(cx);
917 
918  SqlLiteral sum = SqlLiteralChainOperator.concatenateOperands(call);
919  return cx.convertLiteral(sum);
920  }
921 
927  public RexNode convertRow(SqlRexContext cx, SqlRowOperator op, SqlCall call) {
928  if (cx.getValidator().getValidatedNodeType(call).getSqlTypeName()
929  != SqlTypeName.COLUMN_LIST) {
930  return convertCall(cx, call);
931  }
932  final RexBuilder rexBuilder = cx.getRexBuilder();
933  final List<RexNode> columns = new ArrayList<>();
934  for (SqlNode operand : call.getOperandList()) {
935  columns.add(rexBuilder.makeLiteral(((SqlIdentifier) operand).getSimple()));
936  }
937  final RelDataType type =
938  rexBuilder.deriveReturnType(SqlStdOperatorTable.COLUMN_LIST, columns);
939  return rexBuilder.makeCall(type, SqlStdOperatorTable.COLUMN_LIST, columns);
940  }
941 
947  public RexNode convertOverlaps(SqlRexContext cx, SqlOverlapsOperator op, SqlCall call) {
948  // for intervals [t0, t1] overlaps [t2, t3], we can find if the
949  // intervals overlaps by: ~(t1 < t2 or t3 < t0)
950  assert call.getOperandList().size() == 2;
951 
952  final Pair<RexNode, RexNode> left =
953  convertOverlapsOperand(cx, call.getParserPosition(), call.operand(0));
954  final RexNode r0 = left.left;
955  final RexNode r1 = left.right;
956  final Pair<RexNode, RexNode> right =
957  convertOverlapsOperand(cx, call.getParserPosition(), call.operand(1));
958  final RexNode r2 = right.left;
959  final RexNode r3 = right.right;
960 
961  // Sort end points into start and end, such that (s0 <= e0) and (s1 <= e1).
962  final RexBuilder rexBuilder = cx.getRexBuilder();
963  RexNode leftSwap = le(rexBuilder, r0, r1);
964  final RexNode s0 = case_(rexBuilder, leftSwap, r0, r1);
965  final RexNode e0 = case_(rexBuilder, leftSwap, r1, r0);
966  RexNode rightSwap = le(rexBuilder, r2, r3);
967  final RexNode s1 = case_(rexBuilder, rightSwap, r2, r3);
968  final RexNode e1 = case_(rexBuilder, rightSwap, r3, r2);
969  // (e0 >= s1) AND (e1 >= s0)
970  switch (op.kind) {
971  case OVERLAPS:
972  return and(rexBuilder, ge(rexBuilder, e0, s1), ge(rexBuilder, e1, s0));
973  case CONTAINS:
974  return and(rexBuilder, le(rexBuilder, s0, s1), ge(rexBuilder, e0, e1));
975  case PERIOD_EQUALS:
976  return and(rexBuilder, eq(rexBuilder, s0, s1), eq(rexBuilder, e0, e1));
977  case PRECEDES:
978  return le(rexBuilder, e0, s1);
979  case IMMEDIATELY_PRECEDES:
980  return eq(rexBuilder, e0, s1);
981  case SUCCEEDS:
982  return ge(rexBuilder, s0, e1);
983  case IMMEDIATELY_SUCCEEDS:
984  return eq(rexBuilder, s0, e1);
985  default:
986  throw new AssertionError(op);
987  }
988  }
989 
990  private Pair<RexNode, RexNode> convertOverlapsOperand(
991  SqlRexContext cx, SqlParserPos pos, SqlNode operand) {
992  final SqlNode a0;
993  final SqlNode a1;
994  switch (operand.getKind()) {
995  case ROW:
996  a0 = ((SqlCall) operand).operand(0);
997  final SqlNode a10 = ((SqlCall) operand).operand(1);
998  final RelDataType t1 = cx.getValidator().getValidatedNodeType(a10);
999  if (SqlTypeUtil.isInterval(t1)) {
1000  // make t1 = t0 + t1 when t1 is an interval.
1001  a1 = plus(pos, a0, a10);
1002  } else {
1003  a1 = a10;
1004  }
1005  break;
1006  default:
1007  a0 = operand;
1008  a1 = operand;
1009  }
1010 
1011  final RexNode r0 = cx.convertExpression(a0);
1012  final RexNode r1 = cx.convertExpression(a1);
1013  return Pair.of(r0, r1);
1014  }
1015 
1021  public RexNode castToValidatedType(SqlRexContext cx, SqlCall call, RexNode value) {
1022  return castToValidatedType(call, value, cx.getValidator(), cx.getRexBuilder());
1023  }
1024 
1030  public static RexNode castToValidatedType(
1031  SqlNode node, RexNode e, SqlValidator validator, RexBuilder rexBuilder) {
1032  final RelDataType type = validator.getValidatedNodeType(node);
1033  if (e.getType() == type) {
1034  return e;
1035  }
1036  return rexBuilder.makeCast(type, e);
1037  }
1038 
1043  private static class RegrCovarianceConvertlet implements SqlRexConvertlet {
1044  private final SqlKind kind;
1045 
1047  this.kind = kind;
1048  }
1049 
1050  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
1051  assert call.operandCount() == 2;
1052  final SqlNode arg1 = call.operand(0);
1053  final SqlNode arg2 = call.operand(1);
1054  final SqlNode expr;
1055  final RelDataType type = cx.getValidator().getValidatedNodeType(call);
1056  switch (kind) {
1057  case COVAR_POP:
1058  expr = expandCovariance(arg1, arg2, null, type, cx, true);
1059  break;
1060  case COVAR_SAMP:
1061  expr = expandCovariance(arg1, arg2, null, type, cx, false);
1062  break;
1063  case REGR_SXX:
1064  expr = expandRegrSzz(arg2, arg1, type, cx, true);
1065  break;
1066  case REGR_SYY:
1067  expr = expandRegrSzz(arg1, arg2, type, cx, true);
1068  break;
1069  default:
1070  throw Util.unexpected(kind);
1071  }
1072  RexNode rex = cx.convertExpression(expr);
1073  return cx.getRexBuilder().ensureType(type, rex, true);
1074  }
1075 
1076  private SqlNode expandRegrSzz(final SqlNode arg1,
1077  final SqlNode arg2,
1078  final RelDataType avgType,
1079  final SqlRexContext cx,
1080  boolean variance) {
1081  final SqlParserPos pos = SqlParserPos.ZERO;
1082  final SqlNode count = SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg1, arg2);
1083  final SqlNode varPop =
1084  expandCovariance(arg1, variance ? arg1 : arg2, arg2, avgType, cx, true);
1085  final RexNode varPopRex = cx.convertExpression(varPop);
1086  final SqlNode varPopCast;
1087  varPopCast = getCastedSqlNode(varPop, avgType, pos, varPopRex);
1088  return SqlStdOperatorTable.MULTIPLY.createCall(pos, varPopCast, count);
1089  }
1090 
1091  private SqlNode expandCovariance(final SqlNode arg0Input,
1092  final SqlNode arg1Input,
1093  final SqlNode dependent,
1094  final RelDataType varType,
1095  final SqlRexContext cx,
1096  boolean biased) {
1097  // covar_pop(x1, x2) ==>
1098  // (sum(x1 * x2) - sum(x2) * sum(x1) / count(x1, x2))
1099  // / count(x1, x2)
1100  //
1101  // covar_samp(x1, x2) ==>
1102  // (sum(x1 * x2) - sum(x1) * sum(x2) / count(x1, x2))
1103  // / (count(x1, x2) - 1)
1104  final SqlParserPos pos = SqlParserPos.ZERO;
1105  final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
1106 
1107  final RexNode arg0Rex = cx.convertExpression(arg0Input);
1108  final RexNode arg1Rex = cx.convertExpression(arg1Input);
1109 
1110  final SqlNode arg0 = getCastedSqlNode(arg0Input, varType, pos, arg0Rex);
1111  final SqlNode arg1 = getCastedSqlNode(arg1Input, varType, pos, arg1Rex);
1112  final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg0, arg1);
1113  final SqlNode sumArgSquared;
1114  final SqlNode sum0;
1115  final SqlNode sum1;
1116  final SqlNode count;
1117  if (dependent == null) {
1118  sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared);
1119  sum0 = SqlStdOperatorTable.SUM.createCall(pos, arg0, arg1);
1120  sum1 = SqlStdOperatorTable.SUM.createCall(pos, arg1, arg0);
1121  count = SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg0, arg1);
1122  } else {
1123  sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared, dependent);
1124  sum0 = SqlStdOperatorTable.SUM.createCall(
1125  pos, arg0, Objects.equals(dependent, arg0Input) ? arg1 : dependent);
1126  sum1 = SqlStdOperatorTable.SUM.createCall(
1127  pos, arg1, Objects.equals(dependent, arg1Input) ? arg0 : dependent);
1128  count = SqlStdOperatorTable.REGR_COUNT.createCall(
1129  pos, arg0, Objects.equals(dependent, arg0Input) ? arg1 : dependent);
1130  }
1131 
1132  final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum0, sum1);
1133  final SqlNode countCasted =
1134  getCastedSqlNode(count, varType, pos, cx.convertExpression(count));
1135 
1136  final SqlNode avgSumSquared =
1137  SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, countCasted);
1138  final SqlNode diff =
1139  SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared);
1140  SqlNode denominator;
1141  if (biased) {
1142  denominator = countCasted;
1143  } else {
1144  final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
1145  denominator = new SqlCase(SqlParserPos.ZERO,
1146  countCasted,
1147  SqlNodeList.of(
1148  SqlStdOperatorTable.EQUALS.createCall(pos, countCasted, one)),
1149  SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)),
1150  SqlStdOperatorTable.MINUS.createCall(pos, countCasted, one));
1151  }
1152 
1153  return SqlStdOperatorTable.DIVIDE.createCall(pos, diff, denominator);
1154  }
1155 
1156  private SqlNode getCastedSqlNode(
1157  SqlNode argInput, RelDataType varType, SqlParserPos pos, RexNode argRex) {
1158  SqlNode arg;
1159  if (argRex != null && !argRex.getType().equals(varType)) {
1160  arg = SqlStdOperatorTable.CAST.createCall(
1161  pos, argInput, SqlTypeUtil.convertTypeToSpec(varType));
1162  } else {
1163  arg = argInput;
1164  }
1165  return arg;
1166  }
1167  }
1168 
private static class AvgVarianceConvertlet implements SqlRexConvertlet {
  // Which aggregate this instance expands: AVG, STDDEV_POP, STDDEV_SAMP,
  // VAR_POP or VAR_SAMP (any other kind is rejected in convertCall).
  private final SqlKind kind;

  AvgVarianceConvertlet(SqlKind kind) {
    this.kind = kind;
  }

  // Expands the single-argument aggregate call into arithmetic over SUM and
  // COUNT, converts that expression, and coerces the result to the validated
  // type of the original call.
  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
    assert call.operandCount() == 1;
    final SqlNode arg = call.operand(0);
    final SqlNode expr;
    final RelDataType type = cx.getValidator().getValidatedNodeType(call);
    switch (kind) {
    case AVG:
      expr = expandAvg(arg, type, cx);
      break;
    case STDDEV_POP:
      // biased (divide by count), square root
      expr = expandVariance(arg, type, cx, true, true);
      break;
    case STDDEV_SAMP:
      // unbiased (divide by count - 1), square root
      expr = expandVariance(arg, type, cx, false, true);
      break;
    case VAR_POP:
      // biased, no square root
      expr = expandVariance(arg, type, cx, true, false);
      break;
    case VAR_SAMP:
      // unbiased, no square root
      expr = expandVariance(arg, type, cx, false, false);
      break;
    default:
      throw Util.unexpected(kind);
    }
    RexNode rex = cx.convertExpression(expr);
    return cx.getRexBuilder().ensureType(type, rex, true);
  }

  // avg(x) ==> CAST(SUM(x) AS avgType) / COUNT(x); the CAST is omitted when
  // the converted SUM already has avgType.
  private SqlNode expandAvg(
      final SqlNode arg, final RelDataType avgType, final SqlRexContext cx) {
    final SqlParserPos pos = SqlParserPos.ZERO;
    final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
    final RexNode sumRex = cx.convertExpression(sum);
    final SqlNode sumCast;
    sumCast = getCastedSqlNode(sum, avgType, pos, sumRex);
    final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
    return SqlStdOperatorTable.DIVIDE.createCall(pos, sumCast, count);
  }

  // Builds the variance / standard-deviation expansion below. Each
  // intermediate node is converted with cx.convertExpression to decide
  // whether it needs a CAST to varType; the order of those calls is part of
  // the current behavior, so keep it when editing.
  private SqlNode expandVariance(final SqlNode argInput,
      final RelDataType varType,
      final SqlRexContext cx,
      boolean biased,
      boolean sqrt) {
    // stddev_pop(x) ==>
    // power(
    // (sum(x * x) - sum(x) * sum(x) / count(x))
    // / count(x),
    // .5)
    //
    // stddev_samp(x) ==>
    // power(
    // (sum(x * x) - sum(x) * sum(x) / count(x))
    // / (count(x) - 1),
    // .5)
    //
    // var_pop(x) ==>
    // (sum(x * x) - sum(x) * sum(x) / count(x))
    // / count(x)
    //
    // var_samp(x) ==>
    // (sum(x * x) - sum(x) * sum(x) / count(x))
    // / (count(x) - 1)
    final SqlParserPos pos = SqlParserPos.ZERO;

    final SqlNode arg =
        getCastedSqlNode(argInput, varType, pos, cx.convertExpression(argInput));

    final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
    final SqlNode argSquaredCasted = getCastedSqlNode(
        argSquared, varType, pos, cx.convertExpression(argSquared));
    final SqlNode sumArgSquared =
        SqlStdOperatorTable.SUM.createCall(pos, argSquaredCasted);
    final SqlNode sumArgSquaredCasted = getCastedSqlNode(
        sumArgSquared, varType, pos, cx.convertExpression(sumArgSquared));
    final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
    final SqlNode sumCasted =
        getCastedSqlNode(sum, varType, pos, cx.convertExpression(sum));
    final SqlNode sumSquared =
        SqlStdOperatorTable.MULTIPLY.createCall(pos, sumCasted, sumCasted);
    final SqlNode sumSquaredCasted = getCastedSqlNode(
        sumSquared, varType, pos, cx.convertExpression(sumSquared));
    final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
    final SqlNode countCasted =
        getCastedSqlNode(count, varType, pos, cx.convertExpression(count));
    final SqlNode avgSumSquared =
        SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquaredCasted, countCasted);
    final SqlNode avgSumSquaredCasted = getCastedSqlNode(
        avgSumSquared, varType, pos, cx.convertExpression(avgSumSquared));
    final SqlNode diff = SqlStdOperatorTable.MINUS.createCall(
        pos, sumArgSquaredCasted, avgSumSquaredCasted);
    final SqlNode diffCasted =
        getCastedSqlNode(diff, varType, pos, cx.convertExpression(diff));
    final SqlNode denominator;
    if (biased) {
      denominator = countCasted;
    } else {
      // Sample variants: NULL when count = 1 (count - 1 would be zero),
      // otherwise count - 1.
      // NOTE(review): the CASE is built with value operand 'count' and a
      // boolean WHEN 'count = 1'; this presumably relies on CASE conversion
      // treating WHEN operands as conditions — confirm.
      final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
      final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
      denominator = new SqlCase(SqlParserPos.ZERO,
          count,
          SqlNodeList.of(SqlStdOperatorTable.EQUALS.createCall(pos, count, one)),
          SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)),
          SqlStdOperatorTable.MINUS.createCall(pos, count, one));
    }
    final SqlNode div =
        SqlStdOperatorTable.DIVIDE.createCall(pos, diffCasted, denominator);
    // Computed (and converted) in every case, but only used for POWER below.
    final SqlNode divCasted =
        getCastedSqlNode(div, varType, pos, cx.convertExpression(div));

    SqlNode result = div;
    if (sqrt) {
      // Standard deviation: raise the variance to the power 0.5.
      final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos);
      result = SqlStdOperatorTable.POWER.createCall(pos, divCasted, half);
    }
    return result;
  }

  // Returns argInput wrapped in CAST(... AS varType) when argRex is non-null
  // and its type differs from varType; otherwise returns argInput unchanged.
  private SqlNode getCastedSqlNode(
      SqlNode argInput, RelDataType varType, SqlParserPos pos, RexNode argRex) {
    SqlNode arg;
    if (argRex != null && !argRex.getType().equals(varType)) {
      arg = SqlStdOperatorTable.CAST.createCall(
          pos, argInput, SqlTypeUtil.convertTypeToSpec(varType));
    } else {
      arg = argInput;
    }
    return arg;
  }
}
1310 
1315  private static class TrimConvertlet implements SqlRexConvertlet {
1316  private final SqlTrimFunction.Flag flag;
1317 
1318  TrimConvertlet(SqlTrimFunction.Flag flag) {
1319  this.flag = flag;
1320  }
1321 
1322  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
1323  final RexBuilder rexBuilder = cx.getRexBuilder();
1324  final RexNode operand = cx.convertExpression(call.getOperandList().get(0));
1325  return rexBuilder.makeCall(SqlStdOperatorTable.TRIM,
1326  rexBuilder.makeFlag(flag),
1327  rexBuilder.makeLiteral(" "),
1328  operand);
1329  }
1330  }
1331 
1333  private static class GreatestConvertlet implements SqlRexConvertlet {
1334  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
1335  // Translate
1336  // GREATEST(a, b, c, d)
1337  // to
1338  // CASE
1339  // WHEN a IS NULL OR b IS NULL OR c IS NULL OR d IS NULL
1340  // THEN NULL
1341  // WHEN a > b AND a > c AND a > d
1342  // THEN a
1343  // WHEN b > c AND b > d
1344  // THEN b
1345  // WHEN c > d
1346  // THEN c
1347  // ELSE d
1348  // END
1349  final RexBuilder rexBuilder = cx.getRexBuilder();
1350  final RelDataType type = cx.getValidator().getValidatedNodeType(call);
1351  final SqlBinaryOperator op;
1352  switch (call.getKind()) {
1353  case GREATEST:
1354  op = SqlStdOperatorTable.GREATER_THAN;
1355  break;
1356  case LEAST:
1357  op = SqlStdOperatorTable.LESS_THAN;
1358  break;
1359  default:
1360  throw new AssertionError();
1361  }
1362  final List<RexNode> exprs = convertExpressionList(
1363  cx, call.getOperandList(), SqlOperandTypeChecker.Consistency.NONE);
1364  final List<RexNode> list = new ArrayList<>();
1365  final List<RexNode> orList = new ArrayList<>();
1366  for (RexNode expr : exprs) {
1367  orList.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, expr));
1368  }
1369  list.add(RexUtil.composeDisjunction(rexBuilder, orList));
1370  list.add(rexBuilder.makeNullLiteral(type));
1371  for (int i = 0; i < exprs.size() - 1; i++) {
1372  RexNode expr = exprs.get(i);
1373  final List<RexNode> andList = new ArrayList<>();
1374  for (int j = i + 1; j < exprs.size(); j++) {
1375  final RexNode expr2 = exprs.get(j);
1376  andList.add(rexBuilder.makeCall(op, expr, expr2));
1377  }
1378  list.add(RexUtil.composeConjunction(rexBuilder, andList));
1379  list.add(expr);
1380  }
1381  list.add(exprs.get(exprs.size() - 1));
1382  return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, list);
1383  }
1384  }
1385 
1387  private class FloorCeilConvertlet implements SqlRexConvertlet {
1388  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
1389  return convertFloorCeil(cx, call);
1390  }
1391  }
1392 
1394  private static class TimestampAddConvertlet implements SqlRexConvertlet {
1395  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
1396  // TIMESTAMPADD(unit, count, timestamp)
1397  // => timestamp + count * INTERVAL '1' UNIT
1398  final RexBuilder rexBuilder = cx.getRexBuilder();
1399  final SqlLiteral unitLiteral = call.operand(0);
1400  final TimeUnit unit = unitLiteral.symbolValue(TimeUnit.class);
1401  RexNode interval2Add;
1402  SqlIntervalQualifier qualifier =
1403  new SqlIntervalQualifier(unit, null, unitLiteral.getParserPosition());
1404  RexNode op1 = cx.convertExpression(call.operand(1));
1405  switch (unit) {
1406  case MICROSECOND:
1407  case NANOSECOND:
1408  interval2Add = divide(rexBuilder,
1409  multiply(rexBuilder,
1410  rexBuilder.makeIntervalLiteral(BigDecimal.ONE, qualifier),
1411  op1),
1412  BigDecimal.ONE.divide(unit.multiplier, RoundingMode.UNNECESSARY));
1413  break;
1414  default:
1415  interval2Add = multiply(rexBuilder,
1416  rexBuilder.makeIntervalLiteral(unit.multiplier, qualifier),
1417  op1);
1418  }
1419 
1420  return rexBuilder.makeCall(SqlStdOperatorTable.DATETIME_PLUS,
1421  cx.convertExpression(call.operand(2)),
1422  interval2Add);
1423  }
1424  }
1425 
1427  private static class TimestampDiffConvertlet implements SqlRexConvertlet {
1428  public RexNode convertCall(SqlRexContext cx, SqlCall call) {
1429  // TIMESTAMPDIFF(unit, t1, t2)
1430  // => (t2 - t1) UNIT
1431  final RexBuilder rexBuilder = cx.getRexBuilder();
1432  final SqlLiteral unitLiteral = call.operand(0);
1433  TimeUnit unit = unitLiteral.symbolValue(TimeUnit.class);
1434  BigDecimal multiplier = BigDecimal.ONE;
1435  BigDecimal divider = BigDecimal.ONE;
1436  SqlTypeName sqlTypeName =
1437  unit == TimeUnit.NANOSECOND ? SqlTypeName.BIGINT : SqlTypeName.INTEGER;
1438  switch (unit) {
1439  case MICROSECOND:
1440  case MILLISECOND:
1441  case NANOSECOND:
1442  case WEEK:
1443  multiplier = BigDecimal.valueOf(DateTimeUtils.MILLIS_PER_SECOND);
1444  divider = unit.multiplier;
1445  unit = TimeUnit.SECOND;
1446  break;
1447  case QUARTER:
1448  divider = unit.multiplier;
1449  unit = TimeUnit.MONTH;
1450  break;
1451  }
1452  final SqlIntervalQualifier qualifier =
1453  new SqlIntervalQualifier(unit, null, SqlParserPos.ZERO);
1454  final RexNode op2 = cx.convertExpression(call.operand(2));
1455  final RexNode op1 = cx.convertExpression(call.operand(1));
1456  final RelDataType intervalType = cx.getTypeFactory().createTypeWithNullability(
1457  cx.getTypeFactory().createSqlIntervalType(qualifier),
1458  op1.getType().isNullable() || op2.getType().isNullable());
1459  final RexCall rexCall = (RexCall) rexBuilder.makeCall(
1460  intervalType, SqlStdOperatorTable.MINUS_DATE, ImmutableList.of(op2, op1));
1461  final RelDataType intType = cx.getTypeFactory().createTypeWithNullability(
1462  cx.getTypeFactory().createSqlType(sqlTypeName),
1463  SqlTypeUtil.containsNullable(rexCall.getType()));
1464  RexNode e = rexBuilder.makeCast(intType, rexCall);
1465  return rexBuilder.multiplyDivide(e, multiplier, divider);
1466  }
1467  }
1468 }
RexNode convertOverlaps(SqlRexContext cx, SqlOverlapsOperator op, SqlCall call)
RexNode convertSequenceValue(SqlRexContext cx, SqlSequenceValueOperator fun, SqlCall call)
SqlNode expandRegrSzz(final SqlNode arg1, final SqlNode arg2, final RelDataType avgType, final SqlRexContext cx, boolean variance)
RexNode mod(RexBuilder rexBuilder, RelDataType resType, RexNode res, BigDecimal val)
RexNode ge(RexBuilder rexBuilder, RexNode a0, RexNode a1)
RexNode plus(RexBuilder rexBuilder, RexNode a0, RexNode a1)
SqlNode expandCovariance(final SqlNode arg0Input, final SqlNode arg1Input, final SqlNode dependent, final RelDataType varType, final SqlRexContext cx, boolean biased)
RexNode eq(RexBuilder rexBuilder, RexNode a0, RexNode a1)
RexNode or(RexBuilder rexBuilder, RexNode a0, RexNode a1)
SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType, SqlParserPos pos, RexNode argRex)
RexNode convertMultisetQuery(SqlRexContext cx, SqlMultisetQueryConstructor op, SqlCall call)
RexNode convertWindowFunction(SqlRexContext cx, SqlWindowTableFunction fun, SqlCall call)
RexNode convertFloorCeil(SqlRexContext cx, SqlCall call)
static RexNode convertInterval(SqlRexContext cx, SqlCall call)
SqlCall plus(SqlParserPos pos, SqlNode a0, SqlNode a1)
Pair< RexNode, RexNode > convertOverlapsOperand(SqlRexContext cx, SqlParserPos pos, SqlNode operand)
SqlNode expandVariance(final SqlNode argInput, final RelDataType varType, final SqlRexContext cx, boolean biased, boolean sqrt)
RexNode castToValidatedType(SqlRexContext cx, SqlCall call, RexNode value)
RexNode convertJsonValueFunction(SqlRexContext cx, SqlJsonValueFunction fun, SqlCall call)
RexNode convertAggregateFunction(SqlRexContext cx, SqlAggFunction fun, SqlCall call)
RexNode convertBetween(SqlRexContext cx, SqlBetweenOperator op, SqlCall call)
static RexNode multiply(RexBuilder rexBuilder, RexNode a0, RexNode a1)
static RexNode divide(RexBuilder rexBuilder, RexNode res, BigDecimal val)
RexNode convertCase(SqlRexContext cx, SqlCase call)
RexNode convertLiteralChain(SqlRexContext cx, SqlLiteralChainOperator op, SqlCall call)
RexNode and(RexBuilder rexBuilder, RexNode a0, RexNode a1)
RexNode convertPlus(SqlRexContext cx, SqlCall call)
static RelDataType consistentType(SqlRexContext cx, SqlOperandTypeChecker.Consistency consistency, List< RelDataType > types)
static RexNode castToValidatedType(SqlNode node, RexNode e, SqlValidator validator, RexBuilder rexBuilder)
RexNode convertFunction(SqlRexContext cx, SqlFunction fun, SqlCall call)
static List< RexNode > convertExpressionList(SqlRexContext cx, List< SqlNode > nodes, SqlOperandTypeChecker.Consistency consistency)
std::string toString(const Executor::ExtModuleKinds &kind)
Definition: Execute.h:1703
RexNode convertExtract(SqlRexContext cx, SqlExtractFunction op, SqlCall call)
SqlNode expandAvg(final SqlNode arg, final RelDataType avgType, final SqlRexContext cx)
RexNode convertCall(SqlRexContext cx, SqlOperator op, List< SqlNode > operands)
RexNode convertCall(SqlRexContext cx, SqlCall call)
RexNode convertIsDistinctFrom(SqlRexContext cx, SqlCall call, boolean neg)
SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType, SqlParserPos pos, RexNode argRex)
RexNode convertArray(SqlRexContext cx, SqlArrayValueConstructor op, SqlCall call)
static RexNode divideInt(RexBuilder rexBuilder, RexNode a0, RexNode a1)
RexNode convertJdbc(SqlRexContext cx, SqlJdbcFunctionCall op, SqlCall call)
RexNode convertRow(SqlRexContext cx, SqlRowOperator op, SqlCall call)
RexNode convertMap(SqlRexContext cx, SqlMapValueConstructor op, SqlCall call)
RexNode convertMultiset(SqlRexContext cx, SqlMultisetValueConstructor op, SqlCall call)
SqlOperandTypeChecker getOperandTypeChecker()
constexpr double n
Definition: Utm.h:38
RexNode le(RexBuilder rexBuilder, RexNode a0, RexNode a1)
RexNode minus(RexBuilder rexBuilder, RexNode a0, RexNode a1)
RexNode convertCast(SqlRexContext cx, final SqlCall call)
RexNode case_(RexBuilder rexBuilder, RexNode...args)
RexNode convertTryCast(SqlRexContext cx, final SqlCall call)
RexNode convertDatetimeMinus(SqlRexContext cx, SqlDatetimeSubtractionOperator op, SqlCall call)
static RexNode makeConstructorCall(SqlRexContext cx, SqlFunction constructor, List< RexNode > exprs)