25 const std::shared_ptr<RelFilter>& filter,
26 std::vector<std::shared_ptr<const RelAlgNode>> inputs,
27 std::vector<std::shared_ptr<const RelJoin>>& original_joins)
28 : condition_(filter ? filter->getAndReleaseCondition() : nullptr)
29 , original_filter_(filter)
30 , original_joins_(original_joins) {
31 std::vector<std::unique_ptr<const RexScalar>> operands;
32 bool is_notnull =
true;
35 outer_conditions_per_level_.resize(original_joins.size());
36 for (
size_t nesting_level = 0; nesting_level < original_joins.size(); ++nesting_level) {
37 const auto& original_join = original_joins[nesting_level];
38 const auto condition_true =
39 dynamic_cast<const RexLiteral*
>(original_join->getCondition());
40 if (!condition_true || !condition_true->getVal<
bool>()) {
41 if (dynamic_cast<const RexOperator*>(original_join->getCondition())) {
43 is_notnull &&
dynamic_cast<const RexOperator*
>(original_join->getCondition())
47 switch (original_join->getJoinType()) {
51 if (original_join->getCondition()) {
52 operands.emplace_back(original_join->getAndReleaseCondition());
57 if (original_join->getCondition()) {
58 outer_conditions_per_level_[nesting_level].reset(
59 original_join->getAndReleaseCondition());
68 if (!operands.empty()) {
70 CHECK(dynamic_cast<const RexOperator*>(condition_.get()));
73 static_cast<const RexOperator*
>(condition_.get())->getType().get_notnull();
74 operands.emplace_back(std::move(condition_));
76 if (operands.size() > 1) {
80 condition_ = std::move(operands.front());
86 for (
const auto& input : inputs) {
87 addManagedInput(input);
96 const size_t nesting_level)
const {
115 for (
const auto& input :
inputs_) {
116 ret +=
" " + input->toString(config);
119 ret +=
", input node id={";
133 size_t total_size = 0;
134 for (
const auto& input :
inputs_) {
135 total_size += input->size();
154 if (original_join.get() == node) {
167 std::vector<std::shared_ptr<const RelJoin>> original_joins;
169 return original_joins;
175 std::deque<std::shared_ptr<const RelAlgNode>>& inputs,
176 std::vector<std::shared_ptr<const RelJoin>>& original_joins,
177 const std::shared_ptr<const RelJoin>&
join) {
178 original_joins.push_back(join);
179 CHECK_EQ(
size_t(2), join->inputCount());
180 const auto left_input_join =
181 std::dynamic_pointer_cast<
const RelJoin>(join->getAndOwnInput(0));
182 if (left_input_join) {
183 inputs.push_front(join->getAndOwnInput(1));
186 inputs.push_front(join->getAndOwnInput(1));
187 inputs.push_front(join->getAndOwnInput(0));
191 std::pair<std::shared_ptr<RelLeftDeepInnerJoin>, std::shared_ptr<const RelAlgNode>>
195 return {
nullptr,
nullptr};
197 std::deque<std::shared_ptr<const RelAlgNode>> inputs_deque;
198 const auto left_deep_join_filter =
199 std::dynamic_pointer_cast<
RelFilter>(left_deep_join_root);
201 std::dynamic_pointer_cast<
const RelJoin>(left_deep_join_root->getAndOwnInput(0));
203 std::vector<std::shared_ptr<const RelJoin>> original_joins;
205 std::vector<std::shared_ptr<const RelAlgNode>> inputs(inputs_deque.begin(),
207 return {std::make_shared<RelLeftDeepInnerJoin>(
208 left_deep_join_filter, inputs, original_joins),
215 : left_deep_join_(left_deep_join) {
216 std::vector<size_t> input_sizes;
218 for (
size_t i = 0; i < left_deep_join->
inputCount(); ++i) {
219 input_sizes.push_back(left_deep_join->
getInput(i)->
size());
221 input_size_prefix_sums_.resize(input_sizes.size());
223 input_sizes.begin(), input_sizes.end(), input_size_prefix_sums_.begin());
228 if (left_deep_join_->coversOriginalNode(source_node)) {
230 input_size_prefix_sums_.end(),
232 std::less_equal<size_t>());
233 CHECK(it != input_size_prefix_sums_.end());
234 const auto input_node =
235 left_deep_join_->getInput(std::distance(input_size_prefix_sums_.begin(), it));
236 if (it != input_size_prefix_sums_.begin()) {
237 const auto prev_input_count = *(it - 1);
239 const auto input_index = rex_input->
getIndex() - prev_input_count;
248 std::vector<size_t> input_size_prefix_sums_;
258 const std::shared_ptr<RelAlgNode>& node) {
259 const auto left_deep_join_filter =
dynamic_cast<const RelFilter*
>(node.get());
260 if (left_deep_join_filter) {
261 const auto join =
dynamic_cast<const RelJoin*
>(left_deep_join_filter->getInput(0));
270 if (!node || node->inputCount() != 1) {
273 const auto join =
dynamic_cast<const RelJoin*
>(node->getInput(0));
282 RebindRexInputsFromLeftDeepJoin rebind_rex_inputs_from_left_deep_join(left_deep_join);
283 rebind_rex_inputs_from_left_deep_join.visit(rex);
287 std::list<std::shared_ptr<RelAlgNode>> new_nodes;
288 for (
auto& left_deep_join_candidate : nodes) {
289 std::shared_ptr<RelLeftDeepInnerJoin> left_deep_join;
290 std::shared_ptr<const RelAlgNode> old_root;
292 if (!left_deep_join) {
295 CHECK_GE(left_deep_join->inputCount(), size_t(2));
296 for (
size_t nesting_level = 1; nesting_level <= left_deep_join->inputCount() - 1;
298 const auto outer_condition = left_deep_join->getOuterCondition(nesting_level);
299 if (outer_condition) {
304 left_deep_join.get());
305 for (
auto& node : nodes) {
306 if (node && node->hasInput(old_root.get())) {
307 node->replaceInput(left_deep_join_candidate, left_deep_join);
308 std::shared_ptr<const RelJoin> old_join;
309 if (std::dynamic_pointer_cast<const RelJoin>(left_deep_join_candidate)) {
310 old_join = std::static_pointer_cast<
const RelJoin>(left_deep_join_candidate);
312 CHECK_EQ(
size_t(1), left_deep_join_candidate->inputCount());
313 old_join = std::dynamic_pointer_cast<
const RelJoin>(
314 left_deep_join_candidate->getAndOwnInput(0));
319 std::dynamic_pointer_cast<
const RelJoin>(old_join->getAndOwnInput(0));
324 new_nodes.emplace_back(std::move(left_deep_join));
330 nodes.insert(nodes.begin(), new_nodes.begin(), new_nodes.end());
std::vector< std::unique_ptr< const RexScalar > > outer_conditions_per_level_
std::string toString(RelRexToStringConfig config=RelRexToStringConfig::defaults()) const override
size_t size() const override
std::shared_ptr< const RelAlgNode > get_left_deep_join_root(const std::shared_ptr< RelAlgNode > &node)
bool coversOriginalNode(const RelAlgNode *node) const
std::pair< std::shared_ptr< RelLeftDeepInnerJoin >, std::shared_ptr< const RelAlgNode > > create_left_deep_join(const std::shared_ptr< RelAlgNode > &left_deep_join_root)
const RexScalar * getOuterCondition(const size_t nesting_level) const
std::shared_ptr< RelFilter > original_filter_
virtual size_t getOuterConditionsSize() const
std::vector< std::shared_ptr< const RelJoin > > original_joins_
std::shared_ptr< const RelAlgNode > getAndOwnInput(const size_t idx) const
RelLeftDeepInnerJoin()=default
std::shared_ptr< RelAlgNode > deepCopy() const override
DEVICE void partial_sum(ARGS &&...args)
const RelAlgNode * getInput(const size_t idx) const
std::vector< std::shared_ptr< const RelJoin > > getOriginalJoins() const
DEVICE auto lower_bound(ARGS &&...args)
virtual size_t size() const =0
const RelFilter * getOriginalFilter() const
std::string typeName(const T *v)
std::unique_ptr< const RexScalar > condition_
const JoinType getJoinType(const size_t nesting_level) const
const RexScalar * getInnerCondition() const
void collect_left_deep_join_inputs(std::deque< std::shared_ptr< const RelAlgNode >> &inputs, std::vector< std::shared_ptr< const RelJoin >> &original_joins, const std::shared_ptr< const RelJoin > &join)
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
const size_t inputCount() const
void rebind_inputs_from_left_deep_join(const RexScalar *rex, const RelLeftDeepInnerJoin *left_deep_join)