33 #include "oneapi/dal/algo/decision_forest.hpp"
// Registers `model` under `model_name`.
// The name is upper-cased with to_upper() first.
// NOTE(review): the rest of the body (original lines 40-43) is not visible here, so the
// actual insertion into the map and any locking cannot be confirmed from this fragment.
38 void addModel(
const std::string& model_name, std::shared_ptr<AbstractMLModel> model) {
39 const auto upper_model_name =
to_upper(model_name);
// NOTE(review): the enclosing signature (original line 44) is not visible; presumably this
// is the body of modelExists(const std::string&) const -- confirm against the full file.
// Upper-cases the requested name and looks it up in model_map_.
45 const auto upper_model_name =
to_upper(model_name);
47 auto model_map_itr =
model_map_.find(upper_model_name);
// Looks up a model by name and returns the shared_ptr stored in model_map_.
// The name is upper-cased with to_upper() before the lookup.
// Throws std::runtime_error with the message "Model '<NAME>' does not exist."
// NOTE(review): the branch condition (original line 55) is not visible here; presumably the
// return is taken when find() hits and the throw when it misses -- confirm.
51 std::shared_ptr<AbstractMLModel>
getModel(
const std::string& model_name)
const {
52 const auto upper_model_name =
to_upper(model_name);
54 auto model_map_itr =
model_map_.find(upper_model_name);
// Hands back a copy of the stored shared_ptr (shared ownership with the map).
56 return model_map_itr->second;
// The error text reports the upper-cased name, not the caller's original spelling.
58 const std::string error_str =
"Model '" + upper_model_name +
"' does not exist.";
59 throw std::runtime_error(error_str);
// NOTE(review): the enclosing signature (original line 62) is not visible; presumably this
// is the body of deleteModel(const std::string&) -- confirm against the full file.
// Upper-cases the name, looks it up in model_map_, and on the error path throws
// std::runtime_error with "Cannot erase model <NAME>. No model by that name was found."
// The condition guarding the throw and the erase call itself are not visible here.
63 const auto upper_model_name =
to_upper(model_name);
65 auto const model_it =
model_map_.find(upper_model_name);
// Error message is assembled with an ostringstream before being thrown.
67 std::ostringstream error_oss;
68 error_oss <<
"Cannot erase model " << upper_model_name
69 <<
". No model by that name was found.";
70 throw std::runtime_error(error_oss.str());
// NOTE(review): signature and loop header are not visible; presumably the body of
// getModelNames() const, iterating model_map_ -- confirm.
// Collects the key (`model.first`) of each visited entry into model_names.
77 std::vector<std::string> model_names;
80 model_names.emplace_back(model.first);
// Builds and returns a vector of MLModelMetadata, one element per visited `model` entry.
// NOTE(review): the signature, loop header and the leading constructor argument(s)
// (original lines 85, 87-89) are not visible here.
86 std::vector<MLModelMetadata> model_metadata;
// The visible arguments are all read from the model object (`model.second`).
90 model.second->getModelType(),
91 model.second->getModelTypeString(),
92 model.second->getNumLogicalFeatures(),
93 model.second->getNumFeatures(),
94 model.second->getNumCatFeatures(),
// Derived value: logical features minus categorical features.
95 model.second->getNumLogicalFeatures() - model.second->getNumCatFeatures(),
96 model.second->getModelMetadataStr()));
98 return model_metadata;
// Single-model metadata lookup: upper-cases the name, finds it in model_map_, and passes
// the model's type, feature counts and metadata string as arguments to a call whose head
// is not visible here.
// Throws std::runtime_error with "Model '<NAME>' does not exist." on the error path.
// NOTE(review): the signature, the branch condition and the head of the call
// (original lines 101, 105-106) are not visible here.
102 const auto upper_model_name =
to_upper(model_name);
104 auto model_map_itr =
model_map_.find(upper_model_name);
107 model_map_itr->second->getModelType(),
108 model_map_itr->second->getModelTypeString(),
109 model_map_itr->second->getNumLogicalFeatures(),
110 model_map_itr->second->getNumFeatures(),
111 model_map_itr->second->getNumCatFeatures(),
// Derived value: logical features minus categorical features.
112 model_map_itr->second->getNumLogicalFeatures() -
113 model_map_itr->second->getNumCatFeatures(),
114 model_map_itr->second->getModelMetadataStr());
// The error text reports the upper-cased name, not the caller's original spelling.
116 const std::string error_str =
"Model '" + upper_model_name +
"' does not exist.";
117 throw std::runtime_error(error_str);
// Name -> model registry. The lookups in this file upper-case the name with to_upper()
// before calling find(), so keys are presumably stored upper-cased as well -- confirm
// against the insertion code.
121 std::map<std::string, std::shared_ptr<AbstractMLModel>>
model_map_;
// NOTE(review): only the tails of two constructor signatures and one return statement are
// visible here; presumably they belong to LinearRegressionModel -- confirm.
// First overload ends with the metadata string.
130 const std::string& model_metadata)
// Second overload additionally takes the per-feature categorical keys.
134 const std::string& model_metadata,
135 const std::vector<std::vector<std::string>>& cat_feature_keys)
// Returns the number of coefficients minus one (presumably one coefficient is the
// intercept -- TODO confirm). The size is cast to int64_t before the subtraction, so an
// empty coefs_ yields -1 rather than an unsigned wrap-around.
143 return static_cast<int64_t
>(
coefs_.size()) - 1;
// Forward declaration so the tree-model interface below can refer to the visitor type.
154 class TreeModelVisitor;
// Pure virtual: implementations walk the tree selected by `tree_idx`, reporting its nodes
// to `tree_node_visitor`. NOTE(review): "DF" presumably stands for depth-first -- confirm.
163 virtual void traverseDF(
const int64_t tree_idx,
164 TreeModelVisitor& tree_node_visitor)
const = 0;
// NOTE(review): these using-directives sit at namespace scope; if this file is a header,
// every includer inherits the daal names. Consider explicit qualification or narrower
// aliases -- confirm the enclosing scope first.
169 using namespace daal::algorithms;
170 using namespace daal::data_management;
// Short alias for the oneDAL decision-forest namespace used by the visitor overloads.
172 namespace df = oneapi::dal::decision_forest;
// Visitor that records the nodes of a tree into a caller-owned table of DecisionTreeEntry
// rows. It derives from DAAL's regression TreeNodeVisitor (onLeafNode / onSplitNode) and
// also provides operator() overloads for oneDAL decision-forest node-info objects.
//
// Bookkeeping visible in this class:
//  - parent_nodes_ is a stack of row indices; each split handler pushes the index of the
//    row it just appended.
//  - last_node_leaf_ records whether the previously visited node was a leaf; when it was,
//    the next visited node's row index is written into right_child_row_idx of the row on
//    top of the stack.
// NOTE(review): several lines (return statements, the leaf emplace_back, any pop of
// parent_nodes_, closing braces) are not visible here; this presumably relies on the
// traversal visiting nodes depth-first, left child before right -- confirm.
174 class TreeModelVisitor :
public daal::algorithms::regression::TreeNodeVisitor {
// Binds to `decision_table` by reference; the table must outlive the visitor.
176 TreeModelVisitor(std::vector<DecisionTreeEntry>& decision_table)
177 : decision_table_(decision_table) {}
// Read-only access to the table being filled.
179 const std::vector<DecisionTreeEntry>& getDecisionTable()
const {
180 return decision_table_;
// DAAL callback for a leaf node.
183 bool onLeafNode(
size_t level,
double response)
override {
// If the previous node was also a leaf, patch the split on top of the stack so that its
// right child points at the last row of the table.
185 if (last_node_leaf_) {
186 decision_table_[parent_nodes_.top()].right_child_row_idx =
187 static_cast<int64_t
>(decision_table_.size() - 1);
190 last_node_leaf_ =
true;
// DAAL callback for a split node.
194 bool onSplitNode(
size_t level,
size_t featureIndex,
double featureValue)
override {
// Appends a row for this split. The last visible argument is size() + 1 evaluated before
// the append (presumably the row index of the left child -- confirm).
195 decision_table_.emplace_back(
197 static_cast<int64_t>(featureIndex),
198 static_cast<int64_t>(decision_table_.size() + 1)));
// Previous node was a leaf: this split is recorded as the right child of the stack top.
199 if (last_node_leaf_) {
200 decision_table_[parent_nodes_.top()].right_child_row_idx =
201 static_cast<int64_t
>(decision_table_.size() - 1);
204 last_node_leaf_ =
false;
// Remember this split's own row so a later node can be attached as its right child.
205 parent_nodes_.emplace(decision_table_.size() - 1);
// oneDAL counterpart of onLeafNode; same bookkeeping.
209 bool operator()(
const df::leaf_node_info<df::task::regression>& info) {
211 if (last_node_leaf_) {
212 decision_table_[parent_nodes_.top()].right_child_row_idx =
213 static_cast<int64_t
>(decision_table_.size() - 1);
216 last_node_leaf_ =
true;
// oneDAL counterpart of onSplitNode; the feature index comes from info.get_feature_index().
220 bool operator()(
const df::split_node_info<df::task::regression>& info) {
221 decision_table_.emplace_back(
223 static_cast<int64_t
>(info.get_feature_index()),
224 static_cast<int64_t>(decision_table_.size() + 1)));
225 if (last_node_leaf_) {
226 decision_table_[parent_nodes_.top()].right_child_row_idx =
227 static_cast<int64_t
>(decision_table_.size() - 1);
230 last_node_leaf_ =
false;
231 parent_nodes_.emplace(decision_table_.size() - 1);
// Output table; a reference member, so it is not owned by the visitor.
236 std::vector<DecisionTreeEntry>& decision_table_;
// Row indices of split nodes pushed by the split handlers.
237 std::stack<size_t> parent_nodes_;
// True when the most recently visited node was a leaf; starts false.
238 bool last_node_leaf_{
false};
// Wrapper around a DAAL decision-tree regression model pointer.
// NOTE(review): the class header, the first constructor's initializer list and several
// method bodies are not visible here.
// Constructor without categorical feature keys.
243 DecisionTreeRegressionModel(decision_tree::regression::interface1::ModelPtr& model_ptr,
244 const std::string& model_metadata)
// Constructor with categorical feature keys; metadata and keys are forwarded to
// AbstractMLModel and the model pointer is copied into model_ptr_.
246 DecisionTreeRegressionModel(
247 decision_tree::regression::interface1::ModelPtr& model_ptr,
248 const std::string& model_metadata,
249 const std::vector<std::vector<std::string>>& cat_feature_keys)
250 :
AbstractMLModel(model_metadata, cat_feature_keys), model_ptr_(model_ptr) {}
// Model type tag (the returned value is not visible here).
252 virtual MLModelType getModelType()
const override {
// Display name of the model type.
256 virtual std::string getModelTypeString()
const override {
257 return "Decision Tree Regression";
// Feature count as reported by the underlying DAAL model.
260 virtual int64_t getNumFeatures()
const override {
261 return model_ptr_->getNumberOfFeatures();
// This model type always reports exactly one tree.
263 virtual int64_t getNumTrees()
const override {
return 1; }
// Walks the tree with the given visitor. `tree_idx` is not passed to the underlying model
// in the visible call. NOTE(review): original line 266 is not visible; it may validate
// tree_idx -- confirm.
264 virtual void traverseDF(
const int64_t tree_idx,
265 TreeModelVisitor& tree_node_visitor)
const override {
267 model_ptr_->traverseDF(tree_node_visitor);
// Accessor for the wrapped model pointer (body not visible here).
269 const decision_tree::regression::interface1::ModelPtr getModelPtr()
const {
// The wrapped DAAL model.
274 decision_tree::regression::interface1::ModelPtr model_ptr_;
// Wrapper around a DAAL gradient-boosted-trees regression model pointer.
// NOTE(review): the class header, the first constructor's initializer list and
// getModelType() are not visible here.
// Constructor without categorical feature keys.
279 GbtRegressionModel(gbt::regression::interface1::ModelPtr& model_ptr,
280 const std::string& model_metadata)
// Constructor with categorical feature keys; metadata and keys are forwarded to
// AbstractMLModel and the model pointer is copied into model_ptr_.
283 GbtRegressionModel(gbt::regression::interface1::ModelPtr& model_ptr,
284 const std::string& model_metadata,
285 const std::vector<std::vector<std::string>>& cat_feature_keys)
286 :
AbstractMLModel(model_metadata, cat_feature_keys), model_ptr_(model_ptr) {}
// Display name of the model type.
290 virtual std::string getModelTypeString()
const override {
291 return "Gradient Boosted Trees Regression";
// Feature count as reported by the underlying DAAL model.
294 virtual int64_t getNumFeatures()
const override {
295 return model_ptr_->getNumberOfFeatures();
// Tree count as reported by the underlying DAAL model.
297 virtual int64_t getNumTrees()
const override {
return model_ptr_->getNumberOfTrees(); }
// Walks tree `tree_idx` of the underlying model with the given visitor.
298 virtual void traverseDF(
const int64_t tree_idx,
299 TreeModelVisitor& tree_node_visitor)
const override {
300 model_ptr_->traverseDF(tree_idx, tree_node_visitor);
// Returns a copy of the wrapped model pointer.
302 const gbt::regression::interface1::ModelPtr getModelPtr()
const {
return model_ptr_; }
// The wrapped DAAL model.
305 gbt::regression::interface1::ModelPtr model_ptr_;
// Pure virtual accessors that the random-forest model classes in this file override.
// Returns the variable-importance scores by const reference.
310 virtual const std::vector<double>& getVariableImportanceScores()
const = 0;
// Returns the out-of-bag error by value.
// NOTE(review): top-level `const` on a by-value return type has no effect; it cannot be
// dropped here without also updating every override's declaration.
311 virtual const double getOutOfBagError()
const = 0;
// Random-forest regression model backed by a DAAL decision_forest regression ModelPtr.
// Stores the model pointer together with variable-importance scores and the out-of-bag
// error supplied at construction.
// NOTE(review): the base-class initializers (original lines 321, 332), getModelType()'s
// body and getModelPtr()'s body are not visible here.
314 class RandomForestRegressionModel :
public virtual AbstractRandomForestModel {
// Constructor without categorical feature keys.
316 RandomForestRegressionModel(
317 decision_forest::regression::interface1::ModelPtr& model_ptr,
318 const std::string& model_metadata,
319 const std::vector<double>& variable_importance,
320 const double out_of_bag_error)
322 , model_ptr_(model_ptr)
323 , variable_importance_(variable_importance)
324 , out_of_bag_error_(out_of_bag_error) {}
// Constructor with categorical feature keys.
326 RandomForestRegressionModel(
327 decision_forest::regression::interface1::ModelPtr& model_ptr,
328 const std::string& model_metadata,
329 const std::vector<std::vector<std::string>>& cat_feature_keys,
330 const std::vector<double>& variable_importance,
331 const double out_of_bag_error)
333 , model_ptr_(model_ptr)
334 , variable_importance_(variable_importance)
335 , out_of_bag_error_(out_of_bag_error) {}
// Model type tag (the returned value is not visible here).
337 virtual MLModelType getModelType()
const override {
// Display name of the model type.
341 virtual std::string getModelTypeString()
const override {
342 return "Random Forest Regression";
// Feature count as reported by the underlying DAAL model.
344 virtual int64_t getNumFeatures()
const override {
345 return model_ptr_->getNumberOfFeatures();
// Tree count as reported by the underlying DAAL model.
347 virtual int64_t getNumTrees()
const override {
return model_ptr_->getNumberOfTrees(); }
// Walks tree `tree_idx` of the underlying model with the given visitor.
348 virtual void traverseDF(
const int64_t tree_idx,
349 TreeModelVisitor& tree_node_visitor)
const override {
350 model_ptr_->traverseDF(tree_idx, tree_node_visitor);
// Returns the stored importance scores by reference (no copy).
353 virtual const std::vector<double>& getVariableImportanceScores()
const override {
354 return variable_importance_;
// Returns the stored out-of-bag error.
357 virtual const double getOutOfBagError()
const override {
return out_of_bag_error_; }
// Accessor for the wrapped model pointer.
359 const decision_forest::regression::interface1::ModelPtr getModelPtr()
const {
// The wrapped DAAL model.
364 decision_forest::regression::interface1::ModelPtr model_ptr_;
// Copy of the scores passed to the constructor.
365 std::vector<double> variable_importance_;
// Out-of-bag error passed to the constructor.
366 double out_of_bag_error_;
// Random-forest regression model backed by a oneDAL decision_forest model
// (df::model<df::task::regression>) held through a shared_ptr-to-const.
// Unlike the DAAL-backed variant, the feature count is supplied by the caller and stored
// in num_features_ rather than queried from the model.
// NOTE(review): the base-class initializers (original lines 377, 390), getModelType()'s
// body and getModel()'s body are not visible here.
369 class OneAPIRandomForestRegressionModel :
public virtual AbstractRandomForestModel {
// Constructor without categorical feature keys.
// NOTE(review): `model` is a const by-value parameter, so std::move(model) below yields a
// const rvalue and the shared_ptr is copied, not moved (one extra atomic ref-count bump).
371 OneAPIRandomForestRegressionModel(
372 const std::shared_ptr<
const df::model<df::task::regression>> model,
373 const std::string& model_metadata,
374 const std::vector<double>& variable_importance,
375 const double out_of_bag_error,
376 const int64_t num_features)
378 , model_(std::move(model))
379 , variable_importance_(variable_importance)
380 , out_of_bag_error_(out_of_bag_error)
381 , num_features_(num_features) {}
// Constructor with categorical feature keys; the same std::move-on-const note applies.
383 OneAPIRandomForestRegressionModel(
384 const std::shared_ptr<
const df::model<df::task::regression>> model,
385 const std::string& model_metadata,
386 const std::vector<std::vector<std::string>>& cat_feature_keys,
387 const std::vector<double>& variable_importance,
388 const double out_of_bag_error,
389 const int64_t num_features)
391 , model_(std::move(model))
392 , variable_importance_(variable_importance)
393 , out_of_bag_error_(out_of_bag_error)
394 , num_features_(num_features) {}
// Model type tag (the returned value is not visible here).
396 virtual MLModelType getModelType()
const override {
// Display name; the same string as the DAAL-backed random-forest class.
400 virtual std::string getModelTypeString()
const override {
401 return "Random Forest Regression";
// Returns the caller-supplied feature count.
403 virtual int64_t getNumFeatures()
const override {
return num_features_; }
// Tree count as reported by the oneDAL model.
404 virtual int64_t getNumTrees()
const override {
return model_->get_tree_count(); }
// Walks tree `tree_idx` via the oneDAL model, passing the visitor as the callable.
405 virtual void traverseDF(
const int64_t tree_idx,
406 TreeModelVisitor& tree_node_visitor)
const override {
407 model_->traverse_depth_first(tree_idx, tree_node_visitor);
// Returns the stored importance scores by reference (no copy).
410 virtual const std::vector<double>& getVariableImportanceScores()
const override {
411 return variable_importance_;
// Returns the stored out-of-bag error.
414 virtual const double getOutOfBagError()
const override {
return out_of_bag_error_; }
// Accessor for the wrapped oneDAL model.
416 const std::shared_ptr<const df::model<df::task::regression>> getModel()
const {
// The wrapped oneDAL model; the pointer itself is const, so it cannot be reseated.
421 const std::shared_ptr<const df::model<df::task::regression>> model_;
// Copy of the scores passed to the constructor.
422 std::vector<double> variable_importance_;
// Out-of-bag error passed to the constructor.
423 double out_of_bag_error_;
// Feature count passed to the constructor.
424 int64_t num_features_;
427 #endif // #ifdef HAVE_ONEDAL
// NOTE(review): only parameter-list tails, initializer lists and two return statements are
// visible here; presumably they belong to PcaModel -- confirm.
// Constructor without categorical feature keys: copies the column std-devs, eigenvectors
// and eigenvalues into members (col_means_ is initialized from a parameter declared on a
// line that is not visible here).
432 const std::vector<double>& col_std_devs,
433 const std::vector<std::vector<double>>& eigenvectors,
434 const std::vector<double>& eigenvalues,
435 const std::string& model_metadata)
437 , col_means_(col_means)
438 , col_std_devs_(col_std_devs)
439 , eigenvectors_(eigenvectors)
440 , eigenvalues_(eigenvalues) {}
// Constructor with categorical feature keys; same member initialization.
443 const std::vector<double>& col_std_devs,
444 const std::vector<std::vector<double>>& eigenvectors,
445 const std::vector<double>& eigenvalues,
446 const std::string& model_metadata,
447 const std::vector<std::vector<std::string>>& cat_feature_keys)
449 , col_means_(col_means)
450 , col_std_devs_(col_std_devs)
451 , eigenvectors_(eigenvectors)
452 , eigenvalues_(eigenvalues) {}
// Returns the number of column means as int64_t (presumably the body of getNumFeatures(),
// i.e. one mean per input feature -- TODO confirm).
459 return static_cast<int64_t
>(col_means_.size());
// Returns the stored eigenvectors (presumably the body of getEigenvectors() -- confirm).
465 return eigenvectors_;
476 #endif // #ifndef __CUDACC__
PcaModel(const std::vector< double > &col_means, const std::vector< double > &col_std_devs, const std::vector< std::vector< double >> &eigenvectors, const std::vector< double > &eigenvalues, const std::string &model_metadata)
const std::vector< double > & getColumnStdDevs() const
virtual std::string getModelTypeString() const override
virtual int64_t getNumTrees() const =0
virtual void traverseDF(const int64_t tree_idx, TreeModelVisitor &tree_node_visitor) const =0
virtual int64_t getNumFeatures() const override
std::vector< MLModelMetadata > getModelMetadata() const
virtual std::string getModelTypeString() const override
std::vector< double > eigenvalues_
LinearRegressionModel(const std::vector< double > &coefs, const std::string &model_metadata)
virtual MLModelType getModelType() const =0
virtual MLModelType getModelType() const override
std::vector< double > col_std_devs_
std::vector< double > col_means_
void addModel(const std::string &model_name, std::shared_ptr< AbstractMLModel > model)
bool modelExists(const std::string &model_name) const
std::shared_ptr< AbstractMLModel > getModel(const std::string &model_name) const
virtual std::string getModelTypeString() const =0
LinearRegressionModel(const std::vector< double > &coefs, const std::string &model_metadata, const std::vector< std::vector< std::string >> &cat_feature_keys)
MLModelMetadata getModelMetadata(const std::string &model_name) const
std::shared_mutex model_map_mutex_
PcaModel(const std::vector< double > &col_means, const std::vector< double > &col_std_devs, const std::vector< std::vector< double >> &eigenvectors, const std::vector< double > &eigenvalues, const std::string &model_metadata, const std::vector< std::vector< std::string >> &cat_feature_keys)
void deleteModel(const std::string &model_name)
std::vector< double > coefs_
const std::vector< double > & getColumnMeans() const
const std::vector< double > & getCoefs() const
const std::vector< double > & getEigenvalues() const
virtual MLModelType getModelType() const override
std::vector< std::string > getModelNames() const
virtual int64_t getNumFeatures() const override
std::shared_timed_mutex shared_mutex
virtual ~AbstractTreeModel()=default
virtual int64_t getNumFeatures() const =0
std::map< std::string, std::shared_ptr< AbstractMLModel > > model_map_
std::vector< std::vector< double > > eigenvectors_
const std::vector< std::vector< double > > & getEigenvectors() const