Skip to content

Commit

Permalink
make Sequence calls work with SourceModelContainer
Browse files Browse the repository at this point in the history
  • Loading branch information
emolter committed Sep 23, 2024
1 parent 20f0feb commit 7afa00c
Showing 1 changed file with 34 additions and 40 deletions.
74 changes: 34 additions & 40 deletions src/stpipe/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,34 +560,7 @@ def run(self, *args):

# Save the output file if one was specified
if not self.skip and self.save_results:
# Setup the save list.
if not isinstance(step_result, Sequence):
results_to_save = [step_result]
else:
results_to_save = step_result

for idx, result in enumerate(results_to_save):
if len(results_to_save) <= 1:
idx = None
if isinstance(
result, (AbstractDataModel | AbstractModelLibrary)
):
self.save_model(result, idx=idx)
elif hasattr(result, "save"):
try:
output_path = self.make_output_path(idx=idx)
except AttributeError:
self.log.warning(
"`save_results` has been requested, but cannot"
" determine filename."
)
self.log.warning(
"Specify an output file with `--output_file` or set"
" `--save_results=false`"
)
else:
self.log.info("Saving file %s", output_path)
result.save(output_path, overwrite=True)
self.save_model(step_result)

if not self.skip:
self.log.info("Step %s done", self.name)
Expand Down Expand Up @@ -995,6 +968,9 @@ def save_model(
if not force and not self.save_results and not output_file:
return None

if model is None:
return None

if isinstance(model, AbstractModelLibrary):
output_paths = []
with model:
Expand All @@ -1013,20 +989,36 @@ def save_model(
return output_paths

elif isinstance(model, Sequence):
output_paths = []
for i, m in enumerate(model):
output_paths.append(
self.save_model(
m,
idx=i,
suffix=suffix,
force=force,
**components,
if not hasattr(model, "save"):
# list of datamodels, e.g. ModelContainer
output_paths = []
for i, m in enumerate(model):
idx = None if len(model) == 1 else i
output_paths.append(
self.save_model(
m,
idx=idx,
suffix=suffix,
force=force,
**components,
)
)
return output_paths
else:
# JWST SourceModelContainer takes this path
save_model_func = partial(
self.save_model,
suffix=suffix,
force=force,
**components,
)
return output_paths
output_path = model.save(
path=output_file,
save_model_func=save_model_func,
)
return output_path

else:
elif hasattr(model, "save"):
# Search for an output file name.
if self.output_use_model or (
output_file is None and not self.search_output_file
Expand All @@ -1042,8 +1034,10 @@ def save_model(
)
)
self.log.info("Saved model in %s", output_path)
return output_path

return output_path
else:
return

@property
def make_output_path(self):
Expand Down

0 comments on commit 7afa00c

Please sign in to comment.