Skip to content

Commit

Permalink
cast fitting
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 13, 2024
1 parent 638acc2 commit 41d80b9
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 10 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.common import (
cast_precision,
)
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
Expand Down Expand Up @@ -174,6 +177,7 @@ def output_def(self):
]
)

@cast_precision
def call(
self,
descriptor: np.ndarray,
Expand Down
12 changes: 4 additions & 8 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,11 +364,6 @@ def _call_common(
"""
xp = array_api_compat.array_namespace(descriptor, atype)
descriptor = xp.astype(descriptor, get_xp_precision(xp, self.precision))
if fparam is not None:
fparam = xp.astype(fparam, get_xp_precision(xp, self.precision))
if aparam is not None:
aparam = xp.astype(aparam, get_xp_precision(xp, self.precision))
nf, nloc, nd = descriptor.shape
net_dim_out = self._net_out_dim()
# check input dim
Expand Down Expand Up @@ -452,13 +447,14 @@ def _call_common(
outs = self.nets[()](xx)
if xx_zeros is not None:
outs -= self.nets[()](xx_zeros)
outs = xp.astype(outs, get_xp_precision(xp, "global"))
outs += xp.reshape(
xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
xp.take(
xp.astype(self.bias_atom_e, outs.dtype), xp.reshape(atype, [-1]), axis=0
),
[nf, nloc, net_dim_out],
)
# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)
# nf x nloc x nod
outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))}
return {self.var_name: xp.astype(outs, descriptor.dtype)}
4 changes: 4 additions & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.common import (
cast_precision,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
Expand Down Expand Up @@ -203,6 +206,7 @@ def output_def(self):
]
)

@cast_precision
def call(
self,
descriptor: np.ndarray,
Expand Down
11 changes: 9 additions & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DEFAULT_PRECISION,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.fitting.base_fitting import (
Expand Down Expand Up @@ -241,6 +242,7 @@ def change_type_map(
self.scale = self.scale[remap_index]
self.constant_matrix = self.constant_matrix[remap_index]

@cast_precision
def call(
self,
descriptor: np.ndarray,
Expand Down Expand Up @@ -285,7 +287,8 @@ def call(
]
# out = out * self.scale[atype, ...]
scale_atype = xp.reshape(
xp.take(self.scale, xp.reshape(atype, [-1]), axis=0), (*atype.shape, 1)
xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, [-1]), axis=0),
(*atype.shape, 1),
)
out = out * scale_atype
# (nframes * nloc, m1, 3)
Expand All @@ -308,7 +311,11 @@ def call(
if self.shift_diag:
# bias = self.constant_matrix[atype]
bias = xp.reshape(
xp.take(self.constant_matrix, xp.reshape(atype, [-1]), axis=0),
xp.take(
xp.astype(self.constant_matrix, out.dtype),
xp.reshape(atype, [-1]),
axis=0,
),
(nframes, nloc),
)
# (nframes, nloc, 1)
Expand Down

0 comments on commit 41d80b9

Please sign in to comment.