Skip to content

Commit

Permalink
fix(pt): fix type annotations for dummy compress op; improve docs (#4342
Browse files Browse the repository at this point in the history
)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Added type annotations to several functions for improved clarity and
type safety.
- Updated documentation to include installation requirements for the
PyTorch backend when compressing models.

- **Documentation**
- New section on installation requirements added to the `compress.md`
document.

- **Bug Fixes**
	- No bug fixes were introduced in this release. 

- **Refactor**
- Minor refactoring for better code readability without changing
existing functionalities.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
njzjz and wanghan-iapcm authored Nov 12, 2024
1 parent 4793125 commit 4a9ed88
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 21 deletions.
10 changes: 5 additions & 5 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"):

def tabulate_fusion_se_a(
argument0,
argument1,
argument2,
argument3,
argument4,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: torch.Tensor,
argument4: int,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_a is not available since customized PyTorch OP library is not built when freezing the model. "
Expand Down
14 changes: 7 additions & 7 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_atten"):

def tabulate_fusion_se_atten(
argument0,
argument1,
argument2,
argument3,
argument4,
argument5,
argument6,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: torch.Tensor,
argument4: torch.Tensor,
argument5: int,
argument6: bool,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_atten is not available since customized PyTorch OP library is not built when freezing the model. "
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_r"):

def tabulate_fusion_se_r(
argument0,
argument1,
argument2,
argument3,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: int,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_r is not available since customized PyTorch OP library is not built when freezing the model. "
Expand Down
10 changes: 5 additions & 5 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_t"):

def tabulate_fusion_se_t(
argument0,
argument1,
argument2,
argument3,
argument4,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: torch.Tensor,
argument4: int,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_t is not available since customized PyTorch OP library is not built when freezing the model. "
Expand Down
6 changes: 6 additions & 0 deletions doc/freeze/compress.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,9 @@ Notice: Model compression for the `se_atten_v2` descriptor is exclusively design
- relu6
- softplus
- sigmoid

## Requirements of installation {{ pytorch_icon }}

When compressing models in the PyTorch backend, the customized OP library for the Python interface must be installed when [freezing the model](../freeze/freeze.md).

The customized OP library for the Python interface can be installed by setting environment variable {envvar}`DP_ENABLE_PYTORCH` to `1` during [installation](../install/install-from-source.md).

0 comments on commit 4a9ed88

Please sign in to comment.