Skip to content

Commit

Permalink
feat(pt/dp): support dataid and sharable fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 25, 2024
1 parent ad30709 commit ae95c3c
Show file tree
Hide file tree
Showing 29 changed files with 569 additions and 151 deletions.
7 changes: 7 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def get_sel(self) -> list[int]:
"""Get the neighbor selection."""
return self.descriptor.get_sel()

def set_dataid(self, data_idx) -> None:
    """Set the data identification of this atomic model.

    The call is forwarded unchanged to ``self.fitting.set_dataid``. The
    identification is typically concatenated with the output of the
    descriptor and fed into the fitting net.

    Parameters
    ----------
    data_idx
        Identifier of the data source, handed to the fitting as-is.
        NOTE(review): presumably an integer index -- confirm against the fitting.
    """
    self.fitting.set_dataid(data_idx)

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ def get_model_rcuts(self) -> list[float]:
def get_sel(self) -> list[int]:
    """Return a one-element list holding the largest ``get_nsel()`` value among ``self.models``."""
    return [max([model.get_nsel() for model in self.models])]

def set_dataid(self, data_idx) -> None:
    """Set the data identification of every sub-model of this atomic model.

    The same ``data_idx`` is passed to ``set_dataid`` of each entry in
    ``self.models``. The identification is typically concatenated with the
    output of the descriptor and fed into the fitting net.

    Parameters
    ----------
    data_idx
        Identifier of the data source, handed to each sub-model as-is.
        NOTE(review): presumably an integer index -- confirm against the fitting.
    """
    # An exception raised by one sub-model propagates and stops the loop.
    for model in self.models:
        model.set_dataid(data_idx)

def get_model_nsels(self) -> list[int]:
"""Get the processed sels for each individual models. Not distinguishing types."""
return [model.get_nsel() for model in self.models]
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def get_sel(self) -> list[int]:
"""Returns the number of selected atoms for each type."""
pass

@abstractmethod
def set_dataid(self, data_idx) -> None:
    """Set the data identification of this atomic model.

    Abstract hook: concrete atomic models provide the implementation. The
    identification is typically concatenated with the output of the
    descriptor and fed into the fitting net.

    Parameters
    ----------
    data_idx
        Identifier of the data source.
        NOTE(review): presumably an integer index -- confirm against implementations.
    """
    pass

def get_nsel(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
return sum(self.get_sel())
Expand Down
9 changes: 9 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ def get_type_map(self) -> list[str]:
def get_sel(self) -> list[int]:
    """Return ``self.sel`` wrapped in a single-element list."""
    return [self.sel]

def set_dataid(self, data_idx) -> None:
    """Reject the request: data identification is not supported by this model.

    Parameters
    ----------
    data_idx
        Ignored; the method raises before using it.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError(
        "Data identification not supported for PairTabAtomicModel!"
    )

def get_nsel(self) -> int:
    """Return ``self.sel``."""
    return self.sel

Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_dataid: int = 0,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[list[bool]] = None,
Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_dataid=numb_dataid,
rcond=rcond,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
Expand Down Expand Up @@ -159,7 +161,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
var_name = data.pop("var_name", None)
assert var_name == "dipole"
return super().deserialize(data)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_dataid: int = 0,
bias_dos: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
trainable: Union[bool, list[bool]] = True,
Expand All @@ -60,6 +61,7 @@ def __init__(
bias_atom=bias_dos,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_dataid=numb_dataid,
rcond=rcond,
trainable=trainable,
activation_function=activation_function,
Expand All @@ -73,7 +75,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
data["numb_dos"] = data.pop("dim_out")
data.pop("tot_ener_zero", None)
data.pop("var_name", None)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_dataid: int = 0,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[list[bool]] = None,
Expand All @@ -55,6 +56,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_dataid=numb_dataid,
rcond=rcond,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
Expand All @@ -73,7 +75,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)
Expand Down
35 changes: 34 additions & 1 deletion deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_dataid: int = 0,
bias_atom_e: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
Expand All @@ -127,6 +128,7 @@ def __init__(
self.resnet_dt = resnet_dt
self.numb_fparam = numb_fparam
self.numb_aparam = numb_aparam
self.numb_dataid = numb_dataid
self.rcond = rcond
self.tot_ener_zero = tot_ener_zero
self.trainable = trainable
Expand Down Expand Up @@ -171,11 +173,16 @@ def __init__(
self.aparam_inv_std = np.ones(self.numb_aparam, dtype=self.prec)
else:
self.aparam_avg, self.aparam_inv_std = None, None
if self.numb_dataid > 0:
self.dataid = np.zeros(self.numb_dataid, dtype=self.prec)
else:
self.dataid = None
# init networks
in_dim = (
self.dim_descrpt
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
+ self.numb_dataid
)
self.nets = NetworkCollection(
1 if not self.mixed_types else 0,
Expand Down Expand Up @@ -222,6 +229,13 @@ def get_type_map(self) -> list[str]:
"""Get the name to each type of atoms."""
return self.type_map

def set_dataid(self, data_idx):
    """Set the data identification of this fitting net.

    ``self.dataid`` is replaced by row ``data_idx`` of the
    ``numb_dataid`` x ``numb_dataid`` identity matrix (dtype ``self.prec``),
    i.e. a one-hot vector when ``data_idx`` is a scalar index. When
    ``numb_dataid > 0`` this vector is concatenated with the descriptor
    output and fed into the fitting net.

    Parameters
    ----------
    data_idx
        Row of the identity matrix to select.

    Raises
    ------
    IndexError
        If ``numb_dataid`` is not positive, or if ``data_idx`` is out of
        range for ``numb_dataid``.
    """
    # Without this guard, numb_dataid == 0 (the default) fails inside NumPy
    # with an opaque "index out of bounds for axis 0 with size 0" message.
    # The exception type is kept as IndexError so existing handlers still work.
    if self.numb_dataid <= 0:
        raise IndexError(
            "Cannot set dataid: this fitting net has "
            f"numb_dataid={self.numb_dataid}; a positive numb_dataid is required."
        )
    self.dataid = np.eye(self.numb_dataid, dtype=self.prec)[data_idx]

def change_type_map(
self, type_map: list[str], model_with_new_type_stat=None
) -> None:
Expand Down Expand Up @@ -255,6 +269,8 @@ def __setitem__(self, key, value) -> None:
self.aparam_avg = value
elif key in ["aparam_inv_std"]:
self.aparam_inv_std = value
elif key in ["dataid"]:
self.dataid = value
elif key in ["scale"]:
self.scale = value
else:
Expand All @@ -271,6 +287,8 @@ def __getitem__(self, key):
return self.aparam_avg
elif key in ["aparam_inv_std"]:
return self.aparam_inv_std
elif key in ["dataid"]:
return self.dataid
elif key in ["scale"]:
return self.scale
else:
Expand All @@ -287,14 +305,15 @@ def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
"@class": "Fitting",
"@version": 2,
"@version": 3,
"var_name": self.var_name,
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
"neuron": self.neuron,
"resnet_dt": self.resnet_dt,
"numb_fparam": self.numb_fparam,
"numb_aparam": self.numb_aparam,
"numb_dataid": self.numb_dataid,
"rcond": self.rcond,
"activation_function": self.activation_function,
"precision": self.precision,
Expand All @@ -303,6 +322,7 @@ def serialize(self) -> dict:
"nets": self.nets.serialize(),
"@variables": {
"bias_atom_e": to_numpy_array(self.bias_atom_e),
"dataid": to_numpy_array(self.dataid),
"fparam_avg": to_numpy_array(self.fparam_avg),
"fparam_inv_std": to_numpy_array(self.fparam_inv_std),
"aparam_avg": to_numpy_array(self.aparam_avg),
Expand Down Expand Up @@ -423,6 +443,19 @@ def _call_common(
axis=-1,
)

if self.numb_dataid > 0:
assert self.dataid is not None
dataid = xp.tile(xp.reshape(self.dataid, [1, 1, -1]), [nf, nloc, 1])
xx = xp.concat(
[xx, dataid],
axis=-1,
)
if xx_zeros is not None:
xx_zeros = xp.concat(
[xx_zeros, dataid],
axis=-1,
)

# calculate the prediction
if not self.mixed_types:
outs = xp.zeros(
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_dataid: int = 0,
bias_atom: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
Expand Down Expand Up @@ -155,6 +156,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_dataid=numb_dataid,
rcond=rcond,
bias_atom_e=bias_atom,
tot_ener_zero=tot_ener_zero,
Expand Down Expand Up @@ -183,7 +185,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 2, 1)
check_version_compatibility(data.pop("@version", 1), 3, 1)
return super().deserialize(data)

def _net_out_dim(self):
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_dataid: int = 0,
rcond: Optional[float] = None,
tot_ener_zero: bool = False,
trainable: Optional[list[bool]] = None,
Expand Down Expand Up @@ -150,6 +151,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_dataid=numb_dataid,
rcond=rcond,
tot_ener_zero=tot_ener_zero,
trainable=trainable,
Expand Down Expand Up @@ -187,7 +189,7 @@ def __getitem__(self, key):
def serialize(self) -> dict:
data = super().serialize()
data["type"] = "polar"
data["@version"] = 3
data["@version"] = 4
data["embedding_width"] = self.embedding_width
data["fit_diag"] = self.fit_diag
data["shift_diag"] = self.shift_diag
Expand All @@ -198,7 +200,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = data.copy()
check_version_compatibility(data.pop("@version", 1), 3, 1)
check_version_compatibility(data.pop("@version", 1), 4, 1)
var_name = data.pop("var_name", None)
assert var_name == "polar"
return super().deserialize(data)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/property_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
resnet_dt: bool = True,
numb_fparam: int = 0,
numb_aparam: int = 0,
numb_dataid: int = 0,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
mixed_types: bool = True,
Expand All @@ -99,6 +100,7 @@ def __init__(
resnet_dt=resnet_dt,
numb_fparam=numb_fparam,
numb_aparam=numb_aparam,
numb_dataid=numb_dataid,
rcond=rcond,
trainable=trainable,
activation_function=activation_function,
Expand All @@ -111,7 +113,7 @@ def __init__(
@classmethod
def deserialize(cls, data: dict) -> "PropertyFittingNet":
data = data.copy()
check_version_compatibility(data.pop("@version"), 2, 1)
check_version_compatibility(data.pop("@version"), 3, 1)
data.pop("dim_out")
data.pop("var_name")
data.pop("tot_ener_zero")
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,9 @@ def serialize(self) -> dict:
def deserialize(cls, data) -> "CM":
return cls(atomic_model_=T_AtomicModel.deserialize(data))

def set_dataid(self, data_idx) -> None:
    """Set the data identification of this model.

    The call is forwarded unchanged to ``self.atomic_model.set_dataid``.

    Parameters
    ----------
    data_idx
        Identifier of the data source, handed to the atomic model as-is.
        NOTE(review): presumably an integer index -- confirm against the fitting.
    """
    self.atomic_model.set_dataid(data_idx)

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.atomic_model.get_dim_fparam()
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ def get_sel(self) -> list[int]:
"""Get the neighbor selection."""
return self.sel

def set_dataid(self, data_idx) -> None:
    """Set the data identification of this atomic model.

    The call is forwarded unchanged to ``self.fitting_net.set_dataid``. The
    identification is typically concatenated with the output of the
    descriptor and fed into the fitting net.

    Parameters
    ----------
    data_idx
        Identifier of the data source, handed to the fitting net as-is.
        NOTE(review): presumably an integer index -- confirm against the fitting.
    """
    self.fitting_net.set_dataid(data_idx)

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ def get_model_rcuts(self) -> list[float]:
def get_sel(self) -> list[int]:
    """Return a one-element list holding the largest ``get_nsel()`` value among ``self.models``."""
    return [max([model.get_nsel() for model in self.models])]

def set_dataid(self, data_idx) -> None:
    """Set the data identification of every sub-model of this atomic model.

    The same ``data_idx`` is passed to ``set_dataid`` of each entry in
    ``self.models``. The identification is typically concatenated with the
    output of the descriptor and fed into the fitting net.

    Parameters
    ----------
    data_idx
        Identifier of the data source, handed to each sub-model as-is.
        NOTE(review): presumably an integer index -- confirm against the fitting.
    """
    # An exception raised by one sub-model propagates and stops the loop.
    for model in self.models:
        model.set_dataid(data_idx)

def get_model_nsels(self) -> list[int]:
"""Get the processed sels for each individual models. Not distinguishing types."""
return [model.get_nsel() for model in self.models]
Expand Down
9 changes: 9 additions & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ def get_type_map(self) -> list[str]:
def get_sel(self) -> list[int]:
    """Return ``self.sel`` wrapped in a single-element list."""
    return [self.sel]

def set_dataid(self, data_idx) -> None:
    """Reject the request: data identification is not supported by this model.

    Parameters
    ----------
    data_idx
        Ignored; the method raises before using it.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError(
        "Data identification not supported for PairTabAtomicModel!"
    )

def get_nsel(self) -> int:
    """Return ``self.sel``."""
    return self.sel

Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,9 @@ def serialize(self) -> dict:
def deserialize(cls, data) -> "CM":
return cls(atomic_model_=T_AtomicModel.deserialize(data))

def set_dataid(self, data_idx) -> None:
    """Set the data identification of this model.

    The call is forwarded unchanged to ``self.atomic_model.set_dataid``.

    Parameters
    ----------
    data_idx
        Identifier of the data source, handed to the atomic model as-is.
        NOTE(review): presumably an integer index -- confirm against the fitting.
    """
    self.atomic_model.set_dataid(data_idx)

@torch.jit.export
def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
Expand Down
Loading

0 comments on commit ae95c3c

Please sign in to comment.