Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes #217 #218

Merged
merged 6 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scabha/cargo.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ class Parameter(object):
# list of aliases for this parameter (i.e. references to other parameters whose schemas/values this parameter shares)
aliases: Optional[List[str]] = ()

# if true, treat parameter as a path, and ensure that the parent directories it refers to exist
mkdir: bool = False
# if true, create parent directories of file-type outputs if needed
mkdir: bool = True

# if True, and parameter is a path, access to its parent directory is required
access_parent_dir: bool = False
Expand Down
3 changes: 3 additions & 0 deletions scabha/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ def clickify_parameters(schemas: Union[str, Dict[str, Any]]):
decorator_chain = None
for io in schemas.inputs, schemas.outputs:
for name, schema in io.items():
# skip outputs, unless they're named outputs
if io is schemas.outputs and not (schema.is_file_type and not schema.implicit):
continue

name = name.replace("_", "-")
optname = f"--{name}"
Expand Down
20 changes: 15 additions & 5 deletions scabha/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def validate_parameters(params: Dict[str, Any], schemas: Dict[str, Any],
# convert this to a pydantic dataclass which does validation
pcls = pydantic.dataclasses.dataclass(dcls)

# check Files etc. and expand globs
# check Files etc.
for name, value in list(inputs.items()):
# get schema from those that need validation, skip if not in schemas
schema = schemas.get(name)
Expand Down Expand Up @@ -227,6 +227,8 @@ def validate_parameters(params: Dict[str, Any], schemas: Dict[str, Any],
files = value
else:
raise ParameterValidationError(f"'{mkname(name)}={value}': invalid type '{type(value)}'")
# expand ~
files = [os.path.expanduser(f) for f in files]

# check for existence of all files in list, if needed
if must_exist:
Expand Down Expand Up @@ -289,10 +291,18 @@ def validate_parameters(params: Dict[str, Any], schemas: Dict[str, Any],
# check for mkdir directives
if create_dirs:
for name, value in validated.items():
if schemas[name].mkdir and isinstance(value, str):
dirname = os.path.dirname(value)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname, exist_ok=True)
schema = schemas[name]
if schema.is_output and schema.mkdir:
if schema.is_file_type:
files = [value]
elif schema.is_file_list_type:
files = value
else:
continue
for path in files:
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname, exist_ok=True)

# add in unresolved values
validated.update(**unresolved)
Expand Down
3 changes: 2 additions & 1 deletion stimela/backends/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ def get_executable(self):
def run_command_wrapper(self, args: List[str], fqname: Optional[str]=None, log: Optional[logging.Logger]=None) -> List[str]:
output_args = [self.get_executable()]

# reverse fqname to make job name (more informative that way)
if fqname is not None:
output_args += ["-J", fqname]
output_args += ["-J", '.'.join(fqname.split('.')[::-1])]

# add all base options that have been specified
for name, value in self.srun_opts.items():
Expand Down
30 changes: 21 additions & 9 deletions stimela/kitchen/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def finalize(self, config=None, log=None, name=None, fqname=None, backend=None,
self.logopts = config.opts.log.copy()

# update file logger
logsubst = SubstitutionNS(config=config, info=dict(fqname=fqname))
logsubst = SubstitutionNS(config=config, info=dict(fqname=fqname, taskname=fqname))
stimelogging.update_file_logger(log, self.logopts, nesting=nesting, subst=logsubst, location=[self.fqname])

# call Cargo's finalize method
Expand Down Expand Up @@ -744,7 +744,7 @@ def prevalidate(self, params: Dict[str, Any], subst: Optional[SubstitutionNS]=No
subst_outer = subst # outer dictionary is used to prevalidate our parameters

subst = SubstitutionNS()
info = SubstitutionNS(fqname=self.fqname, label='', label_parts=[], suffix='')
info = SubstitutionNS(fqname=self.fqname, taskname=self.fqname, label='', label_parts=[], suffix='')
# mutable=False means these sub-namespaces are not subject to {}-substitutions
subst._add_('info', info.copy(), nosubst=True)
subst._add_('config', self.config, nosubst=True)
Expand Down Expand Up @@ -1037,7 +1037,7 @@ def _update_aliases(self, name: str, value: Any):
alias.step.update_parameter(alias.param, value)


def _iterate_loop_worker(self, params, info, subst, backend, count, iter_var, subprocess=False, raise_exc=True):
def _iterate_loop_worker(self, params, subst, backend, count, iter_var, subprocess=False, raise_exc=True):
""""
Needed for concurrency
"""
Expand All @@ -1046,6 +1046,7 @@ def _iterate_loop_worker(self, params, info, subst, backend, count, iter_var, su
task_stats.add_subprocess_id(count)
task_stats.destroy_progress_bar()
subst.info.subprocess = task_stats.get_subprocess_id()
taskname = subst.info.taskname
outputs = {}
exception = tb = None
task_attrs, task_kwattrs = (), {}
Expand Down Expand Up @@ -1075,6 +1076,8 @@ def _iterate_loop_worker(self, params, info, subst, backend, count, iter_var, su
if status is None:
status = "{index1}/{total}".format(**status_dict)
task_stats.declare_subtask_status(status)
taskname = f"{taskname}.{count}"
subst.info.taskname = taskname
# task_stats.declare_subtask_attributes(count)
# task_attrs = (count,)
context = task_stats.declare_subtask(f"({count})")
Expand All @@ -1085,16 +1088,20 @@ def _iterate_loop_worker(self, params, info, subst, backend, count, iter_var, su
for label, step in self.steps.items():
# update step info
self._prep_step(label, step, subst)
subst.info.taskname = f"{taskname}.{label}"
# reevaluate recipe level assignments (info.fqname etc. have changed)
self.update_assignments(subst, params=params)
# evaluate step-level assignments
self.update_assignments(subst, whose=step, params=params)
# step logger may have changed
stimelogging.update_file_logger(step.log, step.logopts, nesting=step.nesting, subst=subst, location=[step.fqname])
# set our info back temporarily to update log assignments
info_step = subst.info
subst.info = info.copy()
subst.info = info_step

## OMS: note to self, I had this here but not sure why. Seems like a no-op. Something with logname fiddling.
## Leave as a puzzle to future self for a bit. Remove info from args.
# info_step = subst.info
# subst.info = info.copy()
# subst.info = info_step

if step.skip is True:
self.log.debug(f"step '{label}' will be explicitly skipped")
Expand Down Expand Up @@ -1137,7 +1144,7 @@ def _iterate_loop_worker(self, params, info, subst, backend, count, iter_var, su
# else will be returned
exception = exc
tb = FormattedTraceback(sys.exc_info()[2])

return task_attrs, task_kwattrs, task_stats.collect_stats(), outputs, exception, tb

def build(self, backend={}, rebuild=False, build_skips=False, log: Optional[logging.Logger] = None):
Expand Down Expand Up @@ -1172,8 +1179,11 @@ def _run(self, params, subst=None, backend={}) -> Dict[str, Any]:
subst_outer = subst
if subst is None:
subst = SubstitutionNS()
taskname = self.name
else:
taskname = subst.info.taskname

info = SubstitutionNS(fqname=self.fqname, label='', label_parts=[], suffix='')
info = SubstitutionNS(fqname=self.fqname, label='', label_parts=[], suffix='', taskname=taskname)
# nosubst=True means these sub-namespaces are not subject to {}-substitutions
subst._add_('info', info.copy(), nosubst=True)
subst._add_('config', self.config, nosubst=True)
Expand Down Expand Up @@ -1224,7 +1234,7 @@ def _run(self, params, subst=None, backend={}) -> Dict[str, Any]:
# form list of arguments for each invocation of the loop worker
loop_worker_args = []
for count, iter_var in enumerate(self._for_loop_values):
loop_worker_args.append((params, info, subst, backend, count, iter_var))
loop_worker_args.append((params, subst, backend, count, iter_var))

# if scatter is enabled, use a process pool
if self._for_loop_scatter:
Expand Down Expand Up @@ -1264,6 +1274,8 @@ def _run(self, params, subst=None, backend={}) -> Dict[str, Any]:
if errors:
pool.shutdown()
raise StimelaRuntimeError(f"{nfail}/{nloop} jobs have failed", errors)
# drop a rendering of the progress bar onto the console, to overwrite previous garbage if it's there
task_stats.restate_progress()
# else just iterate directly
else:
for args in loop_worker_args:
Expand Down
2 changes: 1 addition & 1 deletion stimela/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def cli(config_files=[], config_dotlist=[], include=[], backend=None,
stimela.CONFIG.opts.log.level = "DEBUG"
# setup file logging
subst = OmegaConf.create(dict(
info=OmegaConf.create(dict(fqname='stimela')),
info=OmegaConf.create(dict(fqname='stimela', taskname='stimela')),
config=stimela.CONFIG))
stimelogging.update_file_logger(log, stimela.CONFIG.opts.log, nesting=-1, subst=subst)

Expand Down
7 changes: 7 additions & 0 deletions stimela/task_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ def destroy_progress_bar():
progress_bar.__exit__(None, None, None)
progress_bar = None

def restate_progress():
"""Renders a snapshot of the progress bar onto the console"""
if progress_bar is not None:
progress_console.print(progress_bar.get_renderable())
progress_console.rule()


@contextlib.contextmanager
def declare_subtask(subtask_name, status_reporter=None, hide_local_metrics=False):
task_names = []
Expand Down
Loading