Skip to content

Commit

Permalink
[tutorials][tmva] Use different name for header and weights in some S…
Browse files Browse the repository at this point in the history
…OFIE tutorials
  • Loading branch information
lmoneta committed Apr 3, 2024
1 parent 3990d95 commit 3421389
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
5 changes: 2 additions & 3 deletions tutorials/tmva/TMVA_SOFIE_Inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
# parse the input Keras model into RModel object
model = ROOT.TMVA.Experimental.SOFIE.PyKeras.Parse(modelFile)

generatedHeaderFile = modelFile.replace(".h5",".hxx")
generatedHeaderFile = "generatedSofieHiggsModel.hxx" #modelFile.replace(".h5",".hxx")
print("Generating inference code for the Keras model from ",modelFile,"in the header ", generatedHeaderFile)
#Generating inference code
model.Generate()
Expand All @@ -44,7 +44,6 @@
ROOT.gInterpreter.Declare('#include "' + generatedHeaderFile + '"')


generatedHeaderFile = modelFile.replace(".h5",".hxx")
print("Generating inference code for the Keras model from ",modelFile,"in the header ", generatedHeaderFile)
#Generating inference

Expand All @@ -67,7 +66,7 @@
print("size of data", dataset_size)

#instantiate SOFIE session class
session = ROOT.TMVA_SOFIE_Higgs_trained_model.Session()
session = ROOT.TMVA_SOFIE_Higgs_trained_model.Session("generatedSofieHiggsModel.dat")

hs = ROOT.TH1D("hs","Signal result",100,0,1)
for i in range(0,dataset_size):
Expand Down
6 changes: 3 additions & 3 deletions tutorials/tmva/TMVA_SOFIE_RDataFrame.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@

# generating inference code
model.Generate()
model.OutputGenerated("Higgs_trained_model.hxx")
model.OutputGenerated("Higgs_generated_Sofie_model.hxx")
model.PrintGenerated()

# compile using ROOT JIT trained model
print("compiling SOFIE model and functor....")
ROOT.gInterpreter.Declare('#include "Higgs_trained_model.hxx"')
ROOT.gInterpreter.Declare('#include "Higgs_generated_Sofie_model.hxx"')
modelName = "Higgs_trained_model"
ROOT.gInterpreter.Declare('auto sofie_functor = TMVA::Experimental::SofieFunctor<7,TMVA_SOFIE_'+modelName+'::Session>(0);')
ROOT.gInterpreter.Declare('auto sofie_functor = TMVA::Experimental::SofieFunctor<7,TMVA_SOFIE_'+modelName+'::Session>(0,"Higgs_generated_Sofie_model.dat");')

# run inference over input data
inputFile = "http://root.cern/files/Higgs_data.root"
Expand Down

0 comments on commit 3421389

Please sign in to comment.