fix(2024Q1): optimize graph memory (copy deepmodeling#4006)
iProzd committed Jul 25, 2024
Parent: d39bb94 · Commit: c07a56f
Showing 3 changed files with 13 additions and 5 deletions.
deepmd/pt/entrypoints/main.py (3 additions, 1 deletion)
@@ -282,7 +282,9 @@ def train(FLAGS):
 
 
 def freeze(FLAGS):
-    model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
+    model = inference.Tester(FLAGS.model, head=FLAGS.head).model
+    model.eval()
+    model = torch.jit.script(model)
     torch.jit.save(
         model,
         FLAGS.output,
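The freeze() change splits the old one-liner so that model.eval() runs before torch.jit.script: the module is then scripted and saved with training == False, which is what lets the create_graph=self.training switch below take effect in frozen models. A minimal sketch of the pattern, where SmallNet is a hypothetical stand-in for inference.Tester(FLAGS.model, head=FLAGS.head).model:

```python
import torch


class SmallNet(torch.nn.Module):
    """Hypothetical stand-in for the Tester model being frozen."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(3, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = SmallNet()
model.eval()  # set training=False *before* scripting, mirroring the fixed freeze()
scripted = torch.jit.script(model)
torch.jit.save(scripted, "frozen_model.pth")
```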
deepmd/pt/model/model/make_model.py (1 addition, 0 deletions)
@@ -265,6 +265,7 @@ def forward_common_lower(
             self.atomic_output_def(),
             cc_ext,
             do_atomic_virial=do_atomic_virial,
+            create_graph=self.training,
         )
         model_predict = self.output_type_cast(model_predict, input_prec)
         return model_predict
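Passing create_graph=self.training threads the module's train/eval state down to the autograd calls in transform_output.py: during training the derivative outputs (forces) must stay differentiable because the loss backpropagates through them, while in evaluation or a frozen eval-mode model the second-order graph can be skipped, which is the memory saving this commit targets. A small self-contained illustration with a toy energy in place of the real model output:

```python
import torch

coord = torch.randn(10, 3, requires_grad=True)
energy = (coord**2).sum()  # toy stand-in for the predicted energy

for training in (True, False):  # plays the role of self.training
    (force,) = torch.autograd.grad(
        [energy], [coord], create_graph=training, retain_graph=True
    )
    # With create_graph=True the force itself is part of a graph, so a loss
    # on forces can be backpropagated; with False that graph is never built.
    print(training, force.requires_grad)  # -> True True, then False False
```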
deepmd/pt/model/model/transform_output.py (9 additions, 4 deletions)
@@ -33,15 +33,15 @@ def atomic_virial_corr(
     faked_grad = torch.ones_like(sumce0)
     lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad])
     extended_virial_corr0 = torch.autograd.grad(
-        [sumce0], [extended_coord], grad_outputs=lst, create_graph=True
+        [sumce0], [extended_coord], grad_outputs=lst, create_graph=False, retain_graph=True,
     )[0]
     assert extended_virial_corr0 is not None
     extended_virial_corr1 = torch.autograd.grad(
-        [sumce1], [extended_coord], grad_outputs=lst, create_graph=True
+        [sumce1], [extended_coord], grad_outputs=lst, create_graph=False, retain_graph=True,
     )[0]
     assert extended_virial_corr1 is not None
     extended_virial_corr2 = torch.autograd.grad(
-        [sumce2], [extended_coord], grad_outputs=lst, create_graph=True
+        [sumce2], [extended_coord], grad_outputs=lst, create_graph=False, retain_graph=True,
     )[0]
     assert extended_virial_corr2 is not None
     extended_virial_corr = torch.concat(
@@ -61,11 +61,12 @@ def task_deriv_one(
     extended_coord: torch.Tensor,
     do_virial: bool = True,
     do_atomic_virial: bool = False,
+    create_graph: bool = True,
 ):
     faked_grad = torch.ones_like(energy)
     lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad])
     extended_force = torch.autograd.grad(
-        [energy], [extended_coord], grad_outputs=lst, create_graph=True
+        [energy], [extended_coord], grad_outputs=lst, create_graph=create_graph, retain_graph=True,
     )[0]
     assert extended_force is not None
     extended_force = -extended_force
@@ -106,6 +106,7 @@ def take_deriv(
     coord_ext: torch.Tensor,
     do_virial: bool = False,
     do_atomic_virial: bool = False,
+    create_graph: bool = True,
 ):
     size = 1
     for ii in vdef.shape:
@@ -123,6 +125,7 @@
             coord_ext,
             do_virial=do_virial,
             do_atomic_virial=do_atomic_virial,
+            create_graph=create_graph,
         )
         # nf x nloc x 1 x 3, nf x nloc x 1 x 9
         ffi = ffi.unsqueeze(-2)
@@ -146,6 +149,7 @@ def fit_output_to_model_output(
     fit_output_def: FittingOutputDef,
     coord_ext: torch.Tensor,
     do_atomic_virial: bool = False,
+    create_graph: bool = True,
 ) -> Dict[str, torch.Tensor]:
     """Transform the output of the fitting network to
     the model output.
@@ -169,6 +173,7 @@
                     coord_ext,
                     do_virial=vdef.c_differentiable,
                     do_atomic_virial=do_atomic_virial,
+                    create_graph=create_graph,
                 )
                 model_ret[kk_derv_r] = dr
                 if vdef.c_differentiable:
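Two different knobs are tuned in transform_output.py: task_deriv_one gates double backward on the new create_graph argument, while atomic_virial_corr hard-codes create_graph=False since the virial-correction terms are not differentiated again; in both places retain_graph=True keeps the shared forward graph alive across the successive grad calls. A rough sketch of that combination, under the assumption that no second-order derivative is needed:

```python
import torch

x = torch.randn(4, 3, requires_grad=True)
h = x**2  # shared intermediate, like the terms behind sumce0/1/2
y0, y1 = h.sum(), (3 * h).sum()

# retain_graph=True keeps the shared graph alive for the second call
# (without it, the first grad would free the buffers and the second would
# raise); create_graph=False skips recording the backward pass itself,
# which is what saves memory when nothing differentiates the result again.
g0 = torch.autograd.grad([y0], [x], retain_graph=True, create_graph=False)[0]
g1 = torch.autograd.grad([y1], [x], retain_graph=True, create_graph=False)[0]
assert not (g0.requires_grad or g1.requires_grad)
```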
