Skip to content

Commit

Permalink
[tmva][sofie] Add a new function CheckModel in RModelParser_ONNX
Browse files Browse the repository at this point in the history
A new function has been added to check the model operators (nodes) and
print the list of missing ones (not yet supported)

To use it do:

```
TMVA::Experimental::SOFIE::RModelParser_ONNX p;
bool ret = p.CheckModel("model.onnx");
```

It will return true if all operators of the model are supported.
In case of missing ones, it will print their list.
The check will also extend to subgraph presented in the model as attributes of
some specific nodes (e.g. of If operator)
  • Loading branch information
lmoneta committed Dec 9, 2024
1 parent 3d0d8cd commit 1ab6713
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 16 deletions.
8 changes: 8 additions & 0 deletions tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
namespace onnx {
class NodeProto;
class GraphProto;
class ModelProto;
} // namespace onnx

namespace TMVA {
Expand Down Expand Up @@ -68,15 +69,22 @@ public:
std::unique_ptr<ROperator> ParseOperator(const size_t /*index*/, const onnx::GraphProto & /*graphproto*/,
const std::vector<size_t> & /*nodes*/);

// check a graph for missing operators
void CheckGraph(const onnx::GraphProto & g, int & level, std::map<std::string, int> & missingOperators);

// parse the ONNX graph
void ParseONNXGraph(RModel & model, const onnx::GraphProto & g, std::string name = "");

std::unique_ptr<onnx::ModelProto> LoadModel(std::string filename);

public:

RModelParser_ONNX() noexcept;

RModel Parse(std::string filename, bool verbose = false);

// check the model for missing operators - return false in case some operator implementation is missing
bool CheckModel(std::string filename, bool verbose = false);

~RModelParser_ONNX();
};
Expand Down
94 changes: 78 additions & 16 deletions tmva/sofie_parsers/src/RModelParser_ONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,19 @@ RModelParser_ONNX::ParseOperator(const size_t i, const onnx::GraphProto &graphpr
RModel RModelParser_ONNX::Parse(std::string filename, bool verbose)
{
fVerbose = verbose;

fTensorTypeMap.clear();

auto model = LoadModel(filename);

const onnx::GraphProto &graph = model->graph(); // not a memory leak. model freed automatically at the end.


std::time_t ttime = std::time(0);
std::tm *gmt_time = std::gmtime(&ttime);
std::string parsetime(std::asctime(gmt_time));

// get name of model (filename without directory name)
char sep = '/';
#ifdef _WIN32
sep = '\\';
Expand All @@ -322,34 +335,83 @@ RModel RModelParser_ONNX::Parse(std::string filename, bool verbose)
filename_nodir = (filename.substr(isep + 1, filename.length() - isep));
}

RModel rmodel(filename_nodir, parsetime);
ParseONNXGraph(rmodel, graph, filename_nodir);
return rmodel;
}

GOOGLE_PROTOBUF_VERIFY_VERSION;
// model I/O
onnx::ModelProto model;

std::unique_ptr<onnx::ModelProto> RModelParser_ONNX::LoadModel(std::string filename) {

fTensorTypeMap.clear();
GOOGLE_PROTOBUF_VERIFY_VERSION;
auto model = std::make_unique<onnx::ModelProto>();

std::fstream input(filename, std::ios::in | std::ios::binary);
if (!model.ParseFromIstream(&input)) {
if (!model->ParseFromIstream(&input)) {
throw std::runtime_error("TMVA::SOFIE - Failed to parse onnx file " + filename);
}

const onnx::GraphProto &graph = model.graph(); // not a memory leak. model freed automatically at the end.
google::protobuf::ShutdownProtobufLibrary();

// ONNX version is ir_version() - model_version() returns 0
if (fVerbose) {
std::cout << "ONNX Version " << model.ir_version() << std::endl;
std::cout << "ONNX Version " << model->ir_version() << std::endl;
}
google::protobuf::ShutdownProtobufLibrary();
return model;

std::time_t ttime = std::time(0);
std::tm *gmt_time = std::gmtime(&ttime);
std::string parsetime(std::asctime(gmt_time));
}

RModel rmodel(filename_nodir, parsetime);
ParseONNXGraph(rmodel, graph, filename_nodir);
return rmodel;
void RModelParser_ONNX::CheckGraph(const onnx::GraphProto & graph, int & level, std::map<std::string, int> & missingOperators) {
if (fVerbose)
std::cout << "\n" << graph.name() << " Graph operator list\n";
for (int i = 0; i < graph.node_size(); i++) {
const auto & node = graph.node(i);
const std::string opType = node.op_type();
if (fVerbose) {
std::cout << "\tOperator " << i << " : " << opType << " (" << node.name() << "), " << graph.node(i).input_size()
<< " inputs : {"
for (int j = 0; j < graph.node(i).input_size(); j++) {
std::cout << graph.node(i).input(j);
if (j < graph.node(i).input_size() - 1)
std::cout << ", ";
}
std::cout << " }" << std::endl;
}
// check if operator exists
if (!IsRegisteredOperator(opType))
missingOperators[opType] = level;
// see if sub-graph exists as node attributes
for (int j = 0; j < node.attribute_size(); j++) {
const auto & attribute = node.attribute(j);
if (attribute.has_g()) {
const auto & subGraph = attribute.g();
level += 1;
CheckGraph(subGraph, level, missingOperators);
}
}
}
}

bool RModelParser_ONNX::CheckModel(std::string filename, bool verbose) {

fVerbose = verbose;
auto model = LoadModel(filename);
const onnx::GraphProto &graph = model->graph();
// Initial operator order
if (fVerbose)
std::cout << "\nModel operator list " << model->producer_name() << "\n";

std::map<std::string, int> missingOperators;
int level = 1;
CheckGraph(graph, level, missingOperators);

if (!missingOperators.empty()) {
std::cout << "List of missing operators for model loaded from file " << filename << std::endl;
for (auto & op : missingOperators) {
std::cout << op.first << " " << op.second << std::endl;
}
return false;
}
std::cout << "All operators in the loaded model are supported!\n";
return true;
}

void RModelParser_ONNX::ParseONNXGraph(RModel & rmodel, const onnx::GraphProto & graph, std::string graphName)
Expand Down

0 comments on commit 1ab6713

Please sign in to comment.