OmniSciDB  a5dc49c757
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Parser::CreateModelStmt Class Reference

#include <ParserNode.h>

+ Inheritance diagram for Parser::CreateModelStmt:
+ Collaboration diagram for Parser::CreateModelStmt:

Public Member Functions

 CreateModelStmt (const rapidjson::Value &payload)
 
const std::string & get_model_name () const
 
const std::string & get_select_query () const
 
void execute (const Catalog_Namespace::SessionInfo &session, bool read_only_mode) override
 
void train_model (const Catalog_Namespace::SessionInfo &session)
 
- Public Member Functions inherited from Parser::DDLStmt
void setColumnDescriptor (ColumnDescriptor &cd, const ColumnDef *coldef)
 
- Public Member Functions inherited from Parser::Node
virtual ~Node ()
 

Private Member Functions

bool check_model_exists ()
 
void parse_model_options ()
 
std::string build_model_query (const std::shared_ptr< Catalog_Namespace::SessionInfo > session_ptr)
 

Private Attributes

MLModelType model_type_
 
std::string model_name_
 
std::string select_query_
 
bool replace_
 
bool if_not_exists_
 
std::list< std::unique_ptr< NameValueAssign > > model_options_
 
std::ostringstream options_oss_
 
size_t num_options_ {0}
 
double data_split_train_fraction_ {1.0}
 
double data_split_eval_fraction_ {0.0}
 
std::string model_predicted_var_
 
std::vector< std::string > model_feature_vars_
 
std::vector< int64_t > feature_permutations_
 

Detailed Description

Definition at line 1959 of file ParserNode.h.

Constructor & Destructor Documentation

Parser::CreateModelStmt::CreateModelStmt ( const rapidjson::Value &  payload)

Definition at line 3455 of file ParserNode.cpp.

References CHECK, g_enable_ml_functions, get_ml_model_type_from_str(), if_not_exists_, json_bool(), json_str(), model_name_, model_options_, model_type_, Parser::anonymous_namespace{ParserNode.cpp}::parse_options(), replace_, and select_query_.

3455  {
3456  if (!g_enable_ml_functions) {
3457  throw std::runtime_error("Cannot create model. ML functions are disabled.");
3458  }
3459  CHECK(payload.HasMember("name"));
3460  const std::string model_type_str = json_str(payload["type"]);
3461  model_type_ = get_ml_model_type_from_str(model_type_str);
3462  model_name_ = json_str(payload["name"]);
3463  replace_ = false;
3464  if (payload.HasMember("replace")) {
3465  replace_ = json_bool(payload["replace"]);
3466  }
3467 
3468  if_not_exists_ = false;
3469  if (payload.HasMember("ifNotExists")) {
3470  if_not_exists_ = json_bool(payload["ifNotExists"]);
3471  }
3472 
3473  CHECK(payload.HasMember("query"));
3474  select_query_ = json_str(payload["query"]);
3475  std::regex newline_re("\\n");
3476  std::regex backtick_re("`");
3477  select_query_ = std::regex_replace(select_query_, newline_re, " ");
3478  select_query_ = std::regex_replace(select_query_, backtick_re, "");
3479 
3480  // No need to ensure trailing semicolon as we will wrap this select statement
3481  // in a CURSOR as input to the train model table function
3482  parse_options(payload, model_options_);
3483 }
std::list< std::unique_ptr< NameValueAssign > > model_options_
Definition: ParserNode.h:1975
const bool json_bool(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:51
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:46
std::string select_query_
Definition: ParserNode.h:1972
void parse_options(const rapidjson::Value &payload, std::list< std::unique_ptr< NameValueAssign >> &nameValueList, bool stringToNull=false, bool stringToInteger=false)
bool g_enable_ml_functions
Definition: Execute.cpp:122
#define CHECK(condition)
Definition: Logger.h:291
MLModelType get_ml_model_type_from_str(const std::string &model_type_str)
Definition: MLModelType.h:52

+ Here is the call graph for this function:

Member Function Documentation

std::string Parser::CreateModelStmt::build_model_query ( const std::shared_ptr< Catalog_Namespace::SessionInfo > session_ptr)
private

Definition at line 3659 of file ParserNode.cpp.

References query_state::QueryState::create(), data_split_train_fraction_, feature_permutations_, Parser::LocalQueryConnector::getColumnDescriptors(), is_regression_model(), model_feature_vars_, model_predicted_var_, model_type_, Parser::LocalQueryConnector::query(), and select_query_.

Referenced by train_model().

3660  {
3661  auto validate_query_state = query_state::QueryState::create(session_ptr, select_query_);
3662 
3663  LocalQueryConnector local_connector;
3664 
3665  auto validate_result = local_connector.query(
3666  validate_query_state->createQueryStateProxy(), select_query_, {}, true, false);
3667 
3668  auto column_descriptors_for_model_create =
3669  local_connector.getColumnDescriptors(validate_result, true);
3670 
3671  std::vector<size_t> categorical_feature_idxs;
3672  std::vector<size_t> numeric_feature_idxs;
3673  bool numeric_feature_seen = false;
3674  bool all_categorical_features_placed_first = true;
3675  bool model_has_predicted_var = is_regression_model(model_type_);
3676  model_feature_vars_.reserve(column_descriptors_for_model_create.size() -
3677  (model_has_predicted_var ? 1 : 0));
3678  bool is_predicted = model_has_predicted_var ? true : false;
3679  size_t feature_idx = 0;
3680  for (auto& cd : column_descriptors_for_model_create) {
3681  // Check to see if the projected column is an expression without a user-provided
3682  // alias, as we don't allow this.
3683  if (cd.columnName.rfind("EXPR$", 0) == 0) {
3684  throw std::runtime_error(
3685  "All projected expressions (i.e. col * 2) that are not column references (i.e. "
3686  "col) must be aliased.");
3687  }
3688  if (is_predicted) {
3689  model_predicted_var_ = cd.columnName;
3690  if (!cd.columnType.is_number()) {
3691  throw std::runtime_error(
3692  "Numeric predicted column expression should be first argument to CREATE "
3693  "MODEL.");
3694  }
3695  is_predicted = false;
3696  } else {
3697  if (cd.columnType.is_number()) {
3698  numeric_feature_idxs.emplace_back(feature_idx);
3699  numeric_feature_seen = true;
3700  } else if (cd.columnType.is_string()) {
3701  categorical_feature_idxs.emplace_back(feature_idx);
3702  if (numeric_feature_seen) {
3703  all_categorical_features_placed_first = false;
3704  }
3705  } else {
3706  throw std::runtime_error("Feature column expression should be numeric or TEXT.");
3707  }
3708  model_feature_vars_.emplace_back(cd.columnName);
3709  feature_idx++;
3710  }
3711  }
3712  auto modified_select_query = select_query_;
3713  if (!all_categorical_features_placed_first) {
3714  std::ostringstream modified_query_oss;
3715  modified_query_oss << "SELECT ";
3716  if (model_has_predicted_var) {
3717  modified_query_oss << model_predicted_var_ << ", ";
3718  }
3719  for (auto categorical_feature_idx : categorical_feature_idxs) {
3720  modified_query_oss << model_feature_vars_[categorical_feature_idx] << ", ";
3721  feature_permutations_.emplace_back(static_cast<int64_t>(categorical_feature_idx));
3722  }
3723  for (auto numeric_feature_idx : numeric_feature_idxs) {
3724  modified_query_oss << model_feature_vars_[numeric_feature_idx];
3725  feature_permutations_.emplace_back(static_cast<int64_t>(numeric_feature_idx));
3726  if (numeric_feature_idx != numeric_feature_idxs.back()) {
3727  modified_query_oss << ", ";
3728  }
3729  }
3730  modified_query_oss << " FROM (" << modified_select_query << ")";
3731  modified_select_query = modified_query_oss.str();
3732  }
3733 
3734  if (data_split_train_fraction_ < 1.0) {
3735  std::ostringstream modified_query_oss;
3736  if (all_categorical_features_placed_first) {
3737  modified_query_oss << "SELECT * FROM (" << modified_select_query << ")";
3738  } else {
3739  modified_query_oss << modified_select_query;
3740  }
3741  modified_query_oss << " WHERE SAMPLE_RATIO(" << data_split_train_fraction_ << ")";
3742  modified_select_query = modified_query_oss.str();
3743  }
3744  return modified_select_query;
3745 }
static std::shared_ptr< QueryState > create(ARGS &&...args)
Definition: QueryState.h:148
std::string model_predicted_var_
Definition: ParserNode.h:1980
std::vector< std::string > model_feature_vars_
Definition: ParserNode.h:1981
std::vector< int64_t > feature_permutations_
Definition: ParserNode.h:1982
std::string select_query_
Definition: ParserNode.h:1972
bool is_regression_model(const MLModelType model_type)
Definition: MLModelType.h:69

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

bool Parser::CreateModelStmt::check_model_exists ( )
private

Definition at line 3545 of file ParserNode.cpp.

References g_ml_models, get_model_name(), if_not_exists_, MLModelMap::modelExists(), and replace_.

Referenced by train_model().

3545  {
3547  if (if_not_exists_) {
3548  // Returning true tells the caller we should just return early and silently (without
3549  // error)
3550  return true;
3551  }
3552  if (!replace_) {
3553  std::ostringstream error_oss;
3554  error_oss << "Model " << get_model_name() << " already exists.";
3555  throw std::runtime_error(error_oss.str());
3556  }
3557  }
3558  // Returning false tells the caller all is clear to proceed with the create model,
3559  // whether that means creating a new one or overwriting an existing model
3560  return false;
3561 }
const std::string & get_model_name() const
Definition: ParserNode.h:1963
bool modelExists(const std::string &model_name) const
Definition: MLModel.h:44
MLModelMap g_ml_models
Definition: MLModel.h:125

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void Parser::CreateModelStmt::execute ( const Catalog_Namespace::SessionInfo & session,
bool  read_only_mode 
)
override virtual

Implements Parser::DDLStmt.

Definition at line 3806 of file ParserNode.cpp.

References model_name_, and train_model().

Referenced by heavydb.cursor.Cursor::executemany().

3807  {
3808  if (read_only_mode) {
3809  throw std::runtime_error("CREATE MODEL invalid in read only mode.");
3810  }
3811 
3812  try {
3813  train_model(session);
3814  } catch (std::exception& e) {
3815  std::ostringstream error_oss;
3816  // Error messages from table functions come back like this:
3817  // Error executing table function: MLTableFunctions.hpp:269 linear_reg_fit_impl: No
3818  // rows exist in training input. Training input must at least contain 1 row.
3819 
3820  // We want to take everything after the function name, so we will search for the
3821  // third colon.
3822  // Todo(todd): Look at making this less hacky by setting a mode for the table
3823  // function that will return only the core error string and not the preprending
3824  // metadata
3825 
3826  auto get_error_substring = [](const std::string& message) -> std::string {
3827  size_t colon_position = std::string::npos;
3828  for (int i = 0; i < 3; ++i) {
3829  colon_position = message.find(':', colon_position + 1);
3830  if (colon_position == std::string::npos) {
3831  return message;
3832  }
3833  }
3834 
3835  if (colon_position + 2 >= message.length()) {
3836  return message;
3837  }
3838  return message.substr(colon_position + 2);
3839  };
3840 
3841  const auto error_substr = get_error_substring(e.what());
3842 
3843  error_oss << "Could not create model " << model_name_ << ". " << error_substr;
3844  throw std::runtime_error(error_oss.str());
3845  }
3846 }
void train_model(const Catalog_Namespace::SessionInfo &session)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

const std::string& Parser::CreateModelStmt::get_model_name ( ) const
inline

Definition at line 1963 of file ParserNode.h.

References model_name_.

Referenced by check_model_exists(), and train_model().

1963 { return model_name_; }

+ Here is the caller graph for this function:

const std::string& Parser::CreateModelStmt::get_select_query ( ) const
inline

Definition at line 1964 of file ParserNode.h.

References select_query_.

1964 { return select_query_; }
std::string select_query_
Definition: ParserNode.h:1972
void Parser::CreateModelStmt::parse_model_options ( )
private

Definition at line 3563 of file ParserNode.cpp.

References data_split_eval_fraction_, data_split_train_fraction_, Parser::DoubleLiteral::get_doubleval(), Parser::IntLiteral::get_intval(), Parser::StringLiteral::get_stringval(), model_options_, num_options_, and options_oss_.

Referenced by train_model().

3563  {
3564  bool train_fraction_specified = false;
3565  bool eval_fraction_specified = false;
3566  for (auto& p : model_options_) {
3567  const auto key = boost::to_lower_copy<std::string>(*p->get_name());
3568  if (key == "train_fraction" || key == "data_split_train_fraction") {
3569  if (train_fraction_specified) {
3570  throw std::runtime_error(
3571  "Error parsing DATA_SPLIT_TRAIN_FRACTION value. "
3572  "Expected only one value.");
3573  }
3574  const DoubleLiteral* fp_literal =
3575  dynamic_cast<const DoubleLiteral*>(p->get_value());
3576  if (fp_literal != nullptr) {
3577  data_split_train_fraction_ = fp_literal->get_doubleval();
3578  if (data_split_train_fraction_ <= 0.0 || data_split_train_fraction_ > 1.0) {
3579  throw std::runtime_error(
3580  "Error parsing DATA_SPLIT_TRAIN_FRACTION value. "
3581  "Expected value between 0.0 and 1.0.");
3582  }
3583  } else {
3584  throw std::runtime_error(
3585  "Error parsing DATA_SPLIT_TRAIN_FRACTION value. "
3586  "Expected floating point value betwen 0.0 and 1.0.");
3587  }
3588  train_fraction_specified = true;
3589  continue;
3590  }
3591  if (key == "eval_fraction" || key == "data_split_eval_fraction") {
3592  if (eval_fraction_specified) {
3593  throw std::runtime_error(
3594  "Error parsing DATA_SPLIT_EVAL_FRACTION value. "
3595  "Expected only one value.");
3596  }
3597  const DoubleLiteral* fp_literal =
3598  dynamic_cast<const DoubleLiteral*>(p->get_value());
3599  if (fp_literal != nullptr) {
3600  data_split_eval_fraction_ = fp_literal->get_doubleval();
3601  if (data_split_eval_fraction_ < 0.0 || data_split_eval_fraction_ >= 1.0) {
3602  throw std::runtime_error(
3603  "Error parsing DATA_SPLIT_EVAL_FRACTION value. "
3604  "Expected value between 0.0 and 1.0.");
3605  }
3606  } else {
3607  throw std::runtime_error(
3608  "Error parsing DATA_SPLIT_EVAL_FRACTION value. "
3609  "Expected floating point value betwen 0.0 and 1.0.");
3610  }
3611  eval_fraction_specified = true;
3612  continue;
3613  }
3614  if (num_options_) {
3615  options_oss_ << ", ";
3616  }
3617  num_options_++;
3618  options_oss_ << key << " => ";
3619  const StringLiteral* str_literal = dynamic_cast<const StringLiteral*>(p->get_value());
3620  if (str_literal != nullptr) {
3621  options_oss_ << "'"
3622  << boost::to_lower_copy<std::string>(*str_literal->get_stringval())
3623  << "'";
3624  continue;
3625  }
3626  const IntLiteral* int_literal = dynamic_cast<const IntLiteral*>(p->get_value());
3627  if (int_literal != nullptr) {
3628  options_oss_ << int_literal->get_intval();
3629  continue;
3630  }
3631  const DoubleLiteral* fp_literal = dynamic_cast<const DoubleLiteral*>(p->get_value());
3632  if (fp_literal != nullptr) {
3633  options_oss_ << fp_literal->get_doubleval();
3634  continue;
3635  }
3636  throw std::runtime_error("Error parsing value.");
3637  }
3638 
3639  // First handle case where data_split_train_fraction was left to default value
3640  // and data_split_eval_fraction was specified. We shouldn't error here,
3641  // but rather set data_split_train_fraction to 1.0 - data_split_eval_fraction
3642  // Likewise if data_split_eval_fraction was left to default value and we have
3643  // a specified data_split_train_fraction, we should set data_split_eval_fraction
3644  // to 1.0 - data_split_train_fraction
3647  } else if (data_split_eval_fraction_ == 0.0 && data_split_train_fraction_ < 1.0) {
3649  }
3650 
3651  // If data_split_train_fraction was specified, and data_split_train_fraction +
3652  // data_split_eval_fraction > 1.0, then we should error
3654  throw std::runtime_error(
3655  "Error parsing DATA_SPLIT_TRAIN_FRACTION and DATA_SPLIT_EVAL_FRACTION values. "
3656  "Expected sum of values to be less than or equal to 1.0.");
3657  }
3658 }
std::list< std::unique_ptr< NameValueAssign > > model_options_
Definition: ParserNode.h:1975
std::ostringstream options_oss_
Definition: ParserNode.h:1976

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void Parser::CreateModelStmt::train_model ( const Catalog_Namespace::SessionInfo & session)

Definition at line 3747 of file ParserNode.cpp.

References build_model_query(), check_model_exists(), query_state::QueryState::create(), data_split_eval_fraction_, data_split_train_fraction_, shared::encode_base64(), feature_permutations_, get_ml_model_type_str(), get_model_name(), model_feature_vars_, model_predicted_var_, model_type_, num_options_, options_oss_, parse_model_options(), Parser::LocalQueryConnector::query(), select_query_, and Parser::write_model_params_to_json().

Referenced by execute().

3747  {
3748  if (check_model_exists()) {
3749  // Will return true if model exists and if_not_exists_ is true, in this
3750  // case we should return only
3751  return;
3752  }
3753 
3755 
3756  auto session_copy = session;
3757  auto session_ptr = std::shared_ptr<Catalog_Namespace::SessionInfo>(
3758  &session_copy, boost::null_deleter());
3759 
3760  // We need to do various manipulations on the raw select query, such
3761  // as adding in any sampling or feature permutation logic. All of this
3762  // work is encapsulated in build_model_query
3763 
3764  const auto modified_select_query = build_model_query(session_ptr);
3765 
3766  // We have to base64 encode the model metadata because depending on the query,
3767  // the training data can have single quotes that trips up the parsing of the combined
3768  // select query with this metadata embedded.
3769 
3770  // This is just a temporary workaround until we store this info in the Catalog
3771  // rather than in the stored model pointer itself (and have to pass the metadata
3772  // down through the table function call)
3773  const auto model_metadata =
3776  select_query_,
3780  if (num_options_) {
3781  // The options string does not have a trailing comma,
3782  // so add it
3783  options_oss_ << ", ";
3784  }
3785  options_oss_ << "model_metadata => '" << model_metadata << "'";
3786 
3787  const std::string options_str = options_oss_.str();
3788 
3789  const std::string model_train_func = get_ml_model_type_str(model_type_) + "_FIT";
3790 
3791  std::ostringstream model_query_oss;
3792  model_query_oss << "SELECT * FROM TABLE(" << model_train_func << "(model_name=>'"
3793  << get_model_name() << "', data=>CURSOR(" << modified_select_query
3794  << ")";
3795  model_query_oss << ", " << options_str;
3796  model_query_oss << "))";
3797 
3798  std::string wrapped_model_query = model_query_oss.str();
3799  auto query_state = query_state::QueryState::create(session_ptr, wrapped_model_query);
3800  // Don't need result back from query, as the query will create the model
3801  LocalQueryConnector local_connector;
3802  local_connector.query(
3803  query_state->createQueryStateProxy(), wrapped_model_query, {}, false);
3804 }
std::string get_ml_model_type_str(const MLModelType model_type)
Definition: MLModelType.h:27
const std::string & get_model_name() const
Definition: ParserNode.h:1963
static std::shared_ptr< QueryState > create(ARGS &&...args)
Definition: QueryState.h:148
std::string write_model_params_to_json(const std::string &predicted, const std::vector< std::string > &features, const std::string &training_query, const double data_split_train_fraction, const double data_split_eval_fraction, const std::vector< int64_t > &feature_permutations)
std::string model_predicted_var_
Definition: ParserNode.h:1980
std::vector< std::string > model_feature_vars_
Definition: ParserNode.h:1981
std::vector< int64_t > feature_permutations_
Definition: ParserNode.h:1982
std::ostringstream options_oss_
Definition: ParserNode.h:1976
std::string build_model_query(const std::shared_ptr< Catalog_Namespace::SessionInfo > session_ptr)
std::string select_query_
Definition: ParserNode.h:1972
static std::string encode_base64(const std::string &val)
Definition: base64.h:45

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

Member Data Documentation

double Parser::CreateModelStmt::data_split_eval_fraction_ {0.0}
private

Definition at line 1979 of file ParserNode.h.

Referenced by parse_model_options(), and train_model().

double Parser::CreateModelStmt::data_split_train_fraction_ {1.0}
private

Definition at line 1978 of file ParserNode.h.

Referenced by build_model_query(), parse_model_options(), and train_model().

std::vector<int64_t> Parser::CreateModelStmt::feature_permutations_
private

Definition at line 1982 of file ParserNode.h.

Referenced by build_model_query(), and train_model().

bool Parser::CreateModelStmt::if_not_exists_
private

Definition at line 1974 of file ParserNode.h.

Referenced by check_model_exists(), and CreateModelStmt().

std::vector<std::string> Parser::CreateModelStmt::model_feature_vars_
private

Definition at line 1981 of file ParserNode.h.

Referenced by build_model_query(), and train_model().

std::string Parser::CreateModelStmt::model_name_
private

Definition at line 1971 of file ParserNode.h.

Referenced by CreateModelStmt(), execute(), and get_model_name().

std::list<std::unique_ptr<NameValueAssign> > Parser::CreateModelStmt::model_options_
private

Definition at line 1975 of file ParserNode.h.

Referenced by CreateModelStmt(), and parse_model_options().

std::string Parser::CreateModelStmt::model_predicted_var_
private

Definition at line 1980 of file ParserNode.h.

Referenced by build_model_query(), and train_model().

MLModelType Parser::CreateModelStmt::model_type_
private

Definition at line 1970 of file ParserNode.h.

Referenced by build_model_query(), CreateModelStmt(), and train_model().

size_t Parser::CreateModelStmt::num_options_ {0}
private

Definition at line 1977 of file ParserNode.h.

Referenced by parse_model_options(), and train_model().

std::ostringstream Parser::CreateModelStmt::options_oss_
private

Definition at line 1976 of file ParserNode.h.

Referenced by parse_model_options(), and train_model().

bool Parser::CreateModelStmt::replace_
private

Definition at line 1973 of file ParserNode.h.

Referenced by check_model_exists(), and CreateModelStmt().

std::string Parser::CreateModelStmt::select_query_
private

Definition at line 1972 of file ParserNode.h.

Referenced by build_model_query(), CreateModelStmt(), get_select_query(), and train_model().


The documentation for this class was generated from the following files: