// Fragment of get_column_stats (the start of the signature, several statements and
// the closing braces are not visible in this listing). The visible code computes
// min / max / sum / count over `data[0, num_rows)` using per-thread partial results
// that are reduced serially at the end.
24 const int64_t num_rows,
// Upper bound on worker threads, taken from the hardware.
// NOTE(review): std::thread::hardware_concurrency() is allowed to return 0.
27 const size_t max_thread_count = std::thread::hardware_concurrency();
// Use at most ceil(num_rows / max_inputs_per_thread) threads, capped by the hardware.
// NOTE(review): for num_rows == 0 this expression evaluates to 0 threads - confirm
// that the lines not shown here guard against that before the arena is created.
29 const size_t num_threads = std::min(
30 max_thread_count, ((num_rows + max_inputs_per_thread - 1) / max_inputs_per_thread));
// One accumulator slot per thread; min/max start at the extreme values of T so the
// first accepted value always replaces them.
32 std::vector<T> local_col_mins(num_threads, std::numeric_limits<T>::max());
33 std::vector<T> local_col_maxes(num_threads, std::numeric_limits<T>::lowest());
34 std::vector<double> local_col_sums(num_threads, 0.);
35 std::vector<int64_t> local_col_non_null_or_filtered_counts(num_threads, 0L);
// Run the work inside an arena constructed with num_threads as its concurrency limit.
36 tbb::task_arena limited_arena(num_threads);
37 limited_arena.execute([&] {
// Range [0, num_rows) and the per-chunk body; presumably arguments of a
// tbb::parallel_for call whose opening line is not visible here.
39 tbb::blocked_range<int64_t>(0, num_rows),
40 [&](
const tbb::blocked_range<int64_t>& r) {
41 const int64_t start_idx = r.begin();
42 const int64_t end_idx = r.end();
// Chunk-local accumulators; they are merged into the per-thread slots after the loop.
43 T local_col_min = std::numeric_limits<T>::max();
44 T local_col_max = std::numeric_limits<T>::lowest();
45 double local_col_sum = 0.;
46 int64_t local_col_non_null_or_filtered_count = 0;
// NOTE: the loop index `r` shadows the blocked_range parameter `r`; the range bounds
// were already copied into start_idx / end_idx above.
47 for (int64_t r = start_idx; r < end_idx; ++r) {
48 const T val = data[r];
// Floating-point columns only: test for NaN / infinity (the action taken is on a
// line not visible here - presumably the value is skipped).
49 if constexpr (std::is_same_v<T, float> || std::is_same_v<T, double>) {
50 if (std::isnan(val) || std::isinf(val)) {
// Test for the inline null sentinel of T (action not visible - presumably skipped).
54 if (val == inline_null_value<T>()) {
// Test the caller-supplied predicate (action not visible - presumably skipped when
// the predicate rejects the value).
57 if (!predicate(val)) {
// Track the running minimum / maximum of the values that reach this point.
60 if (val < local_col_min) {
63 if (val > local_col_max) {
// data[r] is the same element that was read into `val` above.
66 local_col_sum += data[r];
67 local_col_non_null_or_filtered_count++;
// Merge this chunk's results into the slot of the executing thread. A thread may
// process several chunks, hence compare / accumulate rather than plain assignment.
// NOTE(review): assumes current_thread_index() is always in [0, num_threads) inside
// limited_arena - confirm against the TBB documentation.
69 size_t thread_idx = tbb::this_task_arena::current_thread_index();
70 if (local_col_min < local_col_mins[thread_idx]) {
71 local_col_mins[thread_idx] = local_col_min;
73 if (local_col_max > local_col_maxes[thread_idx]) {
74 local_col_maxes[thread_idx] = local_col_max;
76 local_col_sums[thread_idx] += local_col_sum;
77 local_col_non_null_or_filtered_counts[thread_idx] +=
78 local_col_non_null_or_filtered_count;
// Serial reduction of the per-thread slots into column_stats and the running totals.
88 for (
size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) {
89 if (local_col_mins[thread_idx] < column_stats.
min) {
90 column_stats.
min = local_col_mins[thread_idx];
92 if (local_col_maxes[thread_idx] > column_stats.
max) {
93 column_stats.
max = local_col_maxes[thread_idx];
95 col_sum += local_col_sums[thread_idx];
// Right-hand side of the count accumulation; its left-hand side is on a line that is
// not visible in this listing.
97 local_col_non_null_or_filtered_counts[thread_idx];
// Publish the accumulated sum.
101 column_stats.
sum = col_sum;
109 const int64_t num_rows,
113 const int64_t num_rows,
117 const int64_t num_rows,
121 const int64_t num_rows,
125 const int64_t num_rows,
129 const int64_t num_rows,
132 template <
typename T>
// Fragment of the string -> StatsRequestAggType conversion (signature and the other
// comparisons are not visible). The comparison shown is an exact match against "COUNT".
159 if (str ==
"COUNT") {
// Executed when none of the comparisons returned: reject the string, echoing it in
// the exception message.
174 throw std::runtime_error(
"Invalid StatsRequestAggType: " + str);
// Fragment of the string -> StatsRequestPredicateOp conversion (return statements
// and any other comparisons are not visible in this listing).
178 const std::string& str) {
// Either the mnemonic "LT" or the symbol "<" is accepted.
182 if (str ==
"LT" || str ==
"<") {
// Either the mnemonic "GT" or the symbol ">" is accepted.
185 if (str ==
"GT" || str ==
">") {
// Any string not matched above is rejected; the offending value is included in the
// exception message.
188 throw std::runtime_error(
"Invalid StatsRequestPredicateOp: " + str);
// Fragment of replace_substrings (first parameter, loop header and return statement
// are not visible). Works on a copy of `str`, replacing occurrences of `pattern_str`
// with `replacement_str`.
192 const std::string& pattern_str,
193 const std::string& replacement_str) {
194 std::string replaced_str(str);
196 size_t search_start_index = 0;
197 const auto pattern_str_len = pattern_str.size();
198 const auto replacement_str_len = replacement_str.size();
// Find the next occurrence at or after the current search position.
// NOTE(review): with an empty pattern_str, find() would keep succeeding and the npos
// exit below would never be taken - confirm callers never pass an empty pattern.
201 search_start_index = replaced_str.find(pattern_str, search_start_index);
// No further occurrence (the statement inside this branch is not visible -
// presumably it leaves the loop).
202 if (search_start_index == std::string::npos) {
// Substitute in place, then resume searching after the inserted text so that the
// replacement itself is never re-scanned.
205 replaced_str.replace(search_start_index, pattern_str_len, replacement_str);
206 search_start_index += replacement_str_len;
// Fragment of parse_stats_requests_json (several lines, including closing braces,
// are not visible). Parses a JSON array of request objects into StatsRequest values,
// throwing std::runtime_error on malformed input.
212 const std::string& stats_requests_json_str,
213 const int64_t num_attrs) {
214 std::vector<StatsRequest> stats_requests;
215 rapidjson::Document doc;
// Pre-processed copy of the input; the initializer is on lines not visible here.
218 const auto fixed_stats_requests_json_str =
221 if (doc.Parse(fixed_stats_requests_json_str.c_str()).HasParseError()) {
// NOTE(review): on a parse failure the whole JSON text is written to stdout with a
// "DEBUG:" prefix before throwing - looks like leftover debugging output.
223 std::cout <<
"DEBUG: Failed JSON: " << fixed_stats_requests_json_str << std::endl;
224 throw std::runtime_error(
"Could not parse Stats Requests JSON.");
// The document root must be a JSON array.
227 if (!doc.IsArray()) {
228 throw std::runtime_error(
"Stats Request JSON did not contain valid root Array.");
// Keys that every request object has to provide.
230 const std::vector<std::string> required_keys = {
231 "name",
"attr_id",
"agg_type",
"filter_type"};
233 for (
const auto& stat_request_obj : doc.GetArray()) {
// Validate presence and type of each required key before reading any value.
234 for (
const auto& required_key : required_keys) {
235 if (!stat_request_obj.HasMember(required_key)) {
236 throw std::runtime_error(
"Stats Request JSON missing key " + required_key +
".");
// attr_id is checked with IsUint(); the string check below is presumably the else
// branch covering the remaining keys (the branch keyword is not visible).
238 if (required_key ==
"attr_id") {
239 if (!stat_request_obj[required_key].IsUint()) {
240 throw std::runtime_error(required_key +
" must be int type");
243 if (!stat_request_obj[required_key].IsString()) {
244 throw std::runtime_error(required_key +
" must be string type");
249 stats_request.
name = stat_request_obj[
"name"].GetString();
// The JSON value is decremented by one before the range check against
// [0, num_attrs), i.e. the JSON attr_id is 1-based.
// NOTE(review): validated with IsUint() above but read with GetInt() here; a value
// above INT_MAX would pass the check yet is not an int - confirm RapidJSON behaviour.
250 stats_request.
attr_id = stat_request_obj[
"attr_id"].GetInt() - 1;
251 if (stats_request.
attr_id < 0 || stats_request.
attr_id >= num_attrs) {
252 throw std::runtime_error(
"Invalid attr_id: " +
// agg_type is upper-cased in place using ::toupper (the call these arguments belong
// to starts on a line not visible here).
256 std::string agg_type_str = stat_request_obj[
"agg_type"].GetString();
258 agg_type_str.begin(), agg_type_str.end(), agg_type_str.begin(), ::toupper);
// filter_type is read the same way; only part of its transformation call is visible.
261 std::string filter_type_str = stat_request_obj[
"filter_type"].GetString();
263 filter_type_str.end(),
264 filter_type_str.begin(),
// filter_val must be present and numeric; the condition under which this check is
// reached is not visible in this listing.
269 if (!stat_request_obj.HasMember(
"filter_val")) {
270 throw std::runtime_error(
"Stats Request JSON missing expected filter_val");
272 if (!stat_request_obj[
"filter_val"].IsNumber()) {
273 throw std::runtime_error(
"Stats Request JSON filter_val should be numeric.");
275 stats_request.
filter_val = stat_request_obj[
"filter_val"].GetDouble();
// Append the populated request to the result vector.
277 stats_requests.emplace_back(stats_request);
279 return stats_requests;
// Fragment of get_stats_key_value_pairs (the return type / name line and the closing
// braces are not visible). Builds one (name, result) pair per request.
283 const std::vector<StatsRequest>& stats_requests) {
284 std::vector<std::pair<const char*, double>> stats_key_value_pairs;
285 for (
const auto& stats_request : stats_requests) {
// The key is the raw c_str() pointer of the request's name, so each returned pair
// points into storage owned by `stats_requests`; the pointers are only usable while
// those name strings stay alive and unmodified.
286 stats_key_value_pairs.emplace_back(
287 std::make_pair(stats_request.name.c_str(), stats_request.result));
289 return stats_key_value_pairs;
std::vector< StatsRequest > parse_stats_requests_json(const std::string &stats_requests_json_str, const int64_t num_attrs)
std::vector< std::pair< const char *, double > > get_stats_key_value_pairs(const std::vector< StatsRequest > &stats_requests)
DEVICE int64_t size() const
NEVER_INLINE HOST ColumnStats< T > get_column_stats(const T *data, const int64_t num_rows, const StatsRequestPredicate &predicate)
StatsRequestPredicateOp convert_string_to_stats_request_predicate_op(const std::string &str)
DEVICE T * getPtr() const
std::string replace_substrings(const std::string &str, const std::string &pattern_str, const std::string &replacement_str)
const size_t max_inputs_per_thread
int64_t non_null_or_filtered_count
OUTPUT transform(INPUT const &input, FUNC const &func)
StatsRequestAggType convert_string_to_stats_request_agg_type(const std::string &str)
StatsRequestPredicateOp filter_type
void parallel_for(const blocked_range< Int > &range, const Body &body, const Partitioner &p=Partitioner())
StatsRequestAggType agg_type