
feat(pt/dp): support case embedding and sharable fitting #4417

Merged · 16 commits · Nov 28, 2024

Conversation

@iProzd (Collaborator) commented Nov 25, 2024

Summary by CodeRabbit

  • New Features

    • Introduced a set_case_embd method across multiple atomic model classes to enhance case embedding functionality.
    • Added a dim_case_embd parameter in various fitting classes to support case-specific embedding dimensions.
    • Updated serialization methods to include dim_case_embd in the output.
    • Added a comprehensive JSON configuration for multitask models in water simulations.
    • Introduced a new function to validate case embedding dimensions in multi-task training.
    • Updated the share_params method in the DescrptDPA2 class to streamline parameter sharing logic.
  • Bug Fixes

    • Improved version compatibility checks in deserialization methods across several classes.
  • Documentation

    • Enhanced documentation for multi-task training, emphasizing the transition to PyTorch and detailing configuration changes.
  • Tests

    • Updated test cases to incorporate new parameters and configurations related to case embeddings.
    • Introduced new tests for multitask learning configurations.

coderabbitai bot (Contributor) commented Nov 25, 2024

📝 Walkthrough

The changes in this pull request primarily introduce a new method set_case_embd across various atomic model classes and fitting classes, allowing them to set case embeddings. Additionally, several classes have been updated to include a new parameter dim_case_embd, which specifies the dimension of case embeddings. The version compatibility checks for serialization and deserialization methods have also been updated across multiple classes. These modifications enhance the handling of case embeddings in the context of multi-task training.
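As a reading aid, a minimal standalone sketch of the mechanism described above, using the numb_dataid/set_dataid names that appear in the review comments below (sizes are hypothetical; this illustrates the idea, not the shipped implementation):

```python
import torch

# sketch parameters: two models sharing one fitting net; this instance is task 0
numb_dataid = 2
data_idx = 0

# set_dataid: pick row `data_idx` of an identity matrix -> a one-hot vector
dataid = torch.eye(numb_dataid)[data_idx]          # shape [numb_dataid]

# forward: tile the embedding per frame/atom and append it to the
# descriptor output before the fitting net sees it
nf, nloc, dim_descrpt = 4, 192, 240                # hypothetical sizes
xx = torch.randn(nf, nloc, dim_descrpt)            # stand-in descriptor output
tiled = torch.tile(dataid.reshape(1, 1, -1), [nf, nloc, 1])
xx = torch.cat([xx, tiled], dim=-1)                # [nf, nloc, dim_descrpt + numb_dataid]
```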

Changes

| File Path | Change Summary |
| --- | --- |
| deepmd/dpmodel/atomic_model/*.py | Added set_case_embd(case_idx: int) method to DPAtomicModel, LinearEnergyAtomicModel, DPZBLLinearEnergyAtomicModel, PairTabAtomicModel. |
| deepmd/dpmodel/atomic_model/make_base_atomic_model.py | Introduced an abstract method set_case_embd(case_idx: int) -> None in BaseAtomicModel (BAM). |
| deepmd/dpmodel/fitting/*.py | Added dim_case_embd parameter to the constructors of DipoleFitting, DOSFittingNet, EnergyFittingNet, GeneralFitting, InvarFitting, PolarFitting, and PropertyFittingNet. Updated version checks in deserialize methods. |
| deepmd/pt/model/atomic_model/*.py | Same as above: added set_case_embd(case_idx: int) to the corresponding classes. |
| deepmd/pt/model/task/*.py | Updated DipoleFittingNet, DOSFittingNet, EnergyFittingNet, InvarFitting, PolarFittingSeA, and PropertyFittingNet to include dim_case_embd; updated version checks in serialize and deserialize methods. |
| deepmd/utils/argcheck.py | Added dim_case_embd argument to the fitting_ener, fitting_dos, fitting_property, and fitting_polar functions. |
| source/tests/pt/model/water/*.json | Introduced a new JSON configuration for multitask learning and modified existing configurations to reflect the new parameter structures. |
| deepmd/tf/fit/*.py | Added dim_case_embd to DipoleFittingSeA, DOSFitting, EnerFitting, and PolarFittingSeA, with validation checks and version updates in serialization methods. |
| deepmd/pt/model/descriptor/dpa2.py | Simplified the share_params method and updated change_type_map, serialize, and deserialize to reflect the new parameter-sharing logic. |

Suggested reviewers

  • njzjz
  • wanghan-iapcm

📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between b81a716 and bd3fbef.

📒 Files selected for processing (2)
  • deepmd/pd/model/task/fitting.py (8 hunks)
  • deepmd/pd/model/task/invar_fitting.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • deepmd/pd/model/task/fitting.py
  • deepmd/pd/model/task/invar_fitting.py


coderabbitai bot left a comment

Actionable comments posted: 12

🧹 Outside diff range and nitpick comments (25)
deepmd/dpmodel/fitting/ener_fitting.py (1)

Line range hint 35-78: Consider documenting data ID usage patterns

The addition of numb_dataid support enables new model sharing capabilities. Consider:

  1. Documenting recommended patterns for data ID assignment
  2. Adding examples of how to effectively use this feature in multi-task scenarios
  3. Updating model serialization documentation to reflect this change
source/tests/pt/model/water/multitask_sharefit.json (3)

27-37: Consider optimizing the fitting network architecture

The current configuration has:

  1. numb_dataid: 2 correctly matches the number of models
  2. The fitting network uses identical layer sizes [240, 240, 240]

Consider using a progressive reduction in layer sizes (e.g., [240, 180, 120]), which might provide better feature abstraction while reducing the parameter count.


41-52: Consider differentiating model configurations

Both models are currently using identical configurations, including the same fitting network dataid (":1"). This setup might not fully leverage the benefits of multitask learning. Consider:

  1. Using different dataid values (":1" and ":2")
  2. Adjusting data_stat_nbatch differently for each model based on their respective dataset sizes
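A hypothetical fragment implementing both points (the model_dict/share-suffix layout is modeled on typical deepmd multitask configurations; names and values are illustrative, not taken from the actual test file):

```json
"model_dict": {
    "model_1": {
        "descriptor": "my_descriptor",
        "fitting_net": "my_fitting:1",
        "data_stat_nbatch": 10
    },
    "model_2": {
        "descriptor": "my_descriptor",
        "fitting_net": "my_fitting:2",
        "data_stat_nbatch": 20
    }
}
```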

63-82: Review loss weighting strategy

The loss configuration shows:

  1. Identical settings for both models
  2. High initial force preference (1000 → 1)
  3. Disabled volume preference

Consider:

  1. Adjusting preferences based on each model's specific task
  2. Enabling volume preference for better structural prediction
  3. Adding a schedule for preference transition
deepmd/pt/model/task/property.py (1)

86-86: Add documentation for the new parameter.

The numb_dataid parameter is missing from the docstring. Please update the class documentation to include:

numb_dataid : int, optional
    Number of data identifiers. Defaults to 0.
deepmd/pt/model/task/invar_fitting.py (1)

Line range hint 94-146: Overall implementation looks good but needs documentation.

The implementation of numb_dataid support is clean and follows existing patterns. However, please ensure:

  1. Add docstring updates to document the new numb_dataid parameter
  2. Update the class docstring to reflect the version 3 compatibility requirement
  3. Consider adding examples of how to use the new data identification feature
source/tests/universal/dpmodel/fitting/test_fitting.py (1)

Line range hint 42-54: LGTM with a minor suggestion for the docstring.

The addition of numb_dataid is consistent with the PR objectives. Consider making the docstring more explicit about the shared value:

-    numb_param=0,  # test numb_fparam, numb_aparam and numb_dataid together
+    numb_param=0,  # shared value for testing numb_fparam, numb_aparam and numb_dataid
🧰 Tools
🪛 Ruff (0.7.0)

39-39: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

deepmd/dpmodel/atomic_model/dp_atomic_model.py (1)

68-73: Add type hints and enhance documentation.

The implementation looks good, but could benefit from some improvements:

-    def set_dataid(self, data_idx):
+    def set_dataid(self, data_idx: int) -> None:
         """
         Set the data identification of this atomic model by the given data_idx,
         typically concatenated with the output of the descriptor and fed into the fitting net.
+
+        Parameters
+        ----------
+        data_idx : int
+            The data identification index to be set
         """
         self.fitting.set_dataid(data_idx)
deepmd/dpmodel/atomic_model/make_base_atomic_model.py (1)

71-77: Add type hints and enhance documentation for set_dataid.

While the implementation is good, consider these improvements to make the interface more robust:

  1. Add type hints for the data_idx parameter
  2. Enhance the docstring with parameter type, valid range, and error conditions
 @abstractmethod
-def set_dataid(self, data_idx) -> None:
+def set_dataid(self, data_idx: int) -> None:
     """
     Set the data identification of this atomic model by the given data_idx,
     typically concatenated with the output of the descriptor and fed into the fitting net.
+
+    Parameters
+    ----------
+    data_idx : int
+        The index used to identify the data. Must be a non-negative integer.
+
+    Raises
+    ------
+    ValueError
+        If data_idx is negative or exceeds the maximum allowed value.
     """
     pass
deepmd/pt/model/task/ener.py (1)

83-85: Consider documenting version changes and migration path

The version bump to 3 indicates a breaking change in the serialization format. Consider:

  1. Documenting the version change in the changelog
  2. Providing migration instructions for users with existing serialized models
  3. Adding tests to verify backward compatibility handling
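For reference, a minimal sketch of the version-gating convention these modules follow (the check_version_compatibility call pattern (version, maximal, minimal) matches what is verified later in this review; the backward-compatibility default is an assumption):

```python
from deepmd.utils.version import check_version_compatibility

def deserialize_fitting(cls, data: dict):
    """Sketch of the shared deserialize convention for the fitting nets."""
    data = dict(data)
    # models serialized before this PR carry "@version": 2; the bump to 3
    # reflects the new numb_dataid field, while versions 1..3 stay loadable
    check_version_compatibility(data.pop("@version", 1), 3, 1)
    data.setdefault("numb_dataid", 0)  # hypothetical backward-compat default
    return cls(**data)
```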
deepmd/dpmodel/fitting/invar_fitting.py (1)

126-126: Document the numb_dataid parameter in the class docstring.

The new parameter numb_dataid should be documented in the Parameters section of the class docstring to maintain consistency with the existing comprehensive documentation.

Add this to the Parameters section of the docstring:

    Parameters
    ----------
+   numb_dataid : int, optional
+           Number of data identifiers for the fitting process. Defaults to 0.
    bias_atom
            Bias for each element.

Also applies to: 159-159

deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

96-101: Add error handling for edge cases.

The implementation looks good and aligns with the PR objectives. However, consider adding error handling for robustness:

 def set_dataid(self, data_idx):
     """
     Set the data identification of this atomic model by the given data_idx,
     typically concatenated with the output of the descriptor and fed into the fitting net.
+
+    Parameters
+    ----------
+    data_idx : int
+        The data identification index to be set
+
+    Raises
+    ------
+    AttributeError
+        If fitting_net is None
     """
+    if self.fitting_net is None:
+        raise AttributeError("Cannot set dataid: fitting_net is None")
     self.fitting_net.set_dataid(data_idx)

The suggested changes:

  1. Add type hints and parameter documentation
  2. Add error handling for the case when fitting_net is None
  3. Document the potential exception
deepmd/dpmodel/atomic_model/pairtab_atomic_model.py (2)

123-130: LGTM! Consider enhancing the documentation.

The implementation correctly indicates that data identification is not supported for this model type. However, it would be helpful to extend the docstring to explain why data identification is not supported for the PairTabAtomicModel class, helping users understand the limitation.

Consider updating the docstring to:

     def set_dataid(self, data_idx):
         """
         Set the data identification of this atomic model by the given data_idx,
         typically concatenated with the output of the descriptor and fed into the fitting net.
+
+        This functionality is not supported for PairTabAtomicModel as it uses pre-computed
+        tabulated values and does not involve a fitting process that would benefit from
+        data identification.
         """

123-130: Good architectural alignment with other atomic models.

The implementation properly follows the interface contract by implementing the set_dataid method, while clearly indicating its limitations. This maintains consistency with other atomic models while properly handling the specific constraints of the PairTabAtomicModel.

source/tests/pt/test_multitask.py (1)

Line range hint 236-273: Consider increasing test coverage for shared fitting

While the test implementation is correct, running for only 1 step (numb_steps = 1) might not thoroughly verify the shared fitting behavior, especially parameter sharing during training.

Consider:

  1. Increasing numb_steps to ensure parameters remain properly shared throughout training
  2. Adding assertions to verify the shared parameters after multiple training steps
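A sketch of the kind of assertion point 2 asks for, assuming the wrapper exposes per-task models as a dict and that shared parameters alias the same storage (attribute paths are illustrative):

```python
def assert_fitting_shared(wrapper, key_a="model_1", key_b="model_2"):
    """After several training steps, shared fitting parameters must remain
    the same tensors, not merely equal values."""
    fit_a = wrapper.model[key_a].atomic_model.fitting_net
    fit_b = wrapper.model[key_b].atomic_model.fitting_net
    for (name_a, pa), (name_b, pb) in zip(
        fit_a.named_parameters(), fit_b.named_parameters()
    ):
        assert name_a == name_b
        assert pa.data_ptr() == pb.data_ptr(), f"{name_a} is no longer shared"
```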
deepmd/dpmodel/atomic_model/linear_atomic_model.py (1)

137-144: Enhance documentation and error handling

While the implementation correctly propagates the data identification to sub-models, consider the following improvements:

  1. Enhance the docstring by:

    • Specifying the type of data_idx parameter
    • Adding a return type annotation (-> None)
    • Documenting potential exceptions
  2. Add error handling to gracefully handle potential exceptions from sub-models.

-    def set_dataid(self, data_idx):
+    def set_dataid(self, data_idx: int) -> None:
         """
         Set the data identification of this atomic model by the given data_idx,
         typically concatenated with the output of the descriptor and fed into the fitting net.
+
+        Parameters
+        ----------
+        data_idx : int
+            The data identification index to be set
+
+        Raises
+        ------
+        ValueError
+            If data_idx is invalid
+        RuntimeError
+            If setting dataid fails for any sub-model
         """
+        if not isinstance(data_idx, int) or data_idx < 0:
+            raise ValueError(f"data_idx must be a non-negative integer, got {data_idx}")
+
         for model in self.models:
-            model.set_dataid(data_idx)
+            try:
+                model.set_dataid(data_idx)
+            except Exception as e:
+                raise RuntimeError(f"Failed to set dataid for model {model}: {str(e)}")
deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1)

144-151: Enhance docstring with rationale for unsupported feature

The implementation correctly raises NotImplementedError, but the docstring could be more informative by explaining why data identification is not supported for this model type.

     def set_dataid(self, data_idx):
         """
         Set the data identification of this atomic model by the given data_idx,
         typically concatenated with the output of the descriptor and fed into the fitting net.
+
+        This feature is not supported in PairTabAtomicModel as it uses pre-computed
+        tabulated values for pairwise interactions, which are independent of data
+        identification.
         """
         raise NotImplementedError(
             "Data identification not supported for PairTabAtomicModel!"
         )
deepmd/dpmodel/fitting/general_fitting.py (1)

446-458: Consider caching the tiled dataid tensor

The current implementation tiles the dataid tensor for every call. Consider caching the tiled tensor if the shape (nf, nloc) remains constant across multiple calls, which is common in batch processing scenarios.

 if self.numb_dataid > 0:
     assert self.dataid is not None
-    dataid = xp.tile(xp.reshape(self.dataid, [1, 1, -1]), [nf, nloc, 1])
+    # Cache the reshaped dataid
+    if not hasattr(self, '_reshaped_dataid'):
+        self._reshaped_dataid = xp.reshape(self.dataid, [1, 1, -1])
+    # Tile the cached tensor
+    dataid = xp.tile(self._reshaped_dataid, [nf, nloc, 1])
     xx = xp.concat(
         [xx, dataid],
         axis=-1,
     )
deepmd/pt/model/task/fitting.py (3)

218-226: Document the rationale for zeros initialization

There's a commented line suggesting an alternative initialization using an identity matrix. Consider documenting why zeros initialization was chosen over the identity matrix approach, or if the identity matrix approach would be more appropriate.
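For context, the two options at issue in minimal form (the identity variant corresponds to the commented-out line mentioned above; illustrative only):

```python
import torch

numb_dataid, case_idx = 2, 0

# current choice: zeros, so the embedding contributes nothing until
# set_dataid() overwrites it with a one-hot row
dataid = torch.zeros(numb_dataid)

# commented-out alternative referenced above: start from the identity row
# dataid = torch.eye(numb_dataid)[case_idx]
```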


368-376: Add parameter validation in set_dataid

Consider adding validation for the data_idx parameter to ensure it's within the valid range [0, numb_dataid).

 def set_dataid(self, data_idx):
     """
     Set the data identification of this fitting net by the given data_idx,
     typically concatenated with the output of the descriptor and fed into the fitting net.
     """
+    if not 0 <= data_idx < self.numb_dataid:
+        raise ValueError(f"data_idx must be in range [0, {self.numb_dataid})")
     self.dataid = torch.eye(self.numb_dataid, dtype=self.prec, device=device)[
         data_idx
     ]

507-519: LGTM: Consider caching the tiled dataid tensor

The dataid concatenation is implemented correctly. However, since the dataid tensor is constant during forward passes, consider caching the tiled version to avoid repeated operations.

 if self.numb_dataid > 0:
     assert self.dataid is not None
-    dataid = torch.tile(self.dataid.reshape([1, 1, -1]), [nf, nloc, 1])
+    if not hasattr(self, '_cached_dataid') or self._cached_dataid.shape[:2] != (nf, nloc):
+        self._cached_dataid = torch.tile(self.dataid.reshape([1, 1, -1]), [nf, nloc, 1])
+    dataid = self._cached_dataid
     xx = torch.cat(
         [xx, dataid],
         dim=-1,
     )
deepmd/dpmodel/model/make_model.py (1)

555-557: Add type hints and documentation for the new method.

The new set_dataid method should include type hints and documentation to maintain consistency with the rest of the codebase.

-        def set_dataid(self, data_idx):
-            self.atomic_model.set_dataid(data_idx)
+        def set_dataid(self, data_idx: int) -> None:
+            """Set the data identifier for the atomic model.
+            
+            Parameters
+            ----------
+            data_idx : int
+                The data identifier to be set on the atomic model.
+            """
+            self.atomic_model.set_dataid(data_idx)
deepmd/pt/model/atomic_model/linear_atomic_model.py (2)

161-168: Enhance docstring with type hints and example usage.

The docstring should include parameter type hints and an example usage for better clarity.

 def set_dataid(self, data_idx):
     """
     Set the data identification of this atomic model by the given data_idx,
     typically concatenated with the output of the descriptor and fed into the fitting net.
+
+    Parameters
+    ----------
+    data_idx : int
+        The data identification index to be set.
+
+    Example
+    -------
+    >>> model = LinearEnergyAtomicModel(...)
+    >>> model.set_dataid(0)
     """

161-168: Consider adding error handling for invalid data_idx.

The method should validate the data_idx parameter and handle potential errors from sub-models.

 def set_dataid(self, data_idx):
     """
     Set the data identification of this atomic model by the given data_idx,
     typically concatenated with the output of the descriptor and fed into the fitting net.
     """
+    if not isinstance(data_idx, (int, torch.Tensor)):
+        raise TypeError(f"data_idx must be an integer or tensor, got {type(data_idx)}")
     for model in self.models:
-        model.set_dataid(data_idx)
+        try:
+            model.set_dataid(data_idx)
+        except Exception as e:
+            raise RuntimeError(f"Failed to set data_idx for model {model.__class__.__name__}: {str(e)}")
deepmd/utils/argcheck.py (1)

Line range hint 1436-1469: LGTM! The numb_dataid parameter is consistently implemented across fitting functions.

The implementation:

  • Adds proper documentation explaining the parameter's purpose for multitask models
  • Uses consistent default value of 0
  • Maintains consistent parameter type (int) and optional status
  • Clearly indicates PyTorch-only support

Consider documenting the following in the codebase:

  1. Examples of how to use data identification embeddings in multitask scenarios
  2. Best practices for choosing the dimension of data identification embeddings
  3. Performance implications of using data identification embeddings

Also applies to: 1520-1543, 1575-1593, 1624-1659, 1702-1730

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between ad30709 and ae95c3c.

📒 Files selected for processing (29)
  • deepmd/dpmodel/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/linear_atomic_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/make_base_atomic_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/pairtab_atomic_model.py (1 hunks)
  • deepmd/dpmodel/fitting/dipole_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/dos_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/ener_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/general_fitting.py (9 hunks)
  • deepmd/dpmodel/fitting/invar_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/polarizability_fitting.py (4 hunks)
  • deepmd/dpmodel/fitting/property_fitting.py (3 hunks)
  • deepmd/dpmodel/model/make_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/linear_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1 hunks)
  • deepmd/pt/model/model/make_model.py (1 hunks)
  • deepmd/pt/model/task/dipole.py (3 hunks)
  • deepmd/pt/model/task/dos.py (3 hunks)
  • deepmd/pt/model/task/ener.py (3 hunks)
  • deepmd/pt/model/task/fitting.py (10 hunks)
  • deepmd/pt/model/task/invar_fitting.py (3 hunks)
  • deepmd/pt/model/task/polarizability.py (4 hunks)
  • deepmd/pt/model/task/property.py (3 hunks)
  • deepmd/pt/train/training.py (2 hunks)
  • deepmd/pt/train/wrapper.py (2 hunks)
  • deepmd/utils/argcheck.py (10 hunks)
  • source/tests/pt/model/water/multitask_sharefit.json (1 hunks)
  • source/tests/pt/test_multitask.py (4 hunks)
  • source/tests/universal/dpmodel/fitting/test_fitting.py (10 hunks)
🔇 Additional comments (53)
deepmd/pt/train/training.py (2)

268-268: Properly pass 'resuming' parameter to model initialization

The addition of the resuming parameter in the call to get_model_for_wrapper ensures that the model loading logic correctly handles resuming states during training.


1270-1285: Verify correct setting of data IDs for multitask models

The logic in get_model_for_wrapper appropriately sets data IDs for each model when not resuming. This ensures that data identification indices are correctly assigned during fresh multitask training sessions.
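A minimal sketch of the assignment described here, assuming a dict of per-task models with enumeration order defining the index:

```python
def assign_dataids(model_dict: dict, resuming: bool) -> None:
    """Sketch: give each task's model its own data identification index
    on a fresh multitask run; a resumed run keeps the checkpointed state."""
    if resuming:
        return
    for data_idx, task_model in enumerate(model_dict.values()):
        task_model.set_dataid(data_idx)
```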

deepmd/dpmodel/fitting/ener_fitting.py (3)

59-59: LGTM!

The numb_dataid parameter is correctly passed to the superclass.


35-35: LGTM! Verify version compatibility

The addition of numb_dataid parameter with a default value maintains backward compatibility.

✅ Verification successful

Parameter addition is consistently implemented across fitting classes

The numb_dataid parameter with default value 0 is consistently implemented across all fitting classes in the codebase:

  • Base class GeneralFitting properly handles the parameter
  • All derived fitting classes (Energy, Property, DOS, Dipole, Polarizability, etc.) correctly pass it to their superclass
  • Parameter is well-documented in deepmd/utils/argcheck.py as "The dimension of the data identification embedding"
  • Tests in source/tests/universal/dpmodel/fitting/test_fitting.py verify this functionality
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if this version change is consistently applied across related fitting classes
rg -A 2 "numb_dataid.*=.*0" --type python

Length of output: 74


Script:

#!/bin/bash
# Search for numb_dataid in Python files
rg "numb_dataid" -t py

# Check for similar fitting classes
fd "fitting.py$" --exec rg "class.*Fitting" {} \;

Length of output: 8178


78-78: Verify version bump documentation

The version compatibility check has been updated from 2 to 3, which indicates a breaking change. This change needs to be properly documented.

✅ Verification successful

Version bump is part of a consistent update across fitting modules

The version compatibility check update from 2 to 3 in ener_fitting.py is part of a broader, consistent change across the PyTorch and dpmodel fitting modules. The search results show:

  • All fitting modules in deepmd/dpmodel/fitting/ use version 3:

    • ener_fitting.py: version 3
    • invar_fitting.py: version 3
    • property_fitting.py: version 3
    • dipole_fitting.py: version 3
    • dos_fitting.py: version 3
    • polarizability_fitting.py: version 4
  • Similarly, the corresponding PyTorch task modules in deepmd/pt/model/task/ have matching version numbers.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for consistent version numbers across fitting classes
echo "Checking version compatibility in fitting classes..."
rg "check_version_compatibility.*3" --type python

# Check for documentation updates
echo "Checking for version documentation..."
fd -e md -e rst | xargs rg -l "version.*3.*breaking change"

Length of output: 344


Script:

#!/bin/bash
# Check for version compatibility checks in Python files
echo "Checking version compatibility in fitting classes..."
rg "check_version_compatibility.*[0-9]" -t py

# Look for version-related documentation
echo -e "\nChecking documentation files..."
find . -type f -name "*.md" -o -name "*.rst" | xargs grep -l "version"

# Check git history for version-related commits
echo -e "\nChecking recent commits for version changes..."
git log --oneline -n 10 | grep -i "version"

Length of output: 8807

deepmd/dpmodel/fitting/dos_fitting.py (2)

39-39: LGTM: Constructor changes for data ID support

The addition of numb_dataid parameter is well-implemented with proper type hints and default value, maintaining backward compatibility while enabling new functionality.

Also applies to: 64-64


Line range hint 78-85: Verify deserialization of numb_dataid field

While the version compatibility check is updated correctly, we should verify that the numb_dataid field is properly handled during deserialization.

Let's check how other fitting classes handle this field:

✅ Verification successful

Based on the search results, I can now provide a final response:

The numb_dataid field is properly handled during deserialization

The code review shows that:

  1. The numb_dataid field is properly handled in the parent GeneralFitting class which all fitting classes inherit from
  2. The field is:
    • Initialized with a default value of 0 in all fitting classes
    • Properly passed through constructors
    • Included in the serialization via serialize() method
    • Correctly deserialized through the parent class's deserialize() method
  3. The field is used consistently across the codebase for data identification purposes
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check how numb_dataid is handled in deserialization across fitting classes
# Expected: Find consistent patterns of numb_dataid handling in deserialize methods

# Search for deserialize methods in fitting classes
ast-grep --pattern 'def deserialize(cls, $_) {
  $$$
  numb_dataid
  $$$
}'

# Backup: Search with ripgrep for more context
rg -B 5 -A 5 'numb_dataid.*deserialize' deepmd/dpmodel/fitting/

Length of output: 143


Script:

#!/bin/bash
# Let's try a broader search to understand the deserialization patterns in fitting classes

# First, let's find all deserialize methods in fitting classes
ast-grep --pattern 'def deserialize(cls, data):
  $$$
}'

# Also search for any mentions of numb_dataid in the fitting directory
rg "numb_dataid" deepmd/dpmodel/fitting/

# Let's also check the parent class implementation
fd "fitting.py" --type f --exec cat {}

Length of output: 132806


Script:

#!/bin/bash
# Let's search for the handling of numb_dataid in the fitting classes and their deserialization

# Search for numb_dataid in fitting classes
rg -B 5 -A 5 "numb_dataid" deepmd/dpmodel/fitting/

# Search for deserialize methods that handle numb_dataid
ast-grep --pattern 'def deserialize(cls, data):
  $$$
  numb_dataid
  $$$
}'

# Also check the parent class implementation
rg -B 5 -A 5 "deserialize.*data" deepmd/dpmodel/fitting/general_fitting.py

Length of output: 16976
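Given the round-trip handling confirmed above, a sanity check one could add as a test (constructor arguments and the exact serialized key are assumptions based on this review):

```python
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet

# hypothetical constructor arguments; only numb_dataid matters here
fitting = EnergyFittingNet(ntypes=2, dim_descrpt=240, numb_dataid=2)

data = fitting.serialize()
assert data["numb_dataid"] == 2                  # included in the output
restored = EnergyFittingNet.deserialize(data)
assert restored.numb_dataid == 2                 # survives the round trip
```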

source/tests/pt/model/water/multitask_sharefit.json (2)

11-14: Verify the sel parameters for water system

The selection parameters [46, 92] seem unusually high for a water system (O, H, B). These parameters typically represent the maximum number of neighbors considered, and water molecules usually have far fewer neighbors within the cutoff.

✅ Verification successful

The sel parameters [46, 92] are appropriate for water systems

Based on the search results, the selection parameters [46, 92] are consistently used across multiple water system configurations in the codebase:

  • These exact values appear in many water-related configurations including:
    • examples/water/se_e2_a/input.json
    • examples/water_tensor/dipole/dipole_input.json
    • examples/water_tensor/polar/polar_input.json
    • Multiple test configurations under source/tests/tf/water_*

The widespread and consistent use of these selection parameters [46, 92] in water system configurations across the codebase indicates this is a validated parameter choice for water systems, contrary to the initial concern.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for other water system configurations to compare sel parameters
rg -A 10 '"sel":\s*\[' --glob '**/*.json' | grep -B 5 -A 5 '"O".*"H"'

Length of output: 71


Script:

#!/bin/bash
# Let's try a different approach to find water-related configuration files and their sel parameters
# First, find all json files with "water" in their path
fd -e json . | grep -i "water"

Length of output: 2368


Script:

#!/bin/bash
# Let's also search for sel parameters in any json files
rg -A 2 '"sel":' --type json

Length of output: 13676


Script:

#!/bin/bash
# Search for type_map configurations that might indicate water systems
rg -A 5 '"type_map".*\[' --type json

Length of output: 21082


56-60: Review learning rate decay parameters

The current learning rate configuration:

  • Starts at 2e-4
  • Decays by 0.98 every 5000 steps
  • Stops at 3.51e-8 (very small)

This aggressive decay might cause training instability. Consider:

  1. Increasing the decay_steps or decay_rate
  2. Using a higher stop_lr (~1e-6)
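A back-of-envelope check of this schedule, assuming the usual exponential form lr(step) = start_lr * decay_rate**(step / decay_steps):

```python
import math

start_lr, stop_lr = 2.0e-4, 3.51e-8
decay_rate, decay_steps = 0.98, 5000

# decay intervals needed so that start_lr * decay_rate**n == stop_lr
n = math.log(stop_lr / start_lr) / math.log(decay_rate)
print(f"~{n:.0f} decays, i.e. ~{n * decay_steps / 1e6:.1f}M training steps")
# -> ~428 decays, about 2.1M steps before the schedule bottoms out
```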
deepmd/pt/model/task/dos.py (2)

104-104: Verify version compatibility across the codebase.

The version compatibility check has been updated from 2 to 3, which indicates a breaking change.

Let's verify version consistency across related files:

✅ Verification successful

Version compatibility check is consistent across the codebase

The version compatibility check with version 3 is consistently implemented across all related files in both deepmd/dpmodel and deepmd/pt/model directories. All deserialize methods and descriptors uniformly use check_version_compatibility(version, 3, 1), indicating this is an intentional, coordinated update rather than an inconsistency.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check version compatibility checks across the codebase
# Expected: All related classes should use consistent version numbers

# Search for version compatibility checks
rg "check_version_compatibility.*3.*1" -B 2 -A 2

# Search for related version constants or configurations
fd "version|VERSION" --type f --exec rg "version.*=.*3" {} \;

Length of output: 4782


50-50: LGTM! Verify integration with data identification system.

The addition of numb_dataid parameter and its propagation to the parent class is implemented correctly.

Let's verify the consistency of this parameter across related classes:

Also applies to: 77-77

✅ Verification successful

Based on the search results, I can now provide a final response:

The numb_dataid parameter is consistently implemented across the codebase

The addition of numb_dataid parameter in DOSFittingNet is part of a well-integrated data identification system that:

  • Is consistently implemented across all fitting networks (ener, dos, dipole, polarizability, etc.)
  • Has proper documentation indicating it's for multitask model training with data identification embeddings
  • Has consistent default value of 0 across all implementations
  • Is properly handled in the base Fitting class with appropriate tensor initialization and propagation
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if numb_dataid is consistently implemented across fitting classes
# Expected: All fitting classes should have similar numb_dataid parameter handling

# Search for class definitions with numb_dataid parameter
ast-grep --pattern 'class $_ {
  def __init__($$$, numb_dataid: int = 0, $$$) {
    $$$
  }
}'

# Search for super() calls passing numb_dataid
rg "super\(\)\.__init__.*numb_dataid" -A 2

Length of output: 141


Script:

#!/bin/bash
# Let's try a broader search to understand the usage pattern of numb_dataid

# Search for numb_dataid parameter in function definitions
ast-grep --pattern 'def $_($$$ numb_dataid$$$) {
  $$$
}'

# Search for any occurrence of numb_dataid
rg "numb_dataid" -B 2 -A 2

# Search for InvarFitting class definition and its constructor
ast-grep --pattern 'class InvarFitting {
  $$$
  def __init__($$$) {
    $$$
  }
  $$$
}'

Length of output: 28406

deepmd/pt/model/task/property.py (1)

106-106: LGTM!

The numb_dataid parameter is correctly passed to the parent class constructor.

deepmd/dpmodel/fitting/property_fitting.py (3)

81-81: LGTM: Clean implementation of numb_dataid parameter

The addition of the numb_dataid parameter is well-structured with proper type annotation and default value, maintaining backward compatibility. The parameter is correctly forwarded to the parent class.

Also applies to: 103-103


Line range hint 81-103: Verify consistency across fitting classes

Let's ensure the numb_dataid implementation is consistent across all fitting classes.

Also applies to: 116-116

✅ Verification successful

The previous search pattern might have been too specific. Let's try a broader search to find all fitting classes and their initialization parameters, particularly focusing on numb_dataid usage.


Based on the search results, I can now generate the final response as we have sufficient evidence about the numb_dataid implementation across fitting classes.

The numb_dataid parameter implementation is consistent across fitting classes

The implementation is verified to be consistent across the codebase with the following observations:

  • All fitting classes use the same parameter signature: numb_dataid: int = 0
  • The parameter is properly initialized in parent classes and consistently used in child classes
  • The implementation includes proper handling of dataid tensors/arrays when numb_dataid > 0
  • The serialization/deserialization of the parameter is consistent across classes
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for consistent implementation across fitting classes
echo "Checking numb_dataid parameter in other fitting classes..."
ast-grep --pattern 'class $_(InvarFitting) {
  $$$
  def __init__($$$, numb_dataid: int = 0, $$$) {
    $$$
  }
  $$$
}'

echo "Checking version compatibility in other fitting classes..."
fd fitting.py -x rg -A 2 "check_version_compatibility.*3"

Length of output: 1202


Script:

#!/bin/bash
# Let's try a different pattern to find fitting classes and their numb_dataid parameter
echo "Searching for fitting class definitions..."
fd fitting.py -x ast-grep --pattern 'class $_ {
  $$$
  def __init__($$$) {
    $$$
  }
  $$$
}'

echo "Searching for numb_dataid usage..."
fd fitting.py -x rg -B 2 -A 2 "numb_dataid"

Length of output: 7198


116-116: Verify version compatibility handling

The version check has been updated from 2 to 3, which could affect the loading of existing serialized models.

Let's check for version handling across the codebase:

✅ Verification successful

Version compatibility change is properly handled

The version update from 2 to 3 is well-managed in the codebase:

  1. The change is consistently applied across all deserialize methods in both PyTorch and TensorFlow implementations
  2. The model compatibility documentation shows there's a clear upgrade path using dp convert-from utility
  3. The version check implementation in deepmd/utils/version.py ensures backward compatibility by accepting a minimal version (defaulting to 1)
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for other version checks and migration logic
echo "Checking for version compatibility patterns..."
rg -A 2 "check_version_compatibility.*3" 

echo "Checking for migration utilities or documentation..."
fd -e py -e md . -x rg -l "migration|upgrade|version.*2.*3"

Length of output: 3563


Script:

#!/bin/bash
# Let's check the version compatibility function implementation
echo "Checking version compatibility implementation..."
rg -A 5 "def check_version_compatibility"

# Let's also check for any version-related changes in recent commits
echo "Checking version-related changes in recent history..."
git log -n 5 --oneline --grep="version"

# Let's check the documentation for version compatibility
echo "Checking version compatibility documentation..."
cat doc/troubleshooting/model-compatability.md

Length of output: 1762


Script:

#!/bin/bash
# Let's check the implementation of the version compatibility function
echo "Checking version compatibility function content..."
cat deepmd/utils/version.py

# Let's check for any conversion utilities
echo "Checking for conversion utilities..."
rg -l "convert.*from" --type py

# Let's check if there are any tests related to version compatibility
echo "Checking version-related tests..."
fd -e py test -x rg -l "test.*version|version.*test"

Length of output: 3021

deepmd/pt/model/task/invar_fitting.py (2)

117-117: Verify parent class implementation of numb_dataid.

The parameter is correctly passed to the parent class. Let's verify the parent class implementation.

✅ Verification successful

Parameter numb_dataid is correctly implemented in parent class

The parent class GeneralFitting in both PyTorch and native implementations correctly handles the numb_dataid parameter:

  1. Both implementations accept numb_dataid in their constructor with default value of 0
  2. Both initialize self.dataid buffer/attribute when numb_dataid > 0
  3. Both properly concatenate dataid to the input tensor during forward/call operations
  4. Both have consistent serialization/deserialization support for the parameter
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check parent class implementation
# Expected: Find GeneralFitting class and verify numb_dataid handling

# Search for GeneralFitting class definition and its usage of numb_dataid
ast-grep --pattern 'class GeneralFitting($_):
  $$$
  def __init__($params) {
    $$$
  }
  $$$
}'

Length of output: 100


Script:

#!/bin/bash
# Let's try a different approach to find the parent class implementation
# First, let's locate the file containing GeneralFitting
rg -l "class GeneralFitting"

# Then search for numb_dataid usage in constructor
rg "def __init__.*numb_dataid" -A 5

Length of output: 142


Script:

#!/bin/bash
# Now that we found the files, let's examine their content
echo "=== Content of deepmd/pt/model/task/fitting.py ==="
cat deepmd/pt/model/task/fitting.py

echo -e "\n=== Content of deepmd/dpmodel/fitting/general_fitting.py ==="
cat deepmd/dpmodel/fitting/general_fitting.py

Length of output: 40337


94-94: Verify the usage of numb_dataid parameter.

The new parameter is properly typed and follows the existing parameter pattern. Let's verify its usage across the codebase to ensure consistent implementation.

✅ Verification successful

Parameter numb_dataid is consistently implemented across the codebase

The parameter is properly integrated into the codebase with consistent implementation:

  1. It's used in all fitting classes (energy, dipole, polarizability, DOS, etc.) with the same signature and default value of 0
  2. The parameter is well-documented in deepmd/utils/argcheck.py as "The dimension of the data identification embedding for multitask models"
  3. The base implementation in deepmd/pt/model/task/fitting.py handles the parameter correctly:
    • Initializes data identification tensors when numb_dataid > 0
    • Properly includes it in the input dimension calculations
  4. The training module (deepmd/pt/train/training.py) validates that all models have consistent numb_dataid values
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check how numb_dataid is used across the codebase
# Expected: Find other fitting classes and their usage of numb_dataid

# Search for other fitting classes using numb_dataid
rg -l "numb_dataid" --type python

# Search for actual usage of the parameter in code
ast-grep --pattern 'numb_dataid = $_'

Length of output: 104


Script:

#!/bin/bash
# Let's try different search approaches

# Search for files containing numb_dataid
rg -l "numb_dataid"

# Search for actual usage with context
rg -A 3 "numb_dataid"

# Search for class definitions that might use this parameter
ast-grep --pattern 'class $_ {
  def __init__($$$, numb_dataid: int$$$) {
    $$$
  }
}'

Length of output: 23938

source/tests/universal/dpmodel/fitting/test_fitting.py (2)

Line range hint 81-216: Implementation is consistently applied across all fitting types.

The addition of numb_dataid is uniformly implemented across all fitting parameter functions (Dos, Dipole, Polar, Property), maintaining consistency in:

  • Parameter documentation
  • Input dictionary structure
  • Test parameterization

This consistency ensures that data identification is properly tested across all fitting types.


Line range hint 220-231: Verify test coverage for numb_dataid parameter.

The parameterization tests numb_dataid with values 0 and 2, which is consistent with the existing tests for numb_fparam and numb_aparam. Let's verify that these test cases provide adequate coverage.

✅ Verification successful

Based on the comprehensive review of the test files, I can now generate the final response:

Test coverage for numb_dataid parameter is adequate.

The parameter is well-tested through the following mechanisms:

  • It's tested alongside numb_fparam and numb_aparam with values 0 and 2 across all fitting types (Energy, DOS, Dipole, Polar, and Property)
  • The test coverage is implemented through parameterized tests that verify the parameter's functionality across different configurations:
    • Different precision types ("float64", "float32", "bfloat16")
    • Mixed and non-mixed types (True/False)
    • Different backends (TF, PyTorch, JAX, Array API)
  • The parameter is consistently tested in combination with other related parameters like exclude_types and precision
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for additional tests that might be relevant to dataid functionality
rg -l "numb_dataid" "source/tests"

# Look for any existing test assertions related to dataid
rg "assert.*dataid" "source/tests"

Length of output: 174


Script:

#!/bin/bash
# Let's examine the test file content to understand the test coverage
cat source/tests/universal/dpmodel/fitting/test_fitting.py

# Also check the JSON file that mentions numb_dataid
cat source/tests/pt/model/water/multitask_sharefit.json

Length of output: 9279


Script:

#!/bin/bash
# Let's check the test base class to see how these parameters are used in tests
cat source/tests/common/cases/fitting/fitting.py

Length of output: 127


Script:

#!/bin/bash
# Let's try to find the correct path for the fitting test base class
fd -t f "fitting.py" source/tests

# And check for any other test files that might be testing the fitting functionality
rg -l "class.*Fitting.*Test" source/tests

Length of output: 1421


Script:

#!/bin/bash
# Let's check the base test class and consistent test cases
cat source/tests/universal/common/cases/fitting/fitting.py
cat source/tests/consistent/fitting/common.py
cat source/tests/consistent/fitting/test_ener.py

Length of output: 11456

🧰 Tools
🪛 Ruff (0.7.0)

39-39: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

deepmd/pt/model/task/dipole.py (3)

84-84: LGTM: Parameter addition follows best practices

The new numb_dataid parameter is well-integrated with proper type hints and a sensible default value.


107-107: LGTM: Proper parent class initialization

The numb_dataid parameter is correctly forwarded to the parent class constructor.


133-133: Verify version compatibility across the codebase

The version check has been updated from 2 to 3. This change needs verification to ensure:

  1. Consistent version numbers across related fitting classes
  2. Proper handling of existing serialized models
✅ Verification successful

Version compatibility change is consistent with most task modules

The version check update from 2 to 3 in dipole.py aligns with most other task modules:

  • property.py, invar_fitting.py, dos.py, ener.py all use version 3
  • Only polarizability.py uses a higher version (4)

This indicates the version bump is part of a coordinated update across the codebase and maintains consistency with related classes.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check version compatibility across fitting classes

# Check version numbers in other fitting classes
echo "Checking version numbers in fitting classes..."
rg -A 1 'check_version_compatibility.*version.*\d+.*\d+.*\d+' deepmd/pt/model/task/

# Check for any migration utilities or version handling code
echo "Checking for version migration code..."
rg -l 'deserialize|serialize' deepmd/pt/model/task/

Length of output: 1573

deepmd/dpmodel/atomic_model/dp_atomic_model.py (1)

68-73: Verify fitting component integration.

Let's ensure the set_dataid method is properly integrated with the fitting component.

✅ Verification successful

Integration of set_dataid method is properly implemented

The verification shows that:

  • The set_dataid method is consistently implemented across all fitting components in both pt and dpmodel paths
  • The method is properly propagated through the model hierarchy:
    • Called by training code (deepmd/pt/train/training.py)
    • Implemented in base models and atomic models
    • Correctly delegated to the fitting components
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the set_dataid method implementation in fitting components

# Check if all fitting components implement set_dataid
echo "Checking fitting components for set_dataid implementation..."
ast-grep --pattern 'class $class implements BaseFitting {
  $$$
  set_dataid($$$) {
    $$$
  }
  $$$
}'

# Check for any direct usage of set_dataid
echo "Checking usage of set_dataid..."
rg "set_dataid" -A 2

Length of output: 4471

deepmd/dpmodel/atomic_model/make_base_atomic_model.py (1)

71-77: Verify implementations across model types.

Let's ensure all atomic model implementations properly implement the new set_dataid method with consistent behavior.

✅ Verification successful

Let me try a different search strategy to find the implementations.


All atomic model implementations properly implement the required method

The verification shows that all classes inheriting from BaseAtomicModel have implemented the set_dataid method consistently:

  • DPAtomicModel
  • PairTabAtomicModel
  • LinearEnergyAtomicModel

Each implementation includes the proper docstring and follows the same signature as the abstract method. The implementations exist in both the PyTorch (pt) and core (dpmodel) versions of the codebase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check implementations of set_dataid across atomic models
# Expected results: All atomic model classes should implement set_dataid

# Find all potential atomic model implementations
echo "Searching for atomic model implementations..."
rg -l "class.*\(.*BAM.*\)" --type py

# Check for set_dataid implementations
echo "Checking set_dataid implementations..."
ast-grep --pattern 'class $CLASS {
  $$$
  def set_dataid($ARGS) {
    $$$
  }
  $$$
}'

Length of output: 316


Script:

#!/bin/bash
# Search for class definitions that inherit from BaseAtomicModel
echo "Searching for classes inheriting from BaseAtomicModel..."
rg "class.*\(.*BaseAtomicModel.*\)" -A 5 --type py

# Search specifically for set_dataid method implementations
echo -e "\nSearching for set_dataid implementations..."
rg "def set_dataid" -B 2 -A 2 --type py

Length of output: 8848

deepmd/pt/train/wrapper.py (1)

Line range hint 115-131: Verify model structure compatibility

The changes assume all models have an atomic_model attribute. Let's verify this assumption across the codebase.

✅ Verification successful

Based on the extensive search results, I can now make a definitive assessment of the atomic_model structure across the codebase:

All model implementations consistently follow the atomic_model pattern

  • The codebase has a well-structured hierarchy of atomic models:
    • Base class: BaseAtomicModel defined in deepmd/dpmodel/atomic_model/base_atomic_model.py
    • Multiple specialized implementations like DPAtomicModel, DPEnergyAtomicModel, DPPolarAtomicModel etc.
    • All model classes properly inherit and implement the atomic_model interface
    • The attribute access pattern using __getattr__ is consistently implemented across all model types

The code changes in the wrapper are safe as they rely on a standardized atomic_model structure that is uniformly present across the entire codebase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if all model implementations have the expected atomic_model structure

# Search for class definitions that might be used with ModelWrapper
ast-grep --pattern 'class $name($_) {
  $$$
  def forward($$$) {
    $$$
  }
  $$$
}'

# Search for atomic_model references to understand the expected structure
rg "atomic_model" -A 5 -B 5

Length of output: 112437

deepmd/dpmodel/fitting/dipole_fitting.py (2)

98-98: LGTM: Clean parameter addition

The new numb_dataid parameter is well-placed and properly typed with a sensible default value.


134-134: LGTM: Proper parent class initialization

The numb_dataid parameter is correctly passed to the parent class constructor.

deepmd/pt/model/task/ener.py (2)

53-53: LGTM: Parameter addition follows consistent pattern

The addition of numb_dataid parameter with proper typing and default value aligns with the existing parameter pattern and supports the new dataid functionality.


71-71: LGTM: Correct parameter propagation to superclass

The numb_dataid parameter is properly passed to the superclass constructor, maintaining the inheritance chain.

deepmd/dpmodel/fitting/invar_fitting.py (1)

188-188: LGTM! Version compatibility update is appropriate.

The version check update from 2 to 3 properly reflects the breaking change in the serialization format due to the addition of the numb_dataid parameter.

deepmd/pt/model/task/polarizability.py (2)

Line range hint 196-207: Verify version consistency across serialization methods

The version bump from 3 to 4 and compatibility check update look good. This aligns with the addition of the numb_dataid parameter.

Let's verify version consistency across related files:

✅ Verification successful

Version consistency verified across the codebase

The version bump from 3 to 4 in the polarizability fitting module is consistent across both implementations:

  • deepmd/pt/model/task/polarizability.py: version 4
  • deepmd/dpmodel/fitting/polarizability_fitting.py: version 4

Both files maintain backward compatibility by accepting versions 1 through 4 in their deserialize methods. The version bump is isolated to polarizability-related modules and doesn't affect other components, which maintain their respective versions.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check version numbers in serialize/deserialize methods
# Look for other fitting classes and their version numbers

# Search for version numbers in serialize methods
rg -A 1 '@version.*=.*[0-9]+' --type py

# Search for version compatibility checks
rg -A 1 'check_version_compatibility.*[0-9]+.*[0-9]+.*[0-9]+' --type py

Length of output: 7971


Line range hint 88-132: Verify consistent usage of numb_dataid across fitting classes

The addition of numb_dataid parameter looks good. The default value and superclass propagation are correctly implemented.

Let's verify the consistency across other fitting classes:

✅ Verification successful

Based on the search results, I can now generate the final response since we have enough information to verify the consistency of numb_dataid usage across fitting classes:

The numb_dataid parameter is consistently implemented

The parameter is correctly implemented across all fitting classes, with a consistent default value of 0 and a proper inheritance chain:

  • All fitting classes inherit from either GeneralFitting or InvarFitting
  • The parameter is consistently passed to the parent class in all implementations
  • The default value of 0 is maintained across all classes (EnergyFittingNet, DOSFittingNet, DipoleFittingNet, PropertyFittingNet, PolarFittingNet)
  • The parameter's type and position in the constructor signature are consistent
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if numb_dataid is consistently implemented across fitting classes
# Look for other fitting classes and their numb_dataid parameter

# Search for fitting class definitions and their numb_dataid parameter
ast-grep --pattern 'class $_(GeneralFitting):
  $$$
  def __init__($$$, numb_dataid: int = 0, $$$):
    $$$'

Length of output: 110


Script:

#!/bin/bash
# Let's try a different approach to find fitting classes and their numb_dataid parameter

# First, find all Python files that might contain fitting classes
echo "=== Files containing 'Fitting' classes ==="
rg "class.*Fitting" --type py

echo -e "\n=== Checking numb_dataid parameter in constructors ==="
# Look for constructor definitions with numb_dataid parameter
rg "def __init__.*numb_dataid.*=" --type py -A 2

Length of output: 8488


Script:

#!/bin/bash
# Let's check the implementation of GeneralFitting class since it's the parent class
echo "=== Checking GeneralFitting implementation ==="
rg "class GeneralFitting" -A 20 --type py

echo -e "\n=== Checking PolarFittingNet and similar classes ==="
# Look for specific fitting classes and their constructors
for class in "PolarFittingNet" "DipoleFittingNet" "EnergyFittingNet" "DOSFittingNet" "PropertyFittingNet"; do
    echo -e "\n--- $class implementation ---"
    rg "class $class.*:.*\n.*def __init__" -A 10 --type py
done

Length of output: 5422


Script:

#!/bin/bash
# Let's fix the regex and use multiline mode
echo "=== Checking fitting class implementations ==="
for class in "PolarFittingNet" "DipoleFittingNet" "EnergyFittingNet" "DOSFittingNet" "PropertyFittingNet"; do
    echo -e "\n--- $class implementation ---"
    rg "class $class.*:" --multiline --type py -A 30
done

echo -e "\n=== Checking numb_dataid in GeneralFitting constructor ==="
rg "def __init__.*numb_dataid.*:.*$" --multiline --type py -B 5 -A 10 deepmd/pt/model/task/fitting.py

Length of output: 37190

deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

96-101: Verify integration with tests and other atomic models.

Let's ensure proper test coverage and consistent implementation across atomic models.

deepmd/dpmodel/fitting/polarizability_fitting.py (3)

104-104: LGTM: Parameter addition is consistent and well-implemented

The addition of numb_dataid parameter follows the established pattern across fitting classes and is properly integrated into both the constructor signature and superclass initialization.

Also applies to: 154-154


192-192: LGTM: Version update aligns with feature addition

The version bump from 3 to 4 appropriately tracks the addition of dataid support while maintaining backward compatibility.


203-203: LGTM: Version compatibility check properly updated

The version compatibility check is correctly updated to handle versions 1 through 4, maintaining backward compatibility while supporting the new dataid feature.

source/tests/pt/test_multitask.py (3)

55-55: LGTM: Well-structured parameter comparison for shared fitting

The implementation correctly handles shared fitting by:

  1. Adding a configurable share_fitting flag
  2. Excluding model-specific parameters (bias_atom_e, dataid) from comparison

Also applies to: 70-75
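The exclusion logic reads roughly like this (a sketch; key names follow the discussion above, the real test may differ):

import torch

def assert_fitting_shared(state_dict: dict) -> None:
    """Check that fitting-net tensors are identical between the two branches,
    skipping the per-model parameters that are intentionally not shared."""
    for key, value in state_dict.items():
        if "fitting_net" not in key or "model_1" not in key:
            continue
        if "bias_atom_e" in key or "dataid" in key:
            continue  # these stay model-specific by design
        paired_key = key.replace("model_1", "model_2")
        torch.testing.assert_close(value, state_dict[paired_key])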


42-48: Verify the existence of multitask_sharefit.json template file

The template loading looks good, but we should verify that the file exists to prevent runtime errors.

✅ Verification successful

Template file exists at the expected location

The multitask_sharefit.json template file exists at ./source/tests/pt/model/water/multitask_sharefit.json, which is consistent with the path being constructed in the code relative to __file__. The template loading implementation is correct and safe.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if the template file exists
# Expected: File should exist in the water directory

fd -t f "multitask_sharefit.json" -x echo "Template file found: {}"

Length of output: 142


Line range hint 1-394: Verify test coverage for shared fitting functionality

Let's ensure we have adequate test coverage for the new shared fitting functionality.

✅ Verification successful

Let me gather more information about the test coverage by examining the implementation details.


Let me check the test configuration files to understand the test coverage better.


Test coverage for shared fitting functionality is adequate

The test coverage for shared fitting functionality is comprehensive with:

  • Dedicated test configuration (multitask_sharefit.json) that properly defines shared fitting networks
  • TestMultiTaskSeASharefit class that explicitly tests shared fitting with appropriate assertions
  • Proper verification of parameter sharing in test_multitask_train method
  • Clear distinction between shared and non-shared fitting configurations

The TODOs found in other test files are unrelated to the multitask shared fitting functionality.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check test coverage and related test files
# Expected: Should find related test files and verify coverage

# Find all related test files
echo "Related test files:"
fd -t f "test.*\.py" -x grep -l "share_fitting"

# Check for any TODOs or missing test cases
echo -e "\nPotential missing test cases:"
rg -i "todo|fixme|xxx" -g "test_*"

Length of output: 405


Script:

#!/bin/bash
# Check implementation of shared fitting functionality
ast-grep --pattern 'def test_multitask_train' -A 10

# Check test class implementations
rg "class.*Test.*ShareFit" -A 5

# Check for any other shared fitting related implementations
rg "share_fitting" -A 2 -B 2

Length of output: 1371


Script:

#!/bin/bash
# Check the multitask template files
echo "Content of multitask.json:"
cat source/tests/pt/water/multitask.json

echo -e "\nContent of multitask_sharefit.json:"
cat source/tests/pt/water/multitask_sharefit.json

# Check if there are any other test configurations
echo -e "\nOther test configurations:"
fd -t f "multitask.*\.json"

Length of output: 6216

deepmd/dpmodel/atomic_model/linear_atomic_model.py (1)

137-144: Verify sub-model compatibility and version handling

Please ensure:

  1. All possible sub-model types implement the set_dataid method
  2. Version compatibility checks are updated to handle this new feature
✅ Verification successful

Let me gather more specific information about the set_dataid implementation across models.


Let me check one more aspect to verify the inheritance chain and method implementations.


Based on the gathered information, I can now provide the final response:

Implementation of set_dataid is properly handled across the model hierarchy

The verification shows:

  • set_dataid is defined as an abstract method in the base class (make_base_atomic_model.py)
  • All atomic model implementations properly handle the method:
    • linear_atomic_model.py: Correctly propagates to sub-models
    • dp_atomic_model.py: Properly delegates to fitting net
    • pairtab_atomic_model.py: Explicitly raises NotImplementedError as intended

Version compatibility is also properly maintained with @version checks in the relevant model implementations.
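The propagation confirmed above can be sketched as follows (the attribute name self.models is assumed from the composite structure; illustrative only):

class LinearEnergyAtomicModel:
    """Sketch of the composite pattern: the call fans out to every sub-model."""

    def __init__(self, models: list) -> None:
        self.models = models

    def set_dataid(self, data_idx: int) -> None:
        for model in self.models:
            model.set_dataid(data_idx)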

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that all atomic model implementations have set_dataid method

# Find all atomic model implementations
echo "Checking for atomic model implementations..."
rg -l "class.*AtomicModel.*\(" 

# Check for set_dataid implementations
echo "Checking for set_dataid implementations..."
ast-grep --lang python --pattern 'def set_dataid($$$):
    $$$'

# Check version compatibility handling
echo "Checking version compatibility..."
rg "check_version_compatibility|@version" --type py

Length of output: 18104


Script:

#!/bin/bash
# Search for set_dataid method implementations in atomic models
echo "Checking set_dataid implementations in atomic models..."
ast-grep --lang python --pattern 'def set_dataid($$$):
    $$$'

# Search for base class implementation
echo -e "\nChecking base class implementation..."
rg -A 5 "class.*BaseAtomicModel" deepmd/dpmodel/atomic_model/base_atomic_model.py

Length of output: 562


Script:

#!/bin/bash
# Check BaseAtomicModel_ implementation and inheritance
echo "Checking BaseAtomicModel_ implementation..."
rg -B 2 -A 10 "class BaseAtomicModel_" --type py

# Check all atomic model classes that might implement set_dataid
echo -e "\nChecking set_dataid in model files..."
rg -B 2 -A 5 "def set_dataid" deepmd/dpmodel/atomic_model/ deepmd/pt/model/atomic_model/ deepmd/jax/atomic_model/

Length of output: 5424

deepmd/dpmodel/fitting/general_fitting.py (4)

108-108: LGTM: Constructor changes for data ID support

The addition of numb_dataid parameter is well-integrated into the constructor, maintaining backward compatibility with a default value of 0.

Also applies to: 131-131


176-179: LGTM: Data ID initialization is consistent

The initialization of self.dataid follows the established pattern and properly handles both zero and non-zero cases with appropriate precision.


185-185: LGTM: Input dimension calculation updated correctly

The addition of self.numb_dataid to the input dimension calculation is properly integrated.
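Put together, the two pieces above amount to something like this sketch (NumPy-style; sizes and dtype are illustrative):

import numpy as np

numb_dataid = 2      # number of data-identification slots; 0 disables the feature
dim_descrpt = 64     # descriptor output dimension (illustrative)
numb_fparam = 0
numb_aparam = 0

# Zero-initialized dataid vector, allocated only when the feature is enabled.
dataid = np.zeros(numb_dataid, dtype=np.float64) if numb_dataid > 0 else None

# The fitting-net input width grows by the dataid width.
in_dim = dim_descrpt + numb_fparam + numb_aparam + numb_dataid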


308-308: LGTM: Serialization properly updated

The serialization changes are complete with version increment and proper inclusion of new fields.

Also applies to: 316-316, 325-325

deepmd/pt/model/task/fitting.py (5)

70-71: LGTM: Consistent dataid sharing implementation

The addition of dataid sharing follows the same pattern as bias_atom_e, ensuring consistent parameter sharing between instances.


144-144: LGTM: Well-documented numb_dataid parameter

The numb_dataid parameter is properly initialized with a default value of 0, maintaining backward compatibility, and is well-documented in the class docstring.

Also applies to: 166-166


231-231: LGTM: Correct input dimension update

The input dimension is properly updated to include the dataid features.


389-390: LGTM: Consistent dataid accessor implementation

The dataid handling in setitem and getitem follows the same pattern as other parameters, maintaining consistency.

Also applies to: 407-408
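The accessor pattern referenced here looks roughly like this (an illustrative subset of keys, not the full method bodies):

class FittingAccessorSketch:
    """Sketch: dataid is handled exactly like the other per-model entries."""

    def __setitem__(self, key, value):
        if key == "bias_atom_e":
            self.bias_atom_e = value
        elif key == "dataid":
            self.dataid = value
        else:
            raise KeyError(f"can not set the key {key!r}")

    def __getitem__(self, key):
        if key == "bias_atom_e":
            return self.bias_atom_e
        if key == "dataid":
            return self.dataid
        raise KeyError(f"can not get the key {key!r}")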


291-291: Verify version compatibility handling

The version has been incremented from 2 to 3 with the addition of dataid-related fields. Ensure that the deserialize method properly handles both version 2 and 3 formats.

Also applies to: 299-299, 308-308

deepmd/dpmodel/model/make_model.py (1)

555-557: Verify consistent implementation across related classes.

Since this is part of a larger change to support data identification, let's verify that all required classes have implemented the set_dataid method consistently.

✅ Verification successful

Implementation is consistent across the codebase

The verification shows that set_dataid is properly implemented across all relevant classes:

  • Abstract method defined in base class make_base_atomic_model.py
  • Consistent implementations in all atomic models:
    • dp_atomic_model.py
    • linear_atomic_model.py
    • pairtab_atomic_model.py
  • Proper delegation in model classes (make_model.py)
  • Implementation in fitting classes
  • All implementations include proper documentation

The method under review follows the established pattern of delegating to the atomic model, which is consistent with other implementations in the codebase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for consistent implementation of set_dataid across related classes

# Search for class definitions and their set_dataid implementations
echo "Checking set_dataid implementations across classes:"
ast-grep --lang python --pattern 'def set_dataid($$$):
    $$$'

# Search for any TODO or unimplemented set_dataid methods
echo -e "\nChecking for unimplemented or TODO set_dataid methods:"
rg -A 2 'def set_dataid.*:.*(\#.*TODO|raise\s+NotImplementedError|pass\s*$)'

Length of output: 408


Script:

#!/bin/bash
# Let's try a different approach to find set_dataid implementations

# First, find all Python files
echo "Python files containing 'set_dataid':"
rg -l "set_dataid" --type py

echo -e "\nImplementations of set_dataid:"
rg -B2 -A2 "def set_dataid" --type py

echo -e "\nUsages/calls of set_dataid:"
rg "\.set_dataid\(" --type py

Length of output: 5882

deepmd/pt/model/atomic_model/linear_atomic_model.py (1)

161-168: Verify set_dataid implementation in all model classes.

Let's verify that all model classes in the inheritance chain implement the set_dataid method to ensure consistent behavior.

✅ Verification successful

Implementation of set_dataid is consistent across model hierarchy

The verification shows that set_dataid is properly implemented across the model hierarchy:

  1. The abstract method is defined in BaseAtomicModel (deepmd/dpmodel/atomic_model/make_base_atomic_model.py)
  2. Concrete implementations exist in all required model classes:
    • DPAtomicModel (deepmd/dpmodel/atomic_model/dp_atomic_model.py)
    • LinearAtomicModel (deepmd/dpmodel/atomic_model/linear_atomic_model.py)
    • PairTabAtomicModel (deepmd/dpmodel/atomic_model/pairtab_atomic_model.py)
  3. The method is also properly propagated through model wrappers (make_model)
  4. The implementation is consistent across both PyTorch and DPModel backends

The implementation follows the composite pattern correctly, with each model properly delegating the call to its components.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if all model classes implement set_dataid method

# Search for class definitions that might need set_dataid
echo "Searching for model classes..."
ast-grep --pattern 'class $NAME($$$):
    $$$' --lang python

# Search for set_dataid implementations
echo "Searching for set_dataid implementations..."
ast-grep --pattern 'def set_dataid($$$):
    $$$' --lang python

Length of output: 1197


Script:

#!/bin/bash
# Search for class definitions and set_dataid implementations using ripgrep
echo "Searching for model classes and set_dataid implementations..."
rg "class.*Model" -A 5 --type python
echo "---"
rg "def set_dataid" -A 2 --type python

Length of output: 289


Script:

#!/bin/bash
# Search for class definitions and set_dataid implementations
echo "Searching for model classes..."
rg "class.*Model" -A 5
echo "---"
echo "Searching for set_dataid implementations..."
rg "def set_dataid" -A 2

Length of output: 156381


Script:

#!/bin/bash
# Search specifically for set_dataid implementations
echo "Searching for set_dataid implementations..."
rg "def set_dataid" -B 2 -A 2

Length of output: 4504

deepmd/pt/model/model/make_model.py (1)

517-519: LGTM: Clean implementation of data ID delegation.

The set_dataid method correctly delegates the data ID setting to the underlying atomic model, maintaining the class's responsibility hierarchy.

Let's verify the consistency of set_dataid implementation across the codebase:

✅ Verification successful

Implementation of set_dataid is consistent across the codebase

The verification shows that:

  • The set_dataid method is consistently implemented across all atomic models (dp_atomic_model.py, linear_atomic_model.py, pairtab_atomic_model.py)
  • The method signature is uniform, taking a single data_idx parameter
  • The delegation pattern is correctly followed in both PyTorch (pt) and base model implementations
  • The method is properly used in the training module for setting data identification
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for consistent implementation of set_dataid across atomic models
# Expected: All atomic models should have set_dataid method with similar signature

# Search for set_dataid method definitions
ast-grep --lang python --pattern 'def set_dataid($$$):
    $$$'

# Search for any direct usage of set_dataid
rg 'set_dataid' -A 2

Length of output: 4215

deepmd/utils/argcheck.py (1)

Line range hint 1436-1469: Verify PyTorch-only implementation of numb_dataid

The parameter is marked as PyTorch-only. Let's verify there are no TensorFlow dependencies.

✅ Verification successful

Based on the gathered evidence, I can now provide the final response:

PyTorch-only implementation of numb_dataid is correctly marked

The verification confirms that:

  • The parameter is only implemented in PyTorch files under deepmd/pt/
  • No implementation exists in TensorFlow-specific directories
  • The implementation is properly isolated to PyTorch modules with data identification embedding support
  • The universal test suite handles this parameter appropriately without TensorFlow dependencies
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for any TensorFlow-specific implementations of numb_dataid
# Test: Search for numb_dataid in TensorFlow-specific code. Expect: No matches.

# Look for numb_dataid in TensorFlow-specific files
rg -l "numb_dataid" | rg -i "tensorflow|tf"

Length of output: 43


Script:

#!/bin/bash
# Let's first find all occurrences of numb_dataid to understand its usage
rg -l "numb_dataid"

Length of output: 701


Script:

#!/bin/bash
# Let's also check for any implementation files containing dataid to ensure proper backend support
rg -p "class.*Fitting" | rg -B 5 -A 10 "dataid"

Length of output: 47


Script:

#!/bin/bash
# Check test files for any backend-specific tests
fd "test.*\.py" | xargs rg "numb_dataid|data_id"

Length of output: 1155


Script:

#!/bin/bash
# Let's check the content of the files to verify if the implementation is PyTorch-only
rg -B 3 -A 3 "numb_dataid" deepmd/tf/

Length of output: 37


Script:

#!/bin/bash
# Check the test file to understand the implementation details
cat source/tests/universal/dpmodel/fitting/test_fitting.py

Length of output: 6394


Script:

#!/bin/bash
# Check PyTorch implementation details
rg -B 5 -A 5 "numb_dataid" deepmd/pt/model/task/fitting.py

Length of output: 2991

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (3)
examples/water_multi_task/pytorch_example/input_torch_sharefit.json (2)

76-76: Document the new fitting network reference syntax

The syntax shared_fit_with_id:1 appears to use a new format with :1 suffix. Consider adding documentation to explain this syntax and its implications.

Also applies to: 81-81


93-110: Document the loss configuration choices

A few points to consider:

  1. Both models use identical loss configurations. If this is intentional, please document the reasoning.
  2. Volume preferences are set to 0. Consider documenting why volume is not included in the loss calculation.
doc/train/multi-task-training.md (1)

59-65: Consider enhancing the documentation with additional details.

To make the documentation even more helpful, consider adding:

  1. Trade-offs between the two approaches to help users choose the most appropriate one
  2. Guidelines on when to use each approach based on use cases
  3. Explanation of how different shared_level values affect the sharing behavior

Example addition:

    2. **Descriptor and fitting network sharing with data identification**:
       - Share the descriptor with `shared_level`=0.
       - Share the fitting network with `shared_level`=1.
       - {ref}`numb_dataid <model[standard]/fitting_net[ener]/numb_dataid>` must be set to the number of model branches, which will distinguish different data tasks using a one-hot embedding.
       - See [here](../../examples/water_multi_task/pytorch_example/input_torch_sharefit.json) for an example.
+      
+    **Choosing between approaches:**
+    - Use "Descriptor sharing only" when your tasks require different fitting networks due to significantly different output properties or scales
+    - Use "Descriptor and fitting network sharing" when your tasks are similar but need task-specific adaptations
+    
+    The `shared_level` parameter controls the degree of parameter sharing:
+    - Level 0: Complete sharing of all parameters
+    - Level 1: Partial sharing with task-specific adaptations
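The one-hot embedding mentioned in the example above is simple enough to spell out (a hypothetical helper for illustration, not the library API):

import numpy as np

def one_hot_dataid(data_idx: int, numb_dataid: int) -> np.ndarray:
    """Return the one-hot vector that distinguishes data task data_idx
    among numb_dataid model branches."""
    if not 0 <= data_idx < numb_dataid:
        raise ValueError(f"data_idx {data_idx} out of range [0, {numb_dataid})")
    embd = np.zeros(numb_dataid)
    embd[data_idx] = 1.0
    return embd

# With two branches: branch 0 -> [1., 0.], branch 1 -> [0., 1.]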
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between ae95c3c and 412aea9.

📒 Files selected for processing (3)
  • doc/train/multi-task-training.md (1 hunks)
  • examples/water_multi_task/pytorch_example/input_torch_sharefit.json (1 hunks)
  • source/tests/common/test_examples.py (1 hunks)
🔇 Additional comments (7)
source/tests/common/test_examples.py (2)

67-67: LGTM! Good test coverage for sharable fitting.

The addition of input_torch_sharefit.json to input_files_multi appropriately extends test coverage for the new sharable fitting functionality while leveraging the existing multi-task testing infrastructure.


67-67: Verify the new JSON file.

Let's ensure the new configuration file exists and follows the expected structure for multi-task configurations.

✅ Verification successful

JSON file structure verified and valid for multi-task configuration

The JSON file exists and follows the expected structure for multi-task configurations with:

  • Proper model configuration with shared descriptors and fitting networks
  • Two water models (water_1 and water_2) sharing the same descriptor and fitting network
  • Complete training configuration with proper loss functions and data paths
  • Appropriate shared fitting configuration through shared_fit_with_id mechanism
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the existence and structure of the new JSON file

# Check if the file exists
FILE="examples/water_multi_task/pytorch_example/input_torch_sharefit.json"
if [ ! -f "$FILE" ]; then
    echo "Error: $FILE does not exist"
    exit 1
fi

# Verify JSON structure contains required multi-task fields
jq -e '.model.type and .model.descriptor and .model.fitting_net' "$FILE"

# Check for sharable fitting configuration
jq -e '.model | select(has("fitting_net_dict") or has("share_fitting"))' "$FILE"

Length of output: 443


Script:

#!/bin/bash
# Let's check the actual content of the JSON file
echo "Content of the JSON file:"
cat examples/water_multi_task/pytorch_example/input_torch_sharefit.json

# Also check if the file is referenced in any test files
echo -e "\nFiles referencing this JSON:"
rg -l "input_torch_sharefit.json" --type py

Length of output: 3995

examples/water_multi_task/pytorch_example/input_torch_sharefit.json (4)

67-67: Verify numb_dataid matches model count

The numb_dataid value of 2 correctly matches the number of models (water_1 and water_2) that share the fitting network.


73-82: Consider consolidating identical model configurations

Both water_1 and water_2 models use identical configurations. If this is intentional, consider adding a comment explaining why two separate models are needed. If not, consider consolidating them into a single model definition.


85-91: LGTM: Learning rate configuration is well-structured

The exponential decay configuration with the specified parameters provides a good balance for training stability and convergence.


148-153: Consider adjusting training parameters

The configuration uses reasonable defaults for steps and frequencies, but consider:

  1. 100,000 steps might be excessive for this setup - verify if this matches your convergence requirements
  2. Equal save_freq and disp_freq (100) might generate too many checkpoints - consider increasing save_freq
doc/train/multi-task-training.md (1)

59-65: Well-documented approaches with clear examples!

The documentation clearly explains the two approaches for multi-task training and provides example configurations for each approach. The technical requirements, especially for the data identification approach, are well-specified.

iProzd and others added 2 commits November 25, 2024 22:15
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Duo <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (3)
source/tests/pt/test_multitask.py (3)

45-51: Add error handling for template file loading

While the code follows existing patterns, consider adding explicit error handling for file operations to provide clearer error messages if the template file is missing.

-    with open(multitask_sharefit_template_json) as f:
-        multitask_sharefit_template = json.load(f)
+    try:
+        with open(multitask_sharefit_template_json) as f:
+            multitask_sharefit_template = json.load(f)
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Multitask sharefit template not found at {multitask_sharefit_template_json}")
+    except json.JSONDecodeError:
+        raise ValueError(f"Invalid JSON format in {multitask_sharefit_template_json}")

58-58: LGTM! Consider adding docstring for share_fitting flag

The implementation correctly handles shared fitting parameters with appropriate exclusions. Consider adding a docstring to document the purpose of the share_fitting flag and its impact on parameter comparison.

 class MultiTaskTrainTest:
+    """Base class for multitask training tests.
+    
+    Attributes:
+        share_fitting: When True, enables shared fitting network parameters between models,
+                      except for bias_atom_e and dataid parameters.
+    """

Also applies to: 73-78


239-277: Add test class documentation

The implementation looks good, but would benefit from documentation explaining the test scenarios and expectations for shared fitting.

 class TestMultiTaskSeASharefit(unittest.TestCase, MultiTaskTrainTest):
+    """Test case for multitask training with shared fitting networks.
+    
+    This test verifies that:
+    1. Models correctly share fitting network parameters
+    2. Excluded parameters (bias_atom_e, dataid) remain model-specific
+    3. Training proceeds correctly with shared fitting networks
+    """
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 412aea9 and 27de319.

📒 Files selected for processing (2)
  • deepmd/pt/train/training.py (2 hunks)
  • source/tests/pt/test_multitask.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/train/training.py

Collaborator

@wanghan-iapcm wanghan-iapcm left a comment

I find dataid is not self-explanatory. I would suggest something like caseid.

numb_dataid should be named numb_cases.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (6)
deepmd/tf/fit/dipole.py (2)

87-87: Update docstring to document the new parameter.

The numb_dataid parameter is missing from the class docstring. Please add documentation for this parameter.

Add the following to the docstring parameters section:

    Parameters
    ----------
+   numb_dataid : int
+           Number of data identifiers, not supported in TensorFlow implementation

128-129: Consider handling TensorFlow limitations at a higher level.

Currently, the TensorFlow limitation for data IDs is handled within this class. Consider moving this validation to a higher level in the architecture, possibly during the model configuration phase, to make the limitation more visible and maintainable.

This would:

  1. Make the limitation more discoverable
  2. Centralize framework-specific restrictions
  3. Simplify future updates if TensorFlow adds support
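A config-phase check along the lines suggested could look like this (hypothetical helper and parameter names; a sketch, not the library API):

def validate_backend_support(fitting_params: dict, backend: str) -> None:
    """Hypothetical early validation: fail fast before any fitting class is built."""
    if backend == "tensorflow" and fitting_params.get("numb_dataid", 0) > 0:
        raise ValueError(
            "numb_dataid is not supported in the TensorFlow backend; "
            "set numb_dataid=0 or choose another backend."
        )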
deepmd/tf/fit/dos.py (2)

114-114: Enhance the error message for unsupported numb_dataid.

The error message could be more informative by explaining why dataid is not supported in TensorFlow and suggesting potential alternatives.

-    if numb_dataid > 0:
-        raise ValueError("numb_dataid is not supported in TensorFlow.")
+    if numb_dataid > 0:
+        raise ValueError(
+            "numb_dataid > 0 is not supported in TensorFlow implementation. "
+            "Consider using a different backend that supports data identification, "
+            "or set numb_dataid=0 to disable this feature."
+        )

Also applies to: 136-138


679-679: LGTM! Consider documenting version changes.

The serialization changes are consistent and maintain backward compatibility. The version bump from 2 to 3 correctly reflects the addition of the dataid field.

Consider adding a comment explaining the version history:

# Version history:
# 1: Initial version
# 2: <previous changes>
# 3: Added support for dataid field

Also applies to: 706-716, 739-739

deepmd/tf/fit/polar.py (2)

98-98: Validate the new numb_dataid parameter implementation.

The implementation raises a ValueError if numb_dataid > 0, indicating that this feature is not supported in TensorFlow. However, the error message could be more informative.

Consider enhancing the error message to provide more context:

-            raise ValueError("numb_dataid is not supported in TensorFlow.")
+            raise ValueError("Data ID feature is not supported in TensorFlow implementation. Consider using other backends that support this feature.")

Also applies to: 166-172


166-172: Validate error handling consistency.

The error handling for unsupported features (numb_fparam, numb_aparam, numb_dataid) follows a consistent pattern, which is good. However, consider grouping these checks together for better maintainability.

Consider refactoring to group the validation checks:

-        if numb_fparam > 0:
-            raise ValueError("numb_fparam is not supported in the dipole fitting")
-        if numb_aparam > 0:
-            raise ValueError("numb_aparam is not supported in the dipole fitting")
-        if numb_dataid > 0:
-            raise ValueError("numb_dataid is not supported in TensorFlow.")
+        unsupported_params = {
+            "numb_fparam": (numb_fparam, "not supported in the dipole fitting"),
+            "numb_aparam": (numb_aparam, "not supported in the dipole fitting"),
+            "numb_dataid": (numb_dataid, "not supported in TensorFlow"),
+        }
+        for param_name, (value, reason) in unsupported_params.items():
+            if value > 0:
+                raise ValueError(f"{param_name} is {reason}")
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 27de319 and 692d82d.

📒 Files selected for processing (4)
  • deepmd/tf/fit/dipole.py (5 hunks)
  • deepmd/tf/fit/dos.py (6 hunks)
  • deepmd/tf/fit/ener.py (6 hunks)
  • deepmd/tf/fit/polar.py (5 hunks)
🔇 Additional comments (9)
deepmd/tf/fit/dipole.py (3)

Line range hint 392-402: LGTM! Version bump and serialization changes are appropriate.

The version increment and addition of numb_dataid to the serialized data follow best practices:

  • Version bumped from 2 to 3 to reflect the schema change
  • New field added consistently with similar fields

436-436: LGTM! Version compatibility check is correctly updated.

The version compatibility check is properly updated to handle the new version 3 while maintaining backward compatibility with versions 1 and 2.


128-129: Verify TensorFlow limitation for data ID support.

The error message states that data IDs are not supported in TensorFlow. Let's verify this limitation.

✅ Verification successful

Data ID feature is confirmed to be PyTorch-only, TensorFlow limitation is correct

The code search results confirm that:

  • Data ID feature is explicitly documented as PyTorch-only in the codebase (see doc_only_pt_supported + doc_numb_dataid in argcheck.py)
  • All TensorFlow fitting implementations (ener.py, polar.py, dipole.py, dos.py) consistently raise the "not supported in TensorFlow" error
  • The feature is fully implemented in PyTorch path (deepmd/pt/) with working data ID functionality
  • Documentation (multi-task-training.md) describes this as a PyTorch feature for multi-task model training
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any existing data ID implementations or discussions in TensorFlow
rg -g '!*.pyc' -g '!*.so' -i 'dataid|data[_\s]id' .

# Search for related issues or discussions
gh issue list -R deepmodeling/deepmd-kit -S "dataid TensorFlow"

Length of output: 14930

deepmd/tf/fit/polar.py (2)

633-633: Verify version compatibility check implementation.

The version compatibility check has been updated to support version 4 while maintaining backward compatibility with version 1 for PT compatibility.


585-585: Version increment needs validation across dependent modules.

The version has been incremented from 3 to 4 to accommodate the new numb_dataid parameter. This change requires careful validation to ensure backward compatibility.

Also applies to: 595-595

✅ Verification successful

Let's examine the version usage more specifically in the relevant model files.


Based on the analysis of the codebase, I can now provide a final response:

Version compatibility is properly handled across the codebase

The version increment from 3 to 4 in the polarizability module is safe because:

  1. The codebase has a robust version compatibility checking system through check_version_compatibility() function that validates version changes across all modules
  2. The polarizability module explicitly checks version compatibility in both TF and PT implementations:
    • deepmd/pt/model/task/polarizability.py: check_version_compatibility(data.pop("@version", 1), 4, 1)
    • deepmd/tf/fit/polar.py: check_version_compatibility(data.pop("@version", 1), 4, 1)
  3. The version increment accommodates the new numb_dataid parameter while maintaining backward compatibility with version 1 through 4
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for version dependencies in other files
rg -l "version.*[34]" | grep -v "polar.py"

Length of output: 1902


Script:

#!/bin/bash
# Check for version-related code in polarizability and related model files
rg -A 5 -B 5 '"version":\s*[34]' deepmd/tf/fit/polar.py deepmd/dpmodel/fitting/polarizability_fitting.py deepmd/pt/model/task/polarizability.py

Length of output: 145


Script:

#!/bin/bash
# Look for version validation or compatibility checks
ast-grep --lang python --pattern 'check_version_compatibility($$$)'

Length of output: 40


Script:

#!/bin/bash
# Check for version-related model loading or validation code
rg -A 5 "load.*model.*version|check.*version|compatible.*version"

Length of output: 64305

deepmd/tf/fit/ener.py (4)

Line range hint 159-196: LGTM: Constructor changes properly handle the new parameter.

The addition of numb_dataid parameter with validation is well-implemented. The error message clearly indicates that this feature is not supported in TensorFlow.


885-885: LGTM: Version compatibility properly updated.

Version number is correctly bumped from 2 to 3 to reflect the addition of the new numb_dataid parameter. The version compatibility check is properly updated.

Also applies to: 911-911


921-921: LGTM: Serialization format properly updated.

The serialization format is correctly updated to include the new numb_dataid parameter and dataid field, while maintaining backward compatibility.

Also applies to: 953-953


Line range hint 159-196: Verify consistent handling of numb_dataid.

Let's verify that all code paths properly handle the new parameter.

Also applies to: 885-953

✅ Verification successful

Parameter numb_dataid is correctly handled with proper error messages in TensorFlow and full support in PyTorch

The verification shows consistent handling of numb_dataid across the codebase:

  • TensorFlow implementation explicitly raises ValueError when numb_dataid > 0 with consistent error message "numb_dataid is not supported in TensorFlow"
  • PyTorch implementation fully supports numb_dataid with proper initialization and usage:
    • Initializes dataid buffer when numb_dataid > 0
    • Properly extends input dimension by numb_dataid
    • Handles dataid in forward pass by tiling and concatenating with input
    • Used in multi-task training as documented in doc/train/multi-task-training.md
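The tiling-and-concatenation step summarized above can be sketched in PyTorch as follows (shapes and names are illustrative):

import torch

nf, nloc, dim_descrpt, numb_dataid = 2, 8, 64, 2

descriptor = torch.randn(nf, nloc, dim_descrpt)   # per-atom descriptor output
dataid = torch.zeros(numb_dataid)                 # per-model buffer, set via set_dataid

# Tile the per-model dataid over frames and atoms, then append it to each
# atom's descriptor before the fitting net.
dataid_tiled = torch.tile(dataid, (nf, nloc, 1))
xx = torch.cat([descriptor, dataid_tiled], dim=-1)
assert xx.shape == (nf, nloc, dim_descrpt + numb_dataid)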
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for any inconsistencies in numb_dataid handling

# Check if numb_dataid is used consistently
rg -A 5 "numb_dataid"

# Check for any TODO or FIXME comments related to dataid
rg -i "todo.*dataid|fixme.*dataid"

# Check for any potential error handling of dataid
rg "raise.*dataid|error.*dataid"

Length of output: 40117

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (2)
doc/train/multi-task-training.md (2)

56-58: Enhance documentation of the shared_level parameter.

The explanation of the shared_level parameter could be more comprehensive. Consider adding:

  1. Valid values for shared_level
  2. How different values affect parameter sharing
  3. Examples of when to use different shared_level values
-    and use the user-defined integer `shared_level` in the code to share the corresponding module to varying degrees.
-    (The default value for `shared_level` is set to 0. If default sharing is enabled, all parameters in the descriptor will be shared,
-    whereas in the fitting network, all parameters will be shared except for `bias_atom_e` and `dataid`.).
+    and use the user-defined integer `shared_level` in the code to control parameter sharing:
+    - `shared_level=0` (default): Full sharing. All parameters in the descriptor will be shared.
+      In the fitting network, all parameters will be shared except for `bias_atom_e` and `dataid`.
+    - `shared_level=1`: [Document behavior]
+    - `shared_level=2`: [Document behavior]
+    
+    Example use cases:
+    - Use `shared_level=0` when [explain scenario]
+    - Use `shared_level=1` when [explain scenario]

60-65: Enhance the multi-task training approaches section.

While the two approaches are clearly listed, the documentation would be more helpful with:

  1. Pros and cons of each approach
  2. Small inline examples demonstrating the key differences
  3. Guidelines for choosing between approaches
     2. **Descriptor and fitting network sharing with data identification**:
        - Share the descriptor and the fitting network with `shared_level`=0.
        - {ref}`numb_dataid <model[standard]/fitting_net[ener]/numb_dataid>` must be set to the number of model branches, which will distinguish different data tasks using a one-hot embedding.
-       - See [here](../../examples/water_multi_task/pytorch_example/input_torch_sharefit.json) for an example.
+       - See [here](../../examples/water_multi_task/pytorch_example/input_torch_sharefit.json) for a complete example.
+
+    **Choosing between approaches:**
+    - Use approach 1 when [explain scenario], as it provides [benefits]
+    - Use approach 2 when [explain scenario], as it provides [benefits]
+
+    **Quick example of approach 1:**
+    ```json
+    {
+        "model": {
+            "shared_dict": {
+                "descriptor": { ... }
+            },
+            "model_dict": { ... }
+        }
+    }
+    ```
+
+    **Quick example of approach 2:**
+    ```json
+    {
+        "model": {
+            "shared_dict": {
+                "descriptor": { ... },
+                "fitting_net": { ... }
+            },
+            "model_dict": {
+                "model1": {
+                    "fitting_net": {
+                        "numb_dataid": 2
+                    }
+                }
+            }
+        }
+    }
+    ```
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 692d82d and 5c4bfa8.

📒 Files selected for processing (4)
  • deepmd/pt/model/task/fitting.py (10 hunks)
  • doc/train/multi-task-training.md (1 hunks)
  • examples/water_multi_task/pytorch_example/input_torch_sharefit.json (1 hunks)
  • source/tests/pt/model/water/multitask_sharefit.json (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • deepmd/pt/model/task/fitting.py
  • examples/water_multi_task/pytorch_example/input_torch_sharefit.json
  • source/tests/pt/model/water/multitask_sharefit.json

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Outside diff range and nitpick comments (15)
examples/water_multi_task/pytorch_example/input_torch_sharefit.json (2)

59-69: Document the shared_fit_with_id configuration

The shared_fit_with_id configuration introduces a new concept with numb_caseid, but its purpose and usage are not clearly documented. This was also noted in a previous review comment.

Consider adding documentation to explain:

  • The purpose and significance of numb_caseid
  • How the ID sharing mechanism works between models
  • The relationship between numb_caseid and the number of models
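For reference, the referencing pattern in the reviewed file boils down to the following (shown as a Python dict mirroring the JSON; the ':1' suffix presumably encodes the sharing level, which is exactly the kind of detail the documentation should pin down):

model_dict = {
    "water_1": {
        "descriptor": "my_descriptor",           # plain shared entry
        "fitting_net": "shared_fit_with_id:1",   # ':1' suffix: presumed sharing level
    },
    "water_2": {
        "descriptor": "my_descriptor",
        "fitting_net": "shared_fit_with_id:1",
    },
}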

117-146: Document the data sharing strategy

Both models use identical training data paths. If this is intentional for the multi-task learning setup:

  • Document why the same training data is appropriate for both models
  • Explain how the shared data benefits the multi-task learning process
  • Consider adding a comment in the configuration explaining the data sharing strategy
deepmd/pt/model/task/invar_fitting.py (1)

94-94: Document the new numb_caseid parameter

The new parameter numb_caseid is added but not documented in the class docstring. Please add parameter documentation following the existing NumPy docstring format.

Add this to the Parameters section of the class docstring:

    Parameters
    ----------
+   numb_caseid : int, optional
+       Number of case identifiers. Defaults to 0.
    activation_function : str

Also applies to: 117-117

doc/train/multi-task-training.md (1)

60-65: Consider enhancing the multitask training approaches documentation.

The documentation clearly explains the two approaches but could be enhanced by adding:

  1. Guidance on when to use each approach (trade-offs, use cases)
  2. Whether it's possible to mix both approaches in the same training setup
  3. Examples of how the numb_caseid value relates to the number of model branches
source/tests/pt/test_multitask.py (2)

45-51: Add error handling for template file loading

Consider adding explicit error handling for the case when the template file is missing.

-    with open(multitask_sharefit_template_json) as f:
-        multitask_sharefit_template = json.load(f)
+    try:
+        with open(multitask_sharefit_template_json) as f:
+            multitask_sharefit_template = json.load(f)
+    except FileNotFoundError:
+        raise FileNotFoundError(
+            f"Shared fitting template file not found: {multitask_sharefit_template_json}"
+        )
+    except json.JSONDecodeError:
+        raise ValueError(
+            f"Invalid JSON in shared fitting template: {multitask_sharefit_template_json}"
+        )

239-274: Add class documentation

Please add a docstring to explain the purpose of this test class and what aspects of shared fitting it verifies.

 class TestMultiTaskSeASharefit(unittest.TestCase, MultiTaskTrainTest):
+    """Test class for verifying shared fitting functionality in multi-task training.
+    
+    This class tests the scenario where multiple models share the same fitting network
+    parameters, except for bias_atom_e and caseid parameters which remain model-specific.
+    """

Consider adding targeted test methods

While inheriting the generic test methods from MultiTaskTrainTest is good, consider adding a method that verifies shared-fitting behavior directly.

def test_shared_fitting_parameters(self) -> None:
    """Verify that appropriate parameters are shared between models."""
    self.config = update_deepmd_input(self.config, warning=True)
    self.config = normalize(self.config, multi_task=True)
    trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links)
    
    # Get parameters before training
    initial_params = trainer.wrapper.model.state_dict()
    
    # Run training
    trainer.run()
    
    # Get parameters after training
    trained_params = trainer.wrapper.model.state_dict()
    
    # Verify shared parameters were updated together
    for key in trained_params:
        if ('fitting_net' in key and 
            'bias_atom_e' not in key and 
            'caseid' not in key):
            if 'model_1' in key:
                paired_key = key.replace('model_1', 'model_2')
                torch.testing.assert_close(
                    trained_params[key],
                    trained_params[paired_key],
                    msg=f"Parameters not shared: {key} vs {paired_key}"
                )
deepmd/tf/fit/dipole.py (2)

123-123: Consider enhancing the error message

While the validation is correct, the error message could be more informative by explaining why case IDs are not supported in TensorFlow or suggesting alternatives.

-            raise ValueError("numb_caseid is not supported in TensorFlow.")
+            raise ValueError("numb_caseid > 0 is not supported in TensorFlow backend. Consider using other backends that support multi-task learning.")

Also applies to: 128-129


87-87: Consider adding documentation about TensorFlow limitations

The class docstring should document why numb_caseid is not supported in TensorFlow to help users understand the limitation.

Add to the docstring:

    Parameters
    ----------
+   numb_caseid : int, optional
+       Number of case IDs for multi-task learning. Currently not supported in TensorFlow backend
+       due to architectural limitations. Default is 0.
    sel_type : list[int]
deepmd/dpmodel/fitting/general_fitting.py (1)

446-458: Improve assertion message for caseid check

The implementation is correct, but the assertion message could be more descriptive.

-            assert self.caseid is not None
+            assert self.caseid is not None, "caseid should not be None when numb_caseid > 0"
deepmd/pt/model/task/fitting.py (2)

135-135: Add documentation for the numb_caseid parameter.

The numb_caseid parameter has been added to the constructor, but its documentation is missing from the class docstring.

Add the following to the class docstring under the Parameters section:

    numb_aparam : int
        Number of atomic parameters.
+   numb_caseid : int
+       Number of case identifiers.
    activation_function : str
        Activation function.

Also applies to: 157-157


209-217: Remove commented-out initialization code.

The commented-out line suggests an alternative initialization using an identity matrix. If this is not needed, it should be removed to maintain clean code.

            self.register_buffer(
                "caseid",
                torch.zeros(self.numb_caseid, dtype=self.prec, device=device),
-               # torch.eye(self.numb_caseid, dtype=self.prec, device=device)[0],
            )
deepmd/tf/fit/polar.py (1)

171-172: Enhance error message clarity

The current error message could be more informative. Consider expanding it to explain why case IDs aren't supported in TensorFlow and what alternatives are available.

-            raise ValueError("numb_caseid is not supported in TensorFlow.")
+            raise ValueError("Case IDs (numb_caseid > 0) are not supported in TensorFlow. Consider using PyTorch implementation for case ID support.")
deepmd/pt/train/training.py (2)

1270-1286: Add docstring to document the resuming parameter.

The function's behavior regarding case ID initialization should be clearly documented.

Add this docstring:

 def get_model_for_wrapper(_model_params, resuming=False):
+    """Get model(s) for the wrapper with optional case ID initialization.
+    
+    Args:
+        _model_params: Model parameters dictionary
+        resuming: If True, skip case ID initialization for multi-task models
+    
+    Returns:
+        Single model or dictionary of models for multi-task training
+    """

1289-1308: Add type hints to improve code clarity.

The function's return type and parameters should be clearly typed.

Add type hints:

-def get_caseid_config(_model_params):
+def get_caseid_config(_model_params: dict) -> tuple[bool, dict[str, int]]:
deepmd/tf/fit/ener.py (1)

953-953: Clarify the purpose of 'caseid' set to None in serialization

The 'caseid' variable is serialized with a value of None. Please confirm if this is intended. If it's not currently used, consider omitting it or providing a comment explaining its future use.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 5c4bfa8 and 048036b.

📒 Files selected for processing (34)
  • deepmd/dpmodel/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/linear_atomic_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/make_base_atomic_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/pairtab_atomic_model.py (1 hunks)
  • deepmd/dpmodel/fitting/dipole_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/dos_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/ener_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/general_fitting.py (9 hunks)
  • deepmd/dpmodel/fitting/invar_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/polarizability_fitting.py (4 hunks)
  • deepmd/dpmodel/fitting/property_fitting.py (3 hunks)
  • deepmd/dpmodel/model/make_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/linear_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1 hunks)
  • deepmd/pt/model/model/make_model.py (1 hunks)
  • deepmd/pt/model/task/dipole.py (3 hunks)
  • deepmd/pt/model/task/dos.py (3 hunks)
  • deepmd/pt/model/task/ener.py (3 hunks)
  • deepmd/pt/model/task/fitting.py (10 hunks)
  • deepmd/pt/model/task/invar_fitting.py (3 hunks)
  • deepmd/pt/model/task/polarizability.py (4 hunks)
  • deepmd/pt/model/task/property.py (3 hunks)
  • deepmd/pt/train/training.py (2 hunks)
  • deepmd/tf/fit/dipole.py (5 hunks)
  • deepmd/tf/fit/dos.py (6 hunks)
  • deepmd/tf/fit/ener.py (6 hunks)
  • deepmd/tf/fit/polar.py (5 hunks)
  • deepmd/utils/argcheck.py (10 hunks)
  • doc/train/multi-task-training.md (1 hunks)
  • examples/water_multi_task/pytorch_example/input_torch_sharefit.json (1 hunks)
  • source/tests/pt/model/water/multitask_sharefit.json (1 hunks)
  • source/tests/pt/test_multitask.py (3 hunks)
  • source/tests/universal/dpmodel/fitting/test_fitting.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (24)
  • deepmd/dpmodel/atomic_model/dp_atomic_model.py
  • deepmd/dpmodel/atomic_model/linear_atomic_model.py
  • deepmd/dpmodel/atomic_model/make_base_atomic_model.py
  • deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
  • deepmd/dpmodel/fitting/dipole_fitting.py
  • deepmd/dpmodel/fitting/dos_fitting.py
  • deepmd/dpmodel/fitting/ener_fitting.py
  • deepmd/dpmodel/fitting/invar_fitting.py
  • deepmd/dpmodel/fitting/polarizability_fitting.py
  • deepmd/dpmodel/fitting/property_fitting.py
  • deepmd/dpmodel/model/make_model.py
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
  • deepmd/pt/model/atomic_model/linear_atomic_model.py
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py
  • deepmd/pt/model/model/make_model.py
  • deepmd/pt/model/task/dipole.py
  • deepmd/pt/model/task/dos.py
  • deepmd/pt/model/task/ener.py
  • deepmd/pt/model/task/polarizability.py
  • deepmd/pt/model/task/property.py
  • deepmd/tf/fit/dos.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/model/water/multitask_sharefit.json
  • source/tests/universal/dpmodel/fitting/test_fitting.py
🔇 Additional comments (26)
examples/water_multi_task/pytorch_example/input_torch_sharefit.json (4)

72-84: LGTM: Model configuration properly utilizes shared components

The model dictionary correctly references the shared components for both water models, which is appropriate for multi-task learning with shared parameters.


85-91: LGTM: Learning rate configuration is well-defined

The exponential decay configuration with specified start/stop rates and decay steps is appropriate for this training setup.


136-146: ⚠️ Potential issue

Add validation data for water_2

The water_2 model is missing validation data, which could affect model evaluation and training monitoring.


92-111: Verify if identical loss configurations are intended

Both water_1 and water_2 use identical loss configurations. While this might be intentional for multi-task learning, consider:

  • Documenting why identical loss weights are appropriate
  • Whether different weights might be beneficial for different tasks
✅ Verification successful

Based on the search results, I can now generate the final response since I have enough context about multi-task configurations in the codebase:

Identical loss configurations are intentional for this multi-task setup

The identical loss configurations between water_1 and water_2 are appropriate because:

  • Both models share the same fitting network (using "shared_fit_with_id" configuration)
  • The codebase's multi-task implementation (deepmd/pt/train/training.py) supports identical loss weights for shared architectures
  • This is a standard pattern seen in other multi-task examples (e.g., source/tests/pt/model/water/multitask_sharefit.json)
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for similar multi-task configurations in the codebase
rg -l "start_pref_[efv]" | xargs rg "model_dict|loss_dict" -C 5

Length of output: 21862

deepmd/pt/model/task/invar_fitting.py (1)

146-146: Version compatibility documentation needed

A previous review comment already addresses the version compatibility documentation requirements.

source/tests/pt/test_multitask.py (1)

58-58: LGTM! Well-structured parameter comparison logic

The changes properly handle shared fitting parameters while correctly excluding bias and case ID parameters from comparison.

Also applies to: 73-78

deepmd/tf/fit/dipole.py (3)

87-87: LGTM: Parameter addition follows existing patterns

The new numb_caseid parameter is correctly added with appropriate type hints and a sensible default value.


392-392: LGTM: Proper version management and serialization

The version bump and serialization changes are appropriate for the new parameter addition, maintaining backward compatibility through proper versioning.

Also applies to: 402-402


436-436: Verify version compatibility across dependent modules

The version check is correctly updated, but we should verify that all dependent modules are compatible with version 3 of the serialized format.

✅ Verification successful

Based on the comprehensive search results, I can now provide a final response:

Version compatibility is properly maintained across the codebase

The version check for DipoleFitting is consistent with other fitting modules in both TF and PyTorch implementations:

  • All fitting modules (ener, dos, polar, dipole) use version 3 as maximum and version 1 as minimum
  • The version compatibility check is implemented uniformly across the codebase
  • The serialization format version 3 is properly declared in the dipole.py file
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for other modules that might deserialize DipoleFittingSeA data
# to ensure they can handle version 3

# Search for deserialize calls
rg -l "deserialize.*DipoleFitting" 

# Search for version checks related to DipoleFitting
rg "check_version_compatibility.*DipoleFitting"

Length of output: 86


Script:

#!/bin/bash
# Let's broaden the search to find all version compatibility checks and serialization patterns
# This will help understand the versioning system across the codebase

# Search for all version compatibility checks
rg "check_version_compatibility" -A 2

# Search for any serialization/deserialization related to dipole
rg -i "dipole.*(?:serial|deserial)" -A 2

# Look for version-related code in dipole.py
rg "@version" "deepmd/tf/fit/dipole.py" -A 2

Length of output: 26657

deepmd/dpmodel/fitting/general_fitting.py (5)

108-108: LGTM: Constructor changes are well-implemented

The addition of numb_caseid parameter and its initialization follow the existing pattern and are properly documented.

Also applies to: 131-131


176-179: LGTM: Proper initialization of caseid

The initialization logic follows the same pattern as other similar parameters and correctly handles the precision.


185-185: LGTM: Input dimension calculation properly updated

The input dimension calculation correctly includes the case identification dimension.


272-273: LGTM: Dictionary access methods properly updated

The addition of 'caseid' to setitem and getitem is consistent with the existing pattern.

Also applies to: 290-291


308-308: LGTM: Serialization properly handles case identification

The version bump and addition of case identification fields to serialization are appropriate and consistent.

Also applies to: 316-316, 325-325

deepmd/pt/model/task/fitting.py (5)

67-67: LGTM: Comment accurately reflects parameter sharing behavior.

The comment clearly documents that both bias_atom_e and caseid are excluded from parameter sharing when shared_level is 0.


221-222: LGTM: Input dimension calculation correctly includes case identifiers.

The numb_caseid is properly added to the input dimension calculation for the neural network.


282-282: LGTM: Serialization properly handles case identification data.

The serialization has been correctly updated to include case identification data with appropriate version bump.

Also applies to: 290-290, 299-299


380-381: LGTM: Consistent attribute access implementation.

The caseid attribute is properly handled in both __setitem__ and __getitem__ methods, maintaining consistency with other attributes.

Also applies to: 398-399


498-510: LGTM: Case identification properly integrated into forward pass.

The case identification is correctly concatenated to the input tensor, with proper handling of both normal and zero inputs. The implementation follows the same pattern as fparam and aparam handling.
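A minimal sketch of that pattern, with illustrative shapes and names rather than the actual deepmd-kit variables:

import torch

nf, nloc, dim_descrpt, dim_case_embd = 2, 6, 8, 4
xx = torch.randn(nf, nloc, dim_descrpt)  # descriptor output per atom
case_embd = torch.eye(dim_case_embd)[1]  # one-hot vector selecting case 1

# Tile the per-model case embedding over frames and atoms, then append it
# along the feature axis, mirroring the fparam/aparam handling.
tiled = torch.tile(case_embd.reshape(1, 1, -1), (nf, nloc, 1))
xx = torch.cat([xx, tiled], dim=-1)
assert xx.shape == (nf, nloc, dim_descrpt + dim_case_embd)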

deepmd/tf/fit/polar.py (3)

98-98: LGTM: Constructor changes for case ID support

The addition of numb_caseid parameter is well-implemented with:

  • Safe default value of 0
  • Early validation to prevent unsupported usage in TensorFlow
  • Proper initialization in the class

Also applies to: 166-166, 171-172


595-595: LGTM: Proper serialization of new parameter

The numb_caseid field is correctly added to the serialized data, maintaining consistency with the constructor parameters.


585-585: Clarify version compatibility comment

While the version increment is appropriate for the new parameter, the comment "to allow PT version" needs elaboration. Consider documenting:

  • What PT version means
  • Why this compatibility is needed
  • Any specific version requirements

Also applies to: 633-634

✅ Verification successful

Let me gather more context about the version changes and PT references.


Let me check for any version-related changes in the model implementation and serialization.


Based on the gathered information, I can now provide a final response:

PT version refers to PyTorch backend compatibility

The version increment from 3 to 4 in the polarizability fitting module was necessary to maintain compatibility with the PyTorch backend implementation. This is evident from:

  • The version check is used in both TensorFlow (PolarFittingSeA) and PyTorch (PolarFittingPT) implementations
  • The comment "to allow PT version" refers to PyTorch version compatibility
  • The version increment ensures proper model serialization/deserialization across different backend implementations
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for other version compatibility comments
rg -A 2 "PT version" 

Length of output: 327


Script:

#!/bin/bash
# Check for version-related changes and context
rg -B 2 -A 2 '@version.*[0-9]' deepmd/tf/fit/polar.py

# Look for PT/PyTorch related imports or configurations
rg -B 2 -A 2 'torch|pytorch|PT' deepmd/tf/fit/polar.py

# Check the commit history for version changes
git log -p deepmd/tf/fit/polar.py | rg -B 2 -A 2 '@version|PT version'

# Look for version documentation
fd README.md | xargs rg -i 'version|compatibility'

Length of output: 1022


Script:

#!/bin/bash
# Check for model serialization and version handling
ast-grep --pattern 'def serialize($$$)' deepmd/tf/fit/polar.py

# Look for version compatibility checks
rg -B 3 -A 3 'check_version_compatibility' --type py

# Check for any PyTorch/PT related model implementations
fd polar.py | xargs rg -B 2 -A 2 'class.*Polar'

Length of output: 62817

deepmd/tf/fit/ener.py (4)

194-196: Validation for numb_caseid is appropriate

The validation check for numb_caseid ensures that users are informed when an unsupported value is provided in TensorFlow. The error handling is correctly implemented.


885-885: Versioning updated appropriately for deserialization

The check_version_compatibility call has been updated to reflect the new version (3), ensuring that the deserialization process is compatible with the updated data structure.


911-911: Version number increased to 3 in serialization

The version number in the serialized data is correctly updated to 3 to match the changes in the data structure.


921-921: numb_caseid parameter is correctly included in serialization

The addition of numb_caseid to the serialized data ensures that this parameter is preserved when the model is saved and loaded.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (3)
source/tests/pt/model/water/multitask.json (1)

Line range hint 1-185: Consider enhancing the test configuration

As this appears to be a test configuration, consider the following improvements:

  1. Use different model architectures or hyperparameters for model_1 and model_2 to better test the "sharable fitting" feature
  2. Reduce numb_steps from 100000 to a smaller number suitable for testing
  3. Add comments explaining the purpose of this test configuration and how it validates the dataid and sharable fitting features
source/tests/pt/model/water/multitask_sharefit.json (1)

41-54: Consider diversifying model configurations

Both models use identical configurations which might limit the learning capacity of the multitask setup. Consider:

  1. Using different descriptor parameters for each model
  2. Adjusting the neuron configurations to capture different aspects of the system
deepmd/pt/model/task/fitting.py (1)

213-213: Remove commented code.

The commented initialization approach should be removed if it's not being used.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 048036b and d399c6e.

📒 Files selected for processing (13)
  • deepmd/dpmodel/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/linear_atomic_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/make_base_atomic_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/pairtab_atomic_model.py (1 hunks)
  • deepmd/dpmodel/fitting/general_fitting.py (9 hunks)
  • deepmd/dpmodel/model/make_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/linear_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1 hunks)
  • deepmd/pt/model/model/make_model.py (1 hunks)
  • deepmd/pt/model/task/fitting.py (10 hunks)
  • source/tests/pt/model/water/multitask.json (1 hunks)
  • source/tests/pt/model/water/multitask_sharefit.json (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (10)
  • deepmd/dpmodel/atomic_model/dp_atomic_model.py
  • deepmd/dpmodel/atomic_model/linear_atomic_model.py
  • deepmd/dpmodel/atomic_model/make_base_atomic_model.py
  • deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
  • deepmd/dpmodel/fitting/general_fitting.py
  • deepmd/dpmodel/model/make_model.py
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
  • deepmd/pt/model/atomic_model/linear_atomic_model.py
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py
  • deepmd/pt/model/model/make_model.py
🔇 Additional comments (11)
source/tests/pt/model/water/multitask.json (1)

13-14: Please verify the selection parameters for atom types

The sel array has been modified to include a new value 4. Given that my_type_map contains three atom types ["O", "H", "B"], please verify:

  1. The relationship between the sel array [46, 92, 4] and the atom types
  2. Whether these numbers are appropriate for the respective atom types
  3. How this change relates to the dataid support mentioned in the PR title
✅ Verification successful

Let me analyze the search results and generate additional verification scripts to understand the relationship between the sel parameters and atom types better.

The search results show that most water-related configurations use [46, 92] for the sel parameter, but we need to understand:

  1. The documentation of what these numbers represent
  2. The specific type_map configuration in this test case to verify the atom type order
  3. Whether the addition of '4' is consistent with other multitask configurations

This will help us make a definitive assessment about the changes.


Let me search for the complete content of the multitask configuration to understand the full context of the type mapping.

The documentation clearly explains that sel represents the maximum number of type-i neighbors of an atom, and we see that standard water configurations use [46, 92] for O and H atoms. We need to see the complete multitask configuration to understand:

  1. The definition of my_type_map to confirm the atom types order
  2. Whether the addition of '4' corresponds to the third atom type (B)
  3. If this pattern is consistent with other multitask configurations

Selection parameters are correctly configured for the multitask model

The sel array [46, 92, 4] is correctly aligned with the my_type_map ["O", "H", "B"] configuration, where:

  • 46 represents the maximum number of O-type neighbors
  • 92 represents the maximum number of H-type neighbors
  • 4 represents the maximum number of B-type neighbors

This configuration is consistent with:

  1. The documentation which explains that sel values represent the maximum number of type-i neighbors
  2. The standard water model configuration [46, 92] for O and H atoms
  3. The same configuration being used in the related multitask_sharefit.json
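Spelled out, the pairing reads as follows (an illustrative snippet):

type_map = ["O", "H", "B"]
sel = [46, 92, 4]  # maximum number of type-i neighbors of an atom
for atom_type, max_neighbors in zip(type_map, sel):
    print(f"at most {max_neighbors} {atom_type}-type neighbors per atom")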
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for similar descriptor configurations in test files
rg -A 10 '"type":\s*"se_e2_a"' --type json

# Look for documentation or examples of sel parameter usage
fd -e md -e rst . | xargs rg -l "sel.*parameter|descriptor.*sel"

Length of output: 17998


Script:

#!/bin/bash
# Check the content of the model documentation for sel parameter
cat doc/model/sel.md

# Look for type_map configurations in the same directory
rg -A 5 '"type_map"' source/tests/pt/model/water/

Length of output: 4509


Script:

#!/bin/bash
# Get the complete content of the multitask.json to see the my_type_map definition
cat source/tests/pt/model/water/multitask.json

# Look for other multitask examples that might have similar sel configurations
rg -A 10 '"my_type_map"' --type json

Length of output: 7254

source/tests/pt/model/water/multitask_sharefit.json (6)

56-63: LGTM: Learning rate configuration is well structured

The exponential decay strategy with the given parameters provides a good balance between initial learning and convergence.


64-83: LGTM: Loss configuration aligns with water system requirements

The energy-based loss with high initial force preference is appropriate for water system modeling.


89-123: The existing review comment about optimizing training data configuration and batch size is still valid.


125-128: The existing review comment about increasing the number of training steps is still valid.


11-15: Verify the sel parameter configuration

The sel parameter [46, 92, 4] appears to be inconsistent with the number of atom types in my_type_map ["O", "H", "B"]. Each element in sel should correspond to the selection of neighbors for each atom type.


1-134: Verify the shared fitting implementation

While the overall structure of the configuration is well-organized, please ensure that the shared fitting implementation (numb_caseid: 2) is properly tested with different scenarios:

  1. Different input data distributions
  2. Various atom type combinations
  3. Edge cases in the fitting process
✅ Verification successful

Let me gather more information about the test coverage and implementation details.


Let me check the documentation content and other test configurations for a complete understanding.


Shared fitting configuration is correctly implemented

The shared fitting implementation in the configuration is properly set up and aligns with the documented best practices:

  • The numb_caseid: 2 correctly matches the number of model branches (model_1 and model_2)
  • The configuration follows the documented approach of "Descriptor and fitting network sharing with data identification"
  • Equal task sampling weights (0.5/0.5) ensure balanced training between the two models
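Schematically, the sharing layout these checks describe looks like the following. This is a hedged Python rendering of the JSON: the key names follow the review discussion ("shared_fit_with_id", "model_dict") rather than a verified schema, and numb_caseid was later renamed dim_case_embd in this PR.

model = {
    "shared_dict": {
        "my_descriptor": {"type": "se_e2_a", "sel": [46, 92, 4]},
        "shared_fit_with_id": {"type": "ener", "numb_caseid": 2},
    },
    "model_dict": {
        "model_1": {"descriptor": "my_descriptor",
                    "fitting_net": "shared_fit_with_id"},
        "model_2": {"descriptor": "my_descriptor",
                    "fitting_net": "shared_fit_with_id"},
    },
}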
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for other shared fitting test configurations
rg -A 5 '"numb_caseid"' --glob '**/tests/**/*.json'

Length of output: 466


Script:

#!/bin/bash
# Check for test data variations and atom types
rg -l "data_0" --glob '**/tests/**/*.json'

# Look for other multitask configurations
rg -l "model_prob" --glob '**/tests/**/*.json'

# Search for any documentation or examples of shared fitting
fd -e md -e rst . | xargs rg -l "shared.*fit|numb_caseid"

Length of output: 529


Script:

#!/bin/bash
# Check documentation content
cat doc/train/multi-task-training.md

# Compare with non-shared multitask configuration
cat source/tests/pt/model/water/multitask.json

# Check data directory structure
fd -t d "data_0" source/tests/pt/model/water/

Length of output: 10627

deepmd/pt/model/task/fitting.py (4)

67-67: LGTM: Comment accurately reflects parameter sharing behavior.

The comment correctly documents that both bias_atom_e and caseid are excluded from parameter sharing when shared_level is 0.


359-367: LGTM with existing validation suggestion.

The implementation correctly creates a one-hot encoded tensor for case identification.


282-282: LGTM: Proper versioning and serialization.

The version increment and addition of new fields (numb_caseid and caseid) to serialization are handled correctly.

Also applies to: 290-290, 299-299


498-510: LGTM: Verify tensor shapes in production.

The case identification concatenation is implemented correctly. However, it's important to verify the tensor shapes in production, especially for edge cases with different batch sizes.


codecov bot commented Nov 26, 2024

Codecov Report

Attention: Patch coverage is 75.59055% with 31 lines in your changes missing coverage. Please review.

Project coverage is 83.30%. Comparing base (4a45fe5) to head (bd3fbef).
Report is 5 commits behind head on devel.

| Files with missing lines | Patch % | Lines |
|--------------------------------------------------------|--------|--------------|
| deepmd/dpmodel/fitting/general_fitting.py | 75.00% | 4 Missing ⚠️ |
| deepmd/pd/model/task/fitting.py | 60.00% | 4 Missing ⚠️ |
| deepmd/dpmodel/atomic_model/linear_atomic_model.py | 40.00% | 3 Missing ⚠️ |
| deepmd/pt/model/atomic_model/linear_atomic_model.py | 40.00% | 3 Missing ⚠️ |
| deepmd/pt/model/task/fitting.py | 87.50% | 2 Missing ⚠️ |
| deepmd/tf/fit/dipole.py | 50.00% | 2 Missing ⚠️ |
| deepmd/tf/fit/dos.py | 50.00% | 2 Missing ⚠️ |
| deepmd/dpmodel/atomic_model/dp_atomic_model.py | 50.00% | 1 Missing ⚠️ |
| deepmd/dpmodel/atomic_model/make_base_atomic_model.py | 66.66% | 1 Missing ⚠️ |
| deepmd/dpmodel/atomic_model/pairtab_atomic_model.py | 50.00% | 1 Missing ⚠️ |

... and 8 more
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4417      +/-   ##
==========================================
+ Coverage   83.27%   83.30%   +0.02%     
==========================================
  Files         667      667              
  Lines       61445    61533      +88     
  Branches     3486     3486              
==========================================
+ Hits        51166    51257      +91     
+ Misses       9152     9149       -3     
  Partials     1127     1127              

☔ View full report in Codecov by Sentry.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (9)
doc/train/multi-task-training.md (2)

56-63: Enhance the shared_level parameter documentation

The shared_level documentation needs additional clarity:

  1. Add an example showing the complete syntax for specifying different shared levels
  2. Explain the colon syntax in "my_descriptor:shared_level"
  3. Include practical examples for each sharing level (0 and 1)

Consider adding this example:

{
  "model": {
    "descriptor": {
      "type": "se_e2_a",
      "shared_key": "my_descriptor:1"  // Level 1: shares only type embedding
    }
  }
}
🧰 Tools
🪛 LanguageTool

[misspelling] ~62-~62: Use “an” instead of ‘a’ if the following word starts with a vowel sound, e.g. ‘an article’, ‘an hour’.
Context: ...escriptors support all levels (e.g., se_a only supports level 0) - - For fitting ...

(EN_A_VS_AN)


64-69: Enhance the multitask approaches documentation

Consider adding:

  1. A comparison table showing pros/cons of each approach
  2. Guidelines for choosing between the approaches
  3. More details about how numb_caseid affects model behavior and data handling

Consider adding this comparison:

| Aspect | Descriptor-only Sharing | Full Sharing with Data ID |
|--------|------------------------|---------------------------|
| Use Case| Different fitting requirements | Similar fitting patterns |
| Memory | Higher (separate fitting nets) | Lower (shared parameters) |
| Flexibility | More flexible | More efficient |
deepmd/pt/model/task/fitting.py (2)

211-219: Consider alternative initialization strategy.

The current implementation initializes caseid with zeros, but there's a commented-out alternative using torch.eye. Consider if the identity matrix initialization might be more appropriate as it would provide unique, orthogonal case identifiers by default.

-                torch.zeros(self.numb_caseid, dtype=self.prec, device=device),
+                torch.eye(self.numb_caseid, dtype=self.prec, device=device)[0],

500-512: Consider adding shape assertions for debugging.

The implementation correctly handles case identification concatenation. Consider adding shape assertions to help with debugging:

         if self.numb_caseid > 0:
             assert self.caseid is not None
             caseid = torch.tile(self.caseid.reshape([1, 1, -1]), [nf, nloc, 1])
+            assert caseid.shape == (nf, nloc, self.numb_caseid), f"Expected shape {(nf, nloc, self.numb_caseid)}, got {caseid.shape}"
             xx = torch.cat(
                 [xx, caseid],
                 dim=-1,
             )
deepmd/tf/fit/dos.py (2)

77-78: Verify the error message for unsupported feature

The error message "numb_caseid is not supported in TensorFlow" could be more informative. Consider providing guidance on alternatives or explaining why it's not supported.

-            raise ValueError("numb_caseid is not supported in TensorFlow.")
+            raise ValueError("numb_caseid is not supported in TensorFlow. Please use PyTorch backend for this feature or consider using alternative approaches for case identification.")

Also applies to: 116-116, 138-140


Line range hint 1-745: Consider architectural documentation for TensorFlow limitations

The code introduces a feature flag (numb_caseid) that is explicitly unsupported in TensorFlow. This architectural decision should be documented:

  1. Why is this feature only supported in PyTorch?
  2. What are the technical limitations in TensorFlow?
  3. Is there a migration path for users who need this feature?

Consider adding a section in the documentation explaining these architectural decisions and providing guidance for users.

deepmd/tf/fit/polar.py (1)

173-174: Consider adding a TODO comment for future backend support.

The error message indicates TensorFlow-specific limitation. Consider adding a TODO comment to document potential future support in other backends.

 if numb_caseid > 0:
+    # TODO: Add support for case IDs in other backends
     raise ValueError("numb_caseid is not supported in TensorFlow.")
deepmd/pt/model/descriptor/dpa2.py (2)

407-408: Consider updating documentation to reflect supported sharing levels

The share_params method now supports shared_level values of 0 and 1. To improve clarity, please update the method's docstring or comments to specify the supported shared_level values and their intended behaviors.


409-410: Provide a more informative error message for unsupported sharing levels

In the else block, a NotImplementedError is raised without an explanatory message. Consider adding an informative error message indicating the unsupported shared_level value and listing the supported values (0 and 1). This will aid users in debugging and understanding the limitations.
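A hedged sketch of the requested behavior (not the actual DescrptDPA2 code): validate shared_level up front and fail with a message that names the supported values.

class DescrptSketch:
    SUPPORTED_SHARED_LEVELS = (0, 1)

    def share_params(self, base_class, shared_level, resume=False):
        if shared_level == 0:
            ...  # share all parameters with base_class
        elif shared_level == 1:
            ...  # share only the type embedding with base_class
        else:
            raise NotImplementedError(
                f"shared_level={shared_level} is not supported; "
                f"supported levels are {self.SUPPORTED_SHARED_LEVELS}."
            )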

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between d399c6e and ff5c3f8.

📒 Files selected for processing (14)
  • deepmd/dpmodel/atomic_model/linear_atomic_model.py (2 hunks)
  • deepmd/pt/model/atomic_model/linear_atomic_model.py (2 hunks)
  • deepmd/pt/model/descriptor/dpa2.py (1 hunks)
  • deepmd/pt/model/task/dipole.py (4 hunks)
  • deepmd/pt/model/task/fitting.py (11 hunks)
  • deepmd/pt/model/task/invar_fitting.py (4 hunks)
  • deepmd/pt/model/task/polarizability.py (5 hunks)
  • deepmd/pt/model/task/property.py (4 hunks)
  • deepmd/tf/descriptor/se_a_ebd.py (1 hunks)
  • deepmd/tf/fit/dipole.py (6 hunks)
  • deepmd/tf/fit/dos.py (7 hunks)
  • deepmd/tf/fit/ener.py (7 hunks)
  • deepmd/tf/fit/polar.py (6 hunks)
  • doc/train/multi-task-training.md (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (7)
  • deepmd/dpmodel/atomic_model/linear_atomic_model.py
  • deepmd/pt/model/atomic_model/linear_atomic_model.py
  • deepmd/pt/model/task/dipole.py
  • deepmd/pt/model/task/invar_fitting.py
  • deepmd/pt/model/task/polarizability.py
  • deepmd/pt/model/task/property.py
  • deepmd/tf/fit/dipole.py
🧰 Additional context used
🪛 LanguageTool
doc/train/multi-task-training.md

[misspelling] ~62-~62: Use “an” instead of ‘a’ if the following word starts with a vowel sound, e.g. ‘an article’, ‘an hour’.
Context: ...escriptors support all levels (e.g., se_a only supports level 0) - - For fitting ...

(EN_A_VS_AN)

🔇 Additional comments (13)
deepmd/pt/model/task/fitting.py (6)

67-67: LGTM: Appropriate parameter sharing exclusion.

The modification correctly excludes both bias_atom_e and caseid from parameter sharing, which is appropriate as these parameters should remain instance-specific.


98-99: LGTM: Well-documented parameter addition.

The numb_caseid parameter is properly documented and has a safe default value of 0.


224-224: LGTM: Correct input dimension calculation.

The input dimension calculation correctly includes the case identification dimension.


361-369: LGTM: Proper one-hot encoding implementation.

The method correctly implements one-hot encoding for case identification using torch.eye.


382-383: LGTM: Consistent attribute access implementation.

The changes to __setitem__ and __getitem__ maintain consistency with the existing pattern.

Also applies to: 400-401


284-284: Verify version compatibility in deserialization.

The serialization version has been bumped and new fields added correctly. However, ensure that the deserialization method properly handles older versions without these fields.

Also applies to: 292-292, 301-301

✅ Verification successful

No version handling needed in deserialization

The deserialization method in GeneralFitting class is simple and doesn't require version handling because:

  1. It only extracts variables and networks from the serialized data
  2. All other fields are passed directly to the constructor as kwargs
  3. The new fields (numb_caseid and caseid) are optional with default values in the constructor

This design makes it naturally backward compatible - older versions without these fields will work correctly since the constructor handles missing parameters with defaults.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for version handling in deserialize methods
rg -A 5 "def deserialize.*version" 

Length of output: 36


Script:

#!/bin/bash
# Search for deserialization related code with broader patterns
rg -A 10 "deserialize|from_json|load" deepmd/pt/model/task/fitting.py

Length of output: 737


Script:

#!/bin/bash
# Look for version handling in the fitting module
rg -A 10 "@version|version" deepmd/pt/model/task/fitting.py

Length of output: 520


Script:

#!/bin/bash
# Search for any backward compatibility handling
ast-grep --pattern 'if $version'

Length of output: 34


Script:

#!/bin/bash
# Check the entire file content for context
cat deepmd/pt/model/task/fitting.py

Length of output: 20753

deepmd/tf/descriptor/se_a_ebd.py (1)

61-62: ⚠️ Potential issue

Implementation missing for documented parameter

The numb_caseid parameter is documented in the class docstring but not implemented in the class. This creates an inconsistency between the documentation and the actual implementation.

Let's verify how this parameter is used across the codebase:

Suggested implementation:

     def __init__(
         self,
         rcut: float,
         rcut_smth: float,
         sel: list[int],
         neuron: list[int] = [24, 48, 96],
         axis_neuron: int = 8,
         resnet_dt: bool = False,
         trainable: bool = True,
         seed: Optional[int] = None,
         type_one_side: bool = True,
         type_nchanl: int = 2,
         type_nlayer: int = 1,
         numb_aparam: int = 0,
+        numb_caseid: int = 0,
         set_davg_zero: bool = False,
         activation_function: str = "tanh",
         precision: str = "default",
         exclude_types: list[list[int]] = [],
         **kwargs,
     ) -> None:
         """Constructor."""
         DescrptSeA.__init__(
             self,
             rcut,
             rcut_smth,
             sel,
             neuron=neuron,
             axis_neuron=axis_neuron,
             resnet_dt=resnet_dt,
             trainable=trainable,
             seed=seed,
             type_one_side=type_one_side,
             set_davg_zero=set_davg_zero,
             activation_function=activation_function,
             precision=precision,
         )
         self.type_nchanl = type_nchanl
         self.type_nlayer = type_nlayer
         self.type_one_side = type_one_side
         self.numb_aparam = numb_aparam
+        self.numb_caseid = numb_caseid
deepmd/tf/fit/polar.py (3)

66-67: LGTM: Parameter addition and validation are well implemented.

The numb_caseid parameter is properly:

  • Documented in the class docstring
  • Added with a default value of 0
  • Validated with a clear error message

Also applies to: 100-100, 168-168, 173-174


587-587: LGTM: Serialization changes are consistent.

The version increment and addition of numb_caseid to the serialized data are appropriate for maintaining backward compatibility.

Also applies to: 597-597


635-635: LGTM: Version compatibility check is properly updated.

The version compatibility check is correctly updated to support versions 1-4, maintaining backward compatibility while allowing the new version.

deepmd/tf/fit/ener.py (3)

120-121: LGTM: Parameter documentation is complete

The numb_caseid parameter is properly documented in the docstring.


913-913: Verify serialization/deserialization compatibility

The serialization format has been updated with new fields:

  1. Version number increased to 3
  2. Added numb_caseid field
  3. Added caseid to variables (set to None)

Ensure that:

  1. Deserialization handles both old and new formats
  2. Unit tests cover both formats

Also applies to: 923-923, 955-955

✅ Verification successful

Based on the search results, I can now provide a definitive response about the serialization changes:

Serialization version changes are properly handled

The codebase shows comprehensive version compatibility checks:

  • The new version 3 is properly handled in deserialization with check_version_compatibility() calls that accept version 3 as maximum version
  • Both PyTorch and TensorFlow implementations handle version compatibility
  • The deserialization code explicitly handles both old and new formats by using default values for new fields:
    • check_version_compatibility(data.pop("@version", 1), 3, 1) pattern shows backward compatibility with version 1
    • New fields like numb_caseid are safely handled through dictionary pop operations with defaults

The serialization changes follow good practices:

  1. Version number increment is appropriate for the schema changes
  2. Backward compatibility is maintained
  3. Version checks are consistent across the codebase
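The pattern reduces to popping fields with defaults, roughly as below (field names follow the review; the actual deserialize bodies differ in detail):

def deserialize(data):
    data = data.copy()
    version = data.pop("@version", 1)  # legacy payloads lack the key
    if not 1 <= version <= 3:
        raise ValueError(f"unsupported serialization version {version}")
    data.setdefault("numb_caseid", 0)  # field absent before version 3
    return data

print(deserialize({"neuron": [128, 128]}))  # a legacy payload still loads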
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for related test files
fd "test.*ener" -t f

# Check for serialization test coverage
rg -A 5 "test.*serialize.*version" 

Length of output: 325


Script:

#!/bin/bash
# Let's check the test files for serialization tests
for test_file in $(fd "test.*ener" -t f); do
    echo "=== Checking $test_file ==="
    rg -B 2 -A 5 "serialize|deserialize" "$test_file"
done

# Check implementation of deserialization to verify version handling
ast-grep --pattern 'def restore_descriptor($$$) {
    $$$
}'

# Look for version checks in deserialization
rg -B 2 -A 10 "version.*=|@version" deepmd/

Length of output: 105236


887-887: Verify version compatibility with existing models

The version check has been updated from 2 to 3. Ensure that:

  1. Existing models with version 2 can still be loaded
  2. Version 3 changes are documented
✅ Verification successful

Let me analyze the version compatibility checks across the codebase to verify the impact.


Let me check the implementation of the version compatibility function to understand its behavior.


Based on the analysis of the codebase, I can now generate the final response:

Version 3 update is safe and properly implemented

The version update from 2 to 3 is necessary to support the new numb_caseid field in PyTorch models for multi-task training. The change is safe because:

  1. The version compatibility check ensures backward compatibility by accepting any version between 1 and 3
  2. The numb_caseid feature is only supported in PyTorch models, and TensorFlow models explicitly raise an error if numb_caseid > 0
  3. The version update is consistently implemented across all fitting modules (ener, dos, dipole, etc.)
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for version compatibility documentation
rg "version.*2.*to.*3" README.md docs/
# Check for other version checks that might need updating
ast-grep --pattern 'check_version_compatibility($_, $_, $_)'

Length of output: 7931


Script:

#!/bin/bash
# Check for version compatibility implementation
rg -A 5 "def check_version_compatibility" 
# Check for version-related changes in the changelog
rg -i "version.*3" CHANGELOG.md

Length of output: 499


Script:

#!/bin/bash
# Get the full implementation of check_version_compatibility
cat deepmd/utils/version.py
# Look for version-related changes in the PR
rg -B2 -A2 "numb_caseid"

Length of output: 38369

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (10)
doc/train/multi-task-training.md (2)

56-64: Minor improvements needed in shared_level documentation

The documentation is much clearer now, but there are a few minor points to address:

  1. Use "an" instead of "a" before "se_a" for correct grammar
  2. Consider explicitly stating that level 0 is the default when not specified

Apply these changes:

-    - Not all descriptors support all levels (e.g., se_a only supports level 0)
+    - Not all descriptors support all levels (e.g., an se_a only supports level 0)
🧰 Tools
🪛 LanguageTool

[misspelling] ~62-~62: Use “an” instead of ‘a’ if the following word starts with a vowel sound, e.g. ‘an article’, ‘an hour’.
Context: ...escriptors support all levels (e.g., se_a only supports level 0) - - For fitting ...

(EN_A_VS_AN)


65-69: Consider expanding dim_case_embd documentation

The documentation clearly presents the two approaches, but the explanation of dim_case_embd could be more detailed. Consider adding:

  1. An explanation of how the one-hot embedding works
  2. A small example showing the embedding structure
  3. The implications of changing this parameter after training

Would you like me to propose expanded documentation for the dim_case_embd parameter?
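For instance, a minimal illustration of the one-hot structure, assuming dim_case_embd equals the number of tasks (here 3):

import numpy as np

dim_case_embd = 3
for case_idx in range(dim_case_embd):
    # task case_idx receives row case_idx of the identity matrix
    print(case_idx, np.eye(dim_case_embd)[case_idx])
# 0 [1. 0. 0.]
# 1 [0. 1. 0.]
# 2 [0. 0. 1.]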

source/tests/pt/test_multitask.py (2)

58-58: LGTM! Consider adding a docstring.

The shared fitting parameter comparison logic is well implemented. Consider adding a docstring to explain the significance of excluding bias_atom_e and case_embd from the comparison.

 def test_multitask_train(self) -> None:
+    """Test multitask training with support for shared fitting.
+    
+    Note:
+        When share_fitting is enabled, all fitting network parameters are shared
+        except for bias_atom_e and case_embd which are task-specific.
+    """
     # test multitask training

Also applies to: 73-78


239-277: LGTM! Add class documentation.

The test class implementation is well-structured and follows the established patterns. Consider adding class-level documentation to explain the purpose of shared fitting tests.

 class TestMultiTaskSeASharefit(unittest.TestCase, MultiTaskTrainTest):
+    """Test cases for multi-task training with shared fitting networks.
+    
+    This class verifies that:
+    1. Models correctly share fitting network parameters
+    2. Task-specific parameters (bias_atom_e, case_embd) remain independent
+    3. The shared fitting configuration loads and trains successfully
+    """
     def setUp(self) -> None:
deepmd/tf/fit/polar.py (2)

173-174: Consider improving the error message

While the validation is correct, the error message could be more informative by explaining why case embeddings are not supported in TensorFlow and potentially suggesting alternatives.

-            raise ValueError("dim_case_embd is not supported in TensorFlow.")
+            raise ValueError("Case embeddings (dim_case_embd > 0) are not supported in TensorFlow implementation. Please use PyTorch implementation for this feature.")

Line range hint 12-24: Consider adding type hints for imported components

The imports could benefit from explicit type hints to improve code maintainability and IDE support.

-from deepmd.tf.descriptor import (
-    DescrptSeA,
-)
+from deepmd.tf.descriptor import (
+    DescrptSeA,  # type: tf.Module
+)
deepmd/tf/fit/ener.py (1)

120-121: Update error message to clarify PyTorch alternative

The error message should be more informative about the alternative solution.

-            raise ValueError("dim_case_embd is not supported in TensorFlow.")
+            raise ValueError("dim_case_embd is not supported in TensorFlow. Use PyTorch backend for case embeddings support.")

Also applies to: 161-161, 196-198

deepmd/pt/train/training.py (2)

1270-1285: LGTM: Case embedding initialization logic

The implementation correctly handles case embedding initialization for multi-task models, with proper checks for:

  1. Multi-task configuration presence
  2. Training state (new vs resuming)
  3. Case embedding configuration validation

Consider adding docstring documentation to explain:

  • The purpose of case embeddings
  • The resuming parameter's effect
  • Return value structure
 def get_model_for_wrapper(_model_params, resuming=False):
+    """Initialize model with optional case embedding support for multi-task learning.
+    
+    Args:
+        _model_params: Model configuration parameters
+        resuming: If True, skip case embedding initialization as it's loaded from checkpoint
+    
+    Returns:
+        Single model instance or dict of models for multi-task setup
+    """

1289-1311: LGTM: Case embedding configuration validation

The implementation provides robust validation for case embeddings in multi-task models:

  1. Validates multi-task setup requirement
  2. Ensures consistent case embedding dimensions across models
  3. Returns clear configuration for embedding initialization

Consider enhancing the error message to be more descriptive:

-            f"All models must have the same dimension of data identification, while the settings are: {numb_case_embd_list}"
+            f"Inconsistent case embedding dimensions across models. Found dimensions {dict(zip(sorted_model_keys, numb_case_embd_list))}. All models must have the same dimension."
deepmd/utils/argcheck.py (1)

Line range hint 1436-1469: Consider refactoring the repeated dim_case_embd parameter definition.

The implementation is correct, but the same parameter definition is duplicated across all fitting functions. Consider extracting it into a helper function to improve maintainability.

Example refactor:

def _add_case_embd_arg():
    """Helper function to create the case embedding dimension argument."""
    return Argument(
        "dim_case_embd",
        int,
        optional=True,
        default=0,
        doc=doc_only_pt_supported + doc_numb_case_embd,
    )

# Usage in fitting functions:
def fitting_ener():
    return [
        # ... other arguments ...
        _add_case_embd_arg(),
        # ... remaining arguments ...
    ]
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between ff5c3f8 and 99369d9.

📒 Files selected for processing (34)
  • deepmd/dpmodel/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/linear_atomic_model.py (2 hunks)
  • deepmd/dpmodel/atomic_model/make_base_atomic_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/pairtab_atomic_model.py (1 hunks)
  • deepmd/dpmodel/fitting/dipole_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/dos_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/ener_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/general_fitting.py (9 hunks)
  • deepmd/dpmodel/fitting/invar_fitting.py (3 hunks)
  • deepmd/dpmodel/fitting/polarizability_fitting.py (4 hunks)
  • deepmd/dpmodel/fitting/property_fitting.py (3 hunks)
  • deepmd/dpmodel/model/make_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/linear_atomic_model.py (2 hunks)
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py (1 hunks)
  • deepmd/pt/model/model/make_model.py (1 hunks)
  • deepmd/pt/model/task/dipole.py (4 hunks)
  • deepmd/pt/model/task/dos.py (3 hunks)
  • deepmd/pt/model/task/ener.py (3 hunks)
  • deepmd/pt/model/task/fitting.py (11 hunks)
  • deepmd/pt/model/task/invar_fitting.py (4 hunks)
  • deepmd/pt/model/task/polarizability.py (5 hunks)
  • deepmd/pt/model/task/property.py (4 hunks)
  • deepmd/pt/train/training.py (2 hunks)
  • deepmd/tf/fit/dipole.py (6 hunks)
  • deepmd/tf/fit/dos.py (7 hunks)
  • deepmd/tf/fit/ener.py (7 hunks)
  • deepmd/tf/fit/polar.py (6 hunks)
  • deepmd/utils/argcheck.py (10 hunks)
  • doc/train/multi-task-training.md (1 hunks)
  • examples/water_multi_task/pytorch_example/input_torch_sharefit.json (1 hunks)
  • source/tests/pt/model/water/multitask_sharefit.json (1 hunks)
  • source/tests/pt/test_multitask.py (3 hunks)
  • source/tests/universal/dpmodel/fitting/test_fitting.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (27)
  • deepmd/dpmodel/atomic_model/dp_atomic_model.py
  • deepmd/dpmodel/atomic_model/linear_atomic_model.py
  • deepmd/dpmodel/atomic_model/make_base_atomic_model.py
  • deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
  • deepmd/dpmodel/fitting/dipole_fitting.py
  • deepmd/dpmodel/fitting/dos_fitting.py
  • deepmd/dpmodel/fitting/ener_fitting.py
  • deepmd/dpmodel/fitting/invar_fitting.py
  • deepmd/dpmodel/fitting/polarizability_fitting.py
  • deepmd/dpmodel/fitting/property_fitting.py
  • deepmd/dpmodel/model/make_model.py
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
  • deepmd/pt/model/atomic_model/linear_atomic_model.py
  • deepmd/pt/model/atomic_model/pairtab_atomic_model.py
  • deepmd/pt/model/model/make_model.py
  • deepmd/pt/model/task/dipole.py
  • deepmd/pt/model/task/dos.py
  • deepmd/pt/model/task/ener.py
  • deepmd/pt/model/task/fitting.py
  • deepmd/pt/model/task/invar_fitting.py
  • deepmd/pt/model/task/polarizability.py
  • deepmd/pt/model/task/property.py
  • deepmd/tf/fit/dipole.py
  • deepmd/tf/fit/dos.py
  • examples/water_multi_task/pytorch_example/input_torch_sharefit.json
  • source/tests/pt/model/water/multitask_sharefit.json
  • source/tests/universal/dpmodel/fitting/test_fitting.py
🧰 Additional context used
🪛 LanguageTool
doc/train/multi-task-training.md

[misspelling] ~62-~62: Use “an” instead of ‘a’ if the following word starts with a vowel sound, e.g. ‘an article’, ‘an hour’.
Context: ...escriptors support all levels (e.g., se_a only supports level 0) - - For fitting ...

(EN_A_VS_AN)

🔇 Additional comments (11)
deepmd/dpmodel/fitting/general_fitting.py (5)

108-108: LGTM! Constructor changes are well-implemented.

The initialization of case embedding parameters follows the established pattern and includes proper type hints and default values.

Also applies to: 131-131, 176-179


185-185: LGTM! Input dimension calculation is correct.

The case embedding dimension is properly incorporated into the total input dimension calculation.


272-273: LGTM! Serialization changes are complete.

The version bump and addition of case embedding fields to serialization are properly implemented.

Also applies to: 290-291, 308-308, 316-316, 325-325


446-458: LGTM! Case embedding handling in _call_common is well-implemented.

The case embedding is properly concatenated to the input tensor and correctly handled in the vacuum contribution calculation.
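A hedged numpy sketch of that vacuum handling (names are illustrative): the case embedding is appended both to the real input and to the zeroed input used for the vacuum baseline, so both pass through the same network shape.

import numpy as np

nf, nloc, dim_descrpt, dim_case_embd = 1, 4, 8, 2
xx = np.random.rand(nf, nloc, dim_descrpt)
tiled = np.tile(np.eye(dim_case_embd)[0].reshape(1, 1, -1), (nf, nloc, 1))
xx_full = np.concatenate([xx, tiled], axis=-1)
xx_vacuum = np.concatenate([np.zeros_like(xx), tiled], axis=-1)
assert xx_full.shape == xx_vacuum.shape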


232-238: 🛠️ Refactor suggestion

Add input validation for case_idx

Similar to the existing review comments for set_dataid and set_caseid, consider adding input validation for case_idx.

 def set_case_embd(self, case_idx: int):
     """
     Set the case embedding of this fitting net by the given case_idx,
     typically concatenated with the output of the descriptor and fed into the fitting net.
     """
+    if not isinstance(case_idx, (int, np.integer)):
+        raise TypeError(f"case_idx must be an integer, got {type(case_idx)}")
+    if not 0 <= case_idx < self.dim_case_embd:
+        raise ValueError(f"case_idx must be in range [0, {self.dim_case_embd}), got {case_idx}")
     self.case_embd = np.eye(self.dim_case_embd, dtype=self.prec)[case_idx]
deepmd/tf/fit/polar.py (2)

66-67: LGTM: Parameter addition is well-documented

The new parameter dim_case_embd is properly added with clear documentation and a default value of 0.

Also applies to: 100-100


587-587: Verify version compatibility handling

The version update from 3 to 4 is properly implemented in both serialization and deserialization. However, let's verify there are no other version-dependent features that need handling.

Also applies to: 597-597, 635-635

✅ Verification successful

Let me analyze the version-related code further by searching for version handling logic.


Let me search for version-related code specifically in the polar fitting module to verify the version compatibility handling.


Based on the code analysis, I can now provide a definitive response about the version compatibility handling in the polar fitting module.

Version compatibility is properly implemented

The version update to 4 is correctly handled in both TensorFlow and PyTorch implementations:

  • The TensorFlow implementation (polar.py) explicitly sets version 4 in serialization and checks version compatibility with maximum version 4 and minimum version 1
  • The PyTorch implementation (polarizability.py) maintains the same version compatibility checks
  • Both implementations use the standard version compatibility checking utility
  • The version change is properly documented in the serialized data structure
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for version-related code and compatibility handling
# Test: Look for version checks and compatibility handling
rg -A 5 "@version|check_version_compatibility" 

Length of output: 144602


Script:

#!/bin/bash
# Search for version-related code handling and compatibility checks
rg -A 5 "version.*=|check_version|load_version" 

Length of output: 106004


Script:

#!/bin/bash
# Search for version-related code in polar fitting module and its tests
rg -A 5 "version|check_version" deepmd/tf/fit/polar.py deepmd/pt/model/task/polarizability.py

Length of output: 2615

deepmd/tf/fit/ener.py (2)

887-887: LGTM: Version compatibility check updated correctly

The version bump from 2 to 3 is appropriate for the new case embedding feature.


913-913: LGTM: Serialization format updated consistently

The serialization changes properly handle the new case embedding feature:

  • Version number updated to 3
  • Added dim_case_embd field
  • Added case_embd to variables (initialized as None)

Also applies to: 923-923, 955-955

deepmd/pt/train/training.py (1)

268-268: LGTM: Model initialization parameter addition

The addition of the resuming parameter helps control when case embeddings should be initialized, preventing unnecessary reinitialization when resuming training.

deepmd/utils/argcheck.py (1)

Line range hint 1520-1543: LGTM! Identical implementation pattern as previous segment.

The implementation follows the same correct pattern as in the energy fitting function.

Also applies to: 1575-1593, 1624-1659, 1702-1730

@iProzd iProzd changed the title feat(pt/dp): support dataid and sharable fitting feat(pt/dp): support case embedding and sharable fitting Nov 27, 2024
@iProzd iProzd requested a review from njzjz November 27, 2024 11:21
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
doc/train/multi-task-training.md (1)

65-70: Add context for choosing between multi-task approaches

The documentation clearly explains the two approaches but would benefit from additional context:

  1. When to use each approach (use cases)
  2. Trade-offs between the approaches
  3. Performance implications of using shared fitting networks vs. separate ones
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 581de82 and 656d200.

📒 Files selected for processing (1)
  • doc/train/multi-task-training.md (1 hunks)
🧰 Additional context used
🪛 LanguageTool
doc/train/multi-task-training.md

[misspelling] ~63-~63: Use “an” instead of ‘a’ if the following word starts with a vowel sound, e.g. ‘an article’, ‘an hour’.
Context: ...escriptors support all levels (e.g., se_a only supports level 0) - For fitting ...

(EN_A_VS_AN)

🪛 Markdownlint (0.35.0)
doc/train/multi-task-training.md

54-54: null
Spaces inside code span elements

(MD038, no-space-in-code)


56-56: null
Spaces inside code span elements

(MD038, no-space-in-code)

🔇 Additional comments (2)
doc/train/multi-task-training.md (2)

64-64: Expand the fitting nets sharing documentation

The current documentation for fitting nets sharing is too brief. Please consider adding:

  1. Explanation of why only level 0 is supported
  2. Details about how bias_atom_e and caseid parameters behave differently
  3. Example configuration showing fitting net sharing

59-63: Verify the supported shared_level values

The documentation states that valid shared_level values are 0-1, but previous documentation mentioned support for levels 0-3 with additional sharing behaviors for DPA2. Please clarify:

  1. If the reduction to levels 0-1 is intentional
  2. If DPA2-specific levels (2-3) should be documented separately
✅ Verification successful

Based on the search results, I can now generate the final response:

Documentation accurately reflects the implemented shared_level values

The documentation is correct in stating that valid shared_level values are 0-1. The codebase implementation shows:

  • Most descriptors (se_a, se_t, se_r) only support level 0
  • Only DPA1, DPA2, and se_t_tebd support both levels 0 and 1
  • Any other level (2-3) raises NotImplementedError
  • Level 0 shares all parameters
  • Level 1 shares only type embedding
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for shared_level implementation to verify supported values
rg -A 5 "shared_level.*=.*[0-9]" 

Length of output: 8118

🧰 Tools
🪛 LanguageTool

[misspelling] ~63-~63: Use “an” instead of ‘a’ if the following word starts with a vowel sound, e.g. ‘an article’, ‘an hour’.
Context: ...escriptors support all levels (e.g., se_a only supports level 0) - For fitting ...

(EN_A_VS_AN)

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (8)
deepmd/pd/model/task/ener.py (1)

45-45: Add parameter documentation and validation

The new dim_case_embd parameter needs documentation to explain its purpose and constraints. Also consider adding validation to ensure non-negative dimensions.

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        neuron: list[int] = [128, 128, 128],
        bias_atom_e: Optional[paddle.Tensor] = None,
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        dim_case_embd: int = 0,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        # ... remaining parameters unchanged ...
    ) -> None:
+       """
+       Parameters
+       ----------
+       dim_case_embd : int, optional
+           Dimension of case/task embedding for multi-task learning.
+           Must be non-negative. Default is 0 (no case embedding).
+       """
+       if dim_case_embd < 0:
+           raise ValueError("dim_case_embd must be non-negative")
deepmd/pd/model/task/invar_fitting.py (1)

60-62: Enhance parameter documentation

The docstring for dim_case_embd should be more descriptive:

  • Add type hint in the description (: int)
  • Document any validation rules (e.g., must be non-negative)
  • Clarify when this feature will be supported
     dim_case_embd : int
-        (Not supported yet)
-        Dimension of case specific embedding.
+        (Not supported yet - planned for future release)
+        Dimension of case specific embedding. Must be a non-negative integer.
+        When supported, this will enable case-specific embeddings for multi-task training.
deepmd/pd/model/atomic_model/dp_atomic_model.py (2)

142-147: The implementation looks good but could be more robust.

The method is well-structured and properly documented. However, consider these improvements for better robustness:

Consider applying these enhancements:

-    def set_case_embd(self, case_idx: int):
+    def set_case_embd(self, case_idx: int) -> None:
         """
         Set the case embedding of this atomic model by the given case_idx,
         typically concatenated with the output of the descriptor and fed into the fitting net.
+
+        Parameters
+        ----------
+        case_idx : int
+            The index of the case embedding to set. Must be non-negative.
+
+        Raises
+        ------
+        ValueError
+            If case_idx is negative
+        AttributeError
+            If fitting_net doesn't support case embedding
         """
+        if case_idx < 0:
+            raise ValueError(f"case_idx must be non-negative, got {case_idx}")
+        if not hasattr(self.fitting_net, 'set_case_embd'):
+            raise AttributeError("fitting_net doesn't support case embedding")
         self.fitting_net.set_case_embd(case_idx)

142-147: Consider documenting the case embedding architecture

The implementation follows good architectural patterns, but the broader context of case embedding usage could be better documented.

Consider:

  1. Adding a section in the class docstring explaining the case embedding feature and its role in multi-task training
  2. Creating an architecture decision record (ADR) documenting:
    • The motivation for case embeddings
    • The chosen implementation approach
    • Alternative approaches considered
    • Impact on existing components
  3. Updating relevant documentation to explain how case embeddings affect model training and inference
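
For such documentation, a minimal sketch of the mechanism may help. Names, shapes, and the one-hot choice below are assumptions for illustration, not the PR's exact code:

```python
import torch


def concat_case_embd(descriptor_out: torch.Tensor, case_embd: torch.Tensor) -> torch.Tensor:
    """Append a per-task case embedding to every atom's descriptor features.

    Shapes (illustrative): descriptor_out is (nframes, nloc, dim_descrpt);
    case_embd is (dim_case_embd,), e.g. a one-hot row selected by set_case_embd.
    """
    nframes, nloc, _ = descriptor_out.shape
    # Broadcast the task-level embedding to every atom in every frame.
    tiled = case_embd.view(1, 1, -1).expand(nframes, nloc, -1)
    # The fitting net then consumes dim_descrpt + dim_case_embd input features.
    return torch.cat([descriptor_out, tiled], dim=-1)


# Example: task 1 of a 3-task run with a 3-dimensional one-hot case embedding.
out = concat_case_embd(torch.zeros(2, 5, 64), torch.eye(3)[1])
assert out.shape == (2, 5, 67)
```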
deepmd/pd/model/task/fitting.py (2)

168-170: Consider enhancing the error message.

While the validation is correct, the error message could be more informative by including when this feature will be supported or providing alternative approaches.

-            raise ValueError("dim_case_embd is not supported yet in PaddlePaddle.")
+            raise ValueError(
+                "Case embedding (dim_case_embd > 0) is not supported yet in PaddlePaddle. "
+                "This feature is planned for a future release. For now, please use dim_case_embd=0."
+            )

361-367: Consider using a shared error message constant.

To maintain consistency and ease future updates, consider defining the error message as a class constant.

+    _CASE_EMBD_ERROR = "Case embedding is not supported yet in PaddlePaddle."
+
     def __init__(self, ...):
         if dim_case_embd > 0:
-            raise ValueError("dim_case_embd is not supported yet in PaddlePaddle.")
+            raise ValueError(self._CASE_EMBD_ERROR)

     def set_case_embd(self, case_idx: int):
         """
         Set the case embedding of this fitting net by the given case_idx,
         typically concatenated with the output of the descriptor and fed into the fitting net.
         """
-        raise NotImplementedError("set_case_embd is not supported yet in PaddlePaddle.")
+        raise NotImplementedError(self._CASE_EMBD_ERROR)
deepmd/pd/model/model/make_model.py (2)

519-521: Add docstring and return type annotation.

The method implementation looks correct, but could benefit from improved documentation:

-        def set_case_embd(self, case_idx: int):
+        def set_case_embd(self, case_idx: int) -> None:
+            """Set the case embedding index for the atomic model.
+            
+            Parameters
+            ----------
+            case_idx : int
+                The index of the case embedding to be set.
+            """
             self.atomic_model.set_case_embd(case_idx)

519-521: Add test coverage for case embedding functionality.

The new method should be covered by unit tests to ensure proper delegation and error handling.

Would you like me to help create unit tests for:

  1. Basic case embedding setting
  2. Error handling for invalid case indices
  3. Integration with the atomic model
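
As a starting point, here is a self-contained sketch covering those cases. The stub class stands in for the wrapper returned by make_model, so all names are assumptions, and the negative-index check assumes the validation suggested above is adopted:

```python
import unittest
from unittest.mock import MagicMock


class _WrapperStub:
    """Stand-in for the model wrapper produced by make_model (hypothetical)."""

    def __init__(self, atomic_model):
        self.atomic_model = atomic_model

    def set_case_embd(self, case_idx: int) -> None:
        if case_idx < 0:  # assumes the validation suggested above is adopted
            raise ValueError(f"case_idx must be non-negative, got {case_idx}")
        self.atomic_model.set_case_embd(case_idx)


class TestSetCaseEmbd(unittest.TestCase):
    def test_delegates_to_atomic_model(self):
        atomic = MagicMock()
        _WrapperStub(atomic).set_case_embd(2)
        atomic.set_case_embd.assert_called_once_with(2)

    def test_negative_index_rejected(self):
        with self.assertRaises(ValueError):
            _WrapperStub(MagicMock()).set_case_embd(-1)


if __name__ == "__main__":
    unittest.main()
```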
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between e809469 and b81a716.

📒 Files selected for processing (5)
  • deepmd/pd/model/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/pd/model/model/make_model.py (1 hunks)
  • deepmd/pd/model/task/ener.py (3 hunks)
  • deepmd/pd/model/task/fitting.py (6 hunks)
  • deepmd/pd/model/task/invar_fitting.py (3 hunks)
🔇 Additional comments (7)
deepmd/pd/model/task/ener.py (1)

63-63: LGTM: Parameter correctly propagated to parent class

The dim_case_embd parameter is properly passed to the parent class constructor.

deepmd/pd/model/task/invar_fitting.py (1)

Line range hint 32-35: Consider architectural implications of case embedding support

Since this is part of a larger feature for case embedding support:

  1. Document the overall architecture design for case embedding
  2. Consider adding integration tests for the feature when implemented
  3. Update the class docstring to mention case embedding capability in the class overview

Also applies to: 98-98

deepmd/pd/model/atomic_model/dp_atomic_model.py (1)

142-147: Verify case embedding implementation across components

Let's verify the integration of case embedding functionality across the codebase.

✅ Verification successful

Case embedding implementation is properly integrated across components

The verification shows that case embedding functionality is well-integrated across the codebase:

  • The set_case_embd method is consistently implemented in all atomic models (DP, Linear, PairTab)
  • The fitting networks properly handle case embeddings with appropriate initialization and propagation
  • The feature is supported in PyTorch but explicitly marked as unsupported in TensorFlow/PaddlePaddle
  • Test files and examples demonstrate usage with multi-task models
  • Documentation in doc/train/multi-task-training.md properly describes the feature
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify case embedding implementation and usage

# Check if fitting_net classes implement set_case_embd
echo "Checking for set_case_embd implementation in fitting networks:"
ast-grep --pattern 'class $_ {
  $$$
  set_case_embd($$$) {
    $$$
  }
  $$$
}'

# Check for case embedding usage patterns
echo -e "\nChecking for case embedding usage patterns:"
rg "set_case_embd|dim_case_embd" -A 3

# Check for potential test files
echo -e "\nChecking for related test files:"
fd -g "*test*case*embd*.py"

Length of output: 42721

deepmd/pd/model/task/fitting.py (3)

106-108: LGTM: Documentation is clear and accurate.

The documentation clearly indicates that case embedding is not yet supported and provides a good description of the parameter's purpose.


146-146: LGTM: Parameter addition is well-placed.

The dim_case_embd parameter is appropriately placed in the argument list with a safe default value of 0.


284-284: LGTM: Serialization changes are consistent.

The version bump and addition of case embedding fields are appropriate for maintaining backward compatibility.

Also applies to: 292-292, 301-301
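
For context, a version bump of this kind usually pairs with a migration path on load. A minimal sketch of that pattern follows; the field names are assumptions, not copied from the PR:

```python
def upgrade_fitting_payload(data: dict) -> dict:
    """Fill backward-compatible defaults for pre-case-embedding payloads.

    Hypothetical helper: version 3 of the serialized fitting net is assumed
    to add `dim_case_embd` and `case_embd`; older payloads default to
    "no case embedding" so they keep deserializing unchanged.
    """
    data = dict(data)
    if data.get("@version", 1) < 3:
        data.setdefault("dim_case_embd", 0)
        data.setdefault("case_embd", None)
        data["@version"] = 3
    return data


# An old (version-2) payload still loads, with case embedding disabled.
assert upgrade_fitting_payload({"@version": 2})["dim_case_embd"] == 0
```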

deepmd/pd/model/model/make_model.py (1)

519-521: Verify initialization and usage order of case embeddings.

Please ensure that:

  1. Case embeddings are properly initialized in the atomic model before this method is called
  2. The method is called at the appropriate time in the model's lifecycle
✅ Verification successful

Let me try a different search strategy to gather more information about the case embedding functionality.


Based on the search results, the code analysis now gives a complete picture of case embedding initialization and usage:

Case embedding initialization and usage order is properly implemented.

The codebase shows that:

  1. Case embeddings are initialized during model creation with dim_case_embd parameter in the fitting network
  2. The set_case_embd method is called at the right time during training through deepmd/pt/train/training.py, specifically:
    • Only called during multitask training from scratch (not during resuming)
    • Called after model initialization but before training starts
    • The case embedding index is properly managed through get_case_embd_config
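
A minimal sketch of that call order, with the loop shape and names assumed rather than taken from the trainer's exact code:

```python
def assign_case_embd(models: dict, resuming: bool) -> None:
    """Give each multitask branch a distinct case index before training.

    Hypothetical sketch: skipped when resuming, since resumed checkpoints
    already carry their case embeddings.
    """
    if resuming:
        return
    for case_idx, model_key in enumerate(sorted(models)):
        models[model_key].set_case_embd(case_idx)


# Example with stand-in branches:
class _Branch:
    def set_case_embd(self, idx: int) -> None:
        self.case_idx = idx


branches = {"water_1": _Branch(), "water_2": _Branch()}
assign_case_embd(branches, resuming=False)
assert branches["water_2"].case_idx == 1
```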
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for case embedding initialization and usage order

# Search for case embedding initialization
echo "Searching for case embedding initialization..."
rg -p "dim_case_embd|case_embd" --type python

# Search for set_case_embd usage
echo "Searching for set_case_embd usage..."
ast-grep --pattern 'set_case_embd($_)'

Length of output: 305


Script:

#!/bin/bash
# Search for case embedding related code with proper file extension
echo "Searching for case embedding initialization and usage..."
rg -p "case_embd" .

# Search for atomic model initialization and its properties
echo "Searching for atomic model initialization..."
rg -p "atomic_model" .

# Search specifically in the make_model.py file for context
echo "Checking make_model.py content..."
rg -p "class CM|atomic_model" deepmd/pd/model/model/make_model.py

Length of output: 34961

Actionable comments (resolved):

  • deepmd/pd/model/task/ener.py
  • deepmd/pd/model/task/invar_fitting.py
@iProzd iProzd added this pull request to the merge queue Nov 28, 2024
Merged via the queue into deepmodeling:devel with commit a6b61b9 Nov 28, 2024
60 checks passed
@iProzd iProzd deleted the share_fit_prod branch November 28, 2024 11:21
@iProzd iProzd restored the share_fit_prod branch December 13, 2024 06:12