Skip to content

Commit

Permalink
fix errors
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 27, 2024
1 parent 4db7824 commit 741e8e0
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 16 deletions.
6 changes: 3 additions & 3 deletions deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ def get_data(self, is_train=True, task_key="Default"):
def print_header(self, fout, train_results, valid_results):
train_keys = sorted(train_results.keys())
print_str = ""
print_str += "# %5s" % "step"
print_str += "# {:5s}".format("step")
if not self.multi_task:
if valid_results:
prop_fmt = " %11s %11s"
Expand All @@ -1111,15 +1111,15 @@ def print_header(self, fout, train_results, valid_results):
prop_fmt = " %11s"
for k in sorted(train_results[model_key].keys()):
print_str += prop_fmt % (k + f"_trn_{model_key}")
print_str += " %8s\n" % "lr"
print_str += " {:8s}\n".format("lr")
print_str += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n"
fout.write(print_str)
fout.flush()

def print_on_training(self, fout, step_id, cur_lr, train_results, valid_results):
train_keys = sorted(train_results.keys())
print_str = ""
print_str += "%7d" % step_id
print_str += f"{step_id:7d}"
if not self.multi_task:
if valid_results:
prop_fmt = " %11.2e %11.2e"
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,7 @@ def get_data(self, is_train=True, task_key="Default"):
def print_header(self, fout, train_results, valid_results) -> None:
train_keys = sorted(train_results.keys())
print_str = ""
print_str += f"# {'step':5s}"
print_str += "# {:5s}".format("step")
if not self.multi_task:
if valid_results:
prop_fmt = " %11s %11s"
Expand All @@ -1155,7 +1155,7 @@ def print_header(self, fout, train_results, valid_results) -> None:
prop_fmt = " %11s"
for k in sorted(train_results[model_key].keys()):
print_str += prop_fmt % (k + f"_trn_{model_key}")
print_str += f" {'lr':8s}\n"
print_str += " {:8s}\n".format("lr")
print_str += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n"
fout.write(print_str)
fout.flush()
Expand Down
13 changes: 9 additions & 4 deletions deepmd/tf/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,12 @@ def train(self, train_data=None, valid_data=None) -> None:
is_first_step = True
self.cur_batch = cur_batch
log.info(
f"start training at lr {run_sess(self.sess, self.learning_rate):.2e} (== {self.lr.value(cur_batch):.2e}), decay_step {self.lr.decay_steps_}, decay_rate {self.lr.decay_rate_:f}, final lr will be {self.lr.value(stop_batch):.2e}"
"start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e",
run_sess(self.sess, self.learning_rate),
self.lr.value(cur_batch),
self.lr.decay_steps_,
self.lr.decay_rate_,
self.lr.value(stop_batch),
)

prf_options = None
Expand Down Expand Up @@ -595,7 +600,7 @@ def train(self, train_data=None, valid_data=None) -> None:
if self.timing_in_training and elapsed_batch // self.disp_freq > 0:
if elapsed_batch >= 2 * self.disp_freq:
log.info(
"average training time: %.4f s/batcsh (exclude first %d batches)",
"average training time: %.4f s/batch (exclude first %d batches)",
total_train_time
/ (
elapsed_batch // self.disp_freq * self.disp_freq
Expand Down Expand Up @@ -685,7 +690,7 @@ def valid_on_the_fly(
@staticmethod
def print_header(fp, train_results, valid_results) -> None:
print_str = ""
print_str += f'# {"step":5s}'
print_str += "# {:5s}".format("step")
if valid_results is not None:
prop_fmt = " %11s %11s"
for k in train_results.keys():
Expand All @@ -694,7 +699,7 @@ def print_header(fp, train_results, valid_results) -> None:
prop_fmt = " %11s"
for k in train_results.keys():
print_str += prop_fmt % (k + "_trn")
print_str += f' {"lr":8s}\n'
print_str += " {:8s}\n".format("lr")
print_str += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n"
fp.write(print_str)
fp.flush()
Expand Down
19 changes: 15 additions & 4 deletions deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,14 +664,25 @@ def print_summary(
log.info(
f"---Summary of DataSystem: {name:13s}-----------------------------------------------"
)
log.info(f"found {nsystems} system(s):")
log.info("found %d system(s):", nsystems)
log.info(
("{} ".format(_format_name_length("system", sys_width)))
+ (f'{"natoms":6s} {"bch_sz":6s} {"n_bch":6s} {"prob":9s} {"pbc":3s}')
"%s %6s %6s %6s %9s %3s",
_format_name_length("system", sys_width),
"natoms",
"bch_sz",
"n_bch",
"prob",
"pbc",
)
for ii in range(nsystems):
log.info(
f'{_format_name_length(system_dirs[ii], sys_width)} {natoms[ii]:6d} {batch_size[ii]:6d} {nbatches[ii]:6d} {sys_probs[ii]:9.3e} {"T" if pbc[ii] else "F":3s}'
"%s %6d %6d %6d %9.3e %3s",
_format_name_length(system_dirs[ii], sys_width),
natoms[ii],
batch_size[ii],
nbatches[ii],
sys_probs[ii],
"T" if pbc[ii] else "F",
)
log.info(
"--------------------------------------------------------------------------------------"
Expand Down
6 changes: 3 additions & 3 deletions source/tests/pd/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ def paddle2tf(paddle_name, last_layer_id=None):
return None
layer_id = int(fields[4 + offset]) + 1
weight_type = fields[5 + offset]
ret = "filter_type_all/%s_%d_%d:0" % (weight_type, layer_id, element_id)
ret = f"filter_type_all/{weight_type}_{layer_id}_{element_id}:0"
elif fields[1] == "fitting_net":
layer_id = int(fields[4 + offset])
weight_type = fields[5 + offset]
if layer_id != last_layer_id:
ret = "layer_%d_type_%d/%s:0" % (layer_id, element_id, weight_type)
ret = f"layer_{layer_id}_type_{element_id}/{weight_type}:0"
else:
ret = "final_layer_type_%d/%s:0" % (element_id, weight_type)
ret = f"final_layer_type_{element_id}/{weight_type}:0"
else:
raise RuntimeError(f"Unexpected parameter name: {paddle_name}")
return ret
Expand Down

0 comments on commit 741e8e0

Please sign in to comment.