OmniSciDB
a5dc49c757
|
#include "TorchWrapper.h"
#include "Shared/funcannotations.h"
#include "TorchOps.hpp"
#include <torch/script.h>
#include <torch/torch.h>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iostream>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "rapidjson/document.h"
Go to the source code of this file.
Enumerations | |
enum | DetectionIdx { centroid_x = 0, centroid_y = 1, width = 2, height = 3, class_idx = 4, score = 5 } |
enum | BoxDetectionIdx { tl_x = 0, tl_y = 1, br_x = 2, br_y = 3 } |
Functions | |
std::string | get_device_string (const bool use_gpu, const int64_t device_num) |
bool | should_use_half (const bool use_gpu, const std::string &model_path) |
std::shared_ptr < torch::jit::script::Module > | get_model_from_cache (const std::string &model_path) |
void | add_model_to_cache (const std::string &model_path, std::shared_ptr< torch::jit::script::Module > model_module) |
std::shared_ptr < torch::jit::script::Module > | load_module (const std::string &model_path, const std::string compute_device, const at::ScalarType data_type, const bool use_cache) |
std::string | get_json_str_from_file_header (const std::string &filename, const size_t max_search_chars) |
ModelInfo | get_model_info_from_json (const std::string &json_str) |
torch::Tensor | xywh2xyxy (const torch::Tensor &x) |
torch::Tensor | world_scale_detections (const torch::Tensor &input, const int64_t batch_idx, const RasterFormat_Namespace::RasterInfo &raster_info) |
std::vector< Detection > | process_detections (const torch::Tensor &raw_detections, const float min_confidence_threshold, const float iou_threshold, const ModelInfo &model_info, const RasterFormat_Namespace::RasterInfo &raster_info, std::shared_ptr< CpuTimer > timer) |
std::vector< Detection > | detect_objects_in_tiled_raster_impl (const std::string &model_path, const ModelInfo &model_info, const bool use_gpu, const int64_t device_num, std::vector< float > &raster_data, const RasterFormat_Namespace::RasterInfo &raster_info, const float min_confidence_threshold, const float iou_threshold, std::shared_ptr< CpuTimer > timer) |
void | print_model_params (const std::string &model_path, const bool use_gpu, const int64_t device_num) |
ModelInfo | get_model_info_from_file (const std… ) — declared with __attribute__((__used__)); the parameter list is truncated in this generated page |
Variables | |
static std::unordered_map < std::string, std::shared_ptr < torch::jit::script::Module > > | model_cache |
static std::shared_mutex | model_mutex |
enum BoxDetectionIdx |
Enumerator | |
---|---|
tl_x | |
tl_y | |
br_x | |
br_y |
Definition at line 188 of file TorchWrapper.cpp.
enum DetectionIdx |
Enumerator | |
---|---|
centroid_x | |
centroid_y | |
width | |
height | |
class_idx | |
score |
Definition at line 179 of file TorchWrapper.cpp.
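__attribute__((__used__)) ModelInfo get_model_info_from_file | ( | const std… | ) | (parameter list truncated by the documentation generator, which listed this function as "__attribute__") |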
Definition at line 469 of file TorchWrapper.cpp.
References get_json_str_from_file_header(), get_model_info_from_json(), ModelInfo::is_valid, and json_str().
void add_model_to_cache | ( | const std::string & | model_path, |
std::shared_ptr< torch::jit::script::Module > | model_module | ||
) |
Definition at line 76 of file TorchWrapper.cpp.
References model_cache, and model_mutex.
Referenced by load_module().
std::vector<Detection> detect_objects_in_tiled_raster_impl | ( | const std::string & | model_path, |
const ModelInfo & | model_info, | ||
const bool | use_gpu, | ||
const int64_t | device_num, | ||
std::vector< float > & | raster_data, | ||
const RasterFormat_Namespace::RasterInfo & | raster_info, | ||
const float | min_confidence_threshold, | ||
const float | iou_threshold, | ||
std::shared_ptr< CpuTimer > | timer | ||
) |
Definition at line 354 of file TorchWrapper.cpp.
References RasterFormat_Namespace::RasterInfo::batch_tiles, get_device_string(), load_module(), process_detections(), RasterFormat_Namespace::RasterInfo::raster_channels, should_use_half(), RasterFormat_Namespace::RasterInfo::x_pixels_per_tile, RasterFormat_Namespace::RasterInfo::x_tiles, RasterFormat_Namespace::RasterInfo::y_pixels_per_tile, and RasterFormat_Namespace::RasterInfo::y_tiles.
std::string get_device_string | ( | const bool | use_gpu, |
const int64_t | device_num | ||
) |
Definition at line 42 of file TorchWrapper.cpp.
References to_string().
Referenced by detect_objects_in_tiled_raster_impl(), and print_model_params().
std::string get_json_str_from_file_header | ( | const std::string & | filename, |
const size_t | max_search_chars | ||
) |
Definition at line 112 of file TorchWrapper.cpp.
References json_str().
Referenced by get_model_info_from_file() (listed by the generator as __attribute__()).
std::shared_ptr<torch::jit::script::Module> get_model_from_cache | ( | const std::string & | model_path | ) |
Definition at line 66 of file TorchWrapper.cpp.
References model_cache, and model_mutex.
Referenced by load_module().
ModelInfo get_model_info_from_json | ( | const std::string & | json_str | ) |
Definition at line 147 of file TorchWrapper.cpp.
References ModelInfo::batch_size, ModelInfo::class_labels, ModelInfo::is_valid, run_benchmark_import::label, ModelInfo::raster_channels, ModelInfo::raster_tile_height, ModelInfo::raster_tile_width, and ModelInfo::stride.
Referenced by get_model_info_from_file() (listed by the generator as __attribute__()).
std::shared_ptr<torch::jit::script::Module> load_module | ( | const std::string & | model_path, |
const std::string | compute_device, | ||
const at::ScalarType | data_type, | ||
const bool | use_cache | ||
) |
Definition at line 82 of file TorchWrapper.cpp.
References add_model_to_cache(), get_model_from_cache(), and boost::serialization::load().
Referenced by detect_objects_in_tiled_raster_impl(), and print_model_params().
void print_model_params | ( | const std::string & | model_path, |
const bool | use_gpu, | ||
const int64_t | device_num | ||
) |
Definition at line 434 of file TorchWrapper.cpp.
References get_device_string(), load_module(), and should_use_half().
Referenced by TorchWarmer::warmup_torch().
std::vector<Detection> process_detections | ( | const torch::Tensor & | raw_detections, |
const float | min_confidence_threshold, | ||
const float | iou_threshold, | ||
const ModelInfo & | model_info, | ||
const RasterFormat_Namespace::RasterInfo & | raster_info, | ||
std::shared_ptr< CpuTimer > | timer | ||
) |
Definition at line 237 of file TorchWrapper.cpp.
References cat(), centroid_x, centroid_y, class_idx, ModelInfo::class_labels, RasterFormat_Namespace::RasterInfo::halo_x_pixels_per_tile_boundary, RasterFormat_Namespace::RasterInfo::halo_y_pixels_per_tile_boundary, height, logical_and(), nms_kernel(), score, width, world_scale_detections(), RasterFormat_Namespace::RasterInfo::x_pixels_per_tile, RasterFormat_Namespace::RasterInfo::x_tiles, xywh2xyxy(), RasterFormat_Namespace::RasterInfo::y_pixels_per_tile, and RasterFormat_Namespace::RasterInfo::y_tiles.
Referenced by detect_objects_in_tiled_raster_impl().
bool should_use_half | ( | const bool | use_gpu, |
const std::string & | model_path | ||
) |
Definition at line 52 of file TorchWrapper.cpp.
Referenced by detect_objects_in_tiled_raster_impl(), and print_model_params().
torch::Tensor world_scale_detections | ( | const torch::Tensor & | input, |
const int64_t | batch_idx, | ||
const RasterFormat_Namespace::RasterInfo & | raster_info | ||
) |
Definition at line 206 of file TorchWrapper.cpp.
References centroid_x, centroid_y, class_idx, RasterFormat_Namespace::RasterInfo::halo_x_pixels_per_tile_boundary, RasterFormat_Namespace::RasterInfo::halo_y_pixels_per_tile_boundary, height, RasterFormat_Namespace::RasterInfo::logical_x_pixels_per_tile, RasterFormat_Namespace::RasterInfo::logical_y_pixels_per_tile, RasterFormat_Namespace::RasterInfo::min_x_input, RasterFormat_Namespace::RasterInfo::min_y_input, score, width, RasterFormat_Namespace::RasterInfo::x_input_units_per_pixel, RasterFormat_Namespace::RasterInfo::x_tiles, and RasterFormat_Namespace::RasterInfo::y_input_units_per_pixel.
Referenced by process_detections().
torch::Tensor xywh2xyxy | ( | const torch::Tensor & | x | ) |
Definition at line 195 of file TorchWrapper.cpp.
References br_x, br_y, tl_x, and tl_y.
Referenced by process_detections().
std::unordered_map<std::string, std::shared_ptr<torch::jit::script::Module> > model_cache |
static
Definition at line 63 of file TorchWrapper.cpp.
Referenced by add_model_to_cache(), and get_model_from_cache().
std::shared_mutex model_mutex |
static
Definition at line 64 of file TorchWrapper.cpp.
Referenced by add_model_to_cache(), and get_model_from_cache().