-
Notifications
You must be signed in to change notification settings - Fork 520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(pt): fix type annotations for dummy compress op; improve docs #4342
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
📝 Walkthrough📝 WalkthroughWalkthroughThis pull request primarily introduces type annotations to several functions within the Changes
Possibly related PRs
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (5)
doc/freeze/compress.md (1)
132-132
: Consider adding a complete example command.To make the instructions more actionable, consider adding a complete example command showing how to set the environment variable:
The customized OP library for the Python interface can be installed by setting environment variable {envvar}`DP_ENABLE_PYTORCH` to `1` during [installation](../install/install-from-source.md). + +```bash +export DP_ENABLE_PYTORCH=1 +pip install . +```deepmd/pt/model/descriptor/se_a.py (1)
Line range hint
82-85
: Consider enhancing the error message and documentation.The error message could be more specific about which operation is missing and provide a direct link to the model compression documentation. Additionally, the LAMMPS compatibility note could be more prominent.
Consider updating the error message like this:
- "tabulate_fusion_se_a is not available since customized PyTorch OP library is not built when freezing the model. " + "The tabulate_fusion_se_a operation is not available. This operation requires the custom PyTorch operations library, which is not built. " + "Please refer to https://docs.deepmodeling.org/projects/deepmd/en/master/model-compression.html for details on model compression. "deepmd/pt/model/descriptor/se_t.py (1)
76-80
: Improve type annotations and documentationWhile adding type annotations is good for type safety, consider these improvements:
- Use more descriptive parameter names instead of generic
argumentN
- Add dimension information to tensor type hints using
# shape: ...
comments- Add a docstring explaining the parameters and return value
- argument0: torch.Tensor, - argument1: torch.Tensor, - argument2: torch.Tensor, - argument3: torch.Tensor, - argument4: int, + table_data: torch.Tensor, # shape: [table_size, out_dim] + table_info: torch.Tensor, # shape: [6] + env_deriv: torch.Tensor, # shape: [batch_size, 1] + env: torch.Tensor, # shape: [batch_size, n_types_i, n_types_j] + output_dim: int,Also consider adding a docstring:
"""Fuse and tabulate the environment matrix for the SE(3) descriptor. Parameters ---------- table_data : torch.Tensor The tabulated data for interpolation table_info : torch.Tensor Configuration info containing [lower, upper, upper_ext, stride1, stride2, check_freq] env_deriv : torch.Tensor Environment derivatives env : torch.Tensor Environment matrix output_dim : int Output dimension of the network Returns ------- list[torch.Tensor] List containing the fused descriptor values """deepmd/pt/model/descriptor/se_atten.py (2)
55-61
: Improve parameter names and add docstring.While the type annotations are correct, the function would benefit from:
- More descriptive parameter names that indicate their purpose
- A docstring explaining the parameters and return value
Consider renaming parameters and adding documentation:
def tabulate_fusion_se_atten( - argument0: torch.Tensor, - argument1: torch.Tensor, - argument2: torch.Tensor, - argument3: torch.Tensor, - argument4: torch.Tensor, - argument5: int, - argument6: bool, + compress_data: torch.Tensor, + compress_info: torch.Tensor, + input_data: torch.Tensor, + radial_data: torch.Tensor, + gate_data: torch.Tensor, + filter_neuron: int, + is_sorted: bool, ) -> list[torch.Tensor]: + """Fallback implementation for the custom tabulate_fusion_se_atten operation. + + Parameters + ---------- + compress_data : torch.Tensor + Compressed network data + compress_info : torch.Tensor + Compression configuration information + input_data : torch.Tensor + Input tensor for the network + radial_data : torch.Tensor + Radial information tensor + gate_data : torch.Tensor + Gating tensor + filter_neuron : int + Number of filter neurons + is_sorted : bool + Whether the input data is sorted + + Returns + ------- + list[torch.Tensor] + List of output tensors from the fusion operation + """
Line range hint
87-156
: Remove deprecatedtype
parameter.The
type
parameter is immediately deleted after being passed to the constructor, indicating it's deprecated. For better code clarity:
- Remove the parameter from the constructor signature
- Remove its documentation from the docstring
Apply this change:
def __init__( self, rcut: float, rcut_smth: float, sel: Union[list[int], int], ntypes: int, neuron: list = [25, 50, 100], axis_neuron: int = 16, tebd_dim: int = 8, tebd_input_mode: str = "concat", set_davg_zero: bool = True, attn: int = 128, attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, activation_function="tanh", precision: str = "float64", resnet_dt: bool = False, scaling_factor=1.0, normalize=True, temperature=None, smooth: bool = True, type_one_side: bool = False, exclude_types: list[tuple[int, int]] = [], env_protection: float = 0.0, trainable_ln: bool = True, ln_eps: Optional[float] = 1e-5, - seed: Optional[Union[int, list[int]]] = None, - type: Optional[str] = None, + seed: Optional[Union[int, list[int]]] = None ):Also remove the parameter from the docstring.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
deepmd/pt/model/descriptor/se_a.py
(1 hunks)deepmd/pt/model/descriptor/se_atten.py
(1 hunks)deepmd/pt/model/descriptor/se_r.py
(1 hunks)deepmd/pt/model/descriptor/se_t.py
(1 hunks)doc/freeze/compress.md
(1 hunks)
🔇 Additional comments (1)
deepmd/pt/model/descriptor/se_a.py (1)
76-80
: LGTM! Type annotations are accurate and helpful.
The type annotations correctly specify torch.Tensor
for tensor arguments and int
for the scalar argument, improving code clarity and type safety.
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## devel #4342 +/- ##
==========================================
- Coverage 84.35% 84.35% -0.01%
==========================================
Files 593 593
Lines 55899 55899
Branches 3388 3388
==========================================
- Hits 47154 47153 -1
Misses 7635 7635
- Partials 1110 1111 +1 ☔ View full report in Codecov by Sentry. |
Summary by CodeRabbit
New Features
Documentation
compress.md
document.Bug Fixes
Refactor