Executes the DDL command corresponding to the provided JSON payload.
2395 std::string model_name;
2396 std::string select_query;
2397 if (ddl_payload.HasMember(
"modelName")) {
2398 model_name = ddl_payload[
"modelName"].GetString();
2400 if (ddl_payload.HasMember(
"query")) {
2401 select_query = ddl_payload[
"query"].GetString();
2403 std::regex newline_re(
"\\n");
2404 std::regex backtick_re(
"`");
2405 select_query = std::regex_replace(select_query, newline_re,
" ");
2406 select_query = std::regex_replace(select_query, backtick_re,
"");
2408 const auto model_type = model_metadata.getModelType();
2410 const auto& model_type_str = model_metadata.getModelTypeStr();
2411 std::ostringstream error_oss;
2412 error_oss <<
"EVALUATE MODEL not supported for " << model_type_str <<
" models.";
2413 throw std::runtime_error(error_oss.str());
2416 if (select_query.empty()) {
2417 const double data_split_eval_fraction = model_metadata.getDataSplitEvalFraction();
2418 CHECK_LE(data_split_eval_fraction, 1.0);
2419 if (data_split_eval_fraction <= 0.0) {
2420 throw std::runtime_error(
2421 "Unable to evaluate model: " + model_name +
2422 ". Model was not trained with a data split evaluation fraction.");
2424 const auto& training_query = model_metadata.getTrainingQuery();
2425 const auto& feature_permutations = model_metadata.getFeaturePermutations();
2426 std::ostringstream select_query_oss;
2431 const double data_split_train_fraction = 1.0 - data_split_eval_fraction;
2432 select_query_oss <<
"SELECT ";
2433 if (!feature_permutations.empty()) {
2434 select_query_oss << model_metadata.getPredicted() <<
", ";
2435 const auto& features = model_metadata.getFeatures();
2436 for (
const auto feature_permutation : feature_permutations) {
2437 select_query_oss << features[feature_permutation];
2438 if (feature_permutation != feature_permutations.back()) {
2439 select_query_oss <<
", ";
2443 select_query_oss <<
" * ";
2446 select_query_oss <<
" FROM (" << training_query <<
") WHERE NOT SAMPLE_RATIO("
2447 << data_split_train_fraction <<
")";
2448 select_query = select_query_oss.str();
2450 const auto& feature_permutations = model_metadata.getFeaturePermutations();
2451 if (!feature_permutations.empty()) {
2453 auto validate_query_state =
2455 auto validate_result = local_connector.
query(
2456 validate_query_state->createQueryStateProxy(), select_query, {},
true,
false);
2457 auto column_descriptors_list =
2459 std::vector<ColumnDescriptor> column_descriptors;
2460 for (
auto& cd : column_descriptors_list) {
2461 column_descriptors.emplace_back(cd);
2463 std::ostringstream select_query_oss;
2464 select_query_oss <<
"SELECT " << model_metadata.getPredicted() <<
", ";
2465 for (
const auto feature_permutation : feature_permutations) {
2466 select_query_oss << column_descriptors[feature_permutation + 1].columnName;
2467 if (feature_permutation != feature_permutations.back()) {
2468 select_query_oss <<
", ";
2471 select_query_oss <<
" FROM (" << select_query <<
")";
2472 select_query = select_query_oss.str();
2475 std::ostringstream r2_query_oss;
2476 r2_query_oss <<
"SELECT * FROM TABLE(r2_score(model_name => '" << model_name <<
"', "
2477 <<
"data => CURSOR(" << select_query <<
")))";
2478 std::string r2_query = r2_query_oss.str();
2484 local_connector.
query(query_state->createQueryStateProxy(), r2_query, {},
false);
2485 std::vector<std::string> labels{
"r2"};
2486 std::vector<TargetMetaInfo> label_infos;
2487 for (
const auto&
label : labels) {
2490 std::vector<RelLogicalValues::RowValues> logical_values;
2497 auto result_row =
result[0].rs->getNextRow(
true,
true);
2499 auto scalar_r = boost::get<ScalarTargetValue>(&result_row[0]);
2500 auto p = boost::get<double>(scalar_r);
2504 std::shared_ptr<ResultSet> rSet = std::shared_ptr<ResultSet>(
2508 }
catch (
const std::exception& e) {
2509 std::ostringstream error_oss;
2519 auto get_error_substring = [](
const std::string& message) -> std::string {
2520 size_t colon_position = std::string::npos;
2521 for (
int i = 0; i < 3; ++i) {
2522 colon_position = message.find(
':', colon_position + 1);
2523 if (colon_position == std::string::npos) {
2528 if (colon_position + 2 >= message.length()) {
2531 return message.substr(colon_position + 2);
2534 const auto error_substr = get_error_substring(e.what());
2535 error_oss <<
"Could not evaluate model " << model_name <<
". " << error_substr;
2536 throw std::runtime_error(error_oss.str());
auto getExecuteReadLock()
std::unique_ptr< RexLiteral > genLiteralDouble(double val)
std::vector< MLModelMetadata > getModelMetadata() const
static std::shared_ptr< QueryState > create(ARGS &&...args)
AggregatedResult query(QueryStateProxy, std::string &sql_query_string, std::vector< size_t > outer_frag_indices, bool validate_only, bool allow_interrupt)
const DdlCommandData & ddl_data_
const rapidjson::Value & extractPayload(const DdlCommandData &ddl_data)
static ResultSet * create(std::vector< TargetMetaInfo > &label_infos, std::vector< RelLogicalValues::RowValues > &logical_values)
bool is_regression_model(const MLModelType model_type)
std::list< ColumnDescriptor > getColumnDescriptors(AggregatedResult &result, bool for_create)
std::shared_ptr< Catalog_Namespace::SessionInfo const > session_ptr_
std::vector< std::unique_ptr< const RexScalar >> RowValues