#include <algorithm>
#include <cstdint>
#include <tuple>

#include <ATen/ATen.h>
#include <torch/library.h>
13 template <
typename scalar_t>
15 const at::Tensor& scores,
16 double iou_threshold) {
17 TORCH_CHECK(dets.is_cpu(),
"dets must be a CPU tensor");
18 TORCH_CHECK(scores.is_cpu(),
"scores must be a CPU tensor");
19 TORCH_CHECK(dets.scalar_type() == scores.scalar_type(),
20 "dets should have the same type as scores");
22 if (dets.numel() == 0)
23 return at::empty({0}, dets.options().dtype(at::kLong));
25 auto x1_t = dets.select(1, 0).contiguous();
26 auto y1_t = dets.select(1, 1).contiguous();
27 auto x2_t = dets.select(1, 2).contiguous();
28 auto y2_t = dets.select(1, 3).contiguous();
30 at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
33 std::get<1>(scores.sort(
true, 0,
true));
35 auto ndets = dets.size(0);
36 at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
37 at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
39 auto suppressed = suppressed_t.data_ptr<uint8_t>();
40 auto keep = keep_t.data_ptr<int64_t>();
41 auto order = order_t.data_ptr<int64_t>();
42 auto x1 = x1_t.data_ptr<scalar_t>();
43 auto y1 = y1_t.data_ptr<scalar_t>();
44 auto x2 = x2_t.data_ptr<scalar_t>();
45 auto y2 = y2_t.data_ptr<scalar_t>();
46 auto areas = areas_t.data_ptr<scalar_t>();
48 int64_t num_to_keep = 0;
50 for (int64_t _i = 0; _i < ndets; _i++) {
52 if (suppressed[i] == 1)
54 keep[num_to_keep++] = i;
59 auto iarea = areas[i];
61 for (int64_t _j = _i + 1; _j < ndets; _j++) {
63 if (suppressed[j] == 1)
65 auto xx1 = std::max(ix1, x1[j]);
66 auto yy1 = std::max(iy1, y1[j]);
67 auto xx2 = std::min(ix2, x2[j]);
68 auto yy2 = std::min(iy2, y2[j]);
70 auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
71 auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
73 auto ovr = inter / (iarea + areas[j] - inter);
74 if (ovr > iou_threshold)
78 return keep_t.narrow(0, 0, num_to_keep);
82 const at::Tensor& scores,
83 double iou_threshold) {
84 TORCH_CHECK(dets.dim() == 2,
"boxes should be a 2d tensor, got ", dets.dim(),
"D");
85 TORCH_CHECK(dets.size(1) == 4,
86 "boxes should have 4 elements in dimension 1, got ",
88 TORCH_CHECK(scores.dim() == 1,
"scores should be a 1d tensor, got ", scores.dim(),
"D");
89 TORCH_CHECK(dets.size(0) == scores.size(0),
90 "boxes and scores should have same number of elements in ",
96 auto result = at::empty({0}, dets.options());
98 AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(),
"nms_kernel", [&] {
99 result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
at::Tensor nms_kernel(const at::Tensor &dets, const at::Tensor &scores, double iou_threshold)
at::Tensor nms_kernel_impl(const at::Tensor &dets, const at::Tensor &scores, double iou_threshold)