From 7afa00cd2b70bad80250b7e72fe341844ddbe68a Mon Sep 17 00:00:00 2001 From: Ned Molter Date: Mon, 23 Sep 2024 17:55:41 -0400 Subject: [PATCH] make Sequence calls work with SourceModelContainer --- src/stpipe/step.py | 74 +++++++++++++++++++++------------------------- 1 file changed, 34 insertions(+), 40 deletions(-) diff --git a/src/stpipe/step.py b/src/stpipe/step.py index 7f0d508a..80cd962c 100644 --- a/src/stpipe/step.py +++ b/src/stpipe/step.py @@ -560,34 +560,7 @@ def run(self, *args): # Save the output file if one was specified if not self.skip and self.save_results: - # Setup the save list. - if not isinstance(step_result, Sequence): - results_to_save = [step_result] - else: - results_to_save = step_result - - for idx, result in enumerate(results_to_save): - if len(results_to_save) <= 1: - idx = None - if isinstance( - result, (AbstractDataModel | AbstractModelLibrary) - ): - self.save_model(result, idx=idx) - elif hasattr(result, "save"): - try: - output_path = self.make_output_path(idx=idx) - except AttributeError: - self.log.warning( - "`save_results` has been requested, but cannot" - " determine filename." - ) - self.log.warning( - "Specify an output file with `--output_file` or set" - " `--save_results=false`" - ) - else: - self.log.info("Saving file %s", output_path) - result.save(output_path, overwrite=True) + self.save_model(step_result) if not self.skip: self.log.info("Step %s done", self.name) @@ -995,6 +968,9 @@ def save_model( if not force and not self.save_results and not output_file: return None + if model is None: + return None + if isinstance(model, AbstractModelLibrary): output_paths = [] with model: @@ -1013,20 +989,36 @@ def save_model( return output_paths elif isinstance(model, Sequence): - output_paths = [] - for i, m in enumerate(model): - output_paths.append( - self.save_model( - m, - idx=i, - suffix=suffix, - force=force, - **components, + if not hasattr(model, "save"): + # list of datamodels, e.g. ModelContainer + output_paths = [] + for i, m in enumerate(model): + idx = None if len(model) == 1 else i + output_paths.append( + self.save_model( + m, + idx=idx, + suffix=suffix, + force=force, + **components, + ) ) + return output_paths + else: + # JWST SourceModelContainer takes this path + save_model_func = partial( + self.save_model, + suffix=suffix, + force=force, + **components, ) - return output_paths + output_path = model.save( + path=output_file, + save_model_func=save_model_func, + ) + return output_path - else: + elif hasattr(model, "save"): # Search for an output file name. if self.output_use_model or ( output_file is None and not self.search_output_file @@ -1042,8 +1034,10 @@ def save_model( ) ) self.log.info("Saved model in %s", output_path) + return output_path - return output_path + else: + return @property def make_output_path(self):