diff --git a/hepdata_lib/root_utils.py b/hepdata_lib/root_utils.py index 87b243ec..c247b549 100644 --- a/hepdata_lib/root_utils.py +++ b/hepdata_lib/root_utils.py @@ -37,13 +37,13 @@ def tfile(self, tfile): raise RuntimeError( "RootFileReader: Input file is not a ROOT file (name does not end in .root)!" ) - if check_file_existence(tfile): - self._tfile = TFile(tfile) + check_file_existence(tfile) + self._tfile = TFile(tfile) elif isinstance(tfile, TFile): self._tfile = tfile else: raise ValueError( - "RootReader: Encountered unkonown type of variable passed as tfile argument: " + "RootReader: Encountered unknown type of variable passed as tfile argument: " + str(type(tfile))) if not self._tfile: diff --git a/tests/test_rootfilereader.py b/tests/test_rootfilereader.py index 7ede5f49..509ed671 100644 --- a/tests/test_rootfilereader.py +++ b/tests/test_rootfilereader.py @@ -6,7 +6,7 @@ import ctypes import numpy as np import ROOT -from hepdata_lib.root_utils import RootFileReader +from hepdata_lib.root_utils import RootFileReader, get_graph_points from .test_utilities import float_compare, tuple_compare, histogram_compare_1d, make_tmp_root_file @@ -18,11 +18,11 @@ def test_tfile_setter(self): Test the behavior of the RootFileReader member setters. """ - # Check with nonexistant file that ends in .root + # Check with nonexistent file that ends in .root with self.assertRaises(RuntimeError): _reader = RootFileReader("/path/to/nowhere/butEndsIn.root") - # Check with existant file that does not end in .root + # Check with existing file that does not end in .root path_to_file = "test.txt" self.addCleanup(os.remove, path_to_file) @@ -42,9 +42,11 @@ def test_tfile_setter(self): # Finally, try a good call path_to_file = make_tmp_root_file(close=True, testcase=self) + tfile = make_tmp_root_file(testcase=self) try: _reader = RootFileReader(path_to_file) + _reader = RootFileReader(tfile) # pylint: disable=W0702 except: self.fail("RootFileReader raised an unexpected exception.") @@ -817,3 +819,8 @@ def test_retrieve_object_canvas_tpad(self): # Clean up self.doCleanups() + + def test_get_graph_points(self): + '''Check that get_graph_points with input not a TGraph (or similar) gives an exception.''' + with self.assertRaises(TypeError): + get_graph_points(100)