Skip to content

Commit

Permalink
Treat loop as data member of ConcurrrentLearning.
Browse files Browse the repository at this point in the history
Signed-off-by: Han Wang <[email protected]>
  • Loading branch information
Han Wang committed Feb 10, 2022
1 parent 45e76c2 commit 89f5ec4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
14 changes: 10 additions & 4 deletions dpgen2/flow/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,17 @@ class ConcurrentLearning(Steps):
def __init__(
self,
name : str,
loop_op : Steps,
block_op : Steps,
image : str = "dflow:v1.0",
upload_python_package : str = None,
):
self.loop = ConcurrentLearningLoop(
name+'-loop',
block_op,
image = image,
upload_python_package = upload_python_package,
)

self._input_parameters={
"type_map" : InputParameter(),
"numb_models": InputParameter(type=int),
Expand Down Expand Up @@ -233,13 +240,12 @@ def __init__(
self.step_keys = {}
for ii in self._init_keys:
self.step_keys[ii] = os.path.join('init', ii)
self.loop_op_keys = loop_op.keys

self = _dpgen(
self,
self.step_keys,
name,
loop_op,
self.loop,
self.loop_key,
image = image,
upload_python_package = upload_python_package,
Expand Down Expand Up @@ -267,7 +273,7 @@ def init_keys(self):

@property
def loop_keys(self):
return [self.loop_key] + self.loop_op_keys
return [self.loop_key] + self.loop.keys


def _loop (
Expand Down
9 changes: 2 additions & 7 deletions tests/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from dpgen2.flow.block import ConcurrentLearningBlock
from dpgen2.exploration.task import ExplorationTask, ExplorationTaskGroup
from dpgen2.fp.vasp import VaspInputs
from dpgen2.flow.loop import ConcurrentLearning, ConcurrentLearningLoop
from dpgen2.flow.loop import ConcurrentLearning
from dpgen2.exploration.report import ExplorationReport
from dpgen2.exploration.task import ExplorationTaskGroup, ExplorationStage
from dpgen2.exploration.selector import TrustLevelConfSelector, TrustLevel
Expand Down Expand Up @@ -122,14 +122,9 @@ def _setUp_ops(self):
MockedCollectData,
upload_python_package = upload_python_package,
)
self.loop_op = ConcurrentLearningLoop(
self.name+'-loop',
self.block_cl_op,
upload_python_package = upload_python_package,
)
self.dpgen_op = ConcurrentLearning(
self.name,
self.loop_op,
self.block_cl_op,
upload_python_package = upload_python_package,
)

Expand Down

0 comments on commit 89f5ec4

Please sign in to comment.