From 3c1ac2a40e1d145b306968aebfd45b47c925a005 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 8 Jul 2016 14:26:14 -0700 Subject: [PATCH] [PASS] Add save/load json (#1) --- nnvm/include/nnvm/node.h | 6 +- nnvm/src/pass/saveload_json.cc | 202 +++++++++++++++++++++++++++++++++ 2 files changed, 205 insertions(+), 3 deletions(-) create mode 100644 nnvm/src/pass/saveload_json.cc diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index d6eb5817abc7..51e7a6049001 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -31,6 +31,8 @@ struct NodeEntry { * Usually are additional parameters like axis, */ struct NodeAttrs { + /*! \brief name of the node */ + std::string name; /*! \brief The dictionary representation of attributes */ std::unordered_map dict; /*! @@ -46,13 +48,11 @@ struct NodeAttrs { */ class Node { public: - /*! \brief name of the node */ - std::string name; /*! * \brief The operator this node uses. * For place holder variable, op == nullptr. */ - const Op *op; + const Op *op{nullptr}; /*! \brief inputs to this node */ std::vector inputs; /*! diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc new file mode 100644 index 000000000000..9e82f1bd59cc --- /dev/null +++ b/nnvm/src/pass/saveload_json.cc @@ -0,0 +1,202 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file saveload_json.cc + * \brief Passes that defines save and load graph to/from JSON file. + */ +#include +#include +#include + +namespace dmlc { +namespace json { +// overload handler for shared ptr +template<> +struct Handler > { + inline static void Write(JSONWriter *writer, const std::shared_ptr &data) { + writer->Write(*data); + } + inline static void Read(JSONReader *reader, std::shared_ptr *data) { + any v; + reader->Read(&v); + *data = std::make_shared(std::move(v)); + } +}; +} // namespace json +} // namespace dmlc + +namespace nnvm { +namespace pass { + +// auxiliary node structure for serialization. +struct JSONNode { + // the node entry structure in serialized format + typedef std::pair Entry; + // pointer to the graph node + std::shared_ptr node; + // inputs + std::vector inputs; + // control flow dependencies + std::vector control_deps; + + // function to save JSON node. + void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + if (node->op != nullptr) { + writer->WriteObjectKeyValue("op", node->op->name); + writer->WriteObjectKeyValue("attr", node->attrs.dict); + } else { + std::string json_null = "null"; + writer->WriteObjectKeyValue("op", json_null); + } + writer->WriteObjectKeyValue("name", node->attrs.name); + writer->WriteObjectKeyValue("inputs", inputs); + writer->WriteObjectKeyValue("control_deps", control_deps); + writer->EndObject(); + } + + void Load(dmlc::JSONReader *reader) { + node = std::move(Node::Create()); + control_deps.clear(); + dmlc::JSONObjectReadHelper helper; + std::string op_type_str; + helper.DeclareField("op", &op_type_str); + helper.DeclareField("name", &(node->attrs.name)); + helper.DeclareField("inputs", &inputs); + helper.DeclareOptionalField("attr", &(node->attrs.dict)); + helper.DeclareOptionalField("control_deps", &control_deps); + // backward compatible code with mxnet graph. + int backward_source_id; + std::unordered_map param; + helper.DeclareOptionalField("param", ¶m); + helper.DeclareOptionalField("backward_source_id", &backward_source_id); + node->attrs.dict.insert(param.begin(), param.end()); + helper.ReadAllFields(reader); + + if (op_type_str != "null") { + try { + node->op = Op::Get(op_type_str); + } catch (const dmlc::Error &err) { + std::ostringstream os; + os << "Failed loading Op " << node->attrs.name + << " of type " << op_type_str << ": " << err.what(); + throw dmlc::Error(os.str()); + } + } else { + node->op = nullptr; + } + } +}; + +// graph structure to help read/save JSON. +struct JSONGraph { + std::vector nodes; + std::vector arg_nodes; + std::vector heads; + std::unordered_map > attrs; + + void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("nodes", nodes); + writer->WriteObjectKeyValue("arg_nodes", arg_nodes); + writer->WriteObjectKeyValue("heads", heads); + if (attrs.size() != 0) { + writer->WriteObjectKeyValue("attrs", attrs); + } + writer->EndObject(); + } + + void Load(dmlc::JSONReader *reader) { + attrs.clear(); + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("nodes", &nodes); + helper.DeclareField("arg_nodes", &arg_nodes); + helper.DeclareField("heads", &heads); + helper.DeclareOptionalField("attrs", &attrs); + helper.ReadAllFields(reader); + } +}; + +// Load a graph from JSON file. +Graph LoadJSON(const Graph& src) { + CHECK_NE(src.attrs.count("json"), 0) + << "Load JSON require json to be presented."; + const std::string &json_str = + nnvm::get(*src.attrs.at("json")); + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + JSONGraph jgraph; + // load in json graph. + jgraph.Load(&reader); + // connects the nodes + for (JSONNode &n : jgraph.nodes) { + n.node->inputs.reserve(n.inputs.size()); + for (const JSONNode::Entry &e : n.inputs) { + n.node->inputs.emplace_back( + NodeEntry{jgraph.nodes[e.first].node, e.second}); + } + n.node->control_deps.reserve(n.control_deps.size()); + for (uint32_t nid : n.control_deps) { + n.node->control_deps.push_back(jgraph.nodes[nid].node); + } + } + // consistent check + for (uint32_t nid : jgraph.arg_nodes) { + CHECK(jgraph.nodes[nid].node->is_variable()); + } + // return the graph + Graph ret; + ret.attrs = std::move(jgraph.attrs); + ret.outputs.reserve(jgraph.heads.size()); + for (const JSONNode::Entry &e : jgraph.heads) { + ret.outputs.emplace_back( + NodeEntry{jgraph.nodes[e.first].node, e.second}); + } + return ret; +} + +// save a graph to json +Graph SaveJSON(const Graph& src) { + JSONGraph jgraph; + std::unordered_map node2index; + src.DFSVisit([&node2index, &jgraph](const std::shared_ptr& n) { + uint32_t nid = static_cast(jgraph.nodes.size()); + node2index[n.get()] = nid; + if (n->is_variable()) { + jgraph.arg_nodes.push_back(nid); + } + JSONNode jnode; + jnode.node = n; + jnode.inputs.reserve(n->inputs.size()); + for (const NodeEntry& e : n->inputs) { + jnode.inputs.emplace_back( + std::make_pair(node2index.at(e.node.get()), e.index)); + } + for (const std::shared_ptr& c : n->control_deps) { + jnode.control_deps.push_back(node2index.at(c.get())); + } + jgraph.nodes.emplace_back(std::move(jnode)); + }); + + std::ostringstream os; + dmlc::JSONWriter writer(&os); + jgraph.Save(&writer); + Graph ret; + ret.attrs["json"] = std::make_shared(os.str()); + return ret; +} + +// register pass +NNVM_REGISTER_PASS(LoadJSON) +.describe("Return a new Graph, loaded from src.attrs[\"json\"]") +.set_body(LoadJSON) +.set_change_graph(true) +.depend_graph_attr("json"); + +NNVM_REGISTER_PASS(SaveJSON) +.describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]") +.set_body(SaveJSON) +.set_change_graph(true) +.provide_graph_attr("json"); + +} // namespace pass +} // namespace nnvm