OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
EvaluateModelCommand Class Reference

#include <DdlCommandExecutor.h>

+ Inheritance diagram for EvaluateModelCommand:
+ Collaboration diagram for EvaluateModelCommand:

Public Member Functions

 EvaluateModelCommand (const DdlCommandData &ddl_data, std::shared_ptr< Catalog_Namespace::SessionInfo const > session_ptr)
 
ExecutionResult execute (bool read_only_mode) override
 
- Public Member Functions inherited from DdlCommand
 DdlCommand (const DdlCommandData &ddl_data, std::shared_ptr< Catalog_Namespace::SessionInfo const > session_ptr)
 

Additional Inherited Members

- Protected Attributes inherited from DdlCommand
const DdlCommandData & ddl_data_
 
std::shared_ptr< Catalog_Namespace::SessionInfo const > session_ptr_
 

Detailed Description

Definition at line 317 of file DdlCommandExecutor.h.

Constructor & Destructor Documentation

EvaluateModelCommand::EvaluateModelCommand ( const DdlCommandData & ddl_data,
std::shared_ptr< Catalog_Namespace::SessionInfo const >  session_ptr 
)

Definition at line 2383 of file DdlCommandExecutor.cpp.

References g_enable_ml_functions.

2386  : DdlCommand(ddl_data, session_ptr) {
2387  if (!g_enable_ml_functions) {
2388  throw std::runtime_error("Cannot evaluate model. ML functions are disabled.");
2389  }
2390 }
bool g_enable_ml_functions
Definition: Execute.cpp:122
DdlCommand(const DdlCommandData &ddl_data, std::shared_ptr< Catalog_Namespace::SessionInfo const > session_ptr)

Member Function Documentation

ExecutionResult EvaluateModelCommand::execute ( bool  read_only_mode)
override virtual

Executes the DDL command corresponding to provided JSON payload.

Parameters
_return  result of DDL command execution (if applicable)

Implements DdlCommand.

Definition at line 2392 of file DdlCommandExecutor.cpp.

References CHECK_EQ, CHECK_LE, ResultSetLogicalValuesBuilder::create(), query_state::QueryState::create(), DdlCommand::ddl_data_, anonymous_namespace{DdlCommandExecutor.cpp}::extractPayload(), g_ml_models, anonymous_namespace{DdlCommandExecutor.cpp}::genLiteralDouble(), Parser::LocalQueryConnector::getColumnDescriptors(), legacylockmgr::getExecuteReadLock(), MLModelMap::getModelMetadata(), is_regression_model(), kDOUBLE, run_benchmark_import::label, Parser::LocalQueryConnector::query(), run_benchmark_import::result, and DdlCommand::session_ptr_.

Referenced by heavydb.cursor.Cursor::executemany().

2392  {
2393  auto execute_read_lock = legacylockmgr::getExecuteReadLock();
2394  auto& ddl_payload = extractPayload(ddl_data_);
2395  std::string model_name;
2396  std::string select_query;
2397  if (ddl_payload.HasMember("modelName")) {
2398  model_name = ddl_payload["modelName"].GetString();
2399  }
2400  if (ddl_payload.HasMember("query")) {
2401  select_query = ddl_payload["query"].GetString();
2402  }
2403  std::regex newline_re("\\n");
2404  std::regex backtick_re("`");
2405  select_query = std::regex_replace(select_query, newline_re, " ");
2406  select_query = std::regex_replace(select_query, backtick_re, "");
2407  const auto model_metadata = g_ml_models.getModelMetadata(model_name);
2408  const auto model_type = model_metadata.getModelType();
2409  if (!is_regression_model(model_type)) {
2410  const auto& model_type_str = model_metadata.getModelTypeStr();
2411  std::ostringstream error_oss;
2412  error_oss << "EVALUATE MODEL not supported for " << model_type_str << " models.";
2413  throw std::runtime_error(error_oss.str());
2414  }
2415 
2416  if (select_query.empty()) {
2417  const double data_split_eval_fraction = model_metadata.getDataSplitEvalFraction();
2418  CHECK_LE(data_split_eval_fraction, 1.0);
2419  if (data_split_eval_fraction <= 0.0) {
2420  throw std::runtime_error(
2421  "Unable to evaluate model: " + model_name +
2422  ". Model was not trained with a data split evaluation fraction.");
2423  }
2424  const auto& training_query = model_metadata.getTrainingQuery();
2425  const auto& feature_permutations = model_metadata.getFeaturePermutations();
2426  std::ostringstream select_query_oss;
2427  // To get a non-overlapping eval dataset (that does not overlap with the training
2428  // dataset), we need to use NOT SAMPLE_RATIO(training_fraction) and not
2429  // SAMPLE_RATIO(eval_fraction), as the latter will be a subset of the training
2430  // dataset
2431  const double data_split_train_fraction = 1.0 - data_split_eval_fraction;
2432  select_query_oss << "SELECT ";
2433  if (!feature_permutations.empty()) {
2434  select_query_oss << model_metadata.getPredicted() << ", ";
2435  const auto& features = model_metadata.getFeatures();
2436  for (const auto feature_permutation : feature_permutations) {
2437  select_query_oss << features[feature_permutation];
2438  if (feature_permutation != feature_permutations.back()) {
2439  select_query_oss << ", ";
2440  }
2441  }
2442  } else {
2443  select_query_oss << " * ";
2444  }
2445 
2446  select_query_oss << " FROM (" << training_query << ") WHERE NOT SAMPLE_RATIO("
2447  << data_split_train_fraction << ")";
2448  select_query = select_query_oss.str();
2449  } else {
2450  const auto& feature_permutations = model_metadata.getFeaturePermutations();
2451  if (!feature_permutations.empty()) {
2452  Parser::LocalQueryConnector local_connector;
2453  auto validate_query_state =
2454  query_state::QueryState::create(session_ptr_, select_query);
2455  auto validate_result = local_connector.query(
2456  validate_query_state->createQueryStateProxy(), select_query, {}, true, false);
2457  auto column_descriptors_list =
2458  local_connector.getColumnDescriptors(validate_result, true);
2459  std::vector<ColumnDescriptor> column_descriptors;
2460  for (auto& cd : column_descriptors_list) {
2461  column_descriptors.emplace_back(cd);
2462  }
2463  std::ostringstream select_query_oss;
2464  select_query_oss << "SELECT " << model_metadata.getPredicted() << ", ";
2465  for (const auto feature_permutation : feature_permutations) {
2466  select_query_oss << column_descriptors[feature_permutation + 1].columnName;
2467  if (feature_permutation != feature_permutations.back()) {
2468  select_query_oss << ", ";
2469  }
2470  }
2471  select_query_oss << " FROM (" << select_query << ")";
2472  select_query = select_query_oss.str();
2473  }
2474  }
2475  std::ostringstream r2_query_oss;
2476  r2_query_oss << "SELECT * FROM TABLE(r2_score(model_name => '" << model_name << "', "
2477  << "data => CURSOR(" << select_query << ")))";
2478  std::string r2_query = r2_query_oss.str();
2479 
2480  try {
2481  Parser::LocalQueryConnector local_connector;
2482  auto query_state = query_state::QueryState::create(session_ptr_, r2_query);
2483  auto result =
2484  local_connector.query(query_state->createQueryStateProxy(), r2_query, {}, false);
2485  std::vector<std::string> labels{"r2"};
2486  std::vector<TargetMetaInfo> label_infos;
2487  for (const auto& label : labels) {
2488  label_infos.emplace_back(label, SQLTypeInfo(kDOUBLE, true));
2489  }
2490  std::vector<RelLogicalValues::RowValues> logical_values;
2491  logical_values.emplace_back(RelLogicalValues::RowValues{});
2492 
2493  CHECK_EQ(result.size(), size_t(1));
2494  CHECK_EQ(result[0].rs->rowCount(), size_t(1));
2495  CHECK_EQ(result[0].rs->colCount(), size_t(1));
2496 
2497  auto result_row = result[0].rs->getNextRow(true, true);
2498 
2499  auto scalar_r = boost::get<ScalarTargetValue>(&result_row[0]);
2500  auto p = boost::get<double>(scalar_r);
2501 
2502  logical_values.back().emplace_back(genLiteralDouble(*p));
2503 
2504  std::shared_ptr<ResultSet> rSet = std::shared_ptr<ResultSet>(
2505  ResultSetLogicalValuesBuilder::create(label_infos, logical_values));
2506 
2507  return ExecutionResult(rSet, label_infos);
2508  } catch (const std::exception& e) {
2509  std::ostringstream error_oss;
2510  // Error messages from table functions come back like this:
2511  // Error executing table function: MLTableFunctions.hpp:1416 r2_score_impl: No
2512  // rows exist in evaluation data. Evaluation data must at least contain 1 row.
2513 
2514  // We want to take everything after the function name, so we will search for the
2515  // third colon. Todo(todd): Look at making this less hacky by setting a mode for the
2516  // table function that will return only the core error string and not the
2517  // prepending metadata
2518 
2519  auto get_error_substring = [](const std::string& message) -> std::string {
2520  size_t colon_position = std::string::npos;
2521  for (int i = 0; i < 3; ++i) {
2522  colon_position = message.find(':', colon_position + 1);
2523  if (colon_position == std::string::npos) {
2524  return message;
2525  }
2526  }
2527 
2528  if (colon_position + 2 >= message.length()) {
2529  return message;
2530  }
2531  return message.substr(colon_position + 2);
2532  };
2533 
2534  const auto error_substr = get_error_substring(e.what());
2535  error_oss << "Could not evaluate model " << model_name << ". " << error_substr;
2536  throw std::runtime_error(error_oss.str());
2537  }
2538 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
auto getExecuteReadLock()
std::unique_ptr< RexLiteral > genLiteralDouble(double val)
std::vector< MLModelMetadata > getModelMetadata() const
Definition: MLModel.h:84
static std::shared_ptr< QueryState > create(ARGS &&...args)
Definition: QueryState.h:148
AggregatedResult query(QueryStateProxy, std::string &sql_query_string, std::vector< size_t > outer_frag_indices, bool validate_only, bool allow_interrupt)
const DdlCommandData & ddl_data_
const rapidjson::Value & extractPayload(const DdlCommandData &ddl_data)
static ResultSet * create(std::vector< TargetMetaInfo > &label_infos, std::vector< RelLogicalValues::RowValues > &logical_values)
MLModelMap g_ml_models
Definition: MLModel.h:125
#define CHECK_LE(x, y)
Definition: Logger.h:304
bool is_regression_model(const MLModelType model_type)
Definition: MLModelType.h:69
std::list< ColumnDescriptor > getColumnDescriptors(AggregatedResult &result, bool for_create)
std::shared_ptr< Catalog_Namespace::SessionInfo const > session_ptr_
std::vector< std::unique_ptr< const RexScalar >> RowValues
Definition: RelAlgDag.h:2656

+ Here is the call graph for this function:

+ Here is the caller graph for this function:


The documentation for this class was generated from the following files: