53 #include <rapidjson/document.h>
54 #include <rapidjson/stringbuffer.h>
55 #include <rapidjson/writer.h>
60 #include <type_traits>
68 std::shared_ptr<rapidjson::Document>
doc_;
70 rapidjson::Document::AllocatorType&
allo_;
76 :
doc_(std::make_shared<rapidjson::Document>())
98 void parse(
const std::string& json) {
99 if (
doc_->Parse(json.c_str()).HasParseError()) {
100 throw std::runtime_error(
"failed to parse json");
106 if (
doc_->Parse(json).HasParseError()) {
107 throw std::runtime_error(
"failed to parse json");
113 void parse(
const char* json,
size_t len) {
114 if (
doc_->Parse(json, len).HasParseError()) {
115 throw std::runtime_error(
"failed to parse json");
121 rapidjson::StringBuffer buf;
122 rapidjson::Writer<rapidjson::StringBuffer> wr(buf);
124 return buf.GetString();
131 std::string
str()
const {
return static_cast<std::string
>(*this); }
133 bool b1()
const {
return static_cast<bool>(*this); }
135 uint64_t
u64()
const {
return static_cast<uint64_t
>(*this); }
136 int64_t
i64()
const {
return static_cast<int64_t
>(*this); }
138 uint32_t
u32()
const {
return static_cast<uint32_t
>(*this); }
139 int32_t
i32()
const {
return static_cast<int32_t
>(*this); }
141 uint16_t
u16()
const {
return static_cast<uint16_t
>(*this); }
142 int16_t
i16()
const {
return static_cast<int16_t
>(*this); }
144 uint8_t
u8()
const {
return static_cast<uint8_t
>(*this); }
145 int8_t
i8()
const {
return static_cast<int8_t
>(*this); }
147 double d64()
const {
return static_cast<double>(*this); }
148 float f32()
const {
return static_cast<float>(*this); }
166 operator std::string()
const {
167 if (!
vptr_->IsString()) {
168 throw std::runtime_error(
"expected JSON field '" +
name_ +
172 return std::string{
vptr_->GetString(),
vptr_->GetStringLength()};
175 operator bool()
const {
176 if (!
vptr_->IsBool()) {
177 throw std::runtime_error(
"expected JSON field '" +
name_ +
178 "' to be Boolean but got [" +
181 return vptr_->GetBool();
184 operator uint64_t()
const {
185 if (!
vptr_->IsUint64()) {
186 throw std::runtime_error(
"can't convert JSON field '" +
name_ +
187 "' to be unsigned 64-bit integer from [" +
190 return vptr_->GetUint64();
193 operator int64_t()
const {
194 if (!
vptr_->IsInt64()) {
195 throw std::runtime_error(
"can't convert JSON field '" +
name_ +
196 "' to be signed 64-bit integer from [" +
199 return vptr_->GetInt64();
202 operator uint32_t()
const {
203 if (!
vptr_->IsUint()) {
204 throw std::runtime_error(
"can't convert JSON field '" +
name_ +
205 "' to be unsigned 32-bit integer from [" +
208 return vptr_->GetUint();
211 operator int32_t()
const {
212 if (!
vptr_->IsInt()) {
213 throw std::runtime_error(
"can't convert JSON field '" +
name_ +
214 "' to be signed 32-bit integer from [" +
217 return vptr_->GetInt();
220 operator uint16_t()
const {
221 if (!
vptr_->IsUint()) {
222 throw std::runtime_error(
"can't convert JSON field '" +
name_ +
223 "' to be unsigned 16-bit integer from [" +
226 return vptr_->GetUint();
229 operator int16_t()
const {
230 if (!
vptr_->IsInt()) {
231 throw std::runtime_error(
"can't convert JSON field '" +
name_ +
232 "' to be signed 16-bit integer from [" +
235 return vptr_->GetInt();
238 operator uint8_t()
const {
239 if (!
vptr_->IsUint()) {
240 throw std::runtime_error(
"can't convert JSON field '" +
name_ +
241 "' to be unsigned 8-bit integer from [" +
244 return vptr_->GetUint();
247 operator int8_t()
const {
248 if (!
vptr_->IsInt()) {
249 throw std::runtime_error(
"can't convert JSON field '" +
name_ +
250 "' to be signed 8-bit integer from [" +
253 return vptr_->GetInt();
256 operator double()
const {
257 if (!
vptr_->IsDouble()) {
258 throw std::runtime_error(
"can't convert JSON field '" +
name_ +
259 "' to be floating point number from [" +
262 return vptr_->GetDouble();
265 operator float()
const {
266 if (!
vptr_->IsDouble()) {
267 throw std::runtime_error(
"can't convert JSON field '" +
name_ +
268 "' to be floating point number from [" +
271 return static_cast<float>(
vptr_->GetDouble());
282 *
vptr_ = rapidjson::Value().SetString(item.c_str(),
allo_);
287 *
vptr_ = rapidjson::Value().SetString(item,
allo_);
292 vptr_->SetBool(item);
302 vptr_->SetInt64(item);
307 vptr_->SetUint(item);
312 vptr_->SetUint64(item);
320 if (!
vptr_->IsObject()) {
323 if (!
vptr_->HasMember(name)) {
325 rapidjson::Value(name,
allo_).Move(), rapidjson::Value().Move(),
allo_);
326 auto f =
vptr_->FindMember(name);
336 if (!
vptr_->IsObject()) {
337 throw std::runtime_error(
"JSON " +
kTypeNames[
vptr_->GetType()] +
" field '" +
338 name_ +
"' can't use operator []");
340 if (!
vptr_->HasMember(name)) {
341 throw std::runtime_error(
"JSON field '" + std::string(name) +
"' not found");
346 template <
typename T>
348 return operator[](static_cast<size_t>(index));
351 if (!
vptr_->IsArray()) {
354 if (index >=
vptr_->Size()) {
355 throw std::runtime_error(
"JSON array index " +
std::to_string(index) +
361 template <
typename T>
363 return operator[](static_cast<size_t>(index));
366 if (!
vptr_->IsArray()) {
367 throw std::runtime_error(
"JSON " +
kTypeNames[
vptr_->GetType()] +
" field '" +
368 name_ +
"' can't use operator []");
370 if (index >=
vptr_->Size()) {
371 throw std::runtime_error(
"JSON array index " +
std::to_string(index) +
379 {
"Null",
"False",
"True",
"Object",
"Array",
"String",
"Number"};
382 JSON(std::shared_ptr<rapidjson::Document> doc,
383 rapidjson::Value* vptr,
384 rapidjson::Document::AllocatorType& allo,
385 const std::string&
name)
391 template <
typename T>
393 template <
typename T>
396 template <
typename T>
398 template <
typename T>
414 template <
typename T>
416 return (*json.
vptr_ == value);
418 template <
typename T>
420 return (json == value);
423 template <
typename T>
425 return (*json.
vptr_ != value);
427 template <
typename T>
429 return (json != value);
JSON(const std::string &json)
JSON & operator=(const JSON &peer)
std::shared_ptr< rapidjson::Document > doc_
friend bool operator!=(const JSON &json1, const JSON &json2)
JSON operator[](const char *name)
void parse(const std::string &json)
JSON & operator=(uint64_t item)
JSON & operator=(int64_t item)
JSON(const char *json, size_t len)
JSON & operator=(bool item)
JSON operator[](const char *name) const
JSON & operator=(const std::string &item)
std::string stringify() const
JSON & operator=(uint32_t item)
JSON operator[](T index) const
bool hasMember(const std::string &name) const
JSON operator[](size_t index) const
JSON & operator=(const char *item)
JSON operator[](const std::string &name) const
torch::Tensor f(torch::Tensor x, torch::Tensor W_target, torch::Tensor b_target)
JSON operator[](size_t index)
JSON(std::shared_ptr< rapidjson::Document > doc, rapidjson::Value *vptr, rapidjson::Document::AllocatorType &allo, const std::string &name)
friend bool operator==(const JSON &json1, const JSON &json2)
JSON operator[](const std::string &name)
rapidjson::Document::AllocatorType & allo_
JSON & operator=(int32_t item)
bool operator==(const JSON &json1, const JSON &json2)
bool operator!=(const JSON &json1, const JSON &json2)
void parse(const char *json, size_t len)
void parse(const char *json)
static std::string kTypeNames[]