Merge branch 'devel' into spin_rf
iProzd authored Mar 5, 2024
2 parents 64410d0 + c8c941a commit 0cc0776
Showing 2 changed files with 49 additions and 41 deletions.
22 changes: 12 additions & 10 deletions deepmd/pt/model/descriptor/descriptor.py
@@ -140,16 +140,18 @@ def share_params(self, base_class, shared_level, resume=False):
         ), "Only descriptors of the same type can share params!"
         if shared_level == 0:
             # link buffers
-            if hasattr(self, "mean") and not resume:
-                # in case of change params during resume
-                base_env = EnvMatStatSe(base_class)
-                base_env.stats = base_class.stats
-                for kk in base_class.get_stats():
-                    base_env.stats[kk] += self.get_stats()[kk]
-                mean, stddev = base_env()
-                if not base_class.set_davg_zero:
-                    base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE))
-                    base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))
+            if hasattr(self, "mean"):
+                if not resume:
+                    # in case of change params during resume
+                    base_env = EnvMatStatSe(base_class)
+                    base_env.stats = base_class.stats
+                    for kk in base_class.get_stats():
+                        base_env.stats[kk] += self.get_stats()[kk]
+                    mean, stddev = base_env()
+                    if not base_class.set_davg_zero:
+                        base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE))
+                        base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))
+                # must share, even if not do stat
                 self.mean = base_class.mean
                 self.stddev = base_class.stddev
                 # self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model
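The hunk above splits the resume check so that self.mean = base_class.mean and self.stddev = base_class.stddev run even when the statistics are not recomputed. A minimal sketch (plain PyTorch with a made-up TinyDescriptor, not the DeePMD-kit classes) of why that rebinding matters: once the buffers are shared, later in-place updates on the base descriptor are visible to every descriptor that shares them.

# Minimal sketch, assuming simplified modules with registered buffers; the real
# descriptors register mean/stddev the same way but fill them from
# environment-matrix statistics.
import torch


class TinyDescriptor(torch.nn.Module):
    def __init__(self, ndim: int = 3) -> None:
        super().__init__()
        self.register_buffer("mean", torch.zeros(ndim))
        self.register_buffer("stddev", torch.ones(ndim))


base = TinyDescriptor()
child = TinyDescriptor()

# mirror `self.mean = base_class.mean`: both modules now hold the same tensors
child.mean = base.mean
child.stddev = base.stddev

base.mean.copy_(torch.tensor([1.0, 2.0, 3.0]))  # in-place update on the base
print(torch.equal(child.mean, base.mean))  # True: the child sees the update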
68 changes: 37 additions & 31 deletions deepmd/pt/train/training.py
@@ -398,7 +398,10 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
                     f"training in {model_key}",
                     to_numpy_array(self.training_dataloader[model_key].sampler.weights),
                 )
-                if validation_data is not None:
+                if (
+                    validation_data is not None
+                    and validation_data[model_key] is not None
+                ):
                     validation_data[model_key].print_summary(
                         f"validation in {model_key}",
                         to_numpy_array(
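Context for the hunk above: in multi-task training the validation datasets are keyed by model, and a given key may have no validation set at all, so both the mapping and the per-key entry are checked. A toy illustration (keys and values are made up; real entries are dataset objects):

# Toy illustration only; not the trainer's actual data structures.
validation_data = {"water": object(), "spin_system": None}  # hypothetical keys

for model_key in ("water", "spin_system"):
    if validation_data is not None and validation_data[model_key] is not None:
        print(f"summarizing validation data for {model_key}")
    else:
        print(f"{model_key}: no validation data, summary skipped")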
@@ -727,7 +730,7 @@ def log_loss_valid(_task_key="Default"):
                 )
                 if input_dict == {}:
                     # no validation data
-                    return "", None
+                    return {}
                 _, loss, more_loss = self.wrapper(
                     **input_dict,
                     cur_lr=pref_lr,
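In the hunk above, log_loss_valid now returns an empty dict instead of the old ("", None) pair when there is no validation batch, which lets callers use a plain truthiness test (if valid_results:) later in this diff. A rough sketch of that control flow, with a simplified signature and made-up RMSE keys:

# Rough sketch; the signature, keys, and values are illustrative only.
def log_loss_valid(input_dict):
    if input_dict == {}:
        # no validation data
        return {}
    # ... evaluate the model on the validation batch ...
    return {"rmse_e": 1.2e-3, "rmse_f": 4.5e-2}  # made-up numbers


valid_results = log_loss_valid({})
if valid_results:  # an empty dict is falsy, so nothing is logged
    print("val", valid_results)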
@@ -748,23 +751,24 @@ def log_loss_valid(_task_key="Default"):
                 if not self.multi_task:
                     train_results = log_loss_train(loss, more_loss)
                     valid_results = log_loss_valid()
-                    log.info(
-                        format_training_message_per_task(
-                            batch=_step_id,
-                            task_name="trn",
-                            rmse=train_results,
-                            learning_rate=cur_lr,
-                        )
-                    )
-                    if valid_results is not None:
+                    if self.rank == 0:
                         log.info(
                             format_training_message_per_task(
                                 batch=_step_id,
-                                task_name="val",
-                                rmse=valid_results,
-                                learning_rate=None,
+                                task_name="trn",
+                                rmse=train_results,
+                                learning_rate=cur_lr,
                             )
                         )
+                        if valid_results:
+                            log.info(
+                                format_training_message_per_task(
+                                    batch=_step_id,
+                                    task_name="val",
+                                    rmse=valid_results,
+                                    learning_rate=None,
+                                )
+                            )
                 else:
                     train_results = {_key: {} for _key in self.model_keys}
                     valid_results = {_key: {} for _key in self.model_keys}
Expand All @@ -787,33 +791,35 @@ def log_loss_valid(_task_key="Default"):
loss, more_loss, _task_key=_key
)
valid_results[_key] = log_loss_valid(_task_key=_key)
log.info(
format_training_message_per_task(
batch=_step_id,
task_name=_key + "_trn",
rmse=train_results[_key],
learning_rate=cur_lr,
)
)
if valid_results is not None:
if self.rank == 0:
log.info(
format_training_message_per_task(
batch=_step_id,
task_name=_key + "_val",
rmse=valid_results[_key],
learning_rate=None,
task_name=_key + "_trn",
rmse=train_results[_key],
learning_rate=cur_lr,
)
)
if valid_results is not None and valid_results[_key]:
log.info(
format_training_message_per_task(
batch=_step_id,
task_name=_key + "_val",
rmse=valid_results[_key],
learning_rate=None,
)
)

current_time = time.time()
train_time = current_time - self.t0
self.t0 = current_time
log.info(
format_training_message(
batch=_step_id,
wall_time=train_time,
if self.rank == 0:
log.info(
format_training_message(
batch=_step_id,
wall_time=train_time,
)
)
)

if fout:
if self.lcurve_should_print_header:
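Both training.py hunks above also wrap the per-step logging in if self.rank == 0:. A minimal sketch of that pattern (generic logging setup, not the trainer's own distributed wiring): in distributed data-parallel runs every worker executes the same loop, so gating on rank 0 keeps each message from being written once per process.

# Minimal sketch, assuming the launcher exports RANK (as torchrun does);
# the real trainer obtains the rank from its own distributed setup.
import logging
import os

log = logging.getLogger(__name__)
rank = int(os.environ.get("RANK", "0"))


def report(step_id: int, wall_time: float) -> None:
    if rank == 0:  # only the first worker writes to the training log
        log.info("batch %8d: wall time %.2f s", step_id, wall_time)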
