Skip to content

Commit

Permalink
adding another test for coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
hangqianjun committed Sep 19, 2024
1 parent 6836a7a commit af2b439
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
12 changes: 6 additions & 6 deletions src/rail/core/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,27 +485,27 @@ def _finalize_tag(self, tag):
def _check_column_names(self, data, columns_to_check, **kwargs):
try:
groupname = kwargs.get("groupname", self.config.hdf5_groupname)
except Exception:
groupname = None # pragma: no cover
except Exception: # pragma: no cover
groupname = None

if isinstance(data, DataHandle) and data.has_data == False:
if data.has_path == True:
# data handle only has a path, read the columns from the path
path = data.path
data._check_data_columns(path, columns_to_check, parent_groupname=groupname, **kwargs)
elif data.has_path == False:
print("The data handle does not contain data or path.") # pragma: no cover
elif data.has_path == False: # pragma: no cover
print("The data handle does not contain data or path.")

else:
# data has been read in, access the columns in the table/dictionary directly
if isinstance(data, DataHandle) and data.has_data == True:
if groupname == None:
if groupname in [None, ""]:
col_list = list(data.data.keys())
else:
col_list = list(data.data[groupname].keys())
else:
# data is passed as a table
if groupname == None:
if if groupname in [None, ""]:
col_list = list(data.keys())
else:
col_list = list(data[groupname].keys())
Expand Down
33 changes: 32 additions & 1 deletion tests/estimation/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rail.core.stage import RailStage
from rail.estimation.algos import random_gauss, train_z
from rail.utils.path_utils import RAILDIR
from rail.core.data import PqHandle
from rail.core.data import PqHandle, TableHandle

sci_ver_str = scipy.__version__.split(".")

Expand Down Expand Up @@ -147,6 +147,37 @@ def test_train_pz_with_wrong_columns_table():
redshift_col="REDSHIFT",
)

train_algo = train_z.TrainZInformer
train_pz = train_algo.make_stage(**train_config_dict)
train_pz._get_stage_columns()
with pytest.raises(KeyError):
# testing the case where data is a table
train_pz._check_column_names(training_data3, train_pz.stage_columns)


def test_train_pz_with_wrong_columns_table_wgroupname():

DS = RailStage.data_store
DS.clear()
DS.__class__.allow_overwrite = False

traindata = os.path.join(RAILDIR, "rail/examples_data/testdata/training_100gal.hdf5")

# ! create training data to be a table
# ! however it seems that with set_data() one always reads in the data handle
# ! hence the way we make the data here:
training_data2 = DS.read_file("training_data", TableHandle, traindata)
training_data3 = training_data2.data

train_config_dict = dict(
zmin=0.0,
zmax=3.0,
nzbins=301,
hdf5_groupname="photometry",
model="model_train_z.tmp",
redshift_col="REDSHIFT",
)

train_algo = train_z.TrainZInformer
train_pz = train_algo.make_stage(**train_config_dict)
train_pz._get_stage_columns()
Expand Down

0 comments on commit af2b439

Please sign in to comment.