OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RelAlgVisitor.h
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef QUERYENGINE_RELALGVISITOR_H
18 #define QUERYENGINE_RELALGVISITOR_H
19 
20 #include "RelAlgDag.h"
21 
22 template <class T>
24  public:
25  T visit(const RelAlgNode* rel_alg) const {
26  auto result = defaultResult();
27  for (size_t i = 0; i < rel_alg->inputCount(); ++i) {
28  result = aggregateResult(result, visit(rel_alg->getInput(i)));
29  }
30  const auto aggregate = dynamic_cast<const RelAggregate*>(rel_alg);
31  if (aggregate) {
32  return aggregateResult(result, visitAggregate(aggregate));
33  }
34  const auto compound = dynamic_cast<const RelCompound*>(rel_alg);
35  if (compound) {
36  return aggregateResult(result, visitCompound(compound));
37  }
38  const auto filter = dynamic_cast<const RelFilter*>(rel_alg);
39  if (filter) {
40  return aggregateResult(result, visitFilter(filter));
41  }
42  const auto join = dynamic_cast<const RelJoin*>(rel_alg);
43  if (join) {
45  }
46  const auto left_deep_inner_join = dynamic_cast<const RelLeftDeepInnerJoin*>(rel_alg);
47  if (left_deep_inner_join) {
48  return aggregateResult(result, visitLeftDeepInnerJoin(left_deep_inner_join));
49  }
50  const auto project = dynamic_cast<const RelProject*>(rel_alg);
51  if (project) {
52  return aggregateResult(result, visitProject(project));
53  }
54  const auto scan = dynamic_cast<const RelScan*>(rel_alg);
55  if (scan) {
56  return aggregateResult(result, visitScan(scan));
57  }
58  const auto sort = dynamic_cast<const RelSort*>(rel_alg);
59  if (sort) {
61  }
62  const auto logical_values = dynamic_cast<const RelLogicalValues*>(rel_alg);
63  if (logical_values) {
64  return aggregateResult(result, visitLogicalValues(logical_values));
65  }
66  const auto modify = dynamic_cast<const RelModify*>(rel_alg);
67  if (modify) {
68  return aggregateResult(result, visitModify(modify));
69  }
70  const auto table_func = dynamic_cast<const RelTableFunction*>(rel_alg);
71  if (table_func) {
72  return aggregateResult(result, visitTableFunction(table_func));
73  }
74  const auto logical_union = dynamic_cast<const RelLogicalUnion*>(rel_alg);
75  if (logical_union) {
76  return aggregateResult(result, visitLogicalUnion(logical_union));
77  }
78  LOG(FATAL) << "Unhandled rel_alg type: "
80  return {};
81  }
82 
83  virtual T visitAggregate(const RelAggregate*) const { return defaultResult(); }
84 
85  virtual T visitCompound(const RelCompound*) const { return defaultResult(); }
86 
87  virtual T visitFilter(const RelFilter*) const { return defaultResult(); }
88 
89  virtual T visitJoin(const RelJoin*) const { return defaultResult(); }
90 
92  return defaultResult();
93  }
94 
95  virtual T visitProject(const RelProject*) const { return defaultResult(); }
96 
97  virtual T visitScan(const RelScan*) const { return defaultResult(); }
98 
99  virtual T visitSort(const RelSort*) const { return defaultResult(); }
100 
101  virtual T visitLogicalValues(const RelLogicalValues*) const { return defaultResult(); }
102 
103  virtual T visitModify(const RelModify*) const { return defaultResult(); }
104 
105  virtual T visitTableFunction(const RelTableFunction*) const { return defaultResult(); }
106 
107  virtual T visitLogicalUnion(const RelLogicalUnion*) const { return defaultResult(); }
108 
109  protected:
110  virtual T aggregateResult(const T& aggregate, const T& next_result) const {
111  return next_result;
112  }
113 
114  virtual T defaultResult() const { return T{}; }
115 };
116 
117 #endif // QUERYENGINE_RELALGVISITOR_H
virtual T visitCompound(const RelCompound *) const
Definition: RelAlgVisitor.h:85
virtual T visitTableFunction(const RelTableFunction *) const
#define LOG(tag)
Definition: Logger.h:285
virtual T visitJoin(const RelJoin *) const
Definition: RelAlgVisitor.h:89
std::string join(T const &container, std::string const &delim)
DEVICE void sort(ARGS &&...args)
Definition: gpu_enabled.h:105
virtual T visitFilter(const RelFilter *) const
Definition: RelAlgVisitor.h:87
virtual T visitLogicalUnion(const RelLogicalUnion *) const
virtual T visitAggregate(const RelAggregate *) const
Definition: RelAlgVisitor.h:83
T visit(const RelAlgNode *rel_alg) const
Definition: RelAlgVisitor.h:25
const RelAlgNode * getInput(const size_t idx) const
Definition: RelAlgDag.h:877
virtual T visitSort(const RelSort *) const
Definition: RelAlgVisitor.h:99
virtual T visitModify(const RelModify *) const
virtual std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const =0
virtual T visitProject(const RelProject *) const
Definition: RelAlgVisitor.h:95
virtual T visitLogicalValues(const RelLogicalValues *) const
static RelRexToStringConfig defaults()
Definition: RelAlgDag.h:78
virtual T visitLeftDeepInnerJoin(const RelLeftDeepInnerJoin *) const
Definition: RelAlgVisitor.h:91
virtual T defaultResult() const
virtual T visitScan(const RelScan *) const
Definition: RelAlgVisitor.h:97
const size_t inputCount() const
Definition: RelAlgDag.h:875
virtual T aggregateResult(const T &aggregate, const T &next_result) const