From 55de77cdf1427650629976c7501f6c60adad73a1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 11 Nov 2024 23:20:13 -0500 Subject: [PATCH] fix(pt): fix type annotations for dummy compress op; improve docs Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 10 +++++----- deepmd/pt/model/descriptor/se_atten.py | 14 +++++++------- deepmd/pt/model/descriptor/se_r.py | 8 ++++---- deepmd/pt/model/descriptor/se_t.py | 10 +++++----- doc/freeze/compress.md | 6 ++++++ 5 files changed, 27 insertions(+), 21 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index eadce86963..9b5ee6d2c4 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -73,11 +73,11 @@ if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"): def tabulate_fusion_se_a( - argument0, - argument1, - argument2, - argument3, - argument4, + argument0: torch.Tensor, + argument1: torch.Tensor, + argument2: torch.Tensor, + argument3: torch.Tensor, + argument4: int, ) -> list[torch.Tensor]: raise NotImplementedError( "tabulate_fusion_se_a is not available since customized PyTorch OP library is not built when freezing the model. " diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 6ec02de514..a4e3d44bf0 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -52,13 +52,13 @@ if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_atten"): def tabulate_fusion_se_atten( - argument0, - argument1, - argument2, - argument3, - argument4, - argument5, - argument6, + argument0: torch.Tensor, + argument1: torch.Tensor, + argument2: torch.Tensor, + argument3: torch.Tensor, + argument4: torch.Tensor, + argument5: int, + argument6: bool, ) -> list[torch.Tensor]: raise NotImplementedError( "tabulate_fusion_se_atten is not available since customized PyTorch OP library is not built when freezing the model. " diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index f70fdfa9f1..beb8acd5d5 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -62,10 +62,10 @@ if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_r"): def tabulate_fusion_se_r( - argument0, - argument1, - argument2, - argument3, + argument0: torch.Tensor, + argument1: torch.Tensor, + argument2: torch.Tensor, + argument3: int, ) -> list[torch.Tensor]: raise NotImplementedError( "tabulate_fusion_se_r is not available since customized PyTorch OP library is not built when freezing the model. " diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 0eec78fd2f..ae9b3a9c1a 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -73,11 +73,11 @@ if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_t"): def tabulate_fusion_se_t( - argument0, - argument1, - argument2, - argument3, - argument4, + argument0: torch.Tensor, + argument1: torch.Tensor, + argument2: torch.Tensor, + argument3: torch.Tensor, + argument4: int, ) -> list[torch.Tensor]: raise NotImplementedError( "tabulate_fusion_se_t is not available since customized PyTorch OP library is not built when freezing the model. " diff --git a/doc/freeze/compress.md b/doc/freeze/compress.md index 4f30458df1..cdb12cc9e7 100644 --- a/doc/freeze/compress.md +++ b/doc/freeze/compress.md @@ -124,3 +124,9 @@ Notice: Model compression for the `se_atten_v2` descriptor is exclusively design - relu6 - softplus - sigmoid + +## Requirements of installation {{ pytorch_icon }} + +When compressing models in the PyTorch backend, the customized OP library for the Python interface must be installed when [freezing the model](../freeze/freeze.md). + +The customized OP library for the Python interface can be installed by setting environment variable {envvar}`DP_ENABLE_PYTORCH` to `1` during [installation](../install/install-from-source.md).