From 8cefbb1ab1b4de8c7d50229128ac8e99ea8f7c51 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Dec 2023 13:39:19 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/fit/dipole.py | 9 +++++---- deepmd/model/tensor.py | 2 +- deepmd/train/trainer.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/deepmd/fit/dipole.py b/deepmd/fit/dipole.py index 66eec26c49..80c8b986d9 100644 --- a/deepmd/fit/dipole.py +++ b/deepmd/fit/dipole.py @@ -1,6 +1,7 @@ import logging from typing import List from typing import Optional + import numpy as np from paddle import nn @@ -242,7 +243,9 @@ def forward( if type_embedding is not None: nloc_mask = paddle.reshape( - paddle.tile(paddle.repeat_interleave(self.sel_mask, natoms[2:]), [nframes]), + paddle.tile( + paddle.repeat_interleave(self.sel_mask, natoms[2:]), [nframes] + ), [nframes, -1], ) atype_nall = paddle.reshape(atype, [-1, natoms[1]]) @@ -253,9 +256,7 @@ def forward( self.nloc_masked = paddle.shape( paddle.reshape(self.atype_nloc_masked, [nframes, -1]) )[1] - atype_embed = nn.embedding_lookup( - type_embedding, self.atype_nloc_masked - ) + atype_embed = nn.embedding_lookup(type_embedding, self.atype_nloc_masked) else: atype_embed = None diff --git a/deepmd/model/tensor.py b/deepmd/model/tensor.py index c04912c6fe..3f4f74fdaf 100644 --- a/deepmd/model/tensor.py +++ b/deepmd/model/tensor.py @@ -154,7 +154,7 @@ def forward( rot_mat = self.descrpt.get_rot_mat() rot_mat = paddle.clone(rot_mat, name="o_rot_mat" + suffix) - + output = self.fitting( dout, rot_mat, natoms, input_dict, reuse=reuse, suffix=suffix ) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 3424aa95c8..b2a5e2d8f6 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -28,8 +28,8 @@ from deepmd.env import tf from deepmd.env import tfv2 from deepmd.fit import Fitting -from deepmd.fit import ener from deepmd.fit import dipole +from deepmd.fit import ener from deepmd.loss import DOSLoss from deepmd.loss import EnerDipoleLoss from deepmd.loss import EnerSpinLoss