Skip to content

Commit

Permalink
Add compression API to BaseModel and AtomicModel (#4298)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced `enable_compression` method across multiple classes to
allow configuration of compression settings for descriptors.
- Enhanced robustness of output definitions and serialization processes
in the `DPAtomicModel` class.
- Added `enable_compression` method to `LinearEnergyAtomicModel` for
improved model compression capabilities.

- **Bug Fixes**
- Improved error handling in the `fitting_output_def` method to ensure
fallback functionality when the fitting network is unavailable.

These updates enhance the functionality and reliability of the model
management features.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
cherryWangY and pre-commit-ci[bot] authored Nov 2, 2024
1 parent 25bb821 commit bfbe2ed
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 0 deletions.
31 changes: 31 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,37 @@ def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return self.descriptor.need_sorted_nlist_for_lower()

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Call descriptor enable_compression().
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
self.descriptor.enable_compression(
min_nbor_dist,
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
)

def forward_atomic(
self,
extended_coord: np.ndarray,
Expand Down
32 changes: 32 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,38 @@ def _sort_rcuts_sels(self) -> tuple[list[float], list[int]]:
)
return [p[0] for p in zipped], [p[1] for p in zipped]

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Compress model.
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
for model in self.models:
model.enable_compression(
min_nbor_dist,
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
)

def forward_atomic(
self,
extended_coord,
Expand Down
25 changes: 25 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,31 @@ def change_type_map(
) -> None:
pass

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Call descriptor enable_compression().
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
raise NotImplementedError("This atomi model doesn't support compression!")

def make_atom_mask(
self,
atype: t_tensor,
Expand Down
22 changes: 22 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,28 @@ def update_sel(
cls = cls.get_class_by_type(model_type)
return cls.update_sel(train_data, type_map, local_jdata)

def enable_compression(
self,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Enable model compression by tabulation.
Parameters
----------
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
raise NotImplementedError("This atomic model doesn't support compression!")

@classmethod
def get_model(cls, model_params: dict) -> "BaseBaseModel":
"""Get the model by the parameters.
Expand Down
28 changes: 28 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,34 @@ def model_output_type(self) -> list[str]:
]
return vars

def enable_compression(
self,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Call atomic_model enable_compression().
Parameters
----------
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
self.atomic_model.enable_compression(
self.get_min_nbor_dist(),
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
)

def call(
self,
coord,
Expand Down
31 changes: 31 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,37 @@ def deserialize(cls, data) -> "DPAtomicModel":
obj = super().deserialize(data)
return obj

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Call descriptor enable_compression().
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
self.descriptor.enable_compression(
min_nbor_dist,
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
)

def forward_atomic(
self,
extended_coord,
Expand Down
32 changes: 32 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,38 @@ def _sort_rcuts_sels(self) -> tuple[list[float], list[int]]:
sorted_sels: list[int] = outer_sorted[:, 1].to(torch.int64).tolist()
return sorted_rcuts, sorted_sels

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Compress model.
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
for model in self.models:
model.enable_compression(
min_nbor_dist,
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
)

def forward_atomic(
self,
extended_coord: torch.Tensor,
Expand Down
28 changes: 28 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,34 @@ def model_output_type(self) -> list[str]:
vars.append(kk)
return vars

def enable_compression(
self,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Call atomic_model enable_compression().
Parameters
----------
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
self.atomic_model.enable_compression(
self.get_min_nbor_dist(),
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
)

# cannot use the name forward. torch script does not work
def forward_common(
self,
Expand Down

0 comments on commit bfbe2ed

Please sign in to comment.