OmniSciDB
a5dc49c757
|
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include "TestTorchTableFunctions.h"
#include "torch/script.h"
#include "torch/torch.h"
Go to the source code of this file.
Functions
EXTENSION_NOINLINE int32_t tf_test_runtime_torch(TableFunctionManager &mgr, Column<int64_t> &input, Column<int64_t> &output)
template <typename T>
TEMPLATE_NOINLINE int32_t tf_test_runtime_torch_template__template(TableFunctionManager &mgr, const Column<T> &input, Column<T> &output)
template TEMPLATE_NOINLINE int32_t tf_test_runtime_torch_template__template(TableFunctionManager &mgr, const Column<int64_t> &input, Column<int64_t> &output)
template TEMPLATE_NOINLINE int32_t tf_test_runtime_torch_template__template(TableFunctionManager &mgr, const Column<double> &input, Column<double> &output)
EXTENSION_NOINLINE int32_t tf_test_torch_generate_random_column(TableFunctionManager &mgr, int32_t num_elements, Column<double> &output)
torch::Tensor make_features_from_columns(const ColumnList<double> &cols, int32_t batch_size)
torch::Tensor f(torch::Tensor x, torch::Tensor W_target, torch::Tensor b_target)
std::string poly_desc(torch::Tensor W, torch::Tensor b)
std::pair<torch::Tensor, torch::Tensor> get_batch(const ColumnList<double> &cols, torch::Tensor W_target, torch::Tensor b_target, int32_t batch_size)
EXTENSION_NOINLINE int32_t tf_test_torch_regression(TableFunctionManager &mgr, const ColumnList<double> &features, int32_t batch_size, bool use_gpu, bool save_model, const TextEncodingNone &model_filename, Column<double> &output)
EXTENSION_NOINLINE int32_t tf_test_torch_load_model(TableFunctionManager &mgr, const TextEncodingNone &model_filename, Column<bool> &output)
Variables
torch::Device _test_torch_tfs_device = torch::kCPU
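torch::Tensor f(torch::Tensor x, torch::Tensor W_target, torch::Tensor b_target)
Definition at line 103 of file TestTorchTableFunctions.cpp.
Note: the "Referenced by" list below appears to be inflated by Doxygen matching the common identifier "f" across the whole code base. It even lists Java methods, which cannot call this C++ function. Within this file, f() is used by get_batch() and tf_test_torch_regression().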
Referenced by _geoToHex2d(), QueryFragmentDescriptor::assignFragsToKernelDispatch(), QueryFragmentDescriptor::assignFragsToMultiDispatch(), threading_tbb::async(), atomicSumFltSkipVal(), TransformUTMTo4326::calculateY(), ResultSetReductionJIT::codegen(), GpuReductionHelperJIT::codegen(), FixedWidthInt::codegenDecode(), FixedWidthUnsigned::codegenDecode(), DiffFixedWidthInt::codegenDecode(), FixedWidthReal::codegenDecode(), FixedWidthSmallDate::codegenDecode(), org.apache.calcite.sql2rel.SqlToRelConverter::collectInsertTargets(), com.mapd.calcite.parser.HeavyDBParser::convertSqlToRelNode(), ArrowResultSetConverter::convertToArrow(), org.apache.calcite.sql2rel.SqlToRelConverter::convertWhere(), File_Namespace::create(), File_Namespace::FileMgr::createFileInfo(), com.mapd.tests.DateTimeTest.DateAddUnit::DateAddUnit(), com.mapd.tests.DateTimeTest.DateExtractUnit::DateExtractUnit(), Catalog_Namespace::SysCatalog::execInTransaction(), Catalog_Namespace::Catalog::execInTransaction(), Analyzer::Expr::find_expr(), Analyzer::UOper::find_expr(), Analyzer::BinOper::find_expr(), Analyzer::InValues::find_expr(), Analyzer::MLPredictExpr::find_expr(), Analyzer::PCAProjectExpr::find_expr(), Analyzer::CharLengthExpr::find_expr(), Analyzer::KeyForStringExpr::find_expr(), Analyzer::SampleRatioExpr::find_expr(), Analyzer::CardinalityExpr::find_expr(), Analyzer::LikeExpr::find_expr(), Analyzer::RegexpExpr::find_expr(), Analyzer::WidthBucketExpr::find_expr(), Analyzer::LikelihoodExpr::find_expr(), Analyzer::AggExpr::find_expr(), Analyzer::CaseExpr::find_expr(), Analyzer::ExtractExpr::find_expr(), Analyzer::DateaddExpr::find_expr(), Analyzer::DatediffExpr::find_expr(), Analyzer::DatetruncExpr::find_expr(), Analyzer::StringOper::find_expr(), Analyzer::FunctionOper::find_expr(), RasterFormat_Namespace::format_raster_data(), get_batch(), org.apache.calcite.sql2rel.SqlToRelConverter::getInitializerFactory(), import_export::Importer::importGDALRaster(), 
FilterSelectivity::isFilterSelectiveEnough(), CodeGenerator::link_udf_module(), org.apache.calcite.sql2rel.SqlToRelConverter.Blackboard::lookupExp(), com.mapd.tests.DateTimeTest::main(), AutomaticIRMetadataGuard::makeQueryEngineFilename(), File_Namespace::open(), File_Namespace::FileMgr::openExistingFile(), GenericKeyHandler::operator()(), BoundingBoxIntersectKeyHandler::operator()(), RangeKeyHandler::operator()(), heavyai::JSON::operator[](), threading_serial::parallel_for(), threading_std::parallel_for(), ArrowForeignStorageBase::parseArrowTable(), org.apache.calcite.sql.validate.SqlValidatorImpl.Permute::Permute(), Data_Namespace::ProcMeminfoParser::ProcMeminfoParser(), Parser::QuerySpec::QuerySpec(), File_Namespace::FileBuffer::readMetadata(), reg_hex_horiz_pixel_bin_packed(), reg_hex_horiz_pixel_bin_x(), reg_hex_horiz_pixel_bin_y(), reg_hex_vert_pixel_bin_packed(), reg_hex_vert_pixel_bin_x(), reg_hex_vert_pixel_bin_y(), threading_serial::task_group::run(), threading_std::task_group::run(), anonymous_namespace{NativeCodegen.cpp}::show_defined(), com.mapd.tests.DateTimeTest::testAdd(), com.mapd.tests.DateTimeTest::testDateAdd(), com.mapd.tests.DateTimeTest::testDateExtract(), com.mapd.tests.DateTimeTest::testDateTrunc(), com.mapd.tests.DateTimeTest::testDiff(), com.mapd.tests.DateTimeTest::testSub(), tf_metadata_getter__cpu_template(), tf_metadata_getter_bad__cpu_template(), tf_metadata_setter__cpu_template(), GDALTableFunctions::tf_raster_contour_rasterize_impl(), tf_test_torch_regression(), org.apache.calcite.sql2rel.SqlToRelConverter::toRel(), anonymous_namespace{ResultSetReductionCodegen.cpp}::translate_body(), anonymous_namespace{ResultSetReductionCodegen.cpp}::translate_for(), translate_function(), org.apache.calcite.sql.validate.SqlValidatorImpl::validateInsert(), anonymous_namespace{ExpressionRewrite.cpp}::ConstantFoldingVisitor::visitBinOper(), and File_Namespace::FileBuffer::writeMetadata().
std::pair<torch::Tensor, torch::Tensor> get_batch(const ColumnList<double> &cols, torch::Tensor W_target, torch::Tensor b_target, int32_t batch_size)
Definition at line 126 of file TestTorchTableFunctions.cpp.
References f(), and make_features_from_columns().
Referenced by tf_test_torch_regression().
torch::Tensor make_features_from_columns(const ColumnList<double> &cols, int32_t batch_size)
Definition at line 85 of file TestTorchTableFunctions.cpp.
References _test_torch_tfs_device, ColumnList< T >::numCols(), and ColumnList< T >::size().
Referenced by get_batch().
std::string poly_desc(torch::Tensor W, torch::Tensor b)
Definition at line 108 of file TestTorchTableFunctions.cpp.
Referenced by tf_test_torch_regression().
EXTENSION_NOINLINE int32_t tf_test_runtime_torch(TableFunctionManager &mgr, Column<int64_t> &input, Column<int64_t> &output)
Definition at line 43 of file TestTorchTableFunctions.cpp.
template <typename T>
TEMPLATE_NOINLINE int32_t tf_test_runtime_torch_template__template(TableFunctionManager &mgr, const Column<T> &input, Column<T> &output)
Definition at line 51 of file TestTorchTableFunctions.cpp.
template TEMPLATE_NOINLINE int32_t tf_test_runtime_torch_template__template(TableFunctionManager &mgr, const Column<int64_t> &input, Column<int64_t> &output)
template TEMPLATE_NOINLINE int32_t tf_test_runtime_torch_template__template(TableFunctionManager &mgr, const Column<double> &input, Column<double> &output)
EXTENSION_NOINLINE int32_t tf_test_torch_generate_random_column(TableFunctionManager &mgr, int32_t num_elements, Column<double> &output)
Definition at line 70 of file TestTorchTableFunctions.cpp.
References TableFunctionManager::set_output_row_size().
EXTENSION_NOINLINE int32_t tf_test_torch_load_model(TableFunctionManager &mgr, const TextEncodingNone &model_filename, Column<bool> &output)
Definition at line 226 of file TestTorchTableFunctions.cpp.
References TextEncodingNone::getString(), boost::serialization::load(), and TableFunctionManager::set_output_row_size().
EXTENSION_NOINLINE int32_t tf_test_torch_regression(TableFunctionManager &mgr, const ColumnList<double> &features, int32_t batch_size, bool use_gpu, bool save_model, const TextEncodingNone &model_filename, Column<double> &output)
Definition at line 146 of file TestTorchTableFunctions.cpp.
References _test_torch_tfs_device, f(), get_batch(), TextEncodingNone::getString(), ColumnList< T >::numCols(), poly_desc(), boost::serialization::save(), and TableFunctionManager::set_output_row_size().
torch::Device _test_torch_tfs_device = torch::kCPU
Definition at line 40 of file TestTorchTableFunctions.cpp.
Referenced by make_features_from_columns(), and tf_test_torch_regression().