Skip to content

Commit

Permalink
fix bugs in PR deepmodeling#2
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Feb 6, 2022
1 parent 9547663 commit 45ea9c1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 29 deletions.
53 changes: 27 additions & 26 deletions dpgen2/op/run_dp_train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os, json, dpdata, glob
from pathlib import Path
from dpgen2.utils.run_command import run_command
from dpgen2.utils.chdir import chdir
from dpgen2.utils.chdir import set_directory
from dflow.python import (
OP,
OPIO,
Expand Down Expand Up @@ -57,7 +57,6 @@ def get_output_sign(cls):
})

@OP.exec_sign_check
@chdir("task_name")
def execute(
self,
ip : OPIO,
Expand Down Expand Up @@ -100,6 +99,7 @@ def execute(
iter_data_old_exp = _expand_all_multi_sys_to_sys(iter_data[:-1])
iter_data_new_exp = _expand_all_multi_sys_to_sys(iter_data[-1:])
iter_data_exp = iter_data_old_exp + iter_data_new_exp
work_dir = Path(task_name)

# update the input script
input_script = Path(task_path)/train_script_name
Expand All @@ -125,34 +125,35 @@ def execute(
train_dict = RunDPTrain.write_other_to_input_script(
train_dict, config, do_init_model, major_version)

# open log
fplog = open('train.log', 'w')
def clean_before_quit():
fplog.close()
with set_directory(work_dir):
# open log
fplog = open('train.log', 'w')
def clean_before_quit():
fplog.close()

# dump train script
with open(train_script_name, 'w') as fp:
json.dump(train_dict, fp, indent=4)
# dump train script
with open(train_script_name, 'w') as fp:
json.dump(train_dict, fp, indent=4)

# train model
if do_init_model:
command = ['dp', 'train', '--init-frz-model', str(init_model), train_script_name]
else:
command = ['dp', 'train', train_script_name]
ret, out, err = run_command(command)
if ret != 0:
clean_before_quit()
raise FatalError('dp train failed')
fplog.write(out)
# train model
if do_init_model:
command = ['dp', 'train', '--init-frz-model', str(init_model), train_script_name]
else:
command = ['dp', 'train', train_script_name]
ret, out, err = run_command(command)
if ret != 0:
clean_before_quit()
raise FatalError('dp train failed')
fplog.write(out)

# freeze model
ret, out, err = run_command(['dp', 'freeze', '-o', 'frozen_model.pb'])
if ret != 0:
clean_before_quit()
raise FatalError('dp freeze failed')
fplog.write(out)

# freeze model
ret, out, err = run_command(['dp', 'freeze', '-o', 'frozen_model.pb'])
if ret != 0:
clean_before_quit()
raise FatalError('dp freeze failed')
fplog.write(out)

clean_before_quit()

return OPIO({
"script" : work_dir / train_script_name,
Expand Down
7 changes: 4 additions & 3 deletions dpgen2/utils/chdir.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
from functools improt wraps
from functools import wraps
from typing import Callable
from contextlib import contextmanager
from pathlib import Path


from dflow.python import (
OPIO,
)
@contextmanager
def set_directory(path: Path):
"""Sets the current working path within the context.
Expand Down

0 comments on commit 45ea9c1

Please sign in to comment.