Skip to content

Commit

Permalink
[pymva] Remove import of deprecated torch function
Browse files Browse the repository at this point in the history
  • Loading branch information
lmoneta committed Mar 15, 2024
1 parent 2cb5559 commit 4b6659d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tmva/pymva/src/RModelParser_PyTorch.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ RModel Parse(std::string filename, std::vector<std::vector<size_t>> inputShapes,
PyRunString("import torch",fGlobalNS,fLocalNS);
PyRunString("print('Torch Version: '+torch.__version__)",fGlobalNS,fLocalNS);
PyRunString("from torch.onnx.utils import _model_to_graph",fGlobalNS,fLocalNS);
PyRunString("from torch.onnx.symbolic_helper import _set_onnx_shape_inference",fGlobalNS,fLocalNS);
//PyRunString("from torch.onnx.symbolic_helper import _set_onnx_shape_inference",fGlobalNS,fLocalNS);
PyRunString(TString::Format("model= torch.jit.load('%s')",filename.c_str()),fGlobalNS,fLocalNS);
PyRunString("globals().update(locals())",fGlobalNS,fLocalNS);
PyRunString("model.cpu()",fGlobalNS,fLocalNS);
Expand Down

0 comments on commit 4b6659d

Please sign in to comment.