17 TORCH_CHECK(dets.is_cpu(),
"dets must be a CPU tensor");
18 TORCH_CHECK(scores.is_cpu(),
"scores must be a CPU tensor");
19 TORCH_CHECK(dets.scalar_type() == scores.scalar_type(),
20 "dets should have the same type as scores");
22 if (dets.numel() == 0)
23 return at::empty({0}, dets.options().dtype(at::kLong));
25 auto x1_t = dets.select(1, 0).contiguous();
26 auto y1_t = dets.select(1, 1).contiguous();
27 auto x2_t = dets.select(1, 2).contiguous();
28 auto y2_t = dets.select(1, 3).contiguous();
30 at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
33 std::get<1>(scores.sort(
true, 0,
true));
35 auto ndets = dets.size(0);
36 at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
37 at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
39 auto suppressed = suppressed_t.data_ptr<uint8_t>();
40 auto keep = keep_t.data_ptr<int64_t>();
41 auto order = order_t.data_ptr<int64_t>();
42 auto x1 = x1_t.data_ptr<scalar_t>();
43 auto y1 = y1_t.data_ptr<scalar_t>();
44 auto x2 = x2_t.data_ptr<scalar_t>();
45 auto y2 = y2_t.data_ptr<scalar_t>();
46 auto areas = areas_t.data_ptr<scalar_t>();
48 int64_t num_to_keep = 0;
50 for (int64_t _i = 0; _i < ndets; _i++) {
52 if (suppressed[i] == 1)
54 keep[num_to_keep++] = i;
59 auto iarea = areas[i];
61 for (int64_t _j = _i + 1; _j < ndets; _j++) {
63 if (suppressed[j] == 1)
65 auto xx1 = std::max(ix1, x1[j]);
66 auto yy1 = std::max(iy1, y1[j]);
67 auto xx2 = std::min(ix2, x2[j]);
68 auto yy2 = std::min(iy2, y2[j]);
70 auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
71 auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
73 auto ovr = inter / (iarea + areas[j] - inter);
74 if (ovr > iou_threshold)
78 return keep_t.narrow(0, 0, num_to_keep);