#include <torch/script.h>
#include <torch/torch.h>

#ifdef HAVE_CUDA_TORCH
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#endif

#include <fstream>
#include <iostream>
#include <memory>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "rapidjson/document.h"
std::string get_device_string(const bool use_gpu, const int64_t device_num) {
  std::string device_type{"cpu"};
#ifdef HAVE_CUDA_TORCH
  if (torch::cuda::is_available() && use_gpu) {
    // ... (elided: sets device_type to the CUDA device, e.g. "cuda:<device_num>")
  }
#endif
  return device_type;
}
bool should_use_half(const bool use_gpu, const std::string& model_path) {
  bool use_half = false;
#ifdef HAVE_CUDA_TORCH
  // Heuristic: fp16 models are expected to advertise "half" in their filename
  if (use_gpu && model_path.find("half") != std::string::npos) {
    use_half = true;
  }
#endif
  return use_half;
}
static std::unordered_map<std::string, std::shared_ptr<torch::jit::script::Module>>
    model_cache;
static std::shared_mutex model_mutex;
std::shared_ptr<torch::jit::script::Module> get_model_from_cache(
    const std::string& model_path) {
  std::shared_lock<std::shared_mutex> model_cache_read_lock(model_mutex);
  const auto model_itr = model_cache.find(model_path);
  if (model_itr == model_cache.end()) {
    return nullptr;
  }
  return model_itr->second;
}
void add_model_to_cache(const std::string& model_path,
                        std::shared_ptr<torch::jit::script::Module> model_module) {
  std::unique_lock<std::shared_mutex> model_cache_write_lock(model_mutex);
  model_cache.emplace(model_path, model_module);
}
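// Readers take a shared_lock so concurrent lookups do not serialize; the
// writer takes an exclusive unique_lock. If two threads race past the same
// cache miss, the second emplace for that path is a no-op, which is benign.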
std::shared_ptr<torch::jit::script::Module> load_module(
    const std::string& model_path,
    const std::string compute_device,
    const at::ScalarType data_type,
    const bool use_cache) {
  std::shared_ptr<torch::jit::script::Module> module;
  try {
    if (use_cache) {
      module = get_model_from_cache(model_path);
    }
    if (module == nullptr) {
      module = std::make_shared<torch::jit::script::Module>(
          torch::jit::load(model_path));
      module->to(compute_device, data_type);
      if (use_cache) {
        add_model_to_cache(model_path, module);
      }
    }
  } catch (const c10::Error& e) {
    std::string error_msg{"Error loading the provided model: "};
    error_msg += e.what();
    throw std::runtime_error(error_msg);
  }
  return module;
}
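// Hypothetical call (path, device string, and dtype are illustrative only):
//   auto module = load_module("/models/yolo.half.torchscript.pt", "cuda:0",
//                             torch::kHalf, /*use_cache=*/true);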
std::string get_json_str_from_file_header(const std::string& filename,
                                          const size_t max_search_chars) {
  std::ifstream model_file(filename);
  bool found_opening_brace = false;
  size_t brace_nest_count = 0;
  std::string json_str;
  size_t char_idx{0};
  char c;
  if (model_file.is_open()) {
    while (model_file.get(c) && (brace_nest_count >= 1 || char_idx < max_search_chars)) {
      if (c == '{') {
        found_opening_brace = true;
        ++brace_nest_count;
      }
      if (found_opening_brace) {
        json_str += c;
      }
      if (c == '}' && brace_nest_count > 0) {
        --brace_nest_count;
        if (found_opening_brace && brace_nest_count == 0) {
          break;
        }
      }
      ++char_idx;
    }
  }
  if (found_opening_brace && brace_nest_count == 0) {
    return json_str;
  }
  return "";
}
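// Hypothetical usage: scan the first chunk of a model file for an embedded
// JSON metadata header (the path and character budget are illustrative only):
//   const auto json_str =
//       get_json_str_from_file_header("/models/yolo.torchscript.pt", 100000);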
ModelInfo get_model_info_from_json(const std::string& json_str) {
  ModelInfo model_info;
  rapidjson::Document doc;
  if (doc.Parse<0>(json_str.c_str()).HasParseError()) {
    return model_info;
  }
  const auto shape_array_itr = doc.FindMember("shape");
  if (shape_array_itr != doc.MemberEnd() && shape_array_itr->value.IsArray()) {
    const rapidjson::SizeType num_shape_elems = shape_array_itr->value.Size();
    if (num_shape_elems == 4) {
      model_info.batch_size = shape_array_itr->value[0].GetInt();
      // ... (elided: the remaining shape entries, presumably channels and the
      // raster tile width/height)
    }
  }
  const auto stride_itr = doc.FindMember("stride");
  if (stride_itr != doc.MemberEnd() && stride_itr->value.IsInt()) {
    model_info.stride = stride_itr->value.GetInt();
  }
  const auto class_labels_itr = doc.FindMember("names");
  if (class_labels_itr != doc.MemberEnd() && class_labels_itr->value.IsArray()) {
    const rapidjson::SizeType num_class_labels = class_labels_itr->value.Size();
    model_info.class_labels.reserve(static_cast<size_t>(num_class_labels));
    for (auto& label : class_labels_itr->value.GetArray()) {
      model_info.class_labels.emplace_back(label.GetString());
    }
  }
  return model_info;
}
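// The parser above implies metadata of roughly this shape (values are
// illustrative only):
//   {"shape": [16, 3, 640, 640], "stride": 32, "names": ["car", "truck"]}
// where shape[0] is the batch size, "stride" is the model's downsampling
// stride, and "names" lists the class labels in index order.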
torch::Tensor xywh2xyxy(const torch::Tensor& x) {
  auto y = torch::zeros_like(x);
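  // The rest of the body is elided in the listing; presumably it is the
  // standard centre-to-corner box conversion. A minimal sketch:
  y.select(1, 0).copy_(x.select(1, 0) - x.select(1, 2) / 2);  // x1 = cx - w/2
  y.select(1, 1).copy_(x.select(1, 1) - x.select(1, 3) / 2);  // y1 = cy - h/2
  y.select(1, 2).copy_(x.select(1, 0) + x.select(1, 2) / 2);  // x2 = cx + w/2
  y.select(1, 3).copy_(x.select(1, 1) + x.select(1, 3) / 2);  // y2 = cy + h/2
  return y;
}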
torch::Tensor world_scale_detections(
    const torch::Tensor& input,
    const int64_t batch_idx,
    const RasterFormat_Namespace::RasterInfo& raster_info) {
  const int64_t tile_y_idx = batch_idx / raster_info.x_tiles;
  const int64_t tile_x_idx = batch_idx % raster_info.x_tiles;
  // ... (elided: per-tile pixel offsets and the pixel-to-world transform)
  auto options = torch::TensorOptions().dtype(torch::kFloat64);
  auto output = torch::zeros_like(input, options);
  // ... (elided: fill output with world-space coordinates)
  return output;
}
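// Tiles are laid out row-major across the raster: with x_tiles == 4, for
// example, batch_idx 5 maps to tile (x = 1, y = 1). The output is allocated
// as float64 so that precision survives scaling box coordinates from tile
// pixels into world coordinates.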
std::vector<Detection> process_detections(
    const torch::Tensor& raw_detections,
    const float min_confidence_threshold,
    const float iou_threshold,
    const ModelInfo& model_info,
    const RasterFormat_Namespace::RasterInfo& raster_info,
    std::shared_ptr<CpuTimer> timer) {
  const auto& class_labels = model_info.class_labels;
  timer->start_event_timer("Confidence mask");
  // Each detection row carries 4 box coordinates plus an objectness score
  constexpr int64_t item_attr_size = 5;
  const int32_t num_class_labels = static_cast<int32_t>(class_labels.size());
  const auto batch_size = raster_info.x_tiles * raster_info.y_tiles;
  const auto num_classes = raw_detections.size(2) - item_attr_size;
  auto conf_mask =
      raw_detections.select(2, 4).ge(min_confidence_threshold).unsqueeze(2);
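  // raw_detections is assumed to be YOLO-style output of shape
  // [batch, num_boxes, 5 + num_classes]: columns 0-3 hold the box centre and
  // size (x, y, w, h), column 4 the objectness score, and the remaining
  // columns per-class scores. select(2, 4) above thresholds on objectness.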
  torch::Tensor all_world_scaled_detections;
  for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
    auto masked_detections =
        torch::masked_select(raw_detections[batch_idx], conf_mask[batch_idx])
            .view({-1, num_classes + item_attr_size});
    // Skip tiles with no detections above the confidence threshold
    if (masked_detections.size(0) == 0) {
      continue;
    }
    // Weight the per-class scores by the box objectness score
    masked_detections.slice(1, item_attr_size, item_attr_size + num_classes) *=
        masked_detections.select(1, 4).unsqueeze(1);
    // Reduce the per-class scores to the single best class per detection
    std::tuple<torch::Tensor, torch::Tensor> max_classes = torch::max(
        masked_detections.slice(1, item_attr_size, item_attr_size + num_classes), 1);
    auto max_conf_scores = std::get<0>(max_classes);
    auto max_conf_classes = std::get<1>(max_classes);
    max_conf_scores = max_conf_scores.to(torch::kFloat).unsqueeze(1);
    max_conf_classes = max_conf_classes.to(torch::kFloat).unsqueeze(1);
    masked_detections = torch::cat(
        {masked_detections.slice(1, 0, 4), max_conf_classes, max_conf_scores}, 1);
    // Halo filtering: drop detections in the tile's overlapping halo region,
    // presumably so boxes near tile seams are counted only once. The bounds
    // and mask construction are partially elided in the listing.
    const double max_x_pixel = /* ... */;
    const double max_y_pixel = /* ... */;
    const auto halo_mask = /* ... */  // includes, among other bounds checks:
        masked_detections.select(1, 0).le(max_x_pixel) &
        masked_detections.select(1, 1).le(max_y_pixel);
    masked_detections =
        torch::masked_select(masked_detections, halo_mask).view({-1, 6});
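    // Each surviving row is now [x, y, w, h, class_idx, score]; column 5 is
    // the per-detection confidence used below as the NMS score.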
    // Map tile-local pixel coordinates into the raster's world coordinates
    auto world_scaled_detections =
        world_scale_detections(masked_detections, batch_idx, raster_info);
    auto world_scaled_detections_cpu = world_scaled_detections.cpu();
    // The accumulator may still be undefined here if every prior tile was
    // skipped, so test defined() rather than batch_idx == 0
    if (!all_world_scaled_detections.defined()) {
      all_world_scaled_detections = world_scaled_detections_cpu;
    } else {
      all_world_scaled_detections =
          torch::cat({all_world_scaled_detections, world_scaled_detections_cpu}, 0);
    }
  }
  timer->start_event_timer("Per-batch processing");
  std::vector<Detection> processed_detections;
  if (!all_world_scaled_detections.defined() ||
      all_world_scaled_detections.size(0) == 0) {
    return processed_detections;
  }
  // Convert [x, y, w, h] centre boxes to [x1, y1, x2, y2] corners for NMS
  torch::Tensor bboxes = xywh2xyxy(all_world_scaled_detections.slice(1, 0, 4));
  auto kept_bboxes_idxs =
      nms_kernel(bboxes, all_world_scaled_detections.select(1, 5), iou_threshold);
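  // nms_kernel is assumed to follow the torchvision C++ operator contract:
  // given [N, 4] corner boxes and [N] scores it returns the indices of the
  // boxes kept after greedy IoU suppression, sorted by decreasing score.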
  timer->start_event_timer("Nms processing");
  const int64_t num_kept_detections = kept_bboxes_idxs.size(0);
  processed_detections.reserve(num_kept_detections);
  // Accessors provide cheap typed element access into the CPU tensors
  const auto& kept_bboxes_idxs_accessor = kept_bboxes_idxs.accessor<int64_t, 1>();
  const auto& detections_array_accessor =
      all_world_scaled_detections.accessor<double, 2>();
  for (int64_t detection_idx = 0; detection_idx < num_kept_detections;
       ++detection_idx) {
    const int64_t kept_detection_idx = kept_bboxes_idxs_accessor[detection_idx];
    const auto& detection_array = detections_array_accessor[kept_detection_idx];
    // ... (elided: unpack box extents from columns 0-3 and the class index
    // from column 4 of detection_array)
    std::string class_label;
    if (class_idx < num_class_labels) {
      class_label = class_labels[class_idx];
    }
    // ... (elided: populate the Detection struct)
    processed_detections.emplace_back(processed_detection);
  }
  timer->start_event_timer("Output processing");
  return processed_detections;
}
std::vector<Detection> detect_objects_in_tiled_raster_impl(
    const std::string& model_path,
    const ModelInfo& model_info,
    const bool use_gpu,
    const int64_t device_num,
    std::vector<float>& raster_data,
    const RasterFormat_Namespace::RasterInfo& raster_info,
    const float min_confidence_threshold,
    const float iou_threshold,
    std::shared_ptr<CpuTimer> timer) {
  const std::string compute_device = get_device_string(use_gpu, device_num);
  const bool use_half = should_use_half(use_gpu, model_path);
  const auto input_data_type = use_half ? torch::kHalf : torch::kFloat32;
  const bool use_model_cache = use_gpu;
#ifdef HAVE_CUDA_TORCH
  c10::cuda::OptionalCUDAGuard cuda_guard;
  if (use_gpu) {
    cuda_guard.set_index(static_cast<int8_t>(device_num));
  }
#endif
  // Run inference with autograd fully disabled
  c10::InferenceMode guard;
  torch::NoGradGuard no_grad;
  try {
    timer->start_event_timer("Model load");
    const auto module =
        load_module(model_path, compute_device, input_data_type, use_model_cache);
    timer->start_event_timer("Input prep");
    std::cout << "Device: " << compute_device << " Use half: " << use_half
              << std::endl;
    std::cout << "X tiles: " << raster_info.x_tiles
              << " Y tiles: " << raster_info.y_tiles
              << " Batch size: " << raster_info.batch_tiles << std::endl;
    auto input_tensor =
        torch::from_blob(raster_data.data(),
                         /* ... sizes elided in the listing ... */)
            .to(compute_device, input_data_type);
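    // Assumption: the elided from_blob sizes describe an NCHW batch of tiles,
    // i.e. {batch_tiles, channels, tile_height, tile_width}. from_blob wraps
    // the raster buffer without copying; .to() then casts it and, for CUDA
    // devices, uploads it.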
    std::vector<torch::jit::IValue> module_input;
    module_input.emplace_back(input_tensor);
    timer->start_event_timer("Inference");
    torch::jit::IValue output = module->forward(module_input);
    // The scripted model returns a tuple whose first element is the raw
    // detections tensor
    auto raw_detections = output.toTuple()->elements()[0].toTensor();
#ifdef HAVE_CUDA_TORCH
    constexpr bool enable_debug_timing{true};
    if (enable_debug_timing && use_gpu) {
      std::cout << "Synchronizing timing" << std::endl;
      c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
      AT_CUDA_CHECK(cudaStreamSynchronize(stream));
    }
#endif
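    // CUDA launches are asynchronous, so the stream must be synchronized here
    // for the CPU-side timers to measure completed inference work rather than
    // just kernel launch overhead.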
    const auto processed_detections =
        process_detections(raw_detections,
                           min_confidence_threshold,
                           iou_threshold,
                           model_info,
                           raster_info,
                           timer->start_nested_event_timer("process_detections"));
    return processed_detections;
  } catch (std::exception& e) {
    std::string error_msg{"Error during model inference: "};
    error_msg += e.what();
    throw std::runtime_error(error_msg);
  }
}
void print_model_params(const std::string& model_path,
                        const bool use_gpu,
                        const int64_t device_num) {
  const std::string compute_device = get_device_string(use_gpu, device_num);
  const bool use_half = should_use_half(use_gpu, model_path);
  const auto input_data_type = use_half ? torch::kHalf : torch::kFloat32;
  try {
    const auto module =
        load_module(model_path, compute_device, input_data_type, use_gpu);
    const auto module_named_params = module->named_parameters(true);
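    // The boolean argument is the recurse flag: true walks all submodules
    // rather than listing only the top-level module's own parameters.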
    const size_t num_named_params = module_named_params.size();
    std::cout << "Module # params: " << num_named_params << std::endl;
    const size_t max_params_to_print{1000};
    size_t param_idx{0};
    for (const auto& param : module_named_params) {
      std::cout << param.name << std::endl;
      if (param_idx++ == max_params_to_print) {
        break;
      }
    }
    const auto module_named_buffers = module->named_buffers(true);
    const size_t num_named_buffers = module_named_buffers.size();
    std::cout << "Module # named buffers: " << num_named_buffers << std::endl;
    const auto module_named_children = module->named_children();
    const size_t num_named_children = module_named_children.size();
    std::cout << "Module # named children: " << num_named_children << std::endl;
    std::cout << "Finishing torch warmup" << std::endl;
  } catch (std::exception& e) {
    std::string error_msg{"Error fetching Torch model params: "};
    error_msg += e.what();
    std::cout << error_msg << std::endl;
  }
}
// Largely elided in the listing; the flow below is reconstructed from the
// helpers above and the surviving empty-check.
ModelInfo get_model_info_from_file(const std::string& filename) {
  const auto json_str =
      get_json_str_from_file_header(filename, /* max_search_chars elided */);
  if (!json_str.empty()) {
    return get_model_info_from_json(json_str);
  }
  return {};
}
std::vector<Detection> detect_objects_in_tiled_raster(
    const std::string& model_path,
    const ModelInfo& model_info,
    const bool use_gpu,
    const int64_t device_num,
    std::vector<float>& raster_data,
    const RasterFormat_Namespace::RasterInfo& raster_info,
    const float min_confidence_threshold,
    const float iou_threshold,
    std::shared_ptr<CpuTimer> timer) {
  return detect_objects_in_tiled_raster_impl(model_path,
                                             model_info,
                                             use_gpu,
                                             device_num,
                                             raster_data,
                                             raster_info,
                                             min_confidence_threshold,
                                             iou_threshold,
                                             timer);
}
static bool warmup_torch(const std::string& model_path,
                         const bool use_gpu,
                         const int64_t device_num) {
  const auto model_info = get_model_info_from_file(model_path);
  for (size_t l = 0; l < model_info.class_labels.size(); ++l) {
    std::cout << l << ": " << model_info.class_labels[l] << std::endl;
  }
  // ... (remainder elided; presumably exercises the module, e.g. via
  // print_model_params, before real inference)
  return true;
}
#endif  // #ifndef __CUDACC__