-
Notifications
You must be signed in to change notification settings - Fork 520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add 4 pt descriptor compression #4227
Conversation
…o devel pt_tabulate_fusion recommit by Yan.
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 9
🧹 Outside diff range and nitpick comments (5)
deepmd/tf/utils/tabulate.py (1)
76-82
: Document the purpose of the hardcodedFalse
parameter insuper().__init__
.The last parameter in the superclass initialization is hardcoded as
False
without any explanation of its purpose. This reduces code maintainability and makes it harder to understand the initialization logic.Consider adding a comment explaining what this boolean parameter represents, or better yet, make it an explicit named parameter:
super().__init__( descrpt, neuron, type_one_side, exclude_types, - False, + enable_compression=False, # Disable compression during initialization )deepmd/pt/utils/tabulate.py (4)
31-31
: Correct 'it's' to 'its' in docstring possessivesIn the class docstring, "it's" should be "its" when indicating possession. Please change "as it's uniform stride" to "as its uniform stride" in both occurrences.
46-46
: Update activation function reference in docstringIn the parameter description for
activation_fn
, the reference tocommon.ActivationFn
may be incorrect. It should reflect the correct module path, such asdeepmd.pt.utils.utils.ActivationFn
.
471-471
: Typo in docstring: 'Then' should be 'The'The docstring for
_n_all_excluded
has a typo. It should read "The number of types excluding all types."
21-21
: Unused logger 'log'The logger
log
is defined but not used in the file. If logging is not required, consider removing it to clean up the code.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
deepmd/pt/utils/tabulate.py
(1 hunks)deepmd/tf/utils/tabulate.py
(3 hunks)deepmd/utils/tabulate.py
(1 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/pt/utils/tabulate.py
54-54: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
55-55: Do not perform function call ActivationFn
in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
deepmd/utils/tabulate.py
347-347: Use of functools.lru_cache
or functools.cache
on methods can lead to memory leaks
(B019)
🔇 Additional comments (5)
deepmd/tf/utils/tabulate.py (1)
84-84
: Acknowledge static analysis warnings.
The code scanning alerts about overwriting attributes and mixing implicit/explicit returns have been addressed in the previous review comments through the suggested refactoring of the descrpt_type
attribute and the _get_descrpt_type
method.
Also applies to: 172-183
deepmd/utils/tabulate.py (4)
347-348
: Avoid using @lru_cache
on instance methods to prevent memory leaks
As previously noted, using functools.lru_cache
on instance methods can lead to memory leaks because the cache holds strong references to self
, preventing garbage collection of class instances. Consider refactoring _all_excluded
to avoid this issue.
🧰 Tools
🪛 Ruff
347-347: Use of functools.lru_cache
or functools.cache
on methods can lead to memory leaks
(B019)
444-451
: Check for division by zero when rmax
equals rmin
in _spline5_switch
As previously mentioned, if rmax
equals rmin
, the expression (rmax - rmin)
becomes zero, resulting in a ZeroDivisionError
. Ensure that rmax
and rmin
are not equal, or handle this case appropriately to prevent a runtime error.
418-422
: Verify the impact of converting floats to integers
The _convert_numpy_float_to_int
method converts floating-point values of self.lower
and self.upper
to integers. This might lead to loss of precision or unintended truncation. Ensure that this conversion is intentional and does not adversely affect the accuracy of the tabulation results.
243-265
: Ensure array indices are within bounds in _build_lower
In the _build_lower
method, array slicing is performed using indices calculated from nspline
. If nspline
is not computed correctly or is out of expected bounds, this could lead to IndexError
. Verify that nspline
is correctly calculated and that all arrays (vv
, dd
, d2
) have appropriate lengths to prevent out-of-bounds access.
tf's se_t should not pass the test. Can you help me find out what the problem is? I compared the code and it seems to be consistent with the original implementation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 8
🧹 Outside diff range and nitpick comments (1)
deepmd/tf/utils/tabulate.py (1)
37-37
: Add type hints for class attributes.The class should define type hints for its attributes to improve code maintainability and IDE support.
Add type hints at the class level:
class DPTabulate(BaseTabulate): descrpt_type: str graph: tf.Graph graph_def: tf.GraphDef suffix: str functype: int activation_fn: Callable[[tf.Tensor], tf.Tensor]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
deepmd/pt/utils/tabulate.py
(1 hunks)deepmd/tf/utils/tabulate.py
(3 hunks)deepmd/utils/tabulate.py
(1 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/pt/utils/tabulate.py
54-54: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
55-55: Do not perform function call ActivationFn
in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
deepmd/utils/tabulate.py
353-353: Use of functools.lru_cache
or functools.cache
on methods can lead to memory leaks
(B019)
🔇 Additional comments (11)
deepmd/utils/tabulate.py (4)
19-65
: LGTM! Well-structured class initialization.
The class is well-organized with clear documentation and proper use of abstract base class pattern.
353-368
: Remove @lru_cache to prevent memory leaks.
The existing review comment about potential memory leaks with @lru_cache on instance methods is still valid.
🧰 Tools
🪛 Ruff
353-353: Use of functools.lru_cache
or functools.cache
on methods can lead to memory leaks
(B019)
449-457
:
Add validation for rmax and rmin parameters.
The method performs division by (rmax - rmin) without validating that they are different.
Add parameter validation:
def _spline5_switch(self, xx, rmin, rmax):
+ if rmax <= rmin:
+ raise ValueError("rmax must be greater than rmin")
if xx < rmin:
vv = 1
elif xx < rmax:
Likely invalid or redundant comment.
429-448
:
Add validation for min_nbor_dist parameter.
The method uses min_nbor_dist in divisions without validating that it's positive and non-zero.
Add parameter validation:
def _get_env_mat_range(self, min_nbor_dist):
+ if min_nbor_dist <= 0:
+ raise ValueError("min_nbor_dist must be positive")
"""Change the embedding net range to sw / min_nbor_dist."""
sw = self._spline5_switch(min_nbor_dist, self.rcut_smth, self.rcut)
Likely invalid or redundant comment.
deepmd/tf/utils/tabulate.py (5)
76-82
: LGTM: Proper initialization of the parent class.
The constructor correctly initializes the parent class with the required parameters.
84-84
: Consider using a cached property for descriptor type.
The descrpt_type
is computed only once during initialization and doesn't change. Using @cached_property
would be more idiomatic and consistent with other properties in the class.
30-32
: LGTM: Clean import of base class.
The import of BaseTabulate
is clean and follows Python import conventions.
172-172
:
Fix implicit return in _get_descrpt_type.
The method has an implicit return path when no conditions match, which could lead to runtime errors.
76-82
: Verify descriptor type initialization.
The initialization of descrpt_type
happens after the parent class initialization. Let's verify that the parent class doesn't depend on this attribute.
Also applies to: 84-84
✅ Verification successful
Let me gather more information about the parent class and its initialization to make a conclusive decision.
Let me try a different approach to find the parent class and its usage of descrpt_type
.
Let me check the parent class (BaseTabulate) to see how it uses the descriptor.
Based on the code analysis, I can now provide the final response:
The initialization order is correct and safe. The parent class BaseTabulate
only initializes shared attributes and sets descrpt_type
to a default value "Base". The actual descriptor type is determined by the child class's _get_descrpt_type
method, which is called after parent initialization. The parent class doesn't use descrpt_type
during initialization, and its methods that use this attribute are only called later during the build
method.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for any usage of descrpt_type in the parent class
rg -l "descrpt_type" | grep -v "tabulate.py$" | while read -r file; do
echo "=== $file ==="
rg "descrpt_type" "$file"
done
Length of output: 972
Script:
#!/bin/bash
# Find the parent class implementation and its initialization
ast-grep --pattern 'class $_ {
$$$
__init__($$$) {
$$$
}
$$$
}' deepmd/tf/utils/tabulate.py
Length of output: 105
Script:
#!/bin/bash
# First, let's see the content of tabulate.py to understand the inheritance
cat deepmd/tf/utils/tabulate.py
# Then search for any usage of descrpt_type in the file
rg -A 5 -B 5 "descrpt_type" deepmd/tf/utils/tabulate.py
Length of output: 23892
Script:
#!/bin/bash
# Check the BaseTabulate class implementation
cat deepmd/utils/tabulate.py
Length of output: 16199
deepmd/pt/utils/tabulate.py (2)
515-516
:
Improve error handling in grad_grad function
The function returns -1 for unsupported function types, while the grad function raises a ValueError. This inconsistency could lead to silent failures.
Make error handling consistent:
elif functype == 6:
return y * (1 - y) * (1 - 2 * y)
- else:
- return -1
+ raise ValueError(f"Unsupported function type: {functype}")
Likely invalid or redundant comment.
321-321
:
Add division by zero check in layer size calculation
The calculation basic_size // (self.ntypes - self._n_all_excluded)
could raise a ZeroDivisionError if all types are excluded.
Add a check:
if self.type_one_side:
+ denominator = self.ntypes - self._n_all_excluded
+ if denominator == 0:
+ raise ValueError("Cannot calculate layer size: all types are excluded")
- layer_size = basic_size // (self.ntypes - self._n_all_excluded)
+ layer_size = basic_size // denominator
Likely invalid or redundant comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 7
🧹 Outside diff range and nitpick comments (4)
deepmd/utils/tabulate.py (3)
22-65
: Enhance constructor documentation with type hints and attribute descriptions.The constructor's documentation could be improved by:
- Adding type hints for all parameters
- Documenting the purpose of each parameter
- Documenting the attributes that must be initialized in subclasses
Apply this diff to improve the documentation:
def __init__( self, - descrpt, - neuron, - type_one_side, - exclude_types, - is_pt, + descrpt: Any, # TODO: Add specific type + neuron: list[int], + type_one_side: bool, + exclude_types: set[tuple[int, int]], + is_pt: bool, ) -> None: - """Constructor.""" + """Initialize the base tabulate class. + + Parameters + ---------- + descrpt : Any + The descriptor object + neuron : list[int] + List of neurons in each layer + type_one_side : bool + Whether to use one-sided type + exclude_types : set[tuple[int, int]] + Set of type pairs to exclude + is_pt : bool + Whether this is a PyTorch implementation + + Notes + ----- + The following attributes must be initialized in subclasses: + - descrpt_type: str + - sel_a: list + - rcut: float + - rcut_smth: float + - davg: np.ndarray + - dstd: np.ndarray + - ntypes: int + """
336-423
: Enhance abstract method documentation with complete type hints.The abstract methods would benefit from more detailed documentation and complete type hints.
Example improvement for
_get_descrpt_type
:@abstractmethod - def _get_descrpt_type(self): - """Get the descrpt type.""" + def _get_descrpt_type(self) -> str: + """Get the descriptor type. + + Returns + ------- + str + The type of descriptor. Must be one of: + - "Atten" + - "A" + - "T" + - "R" + - "AEbdV2" + """ pass🧰 Tools
🪛 Ruff
354-354: Use of
functools.lru_cache
orfunctools.cache
on methods can lead to memory leaks(B019)
1-458
: Add unit tests for mathematical operations.The file contains complex mathematical operations, particularly in the
build
and_build_lower
methods. Consider adding unit tests to verify:
- Correct calculation of spline coefficients
- Proper handling of boundary conditions
- Accuracy of tabulation results
Would you like me to help generate comprehensive unit tests for these mathematical operations?
🧰 Tools
🪛 Ruff
354-354: Use of
functools.lru_cache
orfunctools.cache
on methods can lead to memory leaks(B019)
deepmd/pt/utils/tabulate.py (1)
81-89
: Moveactivation_map
to a module-level constantThe
activation_map
dictionary is defined inside the__init__
method. Since it does not depend on any instance-specific data, defining it at the module level can improve code clarity and prevent it from being recreated with each instance.You can move
activation_map
outside the class definition:# Module-level constant ACTIVATION_MAP = { "tanh": 1, "gelu": 2, "gelu_tf": 2, "relu": 3, "relu6": 4, "softplus": 5, "sigmoid": 6, } class DPTabulate(BaseTabulate): def __init__(self, ...): # Use ACTIVATION_MAP here
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
deepmd/pt/utils/tabulate.py
(1 hunks)deepmd/utils/tabulate.py
(1 hunks)source/tests/pt/test_tabulate.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/pt/test_tabulate.py
🧰 Additional context used
🪛 Ruff
deepmd/pt/utils/tabulate.py
54-54: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
55-55: Do not perform function call ActivationFn
in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
deepmd/utils/tabulate.py
354-354: Use of functools.lru_cache
or functools.cache
on methods can lead to memory leaks
(B019)
se_a, se_atten(DPA1), se_t, se_r
Summary by CodeRabbit
Release Notes
New Features
enable_compression
methods to various classes, allowing users to enable and configure compression settings.Bug Fixes
Tests
Documentation