diff --git a/changes/190.misc.rst b/changes/190.misc.rst new file mode 100644 index 00000000..5cfbcf04 --- /dev/null +++ b/changes/190.misc.rst @@ -0,0 +1 @@ +improve support for list-like input into Step diff --git a/src/stpipe/step.py b/src/stpipe/step.py index f956c454..54333786 100644 --- a/src/stpipe/step.py +++ b/src/stpipe/step.py @@ -491,31 +491,37 @@ def run(self, *args): e, ) library.shelve(model, i) - elif isinstance(args[0], AbstractDataModel): - if self.class_alias is not None: - if isinstance(args[0], Sequence): - for model in args[0]: - try: - model[f"meta.cal_step.{self.class_alias}"] = ( - "SKIPPED" - ) - except AttributeError as e: # noqa: PERF203 - self.log.info( - "Could not record skip into DataModel " - "header: %s", - e, - ) - elif isinstance(args[0], AbstractDataModel): + + elif ( + (isinstance(args[0], Sequence)) + and (not isinstance(args[0], str)) + and (self.class_alias is not None) + ): + # handle ModelContainer or list of models + if args[0] and isinstance(args[0][0], AbstractDataModel): + for model in args[0]: try: - args[0][ - f"meta.cal_step.{self.class_alias}" - ] = "SKIPPED" + setattr( + model.meta.cal_step, self.class_alias, "SKIPPED" + ) except AttributeError as e: self.log.info( - "Could not record skip into DataModel" - " header: %s", + "Could not record skip into DataModel " + "header: %s", e, ) + + elif ( + isinstance(args[0], AbstractDataModel) + and self.class_alias is not None + ): + try: + args[0][f"meta.cal_step.{self.class_alias}"] = "SKIPPED" + except AttributeError as e: + self.log.info( + "Could not record skip into DataModel header: %s", + e, + ) step_result = args[0] else: if self.prefetch_references: @@ -558,10 +564,13 @@ def run(self, *args): # Save the output file if one was specified if not self.skip and self.save_results: # Setup the save list. - if not isinstance(step_result, list | tuple): - results_to_save = [step_result] + if isinstance(step_result, Sequence): + if hasattr(step_result, "save") or isinstance(step_result, str): + results_to_save = [step_result] + else: + results_to_save = step_result else: - results_to_save = step_result + results_to_save = [step_result] for idx, result in enumerate(results_to_save): if len(results_to_save) <= 1: @@ -992,6 +1001,9 @@ def save_model( if not force and not self.save_results and not output_file: return None + if model is None: + return None + if isinstance(model, AbstractModelLibrary): output_paths = [] with model: @@ -1008,6 +1020,7 @@ def save_model( # leaving modify=True in case saving modify the file model.shelve(m, i) return output_paths + elif isinstance(model, Sequence): save_model_func = partial( self.save_model, diff --git a/tests/test_step.py b/tests/test_step.py index 972f84a9..88197e25 100644 --- a/tests/test_step.py +++ b/tests/test_step.py @@ -1,7 +1,9 @@ """Test step.Step""" +import copy import logging import re +from collections.abc import Sequence from typing import ClassVar import asdf @@ -9,6 +11,7 @@ import stpipe.config_parser as cp from stpipe import cmdline +from stpipe.datamodel import AbstractDataModel from stpipe.pipeline import Pipeline from stpipe.step import Step @@ -411,3 +414,159 @@ def test_log_records(): pipeline.run() assert any(r == "This step has called out a warning." for r in pipeline.log_records) + + +class StepWithModel(Step): + """A step that immediately saves the model it gets passed in""" + + spec = """ + output_ext = string(default='simplestep') + save_results = boolean(default=True) + """ + + def process(self, input_model): + # make a change to ensure step skip is working + # without having to define SimpleDataModel.meta.stepname + if isinstance(input_model, SimpleDataModel): + input_model.stepstatus = "COMPLETED" + elif isinstance(input_model, SimpleContainer): + for model in input_model: + model.stepstatus = "COMPLETED" + return input_model + + +class SimpleDataModel(AbstractDataModel): + """A simple data model""" + + @property + def crds_observatory(self): + return "jwst" + + def get_crds_parameters(self): + return {"test": "none"} + + def save(self, path, dir_path=None, *args, **kwargs): + saveid = getattr(self, "saveid", None) + if saveid is not None: + fname = saveid + "-saved.txt" + with open(fname, "w") as f: + f.write(f"{path}") + return fname + return None + + +def test_save(tmp_cwd): + + model = SimpleDataModel() + model.saveid = "test" + step = StepWithModel() + step.run(model) + assert (tmp_cwd / "test-saved.txt").exists() + + +def test_skip(): + model = SimpleDataModel() + step = StepWithModel() + step.skip = True + out = step.run(model) + assert not hasattr(out, "stepstatus") + assert out is model + + +@pytest.fixture(scope="function") +def model_list(): + model = SimpleDataModel() + model_list = [copy.deepcopy(model) for _ in range(3)] + for i, model in enumerate(model_list): + model.saveid = f"test{i}" + return model_list + + +def test_save_list(tmp_cwd, model_list): + step = StepWithModel() + step.run(model_list) + for i in range(3): + assert (tmp_cwd / f"test{i}-saved.txt").exists() + + +class SimpleContainer(Sequence): + + def __init__(self, models): + self._models = models + + def __len__(self): + return len(self._models) + + def __getitem__(self, idx): + return self._models[idx] + + def __iter__(self): + yield from self._models + + def insert(self, index, model): + self._models.insert(index, model) + + def append(self, model): + self._models.append(model) + + def extend(self, model): + self._models.extend(model) + + def pop(self, index=-1): + self._models.pop(index) + + +class SimpleContainerWithSave(SimpleContainer): + + def save(self, path, dir_path=None, *args, **kwargs): + for model in self._models[1:]: + # skip the first model to test that the save method is called + # rather than just looping over all models like in the without-save case + model.save(path, dir_path, *args, **kwargs) + + +def test_save_container(tmp_cwd, model_list): + """ensure list-like save still works for non-list sequence""" + container = SimpleContainer(model_list) + step = StepWithModel() + step.run(container) + for i in range(3): + assert (tmp_cwd / f"test{i}-saved.txt").exists() + + +def test_skip_container(tmp_cwd, model_list): + step = StepWithModel() + step.skip = True + out = step.run(model_list) + assert not hasattr(out, "stepstatus") + for i, model in enumerate(out): + assert not hasattr(model, "stepstatus") + assert model_list[i] is model + + +def test_save_container_with_save_method(tmp_cwd, model_list): + """ensure list-like save still works for non-list sequence""" + container = SimpleContainerWithSave(model_list) + step = StepWithModel() + step.run(container) + assert not (tmp_cwd / "test0-saved.txt").exists() + assert (tmp_cwd / "test1-saved.txt").exists() + assert (tmp_cwd / "test2-saved.txt").exists() + + +def test_save_tuple_with_nested_list(tmp_cwd, model_list): + """ + in rare cases, multiple outputs are returned from step as tuple. + One example is the jwst badpix_selfcal step, which returns one sci exposure + and a list containing an arbitrary number of background exposures. + Expected behavior in this case, at least at time of writing, is to save the + science exposure and ignore the list + """ + single_model = SimpleDataModel() + single_model.saveid = "test" + container = (single_model, model_list) + step = StepWithModel() + step.run(container) + assert (tmp_cwd / "test-saved.txt").exists() + for i in range(3): + assert not (tmp_cwd / f"test{i}-saved.txt").exists()