
feat(jax): freeze to StableHLO & DeepEval #4256

Merged
merged 9 commits into deepmodeling:devel from the hlo branch on Oct 30, 2024

Conversation

njzjz (Member) commented Oct 25, 2024

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced support for .hlo file extensions in model loading and saving functionalities.
    • Added a DeepEval class for enhanced deep learning model evaluation in molecular simulations.
    • Implemented a new HLO class for managing model predictions within a deep learning framework.
  • Bug Fixes

    • Improved handling of suffixes and backend names in test cases for better consistency.
  • Documentation

    • Added SPDX license identifier to relevant files.
  • Chores

    • Refactored internal methods to streamline model prediction processes.

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz marked this pull request as ready for review October 25, 2024 23:22
coderabbitai bot (Contributor) commented Oct 25, 2024

Warning

Rate limit exceeded

@njzjz has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 7 minutes and 11 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Files that changed from the base of the PR and between 748913b and ac65bc7.

📝 Walkthrough

The pull request introduces several modifications across multiple files, primarily enhancing the JAXBackend class to support deep evaluation and additional file suffixes. It also implements a new DeepEval class for evaluating deep learning models, updates serialization functions to handle .hlo files, and refactors existing methods for better clarity and functionality. The changes aim to improve the handling of model predictions and data serialization while ensuring compatibility with new file formats.

Changes

File path and change summary:

  • deepmd/backend/jax.py: updated JAXBackend to include Backend.Feature.DEEP_EVAL, modified suffixes, and redefined the deep_eval, serialize_hook, and deserialize_hook properties.
  • deepmd/dpmodel/descriptor/se_e2_a.py: simplified handling of sec and updated the computation of gr_tmp using xp.sum.
  • deepmd/dpmodel/utils/serialization.py: expanded save_dp_model and load_dp_model to accept .hlo files in addition to .dp.
  • deepmd/jax/env.py: added the import of jax.export as jax_export and updated the __all__ list.
  • deepmd/jax/infer/__init__.py: added the SPDX license identifier comment for LGPL-3.0-or-later.
  • deepmd/jax/infer/deep_eval.py: introduced the DeepEval class for model evaluation, with various property and evaluation methods.
  • deepmd/jax/model/hlo.py: added the HLO class for managing model predictions, including several methods for parameters and outputs.
  • deepmd/jax/utils/serialization.py: updated deserialize_to_file and serialize_from_file to handle .hlo files.
  • source/tests/consistent/io/test_io.py: enhanced IOTest class methods to standardize suffix access and include "jax" in backend names.
  • deepmd/dpmodel/model/make_model.py: introduced the model_call_from_call_lower function and refactored the call method in the CM class.

Suggested labels

Docs

Suggested reviewers

  • wanghan-iapcm
  • iProzd

Possibly related PRs

  • feat pt : Support property fitting #3867: The changes in this PR enhance the functionality of the DeepEval class, which is directly related to the modifications made in the main PR regarding the deep_eval property in the JAXBackend class.
  • feat: DeepEval.get_model_def_script and common dp show #4131: This PR introduces a new method get_model_def_script, which is relevant to the serialization and deserialization hooks updated in the main PR, as both involve model definitions and their handling.
  • feat(jax): force & virial #4251: The modifications in this PR regarding the handling of forces and virials in the model are related to the changes made in the main PR that enhance the JAXBackend class's capabilities, particularly in terms of output handling.
  • feat(jax/array-api): se_e2_r #4257: The introduction of the DescrptSeR class and its enhancements for array API compatibility align with the main PR's focus on improving the handling of array operations in the JAXBackend class.


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 7

🧹 Outside diff range and nitpick comments (10)
deepmd/dpmodel/utils/serialization.py (1)

144-144: Consider using consistent data structures for extension checks.

While the implementation is correct, there's an inconsistency in the data structures used for extension checking:

  • save_dp_model uses tuple: in (".dp", ".hlo")
  • load_dp_model uses set: in {".dp", ".hlo"}

For consistency and maintainability, consider using the same data structure in both functions.

Suggested change:

-    if filename_extension in {".dp", ".hlo"}:
+    if filename_extension in (".dp", ".hlo"):

Or alternatively, if you prefer sets for performance:

# In save_dp_model:
-    if filename_extension in (".dp", ".hlo"):
+    if filename_extension in {".dp", ".hlo"}:
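As an illustration of keeping the two checks from drifting apart, the accepted extensions could be defined once and reused by both functions. This is only a sketch; ALLOWED_EXTENSIONS and check_extension are hypothetical names, not part of the existing module:

from pathlib import Path

# Hypothetical shared constant: save_dp_model and load_dp_model would both test
# membership against the same object, so the two checks cannot diverge.
ALLOWED_EXTENSIONS = (".dp", ".hlo")


def check_extension(filename: str) -> str:
    ext = Path(filename).suffix
    if ext not in ALLOWED_EXTENSIONS:
        raise ValueError(f"unsupported model file extension: {ext!r}")
    return ext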
source/tests/consistent/io/test_io.py (1)

135-135: Consider adjusting numerical tolerances for cross-backend comparison.

The test now includes the JAX backend in cross-backend result verification. However, the current tolerances (rtol=1e-12, atol=1e-12) might be too strict for comparing results across different backends, especially when comparing with JAX which might use different numerical implementations.

Consider using more relaxed tolerances (e.g., rtol=1e-7, atol=1e-10) to account for minor numerical differences between backends while still ensuring correctness.
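For illustration, a relaxed cross-backend comparison might look like the following; ref and jax_result are placeholder arrays, not names taken from the test:

import numpy as np

# Hypothetical comparison of a reference output against the JAX backend output,
# with tolerances loose enough to absorb backend-level floating-point differences.
ref = np.array([1.0, 2.0, 3.0])
jax_result = ref + 1e-9  # simulated tiny numerical deviation
np.testing.assert_allclose(jax_result, ref, rtol=1e-7, atol=1e-10)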

deepmd/jax/utils/serialization.py (1)

70-85: Consider deep copying the data dictionary to prevent unintended mutations

Using data.copy() creates a shallow copy, which might lead to unintended side effects if the nested objects are modified elsewhere. Consider using copy.deepcopy(data) for a deep copy to avoid potential issues with shared references.

Apply this change if deep copying is necessary:

- data = data.copy()
+ import copy
+ data = copy.deepcopy(data)
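A minimal sketch of why the shallow copy can surprise here (the dictionary contents are made up for the example):

import copy

# A shallow copy shares nested objects with the original ...
data = {"model": {"ntypes": 2}}
shallow = data.copy()
shallow["model"]["ntypes"] = 3
print(data["model"]["ntypes"])  # 3 -- the original was mutated through the copy

# ... while a deep copy duplicates the nested structures as well.
data = {"model": {"ntypes": 2}}
deep = copy.deepcopy(data)
deep["model"]["ntypes"] = 3
print(data["model"]["ntypes"])  # 2 -- the original is untouched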
deepmd/jax/model/hlo.py (7)

2-5: Standardize import statement formatting

There are inconsistent blank lines within the grouped import statements. For better readability and adherence to PEP 8 guidelines, remove unnecessary blank lines between imported items and before the closing parenthesis.

Apply the following diff to standardize the import formatting:

 from typing import (
     Any,
-    
     Optional,
-    
 )
 
 from deepmd.dpmodel.model.transform_output import (
     communicate_extended_output,
-    
 )
 
 from deepmd.dpmodel.output_def import (
     FittingOutputDef,
     ModelOutputDef,
     OutputVariableDef,
-    
 )
 
 from deepmd.dpmodel.utils.nlist import (
     build_neighbor_list,
     extend_coord_with_ghosts,
-    
 )
 
 from deepmd.dpmodel.utils.region import (
     normalize_coord,
-    
 )
 
 from deepmd.jax.env import (
     jax_export,
     jnp,
-    
 )
 
 from deepmd.jax.model.base_model import (
     BaseModel,
-    
 )
 
 from deepmd.utils.data_system import (
     DeepmdDataSystem,
-    
 )

Also applies to: 7-9, 10-14, 15-18, 19-21, 22-25, 26-28, 29-31


52-66: Add docstring to the __init__ method

The __init__ method lacks a docstring. Adding a docstring that describes the purpose of the constructor, its parameters, and any important details will enhance code readability and maintainability.


152-153: Avoid unnecessary variable reassignment

Reassigning variables coord, box, fparam, and aparam to abbreviated names (cc, bb, fp, ap) and deleting the originals may reduce code readability. Consider using the original variable names throughout the method to maintain clarity.

Apply this diff to drop the aliases and keep the original variable names throughout the method:

 nframes, nloc = atype.shape[:2]
-cc, bb, fp, ap = coord, box, fparam, aparam
-del coord, box, fparam, aparam

190-194: Add docstring to the model_output_def method

The model_output_def method lacks a docstring. Providing a docstring will help others understand the purpose and usage of this method.


195-214: Add docstring to the call_lower method

The call_lower method is a key component of the class but lacks a docstring. Adding a docstring with details about the method's functionality, parameters, and return values will improve code comprehension.


215-344: Ensure all public methods have docstrings

Several public methods, such as get_type_map, get_rcut, get_dim_fparam, mixed_types, and others, lack docstrings. Providing docstrings for these methods enhances code readability and maintainability.


251-275: Implement serialization and model-building methods or clarify their purpose

The methods serialize, deserialize, update_sel, and get_model currently raise NotImplementedError. If these methods are intended to be implemented later, consider adding a comment to indicate this. If they are not applicable for this class, provide an explanation or remove them to avoid confusion.

Would you like assistance in implementing these methods or creating placeholders with appropriate comments?

Also applies to: 300-344
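If the intent is to keep these as explicit placeholders, one lightweight pattern is to raise with a message that states the limitation. The class and messages below are a hypothetical sketch, not the actual HLO implementation:

class PlaceholderExample:
    """Hypothetical illustration of documented not-yet-supported methods."""

    def serialize(self) -> dict:
        # An explicit message is clearer for callers than a bare NotImplementedError.
        raise NotImplementedError(
            "serialize() is not supported by this class yet; see the class docstring"
        )

    @classmethod
    def deserialize(cls, data: dict) -> "PlaceholderExample":
        raise NotImplementedError("deserialize() is not supported by this class yet")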

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between a66afd3 and 627947a.

📒 Files selected for processing (9)
  • deepmd/backend/jax.py (2 hunks)
  • deepmd/dpmodel/descriptor/se_e2_a.py (1 hunks)
  • deepmd/dpmodel/utils/serialization.py (2 hunks)
  • deepmd/jax/env.py (1 hunks)
  • deepmd/jax/infer/__init__.py (1 hunks)
  • deepmd/jax/infer/deep_eval.py (1 hunks)
  • deepmd/jax/model/hlo.py (1 hunks)
  • deepmd/jax/utils/serialization.py (3 hunks)
  • source/tests/consistent/io/test_io.py (2 hunks)
✅ Files skipped from review due to trivial changes (2)
  • deepmd/jax/env.py
  • deepmd/jax/infer/__init__.py
🧰 Additional context used
🪛 Ruff
deepmd/jax/infer/deep_eval.py

303-306: Use ternary operator natoms = len(atom_types[0]) if mixed_type else len(atom_types) instead of if-else-block

Replace if-else-block with natoms = len(atom_types[0]) if mixed_type else len(atom_types)

(SIM108)


332-335: Use ternary operator box_input = cells.reshape([-1, 3, 3]) if cells is not None else None instead of if-else-block

Replace if-else-block with box_input = cells.reshape([-1, 3, 3]) if cells is not None else None

(SIM108)

🔇 Additional comments (11)
deepmd/backend/jax.py (2)

74-78: Verify DeepEval implementation completeness.

The implementation now returns the DeepEval class, but we should verify that this class properly implements the required interface.

Let's check the implementation of DeepEval:

#!/bin/bash
# Description: Verify DeepEval implementation
# Test 1: Check if DeepEval implements all required methods
ast-grep --pattern 'class DeepEval:
  $$$
'

# Test 2: Check if there are any TODO/FIXME comments in the implementation
rg -g '!*.pyc' -g '!*.so' -A 3 'TODO|FIXME' deepmd/jax/infer/deep_eval.py

41-41: Consider the order of file suffixes.

The order of suffixes might affect file type resolution. Consider documenting whether .hlo or .jax takes precedence, or if the order matters for your use case.

Let's check if there's any documentation or tests that specify the precedence:

✅ Verification successful

Based on the search results, I can now provide a definitive response:

The suffix order is intentionally designed for file type resolution.

The order of suffixes [".hlo", ".jax"] is deliberate and correct. From the codebase:

  • .hlo files are HDF5-based files (using h5py) that store model data in a specific format for XLA/JAX optimization
  • .jax files are directory-based checkpoints with a different storage mechanism
  • The serialization logic in deepmd/jax/utils/serialization.py handles each format distinctly with clear precedence:
    • .jax files use Checkpointer with composite handlers
    • .hlo files use HDF5 storage with specific attributes

The current order ensures proper file type resolution without conflicts.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for documentation about file suffix precedence
# Test: Look for any mentions of .hlo or .jax in docs and tests
rg -g '!*.pyc' -g '!*.so' -A 3 '[.](hlo|jax)' docs/ tests/

Length of output: 121


Script:

#!/bin/bash
# Let's check how these suffixes are used in the codebase
rg -g '!*.pyc' -g '!*.so' "suffixes.*(\\.hlo|\\.jax)" -A 3

# Also check for any file handling or extension processing logic
ast-grep --pattern 'def $_(self, $$$):
  $$$
  suffixes
  $$$'

# Look for actual usage of these file types
rg -g '!*.pyc' -g '!*.so' "\\.hlo|\\.jax" -B 2 -A 2

# Check if there are any tests related to file handling
fd "test.*\\.py$" tests/ --exec grep -l "suffixes\\|[.]hlo\\|[.]jax" {}

Length of output: 19333

deepmd/dpmodel/utils/serialization.py (1)

93-93: LGTM! Verify consistent handling of .hlo files.

The addition of .hlo extension support is implemented correctly, maintaining the same robust HDF5-based serialization as .dp files.

Let's verify the consistent handling of both extensions:

✅ Verification successful

Based on the search results, I can now generate the final response:

Verified: .hlo files are handled consistently with .dp files

The codebase shows consistent handling of both extensions:

  • Both serialization and deserialization in dpmodel/utils/serialization.py treat .hlo files identically to .dp files using HDF5
  • The JAX backend properly declares .hlo support in its supported suffixes
  • Additional .hlo-specific handling exists in JAX utils for model deserialization, maintaining format compatibility
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that .hlo files are handled consistently with .dp files
# throughout the codebase

# Test 1: Check for any special handling of .hlo files that might differ from .dp
rg -g '!*.md' -g '!*.txt' '\.hlo' -C 3

# Test 2: Verify that both extensions are always handled together
rg -g '!*.md' -g '!*.txt' '\.dp.*\.hlo|\.hlo.*\.dp' -C 3

Length of output: 4067

source/tests/consistent/io/test_io.py (1)

71-84: Verify suffix indices for each backend.

The changes look good and make the suffix index more explicit. However, all backends are using suffix_idx=0. Please verify if this is intentional, especially for the JAX backend which might support multiple suffixes (.jax, .hlo).

✅ Verification successful

From the search results, I can see the backend suffix definitions:

  1. TensorFlow backend: .pb
  2. PyTorch backend: .pth, .pt
  3. JAX backend: .hlo, .jax
  4. DPModel backend: .dp, .yaml, .yml

Since JAX backend has two suffixes (.hlo and .jax), and it's using index 0 in the test code, let's verify which suffix is at index 0:


Now I have enough information to provide the final response. The JAX backend defines its suffixes as [".hlo", ".jax"], where .hlo is at index 0. Therefore, using suffix_idx=0 in the test is correct as it will use the .hlo suffix, which is the intended behavior.

Using suffix_idx=0 for JAX backend is correct

The test is correctly using suffix_idx=0 for all backends, including JAX. For JAX backend, this will use the .hlo suffix which is the first suffix in its list [".hlo", ".jax"]. This is consistent with the backend's implementation and its serialization logic.
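For reference, a tiny self-contained illustration of what suffix_idx=0 selects, using the suffix lists quoted above (the dictionary is illustrative, not code from the repository):

# Suffix lists as reported for each backend; index 0 is the first, preferred suffix.
backend_suffixes = {
    "tensorflow": [".pb"],
    "pytorch": [".pth", ".pt"],
    "jax": [".hlo", ".jax"],
    "dpmodel": [".dp", ".yaml", ".yml"],
}

suffix_idx = 0
assert backend_suffixes["jax"][suffix_idx] == ".hlo"
assert backend_suffixes["pytorch"][suffix_idx] == ".pth"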

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check available suffixes for each backend
# Expected: Verify if using index 0 is appropriate for all backends

# Search for suffix definitions in backend files
echo "Checking backend suffixes:"
rg -A 2 "suffixes\s*=\s*\[" --type py

Length of output: 102


Script:

#!/bin/bash
# Let's try different approaches to find backend suffix definitions

# Search for Backend class definitions and their context
echo "Searching for Backend classes and their properties:"
rg -B 2 -A 10 "class \w+Backend\(" --type py

# Search specifically for suffix-related code
echo -e "\nSearching for suffix-related code:"
rg "suffix" --type py

# Look for backend-specific files
echo -e "\nListing potential backend files:"
fd -e py backend

Length of output: 95223


Script:

#!/bin/bash
# Check JAX backend implementation
echo "Checking JAX backend implementation:"
rg -A 5 "class JAXBackend" deepmd/backend/jax.py

Length of output: 287


Script:

#!/bin/bash
# Let's get the full JAX backend class definition to see the suffixes order
echo "Full JAX backend class definition:"
rg -B 2 -A 15 "suffixes.*=.*\[" deepmd/backend/jax.py

# Also check if there's any specific handling of JAX suffixes in the codebase
echo -e "\nChecking JAX suffix usage:"
rg "\.jax|\.hlo" --type py

Length of output: 6559

deepmd/dpmodel/descriptor/se_e2_a.py (2)

558-558: LGTM: Efficient use of pre-computed values

The change to use self.sel_cumsum directly instead of recomputing the cumulative sum is a good optimization that maintains the same functionality while reducing redundant computations.


Line range hint 580-600: Verify mathematical equivalence of the new computation method

The replacement of einsum operations with direct sum operations appears to maintain the same mathematical operations but in a different form. While this change likely improves JAX compatibility and potentially performance, please verify:

  1. The mathematical equivalence of:

    # Old
    gr_tmp = xp.einsum("lni,lnj->lij", gg, tr)
    # New
    gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)
  2. The mathematical equivalence of:

    # Old
    grrg = xp.einsum("flid,fljd->flij", gr, gr1)
    # New
    grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)

Consider adding comments explaining the mathematical operations being performed to improve code readability, as the direct sum operations are less intuitive than einsum notation.

✅ Verification successful

Looking at the code changes and search results, I can now generate a final response:

Mathematical equivalence verified for einsum replacement

The changes replacing einsum with direct sum operations are mathematically equivalent:

  1. For gr_tmp computation:

    • Original: xp.einsum("lni,lnj->lij", gg, tr) performs matrix multiplication along the 'n' axis
    • New: xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1) does the same by:
      • Adding broadcast dimensions with None
      • Multiplying aligned tensors
      • Summing along the correct axis
  2. For grrg computation:

    • Original: xp.einsum("flid,fljd->flij", gr, gr1) performs matrix multiplication along the 'd' axis
    • New: xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4) achieves the same through explicit broadcasting and summation

The changes improve JAX compatibility while maintaining the same mathematical operations, just expressed differently.
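A small self-contained NumPy check of both identities, with shapes chosen arbitrarily for the demonstration (not the shapes used in se_e2_a.py):

import numpy as np

rng = np.random.default_rng(0)

# First identity: contract over the shared "n" axis.
gg = rng.random((4, 5, 3))   # (l, n, i)
tr = rng.random((4, 5, 2))   # (l, n, j)
old1 = np.einsum("lni,lnj->lij", gg, tr)
new1 = np.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)
np.testing.assert_allclose(new1, old1)

# Second identity: contract over the shared "d" axis.
gr = rng.random((2, 4, 3, 6))    # (f, l, i, d)
gr1 = rng.random((2, 4, 5, 6))   # (f, l, j, d)
old2 = np.einsum("flid,fljd->flij", gr, gr1)
new2 = np.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
np.testing.assert_allclose(new2, old2)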

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Search for any existing tests that verify the descriptor computation
# Test: Look for test files that might contain verification of these operations
rg -l "DescrptSeAArrayAPI.*call" --type py

Length of output: 44


Script:

#!/bin/bash
# Let's try a broader search for test files related to descriptor computation
rg -l "DescrptSeA" --type py tests/

# Also search for test files containing 'einsum' to find relevant test cases
rg -l "einsum" --type py tests/

# Look for the actual implementation to understand the context better
rg -l "class DescrptSeA" --type py

Length of output: 671


Script:

#!/bin/bash
# Search for any numerical testing or comparison functions
rg "assert.*allclose|assert.*almost_equal" --type py tests/

Length of output: 108


Script:

#!/bin/bash
# Look for the actual implementation of the descriptor computation
ast-grep --pattern 'def _compute_descriptor($$$)' 

Length of output: 51


Script:

#!/bin/bash
# Let's look at the test files in the correct location
rg -l "DescrptSeA" source/tests/

# Check the implementation in the JAX version since the changes are JAX-related
rg -B 5 -A 5 "einsum|sum.*axis" deepmd/jax/descriptor/se_e2_a.py

# Look at the implementation in the file being reviewed
rg -B 5 -A 5 "einsum|sum.*axis" deepmd/dpmodel/descriptor/se_e2_a.py

Length of output: 4221


Script:

#!/bin/bash
# Look for any numerical testing in the source/tests directory
rg "assert.*allclose|assert.*almost_equal" source/tests/

Length of output: 111515


Script:

#!/bin/bash
# Let's look at the test files specifically for se_e2_a descriptor
rg -B 5 -A 5 "einsum|sum.*axis" source/tests/array_api_strict/descriptor/se_e2_a.py source/tests/consistent/descriptor/test_se_e2_a.py

# Also check for any numerical comparison tests for this descriptor
rg "assert.*allclose|assert.*almost_equal" source/tests/array_api_strict/descriptor/se_e2_a.py source/tests/consistent/descriptor/test_se_e2_a.py

Length of output: 284


Script:

#!/bin/bash
# Look for any implementation comparisons between different array libraries
rg -l "test_se_e2_a" source/tests/consistent/

Length of output: 45

deepmd/jax/utils/serialization.py (4)

6-16: LGTM: New imports correctly integrate necessary modules

The added imports for numpy, orbax.checkpoint, and jax_export are appropriate and essential for handling .hlo files and exporting models. They integrate well with the existing codebase.


49-86: LGTM: Proper implementation of .hlo file handling in deserialize_to_file

The new conditional branch effectively handles models with the .hlo extension. The serialization process correctly uses jax_export to export the model and updates the data dictionary with serialized information and model constants. This addition enhances the utility by supporting more model formats.


54-69: Verify the handling of optional parameters in model export

The conditional inclusion of fparam and aparam based on model.get_dim_fparam() and model.get_dim_aparam() may skip parameters when the dimensions are zero or None. Ensure this behavior aligns with the intended model export logic, especially when dimensions are zero but parameters are still required.

Would you like to run a verification script to confirm that models with zero-dimensional parameters are correctly exported?


141-145: LGTM: Correct deserialization logic for .hlo files in serialize_from_file

The added branch appropriately handles the deserialization of .hlo files by loading the model data and removing unnecessary entries from the data dictionary. This ensures compatibility with the serialization logic.

deepmd/jax/infer/deep_eval.py (1)

385-388: Clarify the output shape calculation for atomic outputs.

There's a commented-out line with a note "Something wrong here?" indicating uncertainty about the correct output shape in the _get_output_shape method for OutputVariableCategory.OUT. Please verify that the current implementation is correct and remove the commented-out code if it's no longer needed.

Review threads (resolved): deepmd/backend/jax.py; deepmd/jax/model/hlo.py (outdated); deepmd/jax/infer/deep_eval.py ×5 (3 outdated)
@@ -39,6 +46,44 @@
model_def_script=ocp.args.JsonSave(model_def_script),
),
)
elif model_file.endswith(".hlo"):
model = BaseModel.deserialize(data["model"])
model_def_script = data["model_def_script"]

Check notice from Code scanning / CodeQL: unused local variable (variable model_def_script is not used).
codecov bot commented Oct 26, 2024

Codecov Report

Attention: Patch coverage is 78.31325% with 54 lines in your changes missing coverage. Please review.

Project coverage is 84.29%. Comparing base (159361d) to head (ac65bc7).
Report is 1 commit behind head on devel.

Files with missing lines Patch % Lines
deepmd/jax/infer/deep_eval.py 70.42% 42 Missing ⚠️
deepmd/jax/model/hlo.py 82.08% 12 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4256      +/-   ##
==========================================
- Coverage   84.37%   84.29%   -0.08%     
==========================================
  Files         551      553       +2     
  Lines       51585    51820     +235     
  Branches     3052     3052              
==========================================
+ Hits        43524    43683     +159     
- Misses       7100     7175      +75     
- Partials      961      962       +1     

☔ View full report in Codecov by Sentry.

Signed-off-by: Jinzhe Zeng <[email protected]>
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
deepmd/backend/jax.py (1)

74-78: LGTM: Consider enhancing docstring

The implementation correctly returns the DeepEval class. While the existing docstring is adequate, it could be enhanced to include details about the specific DeepEval implementation being used.

     @property
     def deep_eval(self) -> type["DeepEvalBackend"]:
         """The Deep Eval backend of the backend.
 
         Returns
         -------
         type[DeepEvalBackend]
-            The Deep Eval backend of the backend.
+            The JAX-based Deep Eval backend implementation for model evaluation.
         """
         from deepmd.jax.infer.deep_eval import (
             DeepEval,
         )
 
         return DeepEval
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 627947a and c9fee92.

📒 Files selected for processing (1)
  • deepmd/backend/jax.py (2 hunks)
🔇 Additional comments (2)
deepmd/backend/jax.py (2)

37-37: LGTM: Feature flag correctly updated

The addition of Backend.Feature.DEEP_EVAL properly aligns with the implementation of the deep_eval property, resolving the previously identified inconsistency.


41-41: Verify handling of both file formats

The addition of ".hlo" suffix is noted. Please ensure that both serialization hooks can properly handle both formats.

✅ Verification successful

Based on the search results, I can now generate the final response since I have sufficient information about how both formats are handled in the serialization code.

Both .hlo and .jax formats are properly handled in serialization

The serialization code in deepmd/jax/utils/serialization.py has distinct handling paths for both formats:

  • .jax files are handled using BaseModel.deserialize() with checkpointing support
  • .hlo files are processed through load_dp_model() with specific variable management
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that serialization code handles both .hlo and .jax formats
# Search for format-specific handling in serialization code
rg -A 5 "\.hlo|\.jax" --glob "deepmd/jax/utils/serialization.py"

Length of output: 2909

njzjz and others added 2 commits October 25, 2024 22:04
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🧹 Outside diff range and nitpick comments (3)
deepmd/jax/utils/auto_batch_size.py (1)

44-59: Simplify the return statement in is_oom_error.

The return statement can be simplified by returning the condition directly.

Apply this diff to simplify the code:

    def is_oom_error(self, e: Exception) -> bool:
        """Check if the exception is an OOM error.

        Parameters
        ----------
        e : Exception
            Exception
        """
        # several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error,
        # such as https://github.com/JuliaGPU/CUDA.jl/issues/1924
        # (the meaningless error message should be considered as a bug in cusolver)
-        if isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)) and (
-            "RESOURCE_EXHAUSTED:" in e.args[0]
-        ):
-            return True
-        return False
+        return isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)) and (
+            "RESOURCE_EXHAUSTED:" in e.args[0]
+        )
🧰 Tools
🪛 Ruff

55-59: Return the condition directly

Inline condition

(SIM103)

deepmd/dpmodel/model/make_model.py (1)

227-236: Consider using consistent parameter passing style

While the implementation is correct, consider using consistent parameter passing style for better readability:

-                coord=cc,
-                atype=atype,
-                box=bb,
                fparam=fp,
                aparam=ap,
                do_atomic_virial=do_atomic_virial,
+                coord=cc,
+                atype=atype,
+                box=bb,

Group the required parameters together followed by optional parameters for better code organization.

deepmd/jax/infer/deep_eval.py (1)

366-368: Rephrase the in-code comment for professionalism

The comment # this is kinda hacky can be rephrased to maintain a professional tone. Consider changing it to a more descriptive comment like # Assigning placeholder values when output is unavailable.

Apply this diff to update the comment:

-    )  # this is kinda hacky
+    )  # Assigning placeholder values when output is unavailable
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between c9fee92 and d35198b.

📒 Files selected for processing (4)
  • deepmd/dpmodel/model/make_model.py (3 hunks)
  • deepmd/jax/infer/deep_eval.py (1 hunks)
  • deepmd/jax/model/hlo.py (1 hunks)
  • deepmd/jax/utils/auto_batch_size.py (1 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/jax/infer/deep_eval.py

303-306: Use ternary operator natoms = len(atom_types[0]) if mixed_type else len(atom_types) instead of if-else-block

Replace if-else-block with natoms = len(atom_types[0]) if mixed_type else len(atom_types)

(SIM108)


332-335: Use ternary operator box_input = cells.reshape([-1, 3, 3]) if cells is not None else None instead of if-else-block

Replace if-else-block with box_input = cells.reshape([-1, 3, 3]) if cells is not None else None

(SIM108)

deepmd/jax/utils/auto_batch_size.py

55-59: Return the condition directly

Inline condition

(SIM103)

🔇 Additional comments (6)
deepmd/jax/utils/auto_batch_size.py (3)

1-11: LGTM! Clean file structure and imports.

The file has proper licensing, clean imports, and clear class inheritance structure.


24-32: LGTM! Well-structured constructor.

The constructor is properly documented with clear parameter descriptions and correctly calls the parent class initialization.


34-42: LGTM! Effective GPU availability check.

The implementation is simple and effective, with proper type hints and documentation.

deepmd/dpmodel/model/make_model.py (3)

3-3: LGTM: Type hint import addition

The addition of Callable from typing module is appropriate for type-hinting the function parameter.


43-129: Well-structured function extraction with good practices

The new model_call_from_call_lower function demonstrates:

  • Clear separation of concerns
  • Comprehensive documentation
  • Strong type hints
  • Logical flow of operations

43-129: Verify integration with other components

Let's verify the usage of this new function across the codebase:

✅ Verification successful

Refactoring is properly integrated and used

The function model_call_from_call_lower is correctly integrated into the codebase:

  • It's imported and used in deepmd/jax/model/hlo.py
  • The implementation is consistent with its usage
  • The git history shows this is a refactoring that maintains the same functionality
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for other potential uses of model_call_from_call_lower
rg "model_call_from_call_lower" --type python

# Search for other implementations of similar functionality that might need refactoring
ast-grep --pattern 'def call($params) {
  $$$
  normalize_coord($$$)
  $$$
  extend_coord_with_ghosts($$$)
  $$$
}'

Length of output: 194


Script:

#!/bin/bash
# Search for usages of model_call_from_call_lower without file type flag
rg "model_call_from_call_lower"

# Search for similar patterns in Python files
rg "normalize_coord.*extend_coord_with_ghosts" -A 5

# Look for imports of this function
rg "from.*make_model.*import.*model_call_from_call_lower"

Length of output: 426


Script:

#!/bin/bash
# Check the context of usage in hlo.py
rg -B 5 -A 10 "model_call_from_call_lower" deepmd/jax/model/hlo.py

# Check if this is a new file or modification by looking at git history
git log -p deepmd/dpmodel/model/make_model.py | grep -A 10 "model_call_from_call_lower"

# Look for any similar patterns in the codebase
rg "normalize_coord|extend_coord_with_ghosts|build_neighbor_list" --no-filename

Length of output: 1660

Review threads (resolved): deepmd/jax/model/hlo.py ×2; deepmd/jax/infer/deep_eval.py ×2 (1 outdated)
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <[email protected]>
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between d35198b and 2e3f3ad.

📒 Files selected for processing (1)
  • deepmd/jax/infer/deep_eval.py (1 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/jax/infer/deep_eval.py

303-306: Use ternary operator natoms = len(atom_types[0]) if mixed_type else len(atom_types) instead of if-else-block

Replace if-else-block with natoms = len(atom_types[0]) if mixed_type else len(atom_types)

(SIM108)


332-335: Use ternary operator box_input = cells.reshape([-1, 3, 3]) if cells is not None else None instead of if-else-block

Replace if-else-block with box_input = cells.reshape([-1, 3, 3]) if cells is not None else None

(SIM108)

Review threads (resolved): deepmd/jax/infer/deep_eval.py ×5 (3 outdated)
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz enabled auto-merge October 29, 2024 18:51
njzjz added this pull request to the merge queue Oct 29, 2024
github-merge-queue bot removed this pull request from the merge queue due to a conflict with the base branch Oct 29, 2024
njzjz enabled auto-merge October 29, 2024 22:04
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (3)
deepmd/dpmodel/model/make_model.py (1)

43-129: Add input validation for better error handling

While the implementation is solid, consider adding input validation to handle edge cases:

  • Validate that input arrays have correct shapes and types
  • Check if sel list is non-empty
  • Ensure rcut is positive

Example validation:

 def model_call_from_call_lower(
     *,  # enforce keyword-only arguments
     call_lower: Callable[
         [
             np.ndarray,
             np.ndarray,
             np.ndarray,
             Optional[np.ndarray],
             Optional[np.ndarray],
             bool,
         ],
         dict[str, np.ndarray],
     ],
     rcut: float,
     sel: list[int],
     mixed_types: bool,
     model_output_def: ModelOutputDef,
     coord: np.ndarray,
     atype: np.ndarray,
     box: Optional[np.ndarray] = None,
     fparam: Optional[np.ndarray] = None,
     aparam: Optional[np.ndarray] = None,
     do_atomic_virial: bool = False,
 ):
+    if not isinstance(coord, np.ndarray) or not isinstance(atype, np.ndarray):
+        raise TypeError("coord and atype must be numpy arrays")
+    if not sel:
+        raise ValueError("sel list cannot be empty")
+    if rcut <= 0:
+        raise ValueError("rcut must be positive")
     nframes, nloc = atype.shape[:2]
deepmd/jax/infer/deep_eval.py (2)

162-165: Add missing return type annotations

The methods get_has_efield and get_ntypes_spin lack return type annotations. Including return type annotations improves code readability and enables better type checking.

Apply this diff to add the return type annotations:

-def get_has_efield(self):
+def get_has_efield(self) -> bool:
     """Check if the model has efield."""
     return False

-def get_ntypes_spin(self):
+def get_ntypes_spin(self) -> int:
     """Get the number of spin atom types of this model."""
     return 0

Also applies to: 166-169


347-348: Clarify or remove commented-out code

The commented-out code and accompanying comment may cause confusion:

# it seems not doing conversion
# dp_name = self._OUTDEF_DP2BACKEND[odef.name]

Please consider clarifying the intent behind this code or removing it if it's no longer needed.

Apply this diff to remove the commented-out code:

-            # it seems not doing conversion
-            # dp_name = self._OUTDEF_DP2BACKEND[odef.name]
             dp_name = odef.name
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 2e3f3ad and 748913b.

📒 Files selected for processing (6)
  • deepmd/backend/jax.py (2 hunks)
  • deepmd/dpmodel/descriptor/se_e2_a.py (1 hunks)
  • deepmd/dpmodel/model/make_model.py (3 hunks)
  • deepmd/jax/env.py (2 hunks)
  • deepmd/jax/infer/deep_eval.py (1 hunks)
  • source/tests/consistent/io/test_io.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • deepmd/backend/jax.py
  • deepmd/dpmodel/descriptor/se_e2_a.py
  • deepmd/jax/env.py
  • source/tests/consistent/io/test_io.py
🧰 Additional context used
🪛 Ruff
deepmd/jax/infer/deep_eval.py

303-306: Use ternary operator natoms = len(atom_types[0]) if mixed_type else len(atom_types) instead of if-else-block

Replace if-else-block with natoms = len(atom_types[0]) if mixed_type else len(atom_types)

(SIM108)


332-335: Use ternary operator box_input = cells.reshape([-1, 3, 3]) if cells is not None else None instead of if-else-block

Replace if-else-block with box_input = cells.reshape([-1, 3, 3]) if cells is not None else None

(SIM108)

🔇 Additional comments (3)
deepmd/dpmodel/model/make_model.py (3)

1-4: LGTM: Import changes enhance type safety

The addition of Callable from typing module improves type safety and documentation of function signatures.


227-238: LGTM: Well-structured refactoring

The refactoring of the call method improves code organization by:

  • Separating complex logic into a dedicated function
  • Maintaining clear parameter passing
  • Preserving existing functionality
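In miniature, the pattern being praised here: the class method stays a thin wrapper while a standalone, keyword-only helper carries the shared logic. The names below are a toy sketch, not the deepmd code:

from typing import Callable, Optional

def result_from_lower(
    *,  # keyword-only, mirroring the style of model_call_from_call_lower
    lower: Callable[[float], float],
    scale: float,
    x: float,
    offset: Optional[float] = None,
) -> float:
    # Shared pre/post-processing that the class and other callers can reuse.
    y = lower(x * scale)
    return y + (offset or 0.0)

class ToyModel:
    def lower(self, x: float) -> float:
        return x * x

    def call(self, x: float) -> float:
        # Thin wrapper that delegates to the shared helper.
        return result_from_lower(lower=self.lower, scale=2.0, x=x)

print(ToyModel().call(3.0))  # 36.0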

227-238: Verify impact on existing callers

Let's ensure all callers are compatible with the refactored implementation.

✅ Verification successful

Let me analyze the search results. The code shows that the call method is used in various test files and model implementations. The changes in make_model.py involve refactoring the model_call_from_call_lower function. Looking at the test files, they primarily use the call method for testing purposes and comparing outputs between different implementations.

Let me verify if there are any direct usages of model_call_from_call_lower in the codebase:


Let me check the implementation in jax/model/hlo.py to understand how it uses the refactored function:


Based on the analysis of the codebase, I can now provide the final response:

Changes to model_call_from_call_lower are properly synchronized

The refactored implementation in make_model.py is correctly synchronized with its only other usage in jax/model/hlo.py. Both implementations use the same parameter structure and calling convention. The changes are localized to these two files and all test cases that exercise this functionality continue to use the stable call interface.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for direct calls to CM.call method
rg -l "\.call\s*\(" | while read -r file; do
  if [[ "$file" != "deepmd/dpmodel/model/make_model.py" ]]; then
    echo "=== $file ==="
    rg -A 5 "\.call\s*\(" "$file"
  fi
done

Length of output: 20616


Script:

#!/bin/bash
# Search for direct usages of model_call_from_call_lower
rg "model_call_from_call_lower"

Length of output: 312


Script:

#!/bin/bash
# Check the implementation in jax/model/hlo.py
rg -A 10 "model_call_from_call_lower" deepmd/jax/model/hlo.py

Length of output: 633

Review thread (resolved, outdated): deepmd/jax/infer/deep_eval.py
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz added this pull request to the merge queue Oct 30, 2024
Merged via the queue into deepmodeling:devel with commit d165fee Oct 30, 2024
61 checks passed
njzjz deleted the hlo branch October 30, 2024 02:16
Labels: none · Projects: none · 3 participants