Skip to content

Commit

Permalink
fix(pt): fix type annotations for dummy compress op; improve docs
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 12, 2024
1 parent c4a973a commit 55de77c
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 21 deletions.
10 changes: 5 additions & 5 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"):

def tabulate_fusion_se_a(
argument0,
argument1,
argument2,
argument3,
argument4,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: torch.Tensor,
argument4: int,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_a is not available since customized PyTorch OP library is not built when freezing the model. "
Expand Down
14 changes: 7 additions & 7 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_atten"):

def tabulate_fusion_se_atten(
argument0,
argument1,
argument2,
argument3,
argument4,
argument5,
argument6,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: torch.Tensor,
argument4: torch.Tensor,
argument5: int,
argument6: bool,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_atten is not available since customized PyTorch OP library is not built when freezing the model. "
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_r"):

def tabulate_fusion_se_r(
argument0,
argument1,
argument2,
argument3,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: int,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_r is not available since customized PyTorch OP library is not built when freezing the model. "
Expand Down
10 changes: 5 additions & 5 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@
if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_t"):

def tabulate_fusion_se_t(
argument0,
argument1,
argument2,
argument3,
argument4,
argument0: torch.Tensor,
argument1: torch.Tensor,
argument2: torch.Tensor,
argument3: torch.Tensor,
argument4: int,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_t is not available since customized PyTorch OP library is not built when freezing the model. "
Expand Down
6 changes: 6 additions & 0 deletions doc/freeze/compress.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,9 @@ Notice: Model compression for the `se_atten_v2` descriptor is exclusively design
- relu6
- softplus
- sigmoid

## Requirements of installation {{ pytorch_icon }}

When compressing models in the PyTorch backend, the customized OP library for the Python interface must be installed when [freezing the model](../freeze/freeze.md).

The customized OP library for the Python interface can be installed by setting environment variable {envvar}`DP_ENABLE_PYTORCH` to `1` during [installation](../install/install-from-source.md).

0 comments on commit 55de77c

Please sign in to comment.