Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compute output stat for atomic model #3642

Merged
merged 13 commits into from
Apr 7, 2024
6 changes: 6 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@
class BaseAtomicModel(BaseAtomicModel_):
def __init__(
self,
type_map: List[str],
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.type_map = type_map

Check warning on line 35 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L35

Added line #L35 was not covered by tests

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute type_map, which was previously defined in subclass
DPAtomicModel
.
Assignment overwrites attribute type_map, which was previously defined in subclass
DPAtomicModel
.
Assignment overwrites attribute type_map, which was previously defined in subclass
LinearEnergyAtomicModel
.
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

Check warning on line 41 in deepmd/dpmodel/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/base_atomic_model.py#L41

Added line #L41 was not covered by tests

def reinit_atom_exclude(
self,
exclude_types: List[int] = [],
Expand Down
6 changes: 1 addition & 5 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
self.descriptor = descriptor
self.fitting = fitting
self.type_map = type_map
super().__init__(**kwargs)
super().__init__(type_map, **kwargs)

Check warning on line 56 in deepmd/dpmodel/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/dp_atomic_model.py#L56

Added line #L56 was not covered by tests

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand All @@ -67,10 +67,6 @@
"""Get the neighbor selection."""
return self.descriptor.get_sel()

def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
self.mapping_list.append(self.remap_atype(tpmp, self.type_map))
assert len(err_msg) == 0, "\n".join(err_msg)
self.mixed_types_list = [model.mixed_types() for model in self.models]
super().__init__(**kwargs)
super().__init__(type_map, **kwargs)

Check warning on line 69 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L69

Added line #L69 was not covered by tests

def mixed_types(self) -> bool:
"""If true, the model
Expand Down
3 changes: 0 additions & 3 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ def atomic_output_def(self) -> FittingOutputDef:
"""
return self.fitting_output_def()

def get_output_keys(self) -> List[str]:
return list(self.atomic_output_def().keys())

@abstractmethod
def get_rcut(self) -> float:
"""Get the cut-off radius."""
Expand Down
6 changes: 5 additions & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,20 @@
rcut: float,
sel: Union[int, List[int]],
type_map: List[str],
rcond: Optional[float] = None,
atom_ener: Optional[List[float]] = None,
**kwargs,
):
super().__init__()
super().__init__(type_map, **kwargs)

Check warning on line 66 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L66

Added line #L66 was not covered by tests
self.tab_file = tab_file
self.rcut = rcut
self.type_map = type_map

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute type_map, which was previously defined in superclass
BaseAtomicModel
.

self.tab = PairTab(self.tab_file, rcut=rcut)
self.type_map = type_map

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute type_map, which was previously defined in superclass
BaseAtomicModel
.
self.ntypes = len(type_map)
self.rcond = rcond
self.atom_ener = atom_ener

Check warning on line 75 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L74-L75

Added lines #L74 - L75 were not covered by tests

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@
if not self.r_differentiable:
raise ValueError("only r_differentiable variable can calculate hessian")

@property
def size(self):
return self.output_size

Check warning on line 229 in deepmd/dpmodel/output_def.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/output_def.py#L229

Added line #L229 was not covered by tests


class FittingOutputDef:
"""Defines the shapes and other properties of the fitting network outputs.
Expand Down
186 changes: 176 additions & 10 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from deepmd.pt.utils import (
AtomExcludeMask,
PairExcludeMask,
env,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
Expand All @@ -35,19 +36,64 @@
)

log = logging.getLogger(__name__)
dtype = env.GLOBAL_PT_FLOAT_PRECISION
device = env.DEVICE

Check warning on line 40 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L39-L40

Added lines #L39 - L40 were not covered by tests

BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)


class BaseAtomicModel(BaseAtomicModel_):
class BaseAtomicModel(torch.nn.Module, BaseAtomicModel_):

Check warning on line 45 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L45

Added line #L45 was not covered by tests
def __init__(
self,
type_map: List[str],
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
torch.nn.Module.__init__(self)
BaseAtomicModel_.__init__(self)
self.type_map = type_map

Check warning on line 54 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L52-L54

Added lines #L52 - L54 were not covered by tests
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)
self.rcond = None
self.atom_ener = None

Check warning on line 58 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L57-L58

Added lines #L57 - L58 were not covered by tests

def init_out_stat(self):

Check warning on line 60 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L60

Added line #L60 was not covered by tests
"""Initialize the output bias."""
ntypes = self.get_ntypes()
self.bias_keys: List[str] = list(self.fitting_output_def().keys())
self.max_out_size = max(

Check warning on line 64 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L62-L64

Added lines #L62 - L64 were not covered by tests
[self.atomic_output_def()[kk].size for kk in self.bias_keys]
)
self.n_out = len(self.bias_keys)
out_bias_data = torch.zeros(

Check warning on line 68 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L67-L68

Added lines #L67 - L68 were not covered by tests
[self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device
)
out_std_data = torch.ones(

Check warning on line 71 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L71

Added line #L71 was not covered by tests
[self.n_out, ntypes, self.max_out_size], dtype=dtype, device=device
)
self.register_buffer("out_bias", out_bias_data)
self.register_buffer("out_std", out_std_data)

Check warning on line 75 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L74-L75

Added lines #L74 - L75 were not covered by tests

def __setitem__(self, key, value):
if key in ["out_bias"]:
self.out_bias = value
elif key in ["out_std"]:
self.out_std = value

Check warning on line 81 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L77-L81

Added lines #L77 - L81 were not covered by tests
else:
raise KeyError(key)

Check warning on line 83 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L83

Added line #L83 was not covered by tests

def __getitem__(self, key):
if key in ["out_bias"]:
return self.out_bias
elif key in ["out_std"]:
return self.out_std

Check warning on line 89 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L85-L89

Added lines #L85 - L89 were not covered by tests
else:
raise KeyError(key)

Check warning on line 91 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L91

Added line #L91 was not covered by tests

@torch.jit.export
def get_type_map(self) -> List[str]:

Check warning on line 94 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L93-L94

Added lines #L93 - L94 were not covered by tests
"""Get the type map."""
return self.type_map

Check warning on line 96 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L96

Added line #L96 was not covered by tests

def reinit_atom_exclude(
self,
Expand Down Expand Up @@ -165,6 +211,7 @@
fparam=fparam,
aparam=aparam,
)
ret_dict = self.apply_out_bias(ret_dict, atype)

Check warning on line 214 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L214

Added line #L214 was not covered by tests

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].to(torch.int32)
Expand Down Expand Up @@ -210,9 +257,60 @@
"""
raise NotImplementedError

def compute_or_load_out_stat(

Check warning on line 260 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L260

Added line #L260 was not covered by tests
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
):
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
self.change_out_bias(

Check warning on line 281 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L281

Added line #L281 was not covered by tests
merged,
stat_file_path=stat_file_path,
bias_adjust_mode="set-by-statistic",
)

def apply_out_bias(

Check warning on line 287 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L287

Added line #L287 was not covered by tests
self,
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
"""Apply the bias to each atomic output.
The developer may override the method to define how the bias is applied
to the atomic output of the model.

Parameters
----------
ret
The returned dict by the forward_atomic method
atype
The atom types. nf x nloc

"""
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
for kk in self.bias_keys:

Check warning on line 305 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L304-L305

Added lines #L304 - L305 were not covered by tests
# nf x nloc x odims, out_bias: ntypes x odims
ret[kk] = ret[kk] + out_bias[kk][atype]
return ret

Check warning on line 308 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L307-L308

Added lines #L307 - L308 were not covered by tests

def change_out_bias(
self,
sample_merged,
stat_file_path: Optional[DPPath] = None,
iProzd marked this conversation as resolved.
Show resolved Hide resolved
bias_adjust_mode="change-by-statistic",
) -> None:
"""Change the output bias according to the input data and the pretrained model.
Expand All @@ -233,20 +331,28 @@
'set-by-statistic' : directly use the statistic output bias in the target dataset.
"""
if bias_adjust_mode == "change-by-statistic":
delta_bias = compute_output_stats(
delta_bias, out_std = compute_output_stats(

Check warning on line 334 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L334

Added line #L334 was not covered by tests
sample_merged,
self.get_ntypes(),
keys=self.get_output_keys(),
keys=list(self.atomic_output_def().keys()),
stat_file_path=stat_file_path,
model_forward=self._get_forward_wrapper_func(),
)["energy"]
self.set_out_bias(delta_bias, add=True)
rcond=self.rcond,
atom_ener=self.atom_ener,
)
# self.set_out_bias(delta_bias, add=True)
self._store_out_stat(delta_bias, out_std, add=True)

Check warning on line 344 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L344

Added line #L344 was not covered by tests
elif bias_adjust_mode == "set-by-statistic":
bias_atom = compute_output_stats(
bias_out, std_out = compute_output_stats(

Check warning on line 346 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L346

Added line #L346 was not covered by tests
sample_merged,
self.get_ntypes(),
keys=self.get_output_keys(),
)["energy"]
self.set_out_bias(bias_atom)
keys=list(self.atomic_output_def().keys()),
stat_file_path=stat_file_path,
rcond=self.rcond,
atom_ener=self.atom_ener,
)
# self.set_out_bias(bias_out)
self._store_out_stat(bias_out, std_out)

Check warning on line 355 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L355

Added line #L355 was not covered by tests
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)

Expand Down Expand Up @@ -279,3 +385,63 @@
return {kk: vv.detach() for kk, vv in atomic_ret.items()}

return model_forward

def _varsize(

Check warning on line 389 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L389

Added line #L389 was not covered by tests
self,
shape: List[int],
) -> int:
output_size = 1
len_shape = len(shape)
for i in range(len_shape):
output_size *= shape[i]
return output_size

Check warning on line 397 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L393-L397

Added lines #L393 - L397 were not covered by tests

def _get_bias_index(

Check warning on line 399 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L399

Added line #L399 was not covered by tests
self,
kk: str,
) -> int:
res: List[int] = []
for i, e in enumerate(self.bias_keys):
if e == kk:
res.append(i)
assert len(res) == 1
return res[0]

Check warning on line 408 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L403-L408

Added lines #L403 - L408 were not covered by tests

def _store_out_stat(

Check warning on line 410 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L410

Added line #L410 was not covered by tests
self,
out_bias: Dict[str, torch.Tensor],
out_std: Dict[str, torch.Tensor],
add: bool = False,
):
ntypes = self.get_ntypes()
out_bias_data = torch.clone(self.out_bias)
out_std_data = torch.clone(self.out_std)
for kk in out_bias.keys():
assert kk in out_std.keys()
idx = self._get_bias_index(kk)
size = self._varsize(self.atomic_output_def()[kk].shape)
if not add:
out_bias_data[idx, :, :size] = out_bias[kk].view(ntypes, size)

Check warning on line 424 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L416-L424

Added lines #L416 - L424 were not covered by tests
else:
out_bias_data[idx, :, :size] += out_bias[kk].view(ntypes, size)
out_std_data[idx, :, :size] = out_std[kk].view(ntypes, size)
self.out_bias.copy_(out_bias_data)
self.out_std.copy_(out_std_data)

Check warning on line 429 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L426-L429

Added lines #L426 - L429 were not covered by tests

def _fetch_out_stat(

Check warning on line 431 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L431

Added line #L431 was not covered by tests
self,
keys: List[str],
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
ret_bias = {}
ret_std = {}
ntypes = self.get_ntypes()
for kk in keys:
idx = self._get_bias_index(kk)
isize = self._varsize(self.atomic_output_def()[kk].shape)
ret_bias[kk] = self.out_bias[idx, :, :isize].view(

Check warning on line 441 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L435-L441

Added lines #L435 - L441 were not covered by tests
[ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005
)
ret_std[kk] = self.out_std[idx, :, :isize].view(

Check warning on line 444 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L444

Added line #L444 was not covered by tests
[ntypes] + list(self.atomic_output_def()[kk].shape) # noqa: RUF005
)
return ret_bias, ret_std

Check warning on line 447 in deepmd/pt/model/atomic_model/base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/base_atomic_model.py#L447

Added line #L447 was not covered by tests
16 changes: 5 additions & 11 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


@BaseAtomicModel.register("standard")
class DPAtomicModel(torch.nn.Module, BaseAtomicModel):
class DPAtomicModel(BaseAtomicModel):

Check warning on line 37 in deepmd/pt/model/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_atomic_model.py#L37

Added line #L37 was not covered by tests
"""Model give atomic prediction of some physical property.

Parameters
Expand All @@ -55,17 +55,17 @@
type_map: List[str],
**kwargs,
):
torch.nn.Module.__init__(self)
super().__init__(type_map, **kwargs)

Check warning on line 58 in deepmd/pt/model/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_atomic_model.py#L58

Added line #L58 was not covered by tests
ntypes = len(type_map)
self.type_map = type_map

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute type_map, which was previously defined in superclass
BaseAtomicModel
.
self.ntypes = ntypes
self.descriptor = descriptor
self.rcut = self.descriptor.get_rcut()
self.sel = self.descriptor.get_sel()
self.fitting_net = fitting
# order matters ntypes and type_map should be initialized first.
BaseAtomicModel.__init__(self, **kwargs)
super().init_out_stat()

Check warning on line 66 in deepmd/pt/model/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_atomic_model.py#L66

Added line #L66 was not covered by tests

@torch.jit.export

Check warning on line 68 in deepmd/pt/model/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_atomic_model.py#L68

Added line #L68 was not covered by tests
def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
return (
Expand All @@ -79,11 +79,6 @@
"""Get the cut-off radius."""
return self.rcut

@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.sel
Expand Down Expand Up @@ -220,8 +215,7 @@
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
if self.fitting_net is not None:
self.fitting_net.compute_output_stats(wrapped_sampler, stat_file_path)
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

Check warning on line 218 in deepmd/pt/model/atomic_model/dp_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dp_atomic_model.py#L218

Added line #L218 was not covered by tests

def set_out_bias(self, out_bias: torch.Tensor, add=False) -> None:
"""
Expand Down
Loading
Loading