diff --git a/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx b/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx index 2a5ae03a6ee25..350f005574508 100644 --- a/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx +++ b/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx @@ -11,6 +11,7 @@ namespace onnx { class NodeProto; class GraphProto; +class ModelProto; } // namespace onnx namespace TMVA { @@ -68,15 +69,22 @@ public: std::unique_ptr ParseOperator(const size_t /*index*/, const onnx::GraphProto & /*graphproto*/, const std::vector & /*nodes*/); + // check a graph for missing operators + void CheckGraph(const onnx::GraphProto & g, int & level, std::map & missingOperators); + // parse the ONNX graph void ParseONNXGraph(RModel & model, const onnx::GraphProto & g, std::string name = ""); + std::unique_ptr LoadModel(std::string filename); + public: RModelParser_ONNX() noexcept; RModel Parse(std::string filename, bool verbose = false); + // check the model for missing operators - return false in case some operator implementation is missing + bool CheckModel(std::string filename, bool verbose = false); ~RModelParser_ONNX(); }; diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx index 81a4b23697898..23549646486fb 100644 --- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx +++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx @@ -312,6 +312,19 @@ RModelParser_ONNX::ParseOperator(const size_t i, const onnx::GraphProto &graphpr RModel RModelParser_ONNX::Parse(std::string filename, bool verbose) { fVerbose = verbose; + + fTensorTypeMap.clear(); + + auto model = LoadModel(filename); + + const onnx::GraphProto &graph = model->graph(); // not a memory leak. model freed automatically at the end. + + + std::time_t ttime = std::time(0); + std::tm *gmt_time = std::gmtime(&ttime); + std::string parsetime(std::asctime(gmt_time)); + + // get name of model (filename without directory name) char sep = '/'; #ifdef _WIN32 sep = '\\'; @@ -322,34 +335,83 @@ RModel RModelParser_ONNX::Parse(std::string filename, bool verbose) filename_nodir = (filename.substr(isep + 1, filename.length() - isep)); } + RModel rmodel(filename_nodir, parsetime); + ParseONNXGraph(rmodel, graph, filename_nodir); + return rmodel; +} - GOOGLE_PROTOBUF_VERIFY_VERSION; - // model I/O - onnx::ModelProto model; - +std::unique_ptr RModelParser_ONNX::LoadModel(std::string filename) { - fTensorTypeMap.clear(); + GOOGLE_PROTOBUF_VERIFY_VERSION; + auto model = std::make_unique(); std::fstream input(filename, std::ios::in | std::ios::binary); - if (!model.ParseFromIstream(&input)) { + if (!model->ParseFromIstream(&input)) { throw std::runtime_error("TMVA::SOFIE - Failed to parse onnx file " + filename); } - const onnx::GraphProto &graph = model.graph(); // not a memory leak. model freed automatically at the end. - google::protobuf::ShutdownProtobufLibrary(); - // ONNX version is ir_version() - model_version() returns 0 if (fVerbose) { - std::cout << "ONNX Version " << model.ir_version() << std::endl; + std::cout << "ONNX Version " << model->ir_version() << std::endl; } + google::protobuf::ShutdownProtobufLibrary(); + return model; - std::time_t ttime = std::time(0); - std::tm *gmt_time = std::gmtime(&ttime); - std::string parsetime(std::asctime(gmt_time)); +} - RModel rmodel(filename_nodir, parsetime); - ParseONNXGraph(rmodel, graph, filename_nodir); - return rmodel; +void RModelParser_ONNX::CheckGraph(const onnx::GraphProto & graph, int & level, std::map & missingOperators) { + if (fVerbose) + std::cout << "\n" << graph.name() << " Graph operator list\n"; + for (int i = 0; i < graph.node_size(); i++) { + const auto & node = graph.node(i); + const std::string opType = node.op_type(); + if (fVerbose) { + std::cout << "\tOperator " << i << " : " << opType << " (" << node.name() << "), " << graph.node(i).input_size() + << " inputs : {" + for (int j = 0; j < graph.node(i).input_size(); j++) { + std::cout << graph.node(i).input(j); + if (j < graph.node(i).input_size() - 1) + std::cout << ", "; + } + std::cout << " }" << std::endl; + } + // check if operator exists + if (!IsRegisteredOperator(opType)) + missingOperators[opType] = level; + // see if sub-graph exists as node attributes + for (int j = 0; j < node.attribute_size(); j++) { + const auto & attribute = node.attribute(j); + if (attribute.has_g()) { + const auto & subGraph = attribute.g(); + level += 1; + CheckGraph(subGraph, level, missingOperators); + } + } + } +} + +bool RModelParser_ONNX::CheckModel(std::string filename, bool verbose) { + + fVerbose = verbose; + auto model = LoadModel(filename); + const onnx::GraphProto &graph = model->graph(); + // Initial operator order + if (fVerbose) + std::cout << "\nModel operator list " << model->producer_name() << "\n"; + + std::map missingOperators; + int level = 1; + CheckGraph(graph, level, missingOperators); + + if (!missingOperators.empty()) { + std::cout << "List of missing operators for model loaded from file " << filename << std::endl; + for (auto & op : missingOperators) { + std::cout << op.first << " " << op.second << std::endl; + } + return false; + } + std::cout << "All operators in the loaded model are supported!\n"; + return true; } void RModelParser_ONNX::ParseONNXGraph(RModel & rmodel, const onnx::GraphProto & graph, std::string graphName)