23 #include "../Shared/funcannotations.h"
29 template <
typename KeyT =
int64_t,
typename IndexT =
int32_t>
32 const size_t key_stride,
36 auto keys_ptr =
reinterpret_cast<const KeyT*
>(
buffer +
stride * rowid);
37 return keys_ptr[
index];
45 template <
typename KeyT =
int64_t>
81 template <
typename KeyT =
int64_t,
typename NodeT =
int64_t>
83 const size_t heap_size,
87 for (NodeT i = curr_idx, last = static_cast<NodeT>(heap_size); i < last;) {
89 const auto left_child = min(2 * i + 1, last);
90 const auto right_child = min(2 * i + 2, last);
92 const auto left_child = std::min(2 * i + 1, last);
93 const auto right_child = std::min(2 * i + 2, last);
95 auto candidate_idx = last;
96 if (left_child < last) {
97 if (right_child < last) {
98 const auto left_key = accessor.
get(heap[left_child]);
99 const auto right_key = accessor.
get(heap[right_child]);
100 candidate_idx = compare(left_key, right_key) ? left_child : right_child;
102 candidate_idx = left_child;
105 candidate_idx = right_child;
107 if (candidate_idx >= last) {
110 const auto curr_key = accessor.
get(heap[i]);
111 const auto candidate_key = accessor.
get(heap[candidate_idx]);
112 if (compare(curr_key, candidate_key)) {
115 auto temp_id = heap[i];
116 heap[i] = heap[candidate_idx];
117 heap[candidate_idx] = temp_id;
122 template <
typename KeyT =
int64_t,
typename NodeT =
int64_t>
124 const NodeT curr_idx,
127 for (NodeT i = curr_idx; i > 0 && (i - 1) < i;) {
128 const auto parent = (i - 1) / 2;
129 const auto curr_key = accessor.
get(heap[i]);
130 const auto parent_key = accessor.
get(heap[parent]);
131 if (compare(parent_key, curr_key)) {
134 auto temp_id = heap[i];
135 heap[i] = heap[parent];
136 heap[parent] = temp_id;
141 template <
typename KeyT =
int64_t,
typename NodeT =
int64_t>
145 const uint32_t row_size_quad,
146 const uint32_t key_offset,
149 const KeyT curr_key) {
150 const NodeT bin_index = node_count++;
151 heap_ptr[bin_index] = bin_index;
152 int8_t* row_ptr =
reinterpret_cast<int8_t*
>(rows_ptr + bin_index * row_size_quad);
153 auto key_ptr =
reinterpret_cast<KeyT*
>(row_ptr + key_offset);
156 sift_up<KeyT, NodeT>(heap_ptr, bin_index, comparator, accessor);
159 template <
typename KeyT =
int64_t,
typename NodeT =
int64_t>
162 const NodeT node_count,
163 const uint32_t row_size_quad,
164 const uint32_t key_offset,
167 const KeyT curr_key) {
168 const NodeT top_bin_idx =
static_cast<NodeT
>(heap_ptr[0]);
169 int8_t* top_row_ptr =
reinterpret_cast<int8_t*
>(rows_ptr + top_bin_idx * row_size_quad);
170 auto top_key =
reinterpret_cast<KeyT*
>(top_row_ptr + key_offset);
171 if (compare(curr_key, *top_key)) {
177 sift_down<KeyT, NodeT>(heap_ptr, node_count, 0, compare, accessor);
182 template <
typename KeyT =
int64_t>
185 const uint32_t row_size_quad,
186 const uint32_t key_offset,
189 const bool nulls_first,
191 const KeyT curr_key) {
194 int64_t& node_count = heaps[thread_global_index];
195 int64_t* heap_ptr = heaps + thread_count + thread_global_index * k;
197 heaps + thread_count + thread_count * k + thread_global_index * row_size_quad * k;
203 row_size_quad *
sizeof(int64_t),
204 key_offset /
sizeof(KeyT));
205 if (node_count < static_cast<int64_t>(k)) {
214 const auto last_bin_index = node_count - 1;
215 auto row_ptr = rows_ptr + last_bin_index * row_size_quad;
216 row_ptr[0] = last_bin_index;
219 const int64_t top_bin_idx = heap_ptr[0];
231 auto row_ptr = rows_ptr + top_bin_idx * row_size_quad;
232 row_ptr[0] = top_bin_idx;
237 #define DEF_GET_BIN_FROM_K_HEAP(key_type) \
238 extern "C" RUNTIME_EXPORT NEVER_INLINE DEVICE int64_t* get_bin_from_k_heap_##key_type( \
241 const uint32_t row_size_quad, \
242 const uint32_t key_offset, \
243 const bool min_heap, \
244 const bool has_null, \
245 const bool nulls_first, \
246 const key_type null_key, \
247 const key_type curr_key) { \
248 return get_bin_from_k_heap_impl(heaps, \
ALWAYS_INLINE DEVICE void sift_down(NodeT *heap, const size_t heap_size, const NodeT curr_idx, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor)
const NullsOrdering nulls_ordering
ALWAYS_INLINE DEVICE void sift_up(NodeT *heap, const NodeT curr_idx, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor)
ALWAYS_INLINE DEVICE void push_heap(int64_t *heap_ptr, int64_t *rows_ptr, NodeT &node_count, const uint32_t row_size_quad, const uint32_t key_offset, const KeyComparator< KeyT > &comparator, const KeyAccessor< KeyT, NodeT > &accessor, const KeyT curr_key)
__device__ int32_t pos_step_impl()
const HeapOrdering heap_ordering
__device__ int32_t pos_start_impl(const int32_t *row_index_resume)
#define DEF_GET_BIN_FROM_K_HEAP(key_type)
ALWAYS_INLINE DEVICE KeyT get(const IndexT rowid) const
ALWAYS_INLINE DEVICE bool operator()(const KeyT lhs, const KeyT rhs) const
ALWAYS_INLINE DEVICE bool pop_and_push_heap(int64_t *heap_ptr, int64_t *rows_ptr, const NodeT node_count, const uint32_t row_size_quad, const uint32_t key_offset, const KeyComparator< KeyT > &compare, const KeyAccessor< KeyT, NodeT > &accessor, const KeyT curr_key)
ALWAYS_INLINE DEVICE int64_t * get_bin_from_k_heap_impl(int64_t *heaps, const uint32_t k, const uint32_t row_size_quad, const uint32_t key_offset, const bool min_heap, const bool has_null, const bool nulls_first, const KeyT null_key, const KeyT curr_key)
DEVICE KeyComparator(const HeapOrdering hp_order, const bool nullable, const KeyT null_val, const NullsOrdering null_order)
DEVICE KeyAccessor(const int8_t *key_buff, const size_t key_stride, const size_t key_idx)