OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
RelLeftDeepInnerJoin.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "RelLeftDeepInnerJoin.h"
18 #include "Logger/Logger.h"
19 #include "RelAlgDag.h"
20 #include "RexVisitor.h"
21 
22 #include <numeric>
23 
25  const std::shared_ptr<RelFilter>& filter,
26  std::vector<std::shared_ptr<const RelAlgNode>> inputs,
27  std::vector<std::shared_ptr<const RelJoin>>& original_joins)
28  : condition_(filter ? filter->getAndReleaseCondition() : nullptr)
29  , original_filter_(filter)
30  , original_joins_(original_joins) {
31  std::vector<std::unique_ptr<const RexScalar>> operands;
32  bool is_notnull = true;
33  // Accumulate join conditions from the (explicit) joins themselves and
34  // from the filter node at the root of the left-deep tree pattern.
35  outer_conditions_per_level_.resize(original_joins.size());
36  for (size_t nesting_level = 0; nesting_level < original_joins.size(); ++nesting_level) {
37  const auto& original_join = original_joins[nesting_level];
38  const auto condition_true =
39  dynamic_cast<const RexLiteral*>(original_join->getCondition());
40  if (!condition_true || !condition_true->getVal<bool>()) {
41  if (dynamic_cast<const RexOperator*>(original_join->getCondition())) {
42  is_notnull =
43  is_notnull && dynamic_cast<const RexOperator*>(original_join->getCondition())
44  ->getType()
45  .get_notnull();
46  }
47  switch (original_join->getJoinType()) {
48  case JoinType::INNER:
49  case JoinType::SEMI:
50  case JoinType::ANTI: {
51  if (original_join->getCondition()) {
52  operands.emplace_back(original_join->getAndReleaseCondition());
53  }
54  break;
55  }
56  case JoinType::LEFT: {
57  if (original_join->getCondition()) {
58  outer_conditions_per_level_[nesting_level].reset(
59  original_join->getAndReleaseCondition());
60  }
61  break;
62  }
63  default:
64  CHECK(false);
65  }
66  }
67  }
68  if (!operands.empty()) {
69  if (condition_) {
70  CHECK(dynamic_cast<const RexOperator*>(condition_.get()));
71  is_notnull =
72  is_notnull &&
73  static_cast<const RexOperator*>(condition_.get())->getType().get_notnull();
74  operands.emplace_back(std::move(condition_));
75  }
76  if (operands.size() > 1) {
77  condition_.reset(
78  new RexOperator(kAND, operands, SQLTypeInfo(kBOOLEAN, is_notnull)));
79  } else {
80  condition_ = std::move(operands.front());
81  }
82  }
83  if (!condition_) {
84  condition_.reset(new RexLiteral(true, kBOOLEAN, kBOOLEAN, 0, 0, 0, 0));
85  }
86  for (const auto& input : inputs) {
87  addManagedInput(input);
88  }
89 }
90 
92  return condition_.get();
93 }
94 
96  const size_t nesting_level) const {
97  CHECK_GE(nesting_level, size_t(1));
98  CHECK_LE(nesting_level, outer_conditions_per_level_.size());
99  // Outer join conditions are collected depth-first while the returned condition
100  // must be consistent with the order of the loops (which is reverse depth-first).
101  return outer_conditions_per_level_[outer_conditions_per_level_.size() - nesting_level]
102  .get();
103 }
104 
105 const JoinType RelLeftDeepInnerJoin::getJoinType(const size_t nesting_level) const {
106  CHECK_LE(nesting_level, original_joins_.size());
107  return original_joins_[original_joins_.size() - nesting_level]->getJoinType();
108 }
109 
111  if (!config.attributes_only) {
112  std::string ret = ::typeName(this) + "(";
113  ret += condition_->toString(config);
114  if (!config.skip_input_nodes) {
115  for (const auto& input : inputs_) {
116  ret += " " + input->toString(config);
117  }
118  } else {
119  ret += ", input node id={";
120  for (auto& input : inputs_) {
121  ret += std::to_string(input->getId()) + " ";
122  }
123  ret += "}";
124  }
125  ret += ")";
126  return ret;
127  } else {
128  return ::typeName(this) + "()";
129  }
130 }
131 
133  size_t total_size = 0;
134  for (const auto& input : inputs_) {
135  total_size += input->size();
136  }
137  return total_size;
138 }
139 
141  return outer_conditions_per_level_.size();
142 }
143 
144 std::shared_ptr<RelAlgNode> RelLeftDeepInnerJoin::deepCopy() const {
145  CHECK(false);
146  return nullptr;
147 }
148 
150  if (node == original_filter_.get()) {
151  return true;
152  }
153  for (const auto& original_join : original_joins_) {
154  if (original_join.get() == node) {
155  return true;
156  }
157  }
158  return false;
159 }
160 
162  return original_filter_.get();
163 }
164 
165 std::vector<std::shared_ptr<const RelJoin>> RelLeftDeepInnerJoin::getOriginalJoins()
166  const {
167  std::vector<std::shared_ptr<const RelJoin>> original_joins;
168  original_joins.assign(original_joins_.begin(), original_joins_.end());
169  return original_joins;
170 }
171 
172 namespace {
173 
175  std::deque<std::shared_ptr<const RelAlgNode>>& inputs,
176  std::vector<std::shared_ptr<const RelJoin>>& original_joins,
177  const std::shared_ptr<const RelJoin>& join) {
178  original_joins.push_back(join);
179  CHECK_EQ(size_t(2), join->inputCount());
180  const auto left_input_join =
181  std::dynamic_pointer_cast<const RelJoin>(join->getAndOwnInput(0));
182  if (left_input_join) {
183  inputs.push_front(join->getAndOwnInput(1));
184  collect_left_deep_join_inputs(inputs, original_joins, left_input_join);
185  } else {
186  inputs.push_front(join->getAndOwnInput(1));
187  inputs.push_front(join->getAndOwnInput(0));
188  }
189 }
190 
191 std::pair<std::shared_ptr<RelLeftDeepInnerJoin>, std::shared_ptr<const RelAlgNode>>
192 create_left_deep_join(const std::shared_ptr<RelAlgNode>& left_deep_join_root) {
193  const auto old_root = get_left_deep_join_root(left_deep_join_root);
194  if (!old_root) {
195  return {nullptr, nullptr};
196  }
197  std::deque<std::shared_ptr<const RelAlgNode>> inputs_deque;
198  const auto left_deep_join_filter =
199  std::dynamic_pointer_cast<RelFilter>(left_deep_join_root);
200  const auto join =
201  std::dynamic_pointer_cast<const RelJoin>(left_deep_join_root->getAndOwnInput(0));
202  CHECK(join);
203  std::vector<std::shared_ptr<const RelJoin>> original_joins;
204  collect_left_deep_join_inputs(inputs_deque, original_joins, join);
205  std::vector<std::shared_ptr<const RelAlgNode>> inputs(inputs_deque.begin(),
206  inputs_deque.end());
207  return {std::make_shared<RelLeftDeepInnerJoin>(
208  left_deep_join_filter, inputs, original_joins),
209  old_root};
210 }
211 
212 class RebindRexInputsFromLeftDeepJoin : public RexVisitor<void*> {
213  public:
215  : left_deep_join_(left_deep_join) {
216  std::vector<size_t> input_sizes;
217  CHECK_GT(left_deep_join->inputCount(), size_t(1));
218  for (size_t i = 0; i < left_deep_join->inputCount(); ++i) {
219  input_sizes.push_back(left_deep_join->getInput(i)->size());
220  }
221  input_size_prefix_sums_.resize(input_sizes.size());
223  input_sizes.begin(), input_sizes.end(), input_size_prefix_sums_.begin());
224  }
225 
226  void* visitInput(const RexInput* rex_input) const override {
227  const auto source_node = rex_input->getSourceNode();
228  if (left_deep_join_->coversOriginalNode(source_node)) {
229  const auto it = std::lower_bound(input_size_prefix_sums_.begin(),
230  input_size_prefix_sums_.end(),
231  rex_input->getIndex(),
232  std::less_equal<size_t>());
233  CHECK(it != input_size_prefix_sums_.end());
234  const auto input_node =
235  left_deep_join_->getInput(std::distance(input_size_prefix_sums_.begin(), it));
236  if (it != input_size_prefix_sums_.begin()) {
237  const auto prev_input_count = *(it - 1);
238  CHECK_LE(prev_input_count, rex_input->getIndex());
239  const auto input_index = rex_input->getIndex() - prev_input_count;
240  rex_input->setIndex(input_index);
241  }
242  rex_input->setSourceNode(input_node);
243  }
244  return nullptr;
245  };
246 
247  private:
248  std::vector<size_t> input_size_prefix_sums_;
250 };
251 
252 } // namespace
253 
254 // Recognize the left-deep join tree pattern with an optional filter as root
255 // with `node` as the parent of the join sub-tree. On match, return the root
256 // of the recognized tree (either the filter node or the outermost join).
257 std::shared_ptr<const RelAlgNode> get_left_deep_join_root(
258  const std::shared_ptr<RelAlgNode>& node) {
259  const auto left_deep_join_filter = dynamic_cast<const RelFilter*>(node.get());
260  if (left_deep_join_filter) {
261  const auto join = dynamic_cast<const RelJoin*>(left_deep_join_filter->getInput(0));
262  if (!join) {
263  return nullptr;
264  }
265  if (join->getJoinType() == JoinType::INNER || join->getJoinType() == JoinType::SEMI ||
266  join->getJoinType() == JoinType::ANTI) {
267  return node;
268  }
269  }
270  if (!node || node->inputCount() != 1) {
271  return nullptr;
272  }
273  const auto join = dynamic_cast<const RelJoin*>(node->getInput(0));
274  if (!join) {
275  return nullptr;
276  }
277  return node->getAndOwnInput(0);
278 }
279 
281  const RelLeftDeepInnerJoin* left_deep_join) {
282  RebindRexInputsFromLeftDeepJoin rebind_rex_inputs_from_left_deep_join(left_deep_join);
283  rebind_rex_inputs_from_left_deep_join.visit(rex);
284 }
285 
286 void create_left_deep_join(std::vector<std::shared_ptr<RelAlgNode>>& nodes) {
287  std::list<std::shared_ptr<RelAlgNode>> new_nodes;
288  for (auto& left_deep_join_candidate : nodes) {
289  std::shared_ptr<RelLeftDeepInnerJoin> left_deep_join;
290  std::shared_ptr<const RelAlgNode> old_root;
291  std::tie(left_deep_join, old_root) = create_left_deep_join(left_deep_join_candidate);
292  if (!left_deep_join) {
293  continue;
294  }
295  CHECK_GE(left_deep_join->inputCount(), size_t(2));
296  for (size_t nesting_level = 1; nesting_level <= left_deep_join->inputCount() - 1;
297  ++nesting_level) {
298  const auto outer_condition = left_deep_join->getOuterCondition(nesting_level);
299  if (outer_condition) {
300  rebind_inputs_from_left_deep_join(outer_condition, left_deep_join.get());
301  }
302  }
303  rebind_inputs_from_left_deep_join(left_deep_join->getInnerCondition(),
304  left_deep_join.get());
305  for (auto& node : nodes) {
306  if (node && node->hasInput(old_root.get())) {
307  node->replaceInput(left_deep_join_candidate, left_deep_join);
308  std::shared_ptr<const RelJoin> old_join;
309  if (std::dynamic_pointer_cast<const RelJoin>(left_deep_join_candidate)) {
310  old_join = std::static_pointer_cast<const RelJoin>(left_deep_join_candidate);
311  } else {
312  CHECK_EQ(size_t(1), left_deep_join_candidate->inputCount());
313  old_join = std::dynamic_pointer_cast<const RelJoin>(
314  left_deep_join_candidate->getAndOwnInput(0));
315  }
316  while (old_join) {
317  node->replaceInput(old_join, left_deep_join);
318  old_join =
319  std::dynamic_pointer_cast<const RelJoin>(old_join->getAndOwnInput(0));
320  }
321  }
322  }
323 
324  new_nodes.emplace_back(std::move(left_deep_join));
325  }
326 
327  // insert the new left join nodes to the front of the owned RelAlgNode list.
328  // This is done to ensure all created RelAlgNodes exist in this list for later
329  // visitation, such as RelAlgDag::resetQueryExecutionState.
330  nodes.insert(nodes.begin(), new_nodes.begin(), new_nodes.end());
331 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
std::vector< std::unique_ptr< const RexScalar > > outer_conditions_per_level_
Definition: RelAlgDag.h:2007
std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const override
JoinType
Definition: sqldefs.h:238
size_t size() const override
std::shared_ptr< const RelAlgNode > get_left_deep_join_root(const std::shared_ptr< RelAlgNode > &node)
bool coversOriginalNode(const RelAlgNode *node) const
std::pair< std::shared_ptr< RelLeftDeepInnerJoin >, std::shared_ptr< const RelAlgNode > > create_left_deep_join(const std::shared_ptr< RelAlgNode > &left_deep_join_root)
const RexScalar * getOuterCondition(const size_t nesting_level) const
std::shared_ptr< RelFilter > original_filter_
Definition: RelAlgDag.h:2008
std::string join(T const &container, std::string const &delim)
#define CHECK_GE(x, y)
Definition: Logger.h:306
virtual size_t getOuterConditionsSize() const
std::vector< std::shared_ptr< const RelJoin > > original_joins_
Definition: RelAlgDag.h:2009
#define CHECK_GT(x, y)
Definition: Logger.h:305
std::shared_ptr< const RelAlgNode > getAndOwnInput(const size_t idx) const
Definition: RelAlgDag.h:890
std::string to_string(char const *&&v)
unsigned getIndex() const
Definition: RelAlgDag.h:174
RelLeftDeepInnerJoin()=default
void setIndex(const unsigned in_index) const
Definition: RelAlgDag.h:176
std::shared_ptr< RelAlgNode > deepCopy() const override
DEVICE void partial_sum(ARGS &&...args)
Definition: gpu_enabled.h:87
const RelAlgNode * getInput(const size_t idx) const
Definition: RelAlgDag.h:877
Definition: sqldefs.h:39
std::vector< std::shared_ptr< const RelJoin > > getOriginalJoins() const
DEVICE auto lower_bound(ARGS &&...args)
Definition: gpu_enabled.h:78
#define CHECK_LE(x, y)
Definition: Logger.h:304
void setSourceNode(const RelAlgNode *node) const
Definition: RelAlgDag.h:1061
virtual size_t size() const =0
const RelAlgNode * getSourceNode() const
Definition: RelAlgDag.h:1056
const RelFilter * getOriginalFilter() const
std::string typeName(const T *v)
Definition: toString.h:106
std::unique_ptr< const RexScalar > condition_
Definition: RelAlgDag.h:2006
const JoinType getJoinType(const size_t nesting_level) const
const RexScalar * getInnerCondition() const
#define CHECK(condition)
Definition: Logger.h:291
void collect_left_deep_join_inputs(std::deque< std::shared_ptr< const RelAlgNode >> &inputs, std::vector< std::shared_ptr< const RelJoin >> &original_joins, const std::shared_ptr< const RelJoin > &join)
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
Definition: RelAlgDag.cpp:528
const size_t inputCount() const
Definition: RelAlgDag.h:875
void rebind_inputs_from_left_deep_join(const RexScalar *rex, const RelLeftDeepInnerJoin *left_deep_join)
RelAlgInputs inputs_
Definition: RelAlgDag.h:945