Skip to content

Commit

Permalink
reformulate super-OPs as derived classes of Steps
Browse files Browse the repository at this point in the history
Signed-off-by: Han Wang <[email protected]>
  • Loading branch information
Han Wang committed Feb 10, 2022
1 parent 453abd9 commit 95df260
Show file tree
Hide file tree
Showing 10 changed files with 598 additions and 228 deletions.
144 changes: 107 additions & 37 deletions dpgen2/flow/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,117 @@
from typing import Set, List
from pathlib import Path

def block_cl(

class ConcurrentLearningBlock(Steps):
def __init__(
self,
name : str,
prep_run_dp_train_op : OP,
prep_run_lmp_op : OP,
select_confs_op : OP,
prep_run_fp_op : OP,
collect_data_op : OP,
select_confs_image : str = "dflow:v1.0",
collect_data_image : str = "dflow:v1.0",
upload_python_package : str = None,
):
self._input_parameters={
"block_id" : InputParameter(),
"type_map" : InputParameter(),
"numb_models": InputParameter(type=int),
"template_script" : InputParameter(),
"train_config" : InputParameter(),
"lmp_task_grp" : InputParameter(),
"lmp_config" : InputParameter(),
"conf_selector" : InputParameter(),
"fp_inputs" : InputParameter(),
"fp_config" : InputParameter(),
}
self._input_artifacts={
"init_models" : InputArtifact(),
"init_data" : InputArtifact(),
"iter_data" : InputArtifact(),
}
self._output_parameters={
"exploration_report": OutputParameter(),
}
self._output_artifacts={
"models": OutputArtifact(),
"iter_data" : OutputArtifact(),
"trajs" : OutputArtifact(),
}

super().__init__(
name = name,
inputs = Inputs(
parameters=self._input_parameters,
artifacts=self._input_artifacts,
),
outputs=Outputs(
parameters=self._output_parameters,
artifacts=self._output_artifacts,
),
)

self._my_keys = ['select-confs', 'collect-data']
self._keys = \
prep_run_dp_train_op.keys + \
prep_run_lmp_op.keys + \
self._my_keys[:1] + \
prep_run_fp_op.keys + \
self._my_keys[1:2]
self.step_keys = {}
for ii in self._my_keys:
self.step_keys[ii] = os.path.join("%s"%self.inputs.parameters["block_id"], ii)

self = _block_cl(
self,
self.step_keys,
name,
prep_run_dp_train_op,
prep_run_lmp_op,
select_confs_op,
prep_run_fp_op,
collect_data_op,
select_confs_image = select_confs_image,
collect_data_image = collect_data_image,
upload_python_package = upload_python_package,
)

@property
def input_parameters(self):
return self._input_parameters

@property
def input_artifacts(self):
return self._input_artifacts

@property
def output_parameters(self):
return self._output_parameters

@property
def output_artifacts(self):
return self._output_artifacts

@property
def keys(self):
return self._keys


def _block_cl(
block_steps : Steps,
step_keys : List[str],
name : str,
prep_run_dp_train_op : OP,
prep_run_lmp_op : OP,
select_confs_op : OP,
prep_run_fp_op : OP,
collect_data_op : OP,
select_confs_image : str = "dflow:v1.0",
collect_data_image : str = "dflow:v1.0",
upload_python_package : str = None,
):
block_steps = Steps(
name = name,
inputs = Inputs(
parameters={
"block_id" : InputParameter(),
"type_map" : InputParameter(),
"numb_models": InputParameter(type=int),
"template_script" : InputParameter(),
"train_config" : InputParameter(),
"lmp_task_grp" : InputParameter(),
"lmp_config" : InputParameter(),
"conf_selector" : InputParameter(),
"fp_inputs" : InputParameter(),
"fp_config" : InputParameter(),
},
artifacts={
"init_models" : InputArtifact(),
"init_data" : InputArtifact(),
"iter_data" : InputArtifact(),
},
),
outputs=Outputs(
parameters={
"exploration_report": OutputParameter(),
},
artifacts={
"models": OutputArtifact(),
"iter_data" : OutputArtifact(),
"trajs" : OutputArtifact(),
},
),
)

prep_run_dp_train = Step(
name + '-prep-run-dp-train',
Expand Down Expand Up @@ -106,7 +176,7 @@ def block_cl(
name = name + '-select-confs',
template=PythonOPTemplate(
select_confs_op,
image="dflow:v1.0",
image=select_confs_image,
output_artifact_archive={
"confs": None
},
Expand All @@ -121,7 +191,7 @@ def block_cl(
"trajs" : prep_run_lmp.outputs.artifacts['trajs'],
"model_devis" : prep_run_lmp.outputs.artifacts['model_devis'],
},
key = os.path.join("%s"%block_steps.inputs.parameters["block_id"], "select-conf"),
key = step_keys['select-confs'],
)
block_steps.add(select_confs)

Expand All @@ -144,7 +214,7 @@ def block_cl(
name = name + '-collect-data',
template=PythonOPTemplate(
collect_data_op,
image="dflow:v1.0",
image=collect_data_image,
output_artifact_archive={
"iter_data": None
},
Expand All @@ -157,7 +227,7 @@ def block_cl(
"iter_data" : block_steps.inputs.artifacts['iter_data'],
"labeled_data" : prep_run_fp.outputs.artifacts['labeled_data'],
},
key = os.path.join("%s"%block_steps.inputs.parameters["block_id"], "collect-data"),
key = step_keys['collect-data'],
)
block_steps.add(collect_data)

Expand Down
Loading

0 comments on commit 95df260

Please sign in to comment.