Skip to content

Commit

Permalink
Import ROOT functions only where needed
Browse files Browse the repository at this point in the history
* Allows basic hepdata_lib functionality to work without ROOT.
* Need to disable pylint "import-outside-toplevel" error.
  • Loading branch information
GraemeWatt committed Mar 15, 2024
1 parent eb392fe commit f842b3a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
5 changes: 4 additions & 1 deletion hepdata_lib/c_file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import io
from array import array
from future.utils import raise_from
from ROOT import TGraph, TGraphErrors
import hepdata_lib.root_utils as ru
from hepdata_lib.helpers import check_file_existence

Expand Down Expand Up @@ -137,6 +136,8 @@ def create_tgrapherrors(self, x_value, y_value, dx_value, dy_value):
"""Function to create pyroot TGraphErrors object"""
# pylint: disable=no-self-use

from ROOT import TGraphErrors # pylint: disable=import-outside-toplevel

# Creating pyroot TGraphErrors object
x_values = array('i')
y_values = array('i')
Expand Down Expand Up @@ -173,6 +174,8 @@ def create_tgraph(self, x_value, y_value):
"""Function to create pyroot TGraph object"""
# pylint: disable=no-self-use

from ROOT import TGraph # pylint: disable=import-outside-toplevel

# Creating pyroot TGraph object
x_values = array('i')
y_values = array('i')
Expand Down
26 changes: 17 additions & 9 deletions hepdata_lib/root_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import ctypes
from future.utils import raise_from
import numpy as np
import ROOT as r
from hepdata_lib.helpers import check_file_existence

class RootFileReader:
Expand Down Expand Up @@ -31,14 +30,16 @@ def tfile(self, tfile):
Can either be an already open TFile or a path to the file on disk.
:type tfile: TFile or str
"""
from ROOT import TFile # pylint: disable=import-outside-toplevel

if isinstance(tfile, str):
if not tfile.endswith(".root"):
raise RuntimeError(
"RootFileReader: Input file is not a ROOT file (name does not end in .root)!"
)
if check_file_existence(tfile):
self._tfile = r.TFile(tfile)
elif isinstance(tfile, r.TFile):
self._tfile = TFile(tfile)
elif isinstance(tfile, TFile):
self._tfile = tfile
else:
raise ValueError(
Expand Down Expand Up @@ -216,8 +217,10 @@ def read_tree(self, path_to_tree, branch_name):
:returns: list -- The values saved in the tree branch.
"""
from ROOT import TTree # pylint: disable=import-outside-toplevel

tree = self.tfile.Get(path_to_tree)
if not tree or not isinstance(tree, r.TTree):
if not tree or not isinstance(tree, TTree):
raise RuntimeError(f"No TTree found for path '{path_to_tree}'.")
values = []
for event in tree:
Expand Down Expand Up @@ -293,6 +296,8 @@ def get_hist_2d_points(hist, **kwargs):
Symmetric errors are returned if the histogram error option
TH1::GetBinErrorOption() returns TH1::kNormal.
"""
from ROOT import TH1 # pylint: disable=import-outside-toplevel

xlim = kwargs.pop('xlim', (None, None))
ylim = kwargs.pop('ylim', (None, None))
force_symmetric_errors = kwargs.pop('force_symmetric_errors', False)
Expand All @@ -317,7 +322,7 @@ def get_hist_2d_points(hist, **kwargs):
ixmax = hist.GetXaxis().FindBin(xlim[1]) if xlim[1] is not None else hist.GetNbinsX() + 1
iymin = hist.GetYaxis().FindBin(ylim[0]) if ylim[0] is not None else 1
iymax = hist.GetYaxis().FindBin(ylim[1]) if ylim[1] is not None else hist.GetNbinsY() + 1
symmetric = (hist.GetBinErrorOption() == r.TH1.kNormal)
symmetric = (hist.GetBinErrorOption() == TH1.kNormal)
if force_symmetric_errors:
symmetric = True
for x_bin in range(ixmin, ixmax):
Expand Down Expand Up @@ -374,6 +379,8 @@ def get_hist_1d_points(hist, **kwargs):
Symmetric errors are returned if the histogram error option
TH1::GetBinErrorOption() returns TH1::kNormal.
"""
from ROOT import TH1 # pylint: disable=import-outside-toplevel

xlim = kwargs.pop('xlim', (None, None))
force_symmetric_errors = kwargs.pop('force_symmetric_errors', False)
if kwargs:
Expand All @@ -388,7 +395,7 @@ def get_hist_1d_points(hist, **kwargs):
for key in ["x", "y", "x_edges", "x_labels", "dy"]:
points[key] = []

symmetric = (hist.GetBinErrorOption() == r.TH1.kNormal)
symmetric = (hist.GetBinErrorOption() == TH1.kNormal)
if force_symmetric_errors:
symmetric = True
ixmin = hist.GetXaxis().FindBin(xlim[0]) if xlim[0] is not None else 1
Expand Down Expand Up @@ -428,9 +435,10 @@ def get_graph_points(graph):
For asymmetric errors, a list of tuples of (down,up) values is given.
"""
from ROOT import TGraph, TGraphErrors, TGraphAsymmErrors # pylint: disable=import-outside-toplevel

# Check input
if not isinstance(graph, (r.TGraph, r.TGraphErrors, r.TGraphAsymmErrors)):
if not isinstance(graph, (TGraph, TGraphErrors, TGraphAsymmErrors)):
raise TypeError(
"Expected to input to be TGraph or similar, instead got '{0}'".
format(type(graph)))
Expand All @@ -444,10 +452,10 @@ def get_graph_points(graph):
graph.GetPoint(i, x_val, y_val)
points["x"].append(float(x_val.value))
points["y"].append(float(y_val.value))
if isinstance(graph, r.TGraphErrors):
if isinstance(graph, TGraphErrors):
points["dx"].append(graph.GetErrorX(i))
points["dy"].append(graph.GetErrorY(i))
elif isinstance(graph, r.TGraphAsymmErrors):
elif isinstance(graph, TGraphAsymmErrors):
points["dx"].append((-graph.GetErrorXlow(i),
graph.GetErrorXhigh(i)))
points["dy"].append((-graph.GetErrorYlow(i),
Expand Down

0 comments on commit f842b3a

Please sign in to comment.