43 #include <type_traits>
49 template <
typename RealType,
typename IndexType>
55 template <
typename RealType,
typename IndexType>
// Positive infinity for RealType, as reported by std::numeric_limits.
62 static constexpr RealType
infinity = std::numeric_limits<RealType>::infinity();
// Quiet NaN for RealType, as reported by std::numeric_limits.
// NOTE(review): the visible quantile(q) code returns `centroids_.nan`,
// apparently when there are no centroids - confirm against the full body.
63 static constexpr RealType
nan = std::numeric_limits<RealType>::quiet_NaN();
70 :
sums_(sums, size, size),
counts_(counts, size, size) {}
// Befriends the stream-insertion operator template for Centroids so that its
// out-of-line definition can read curr_idx_, next_idx_, sums_ and counts_.
// The template parameters are spelled RealType2/IndexType2 rather than
// RealType/IndexType.
// NOTE(review): presumably so they do not redeclare the enclosing class
// template's parameter names - confirm.
128 template <
typename RealType2,
typename IndexType2>
129 friend std::ostream& operator<<(std::ostream&, Centroids<RealType2, IndexType2>
const&);
133 template <
typename RealType,
typename IndexType>
167 template <
typename RealType,
typename IndexType>
175 size_t nbytes()
const {
return sums_.size() * (
sizeof(RealType) +
sizeof(IndexType)); }
183 template <
typename RealType,
typename IndexType =
size_t>
// Optional RealType value; empty (std::nullopt) unless set. Move-assigned from
// rhs.q_ in the move-assignment operator.
// NOTE(review): presumably the quantile passed to the constructor taking `q`
// and used by the no-argument quantile() - confirm.
197 std::optional<RealType>
q_{std::nullopt};
211 IndexType
const total_weight,
219 IndexType
const idx1,
220 IndexType
const prefix_sum)
const;
// Declaration only; the definition appears later in this file.
// idx1 and idx2 are indices into centroids_ (the definition reads
// centroids_.counts_[idx1] and centroids_.sums_[idx2]).
// Callers in this file use the result as `mean + slope(...) * dx`.
224 DEVICE RealType
slope(IndexType
const idx1, IndexType
const idx2)
const;
232 :
centroids_(mem.sums().data(), mem.counts().data(), mem.size()) {
238 IndexType buf_allocate,
239 IndexType centroids_allocate)
248 buf_ = std::move(rhs.buf_);
252 q_ = std::move(rhs.q_);
261 static IndexType
nbytes(IndexType buf_allocate, IndexType centroids_allocate) {
262 return (buf_allocate + centroids_allocate) * (
sizeof(RealType) +
sizeof(IndexType));
293 RealType
const q)
const;
337 template <
typename RealType,
typename IndexType>
344 return lhs < rhs || (lhs == rhs && b.
value1() < a.
value1());
350 template <
typename RealType,
typename IndexType>
352 if (inc_ == -1 && curr_idx_ != 0) {
362 IndexType
const offset = inc_ == 1 ? 0 : buff.
curr_idx_;
363 IndexType
const buff_size =
368 IndexType
const curr_size = inc_ == 1 ? curr_idx_ + 1 : size() - curr_idx_;
369 IndexType
const total_size = curr_size + buff_sums.
size();
370 assert(total_size <= sums_.capacity());
371 sums_.resize(total_size);
372 gpu_enabled::copy(buff_sums.begin(), buff_sums.end(), sums_.begin() + curr_size);
373 assert(total_size <= counts_.capacity());
374 counts_.resize(total_size);
375 gpu_enabled::copy(buff_counts.begin(), buff_counts.end(), counts_.begin() + curr_size);
384 template <
typename RealType,
typename IndexType>
386 IndexType
const max_count) {
387 if (counts_[curr_idx_] + centroid.
nextCount() <= max_count) {
388 sums_[curr_idx_] += centroid.
nextSum();
389 counts_[curr_idx_] += centroid.
nextCount();
396 template <
typename RealType,
typename IndexType>
399 if (curr_idx_ != next_idx_) {
400 sums_[curr_idx_] = sums_[next_idx_];
401 counts_[curr_idx_] = counts_[next_idx_];
406 template <
typename RealType,
typename IndexType>
410 curr_idx_ = ~IndexType(0);
415 static_assert(std::is_unsigned<IndexType>::value,
416 "IndexType must be an unsigned type.");
417 next_idx_ = curr_idx_ + inc_;
421 template <
typename RealType,
typename IndexType>
424 out <<
"Centroids<" <<
typeid(RealType).
name() <<
',' <<
typeid(IndexType).
name()
425 <<
">(size(" << centroids.
size() <<
") curr_idx_(" << centroids.
curr_idx_
426 <<
") next_idx_(" << centroids.
next_idx_ <<
") sums_(";
427 for (IndexType i = 0; i < centroids.
sums_.
size(); ++i) {
428 out << (i ?
" " :
"") << std::setprecision(20) << centroids.
sums_[i];
431 for (IndexType i = 0; i < centroids.
counts_.
size(); ++i) {
432 out << (i ?
" " :
"") << centroids.
counts_[i];
440 template <
typename RealType,
typename IndexType>
446 , centroids_(centroids)
447 , total_weight_(centroids->totalWeight() + buf->totalWeight())
448 , forward_(forward) {
455 template <
typename RealType,
typename IndexType>
458 if (buf_->hasNext()) {
459 if (centroids_->hasNext()) {
460 return (*buf_ < *centroids_) == forward_ ? buf_ : centroids_;
463 }
else if (centroids_->hasNext()) {
474 template <
typename RealType,
typename IndexType>
// Counter, zero by default. The visible code increments it when
// count_skipped_ is non-zero, and resets it to 0 right after a
// shiftCentroids() call.
479 IndexType count_merged_{0};
// Counter, zero by default. The visible code increments it with
// ++data_[idx].count_skipped_ and tests it for non-zero before counting a merge.
480 IndexType count_skipped_{0};
497 template <
typename T>
500 IndexType
const merged,
503 T* src = begin + inc * (skipped - 1);
504 T* dst = src + inc * merged;
505 for (; skipped; --skipped, src -= inc, dst -= inc) {
510 std::copy_backward(begin, begin + skipped, begin + skipped + merged);
512 std::copy(begin + 1 - skipped, begin + 1, begin + 1 - skipped - merged);
519 return data_[0].centroid_ != centroid;
522 return mean_.sum_ * next_centroid->
nextCount() !=
523 next_centroid->
nextSum() * mean_.count_;
526 IndexType
const idx = index(next_centroid);
527 if (data_[idx].count_skipped_) {
528 ++data_[idx].count_merged_;
532 return data_[0].centroid_;
536 shiftCentroids(data_[0]);
537 data_[0].centroid_->next_idx_ = data_[0].start_;
538 if (data_[1].centroid_) {
539 shiftCentroids(data_[1]);
540 data_[1].centroid_->next_idx_ = data_[1].start_;
546 data_[0] = {next_centroid, next_centroid->
next_idx_, 0, 1};
550 IndexType
const idx = index(next_centroid);
551 if (idx == 1 && data_[1].centroid_ ==
nullptr) {
552 data_[1] = {next_centroid, next_centroid->
next_idx_, 0, 1};
554 if (data_[idx].count_merged_) {
555 shiftCentroids(data_[idx]);
556 data_[idx].count_merged_ = 0;
558 ++data_[idx].count_skipped_;
567 template <
typename RealType,
typename IndexType>
569 Skipped<RealType, IndexType> skipped;
570 while (
auto* next_centroid = getNextCentroid()) {
572 if (skipped.isDifferentMean(next_centroid)) {
574 }
else if (curr_centroid_->mergeIfFits(*next_centroid, max_count)) {
575 skipped.merged(next_centroid);
577 skipped.skipSubsequent(next_centroid);
579 }
else if (!curr_centroid_->mergeIfFits(*next_centroid, max_count)) {
580 skipped.skipFirst(next_centroid);
584 skipped.shiftCentroidsAndSetNext();
596 template <
typename RealType,
typename IndexType>
598 if (centroids_->max_ < buf_->max_) {
599 centroids_->max_ = buf_->max_;
601 if (buf_->min_ < centroids_->min_) {
602 centroids_->min_ = buf_->min_;
607 template <
typename RealType,
typename IndexType>
609 prefix_sum_ += curr_centroid_->currCount();
613 template <
typename RealType,
typename IndexType>
615 if ((curr_centroid_ = getNextCentroid())) {
616 curr_centroid_->moveNextToCurrent();
622 template <
typename RealType,
typename IndexType>
624 if (buf_.sums_.full()) {
627 buf_.sums_.push_back(value);
628 buf_.counts_.push_back(1);
632 template <
typename RealType,
typename IndexType>
634 if (buf_.capacity() == 0) {
635 auto* p0 = simple_allocator_->allocate(buf_allocate_ *
sizeof(RealType));
636 auto* p1 = simple_allocator_->allocate(buf_allocate_ *
sizeof(IndexType));
640 p0 = simple_allocator_->allocate(centroids_allocate_ *
sizeof(RealType));
641 p1 = simple_allocator_->allocate(centroids_allocate_ *
sizeof(IndexType));
649 template <
typename RealType,
typename IndexType>
652 IndexType
const total_weight,
654 IndexType
const max_bins = centroids_.capacity();
655 if (total_weight <= max_bins) {
657 }
else if (use_linear_scaling_function_) {
659 return 2 * total_weight / max_bins;
662 RealType
const x = 2.0 * sum / total_weight - 1;
663 RealType
const f_inv = 0.5 + 0.5 * std::sin(c + std::asin(x));
664 constexpr RealType eps = 1e-5;
665 IndexType
const dsum =
static_cast<IndexType
>(total_weight * f_inv + eps);
666 return dsum < sum ? 0 : dsum - sum;
671 template <
typename RealType,
typename IndexType>
675 buf_.min_ = buf_.sums_.front();
676 buf_.max_ = buf_.sums_.back();
677 mergeCentroids(buf_);
682 template <
typename RealType,
typename IndexType>
685 std::lock_guard<std::mutex>
lock_guard(merge_buffer_final_called_mutex_);
687 if (!merge_buffer_final_called_) {
689 assert(centroids_.size() <= buf_.capacity());
690 partialSumOfCounts(buf_.counts_.data());
691 merge_buffer_final_called_ =
true;
695 template <
typename RealType,
typename IndexType>
700 if (buf_.capacity() == 0) {
704 buf_.counts_.set(counts, size);
707 buf_.min_ = buf_.sums_.front();
708 buf_.max_ = buf_.sums_.back();
709 mergeCentroids(buf_);
718 template <
typename RealType,
typename IndexType>
721 constexpr RealType two_pi = 6.283185307179586476925286766559005768e+00;
723 RealType
const c = two_pi / centroids_.capacity();
728 for (CM cm(&buf, &centroids_, forward_); cm.hasNext(); cm.next()) {
731 IndexType
const max_cardinality = maxCardinality(cm.prefixSum(), cm.totalWeight(), c);
732 cm.
merge(max_cardinality);
735 centroids_.appendAndSortCurrent(buf);
741 template <
typename CountsIterator>
748 template <
typename RealType,
typename IndexType>
752 }
else if (centroids_.size() == 1) {
753 return oneCentroid(x);
754 }
else if (centroids_.counts_.front() == 2) {
755 RealType
const sum = centroids_.sums_.front();
756 return x == 1 ? 0.5 * sum : sum - min();
758 RealType
const count = centroids_.counts_.front();
759 RealType
const dx = x - RealType(0.5) * (1 + count);
760 RealType
const mean = (centroids_.sums_.front() - min()) / (count - 1);
761 return mean + slope(0, 0 < dx) * dx;
766 template <
typename RealType,
typename IndexType>
769 IndexType
const idx1,
770 IndexType
const prefix_sum)
const {
771 if (
isSingleton(centroids_.counts_.begin() + idx1)) {
772 RealType
const sum1 = centroids_.sums_[idx1];
773 if (x == prefix_sum - centroids_.counts_[idx1]) {
774 if (
isSingleton(centroids_.counts_.begin() + idx1 - 1)) {
775 return 0.5 * (centroids_.sums_[idx1 - 1] + sum1);
776 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
777 return 0.5 * (centroids_.sums_[idx1 - 1] - min() + sum1);
782 RealType
const dx = x + RealType(0.5) * centroids_.counts_[idx1] - prefix_sum;
783 IndexType
const idx2 = idx1 + 2 * (0 < dx) - 1;
784 return centroids_.mean(idx1) + slope(idx1, idx2) * dx;
789 template <
typename RealType,
typename IndexType>
791 IndexType
const N)
const {
795 IndexType
const idx1 = centroids_.size() - 1;
796 RealType
const sum1 = centroids_.sums_[idx1];
797 IndexType
const count1 = centroids_.counts_[idx1];
799 if (
isSingleton(centroids_.counts_.begin() + (idx1 - 1))) {
800 return 0.5 * (centroids_.sums_[idx1 - 1] + sum1);
801 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
802 return 0.5 * (centroids_.sums_[idx1 - 1] - min() + sum1);
806 }
else if (count1 == 2) {
809 }
else if (x == N - 2) {
810 RealType
const sum2 = centroids_.sums_[idx1 - 1];
811 if (
isSingleton(centroids_.counts_.begin() + (idx1 - 1))) {
812 return 0.5 * (sum2 + sum1 - max());
813 }
else if (idx1 == 1 && centroids_.counts_[0] == 2) {
814 return 0.5 * (sum2 - min() + sum1 - max());
819 RealType
const dx = x + RealType(0.5) * (count1 + 1) - N;
820 RealType
const mean = (sum1 - max()) / (count1 - 1);
821 return mean + slope(idx1, idx1 - (dx < 0)) * dx;
826 template <
typename RealType,
typename IndexType>
828 IndexType
const N = centroids_.counts_.front();
832 return 0.5 * centroids_.sums_.front();
835 return 0.5 * (centroids_.sums_.front() - min());
837 RealType
const s = centroids_.sums_.front() - max();
838 return x == 1 ? 0.5 * s : s - min();
841 RealType
const dx = x - RealType(0.5) *
N;
842 RealType
const mean = (centroids_.sums_.front() - (min() + max())) / (N - 2);
843 RealType
const slope = 2 * (0 < dx ? max() - mean : mean - min()) / (N - 2);
844 return mean + slope * dx;
849 template <
typename RealType,
typename IndexType>
851 IndexType*
const buf)
const {
853 return {buf, centroids_.size()};
856 template <
typename RealType,
typename IndexType>
859 RealType
const q)
const {
860 if (centroids_.size()) {
861 IndexType
const N = partial_sum.
back();
862 RealType
const x = q *
N;
864 if (it1 == partial_sum.
begin()) {
865 return firstCentroid(x);
866 }
else if (it1 == partial_sum.
end()) {
868 }
else if (it1 + 1 == partial_sum.
end()) {
869 return lastCentroid(x, N);
871 return interiorCentroid(x, it1 - partial_sum.
begin(), *it1);
874 return centroids_.nan;
882 template <
typename RealType,
typename IndexType>
884 IndexType idx2)
const {
885 IndexType
const M = centroids_.size();
887 RealType
const n =
static_cast<RealType
>(centroids_.counts_[idx1]);
888 RealType
const s = centroids_.sums_[idx1];
889 return idx1 == 0 ? 2 * (s - n * min()) / ((n - 1) * (n - 1))
890 : 2 * (n * max() - s) / ((n - 1) * (n - 1));
892 bool const min1 = idx1 == 0;
893 bool const max1 = idx1 == M - 1;
894 bool const min2 = idx2 == 0;
895 bool const max2 = idx2 == M - 1;
896 RealType
const n1 =
static_cast<RealType
>(centroids_.counts_[idx1] - min1 - max1);
897 RealType
const s1 = centroids_.sums_[idx1] - (min1 ? min() : max1 ? max() : 0);
898 RealType
const s2 = centroids_.sums_[idx2] - (min2 ? min() : max2 ? max() : 0);
899 if (
isSingleton(centroids_.counts_.begin() + idx2)) {
900 return (idx1 < idx2 ? 2 : -2) * (n1 * s2 - s1) / (n1 * n1);
902 RealType
const n2 =
static_cast<RealType
>(centroids_.counts_[idx2] - min2 - max2);
903 return (idx1 < idx2 ? 2 : -2) * (n1 * s2 - n2 * s1) / (n1 * n2 * (n1 + n2));
DEVICE auto upper_bound(ARGS &&...args)
std::lock_guard< T > lock_guard
DEVICE void setCurrCentroid()
static constexpr RealType infinity
DEVICE void push_back(RealType const value, RealType const count)
DEVICE void push_back(T const &value)
std::vector< RealType > sums_
DEVICE size_type capacity() const
Centroids< RealType, IndexType > * buf_
DEVICE void skipSubsequent(Centroids< RealType, IndexType > *next_centroid)
std::ostream & operator<<(std::ostream &out, Centroids< RealType, IndexType > const &centroids)
CentroidsMemory(size_t const size)
DEVICE RealType mean(IndexType const i) const
DEVICE bool index(Centroids< RealType, IndexType > *centroid) const
DEVICE void resetIndices(bool const forward)
DEVICE Centroids< RealType, IndexType > & centroids()
DEVICE void moveNextToCurrent()
DEVICE void add(RealType value)
DEVICE VectorView< IndexType const > partialSumOfCounts(IndexType *const buf) const
Centroids< RealType, IndexType > * centroid_
DEVICE RealType quantile(RealType const q) const
SimpleAllocator * simple_allocator_
IndexType const total_weight_
DEVICE void sort(ARGS &&...args)
DEVICE void merged(Centroids< RealType, IndexType > *next_centroid)
DEVICE bool mergeIfFits(Centroids &centroid, IndexType const max_count)
DEVICE bool hasNext() const
DEVICE TDigest(Memory &mem)
DEVICE void mergeBuffer()
DEVICE IndexType totalWeight() const
TDigest & operator=(TDigest &&rhs)
VectorView< RealType > sums()
bool use_linear_scaling_function_
VectorView< RealType > sums_
VectorView< IndexType > counts_
std::vector< IndexType > counts_
static DEVICE void shiftCentroids(Data &data)
DEVICE void mergeBufferFinal()
DEVICE RealType currMean() const
DEVICE RealType quantile()
DEVICE void setCentroids(Memory &mem)
DEVICE size_type size() const
DEVICE TDigest(RealType q, SimpleAllocator *simple_allocator, IndexType buf_allocate, IndexType centroids_allocate)
static IndexType nbytes(IndexType buf_allocate, IndexType centroids_allocate)
Centroids< RealType, IndexType > * centroids_
DEVICE RealType max() const
DEVICE void fill(ARGS &&...args)
DEVICE void set(T *data, size_type const size)
DEVICE auto copy(ARGS &&...args)
std::optional< RealType > q_
DEVICE Centroids(VectorView< RealType > sums, VectorView< IndexType > counts)
bool merge_buffer_final_called_
DEVICE CentroidsMerger(Centroids< RealType, IndexType > *buf, Centroids< RealType, IndexType > *centroids, bool const forward)
DEVICE void mergeTDigest(TDigest &t_digest)
DEVICE IndexType nextCount() const
DEVICE void partial_sum(ARGS &&...args)
DEVICE void setCentroids(VectorView< RealType > const sums, VectorView< IndexType > const counts)
DEVICE RealType slope(IndexType const idx1, IndexType const idx2) const
DEVICE auto accumulate(ARGS &&...args)
Centroids< RealType, IndexType > buf_
DEVICE RealType nextSum() const
DEVICE RealType lastCentroid(RealType const x, IndexType const N) const
DEVICE void skipFirst(Centroids< RealType, IndexType > *next_centroid)
static DEVICE void shiftRange(T *const begin, IndexType skipped, IndexType const merged, int const inc)
DEVICE void setBuffer(Memory &mem)
DEVICE bool operator()(Value const &a, Value const &b) const
static constexpr RealType nan
DEVICE Centroids< RealType, IndexType > * getNextCentroid() const
DEVICE IndexType prefixSum() const
std::mutex merge_buffer_final_called_mutex_
DEVICE Centroids(RealType *sums, IndexType *counts, IndexType const size)
DEVICE void merge(IndexType const max_count)
DEVICE bool operator<(Centroids const &b) const
DEVICE bool hasNext() const
DEVICE void shiftCentroidsAndSetNext()
Centroids< RealType, IndexType > * curr_centroid_
DEVICE IndexType maxCardinality(IndexType const sum, IndexType const total_weight, RealType const c)
DEVICE size_t size() const
DEVICE IndexType currCount() const
DEVICE IndexType totalWeight() const
DEVICE bool isDifferentMean(Centroids< RealType, IndexType > *next_centroid) const
DEVICE IndexType capacity() const
DEVICE void mergeMinMax()
DEVICE IndexType totalWeight() const
DEVICE RealType interiorCentroid(RealType const x, IndexType const idx1, IndexType const prefix_sum) const
DEVICE RealType oneCentroid(RealType const x) const
IndexType centroids_allocate_
DEVICE void mergeCentroids(Centroids< RealType, IndexType > &)
DEVICE void reverse(ARGS &&...args)
DEVICE RealType min() const
DEVICE RealType firstCentroid(RealType const x) const
DEVICE bool isSingleton(CountsIterator itr)
Centroid< RealType, IndexType > mean_
DEVICE void appendAndSortCurrent(Centroids &buff)
DEVICE bool hasCurr() const
DEVICE void mergeSorted(RealType *sums, IndexType *counts, IndexType size)
VectorView< IndexType > counts()
Centroids< RealType, IndexType > centroids_