Skip to content

Commit

Permalink
revert big changes to step save logic
Browse files Browse the repository at this point in the history
  • Loading branch information
emolter committed Oct 8, 2024
1 parent 843c18c commit ec7e160
Showing 1 changed file with 44 additions and 37 deletions.
81 changes: 44 additions & 37 deletions src/stpipe/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,37 @@ def run(self, *args):

# Save the output file if one was specified
if not self.skip and self.save_results:
self.save_model(step_result)
# Setup the save list.
if isinstance(step_result, Sequence):
if hasattr(step_result, "save") or isinstance(step_result, str):
results_to_save = [step_result]
else:
results_to_save = step_result
else:
results_to_save = [step_result]

for idx, result in enumerate(results_to_save):
if len(results_to_save) <= 1:
idx = None
if isinstance(
result, (AbstractDataModel | AbstractModelLibrary)
):
self.save_model(result, idx=idx)
elif hasattr(result, "save"):
try:
output_path = self.make_output_path(idx=idx)
except AttributeError:
self.log.warning(
"`save_results` has been requested, but cannot"
" determine filename."
)
self.log.warning(
"Specify an output file with `--output_file` or set"
" `--save_results=false`"
)
else:
self.log.info("Saving file %s", output_path)
result.save(output_path, overwrite=True)

if not self.skip:
self.log.info("Step %s done", self.name)
Expand Down Expand Up @@ -988,39 +1018,18 @@ def save_model(
model.shelve(m, i)
return output_paths

elif isinstance(model, Sequence) and not isinstance(model, str):
if not hasattr(model, "save"):
# list of datamodels, e.g. JWST ModelContainer
output_paths = []
for i, m in enumerate(model):
# ignore list of lists. individual steps should handle this
if not isinstance(m, Sequence):
idx = None if len(model) == 1 else i
output_paths.append(
self.save_model(
m,
idx=idx,
suffix=suffix,
force=force,
**components,
)
)
return output_paths
else:
# JWST SourceModelContainer takes this path
save_model_func = partial(
self.save_model,
suffix=suffix,
force=force,
**components,
)
output_path = model.save(
path=output_file,
save_model_func=save_model_func,
)
return output_path

elif hasattr(model, "save"):
elif isinstance(model, Sequence):
save_model_func = partial(
self.save_model,
suffix=suffix,
force=force,
**components,
)
output_path = model.save(
path=output_file,
save_model_func=save_model_func,
)
else:
# Search for an output file name.
if self.output_use_model or (
output_file is None and not self.search_output_file
Expand All @@ -1036,10 +1045,8 @@ def save_model(
)
)
self.log.info("Saved model in %s", output_path)
return output_path

else:
return
return output_path

@property
def make_output_path(self):
Expand Down

0 comments on commit ec7e160

Please sign in to comment.