fix(dpmodel/jax): fix fparam and aparam support in DeepEval #4285
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
📝 Walkthrough: The pull request introduces modifications across several files related to fparam and aparam support in DeepEval.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (2)
deepmd/dpmodel/infer/deep_eval.py (1)
Lines 326-333: Consider adding dimension validation.
While the reshaping logic is correct, consider adding validation to ensure the input dimensions match the model's expectations:
- `fparam.shape[-1] == self.get_dim_fparam()`
- `aparam.shape[-1] == self.get_dim_aparam()`
Example implementation:
```diff
 if fparam is not None:
+    if fparam.shape[-1] != self.get_dim_fparam():
+        raise ValueError(f"fparam dimension {fparam.shape[-1]} does not match model's dim_fparam {self.get_dim_fparam()}")
     fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
 else:
     fparam_input = None
 if aparam is not None:
+    if aparam.shape[-1] != self.get_dim_aparam():
+        raise ValueError(f"aparam dimension {aparam.shape[-1]} does not match model's dim_aparam {self.get_dim_aparam()}")
     aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam())
 else:
     aparam_input = None
```
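For context, a minimal usage sketch of how such inputs typically reach the evaluator, assuming the public DeepPot interface (the model file name and all array values here are illustrative assumptions, not taken from this PR):

```python
import numpy as np
from deepmd.infer import DeepPot  # public inference interface

dp = DeepPot("model.savedmodel")  # hypothetical model file name
nframes, natoms = 2, 6
coords = np.random.rand(nframes, natoms * 3)
cells = None                      # non-periodic system
atom_types = [0] * natoms

# shapes follow the convention discussed above:
# fparam (nframes, dim_fparam), aparam (nframes, natoms, dim_aparam)
fparam = np.random.rand(nframes, dp.get_dim_fparam())
aparam = np.random.rand(nframes, natoms, dp.get_dim_aparam())

e, f, v = dp.eval(coords, cells, atom_types, fparam=fparam, aparam=aparam)
```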
deepmd/jax/infer/deep_eval.py (1)
Lines 336-343: Add dimension validation for fparam and aparam.
While the reshaping logic is correct, consider adding validation for input dimensions to provide clearer error messages. This would help users identify dimension mismatches earlier.
Add validation before reshaping:
```diff
 if fparam is not None:
+    if fparam.size % self.get_dim_fparam() != 0:
+        raise ValueError(f"fparam size {fparam.size} is not divisible by dim_fparam {self.get_dim_fparam()}")
     fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
 else:
     fparam_input = None
 if aparam is not None:
+    expected_size = nframes * natoms * self.get_dim_aparam()
+    if aparam.size != expected_size:
+        raise ValueError(f"aparam size {aparam.size} does not match expected size {expected_size}")
     aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam())
 else:
     aparam_input = None
```
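Why an explicit check helps: `reshape` only verifies the total element count, so a wrongly shaped input can slip through silently. A small self-contained demonstration (plain NumPy, not code from this PR):

```python
import numpy as np

nframes, dim_fparam = 4, 3
good = np.zeros((nframes, dim_fparam))
bad = np.zeros((nframes * dim_fparam,))     # flat, but same element count
also_bad = np.zeros((dim_fparam, nframes))  # transposed, same element count

# All three reshape without error, because reshape checks only total size;
# the explicit size/shape validation suggested above is what catches real mismatches.
for arr in (good, bad, also_bad):
    print(arr.reshape(nframes, dim_fparam).shape)  # (4, 3) every time
```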
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
- deepmd/dpmodel/fitting/general_fitting.py (2 hunks)
- deepmd/dpmodel/infer/deep_eval.py (3 hunks)
- deepmd/jax/infer/deep_eval.py (4 hunks)
- deepmd/jax/utils/serialization.py (1 hunks)
- source/tests/consistent/io/test_io.py (3 hunks)
🔇 Additional comments (12)
deepmd/jax/utils/serialization.py (2)
Line 54: LGTM: Improved flexibility by removing hardcoded shape parameters.
The change to use only the necessary shape parameters (nf, nloc, nghost) is more maintainable, as the fparam and aparam dimensions are now obtained dynamically from the model.
Lines 60-65: LGTM: Proper handling of optional fparam and aparam parameters.
The implementation correctly handles the shape structures for both fparam and aparam, with appropriate null checks and dimensions.
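As an illustration of the pattern under review, a minimal sketch of conditionally building an input signature (names, shapes, and the use of `tf.TensorSpec` here are assumptions for illustration; the real serialization code differs in detail):

```python
import tensorflow as tf

def build_input_signature(model, nf=None, nloc=None):
    """Sketch: append optional fparam/aparam specs only when the model declares them."""
    signature = [
        tf.TensorSpec([nf, nloc, 3], tf.float64),  # coordinates
        tf.TensorSpec([nf, nloc], tf.int32),       # atom types
    ]
    if model.get_dim_fparam():
        signature.append(tf.TensorSpec([nf, model.get_dim_fparam()], tf.float64))
    if model.get_dim_aparam():
        signature.append(tf.TensorSpec([nf, nloc, model.get_dim_aparam()], tf.float64))
    return signature
```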
Let's verify the consistency of fparam and aparam usage across the codebase:
✅ Verification successful
Based on the search results, there is comprehensive evidence of consistent handling of fparam and aparam dimensions across the codebase:
- The get_dim_fparam() and get_dim_aparam() methods are consistently implemented across model classes and frameworks (JAX, PyTorch, TensorFlow).
- The shape handling in the review matches the expected tensor shapes used throughout the codebase:
  - fparam: (nf, dim_fparam)
  - aparam: (nf, nloc, dim_aparam)
- The null checks (if model.get_dim_fparam() and if model.get_dim_aparam()) are consistently used before accessing these parameters.
- The dimensions are properly propagated through the model compression, testing, and API layers.
LGTM: The fparam and aparam shape handling is consistent with the codebase patterns.
The implementation correctly handles optional parameters with proper null checks and maintains consistent tensor shapes across all usage points.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Description: Verify consistent handling of fparam and aparam dimensions
# Expected: All dimension getters should be used consistently with shape structs

# Check for any inconsistencies in fparam dimension handling
rg -A 2 "get_dim_fparam|dim_fparam"

# Check for any inconsistencies in aparam dimension handling
rg -A 2 "get_dim_aparam|dim_aparam"
```
Length of output: 127297
source/tests/consistent/io/test_io.py (3)
Lines 139-140: LGTM! Well-structured test enhancements.
The changes properly handle both supported and unsupported cases for fparam and aparam:
- Correctly extract array dimensions from the input data
- Create parameters only when the dimensions are > 0
- Use appropriate array shapes for both parameter types
Also applies to: lines 150-157, 162-163.
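A sketch of what such a test setup typically looks like (illustrative only; the dimension values and RNG seed are assumptions standing in for what the test reads from the model):

```python
import numpy as np

rng = np.random.default_rng(0)
nframes, natoms = 1, 4
dim_fparam = 1  # stand-in for model.get_dim_fparam()
dim_aparam = 1  # stand-in for model.get_dim_aparam()

# create optional parameters only when the model declares them
fparam = rng.random((nframes, dim_fparam)) if dim_fparam > 0 else None
aparam = rng.random((nframes, natoms, dim_aparam)) if dim_aparam > 0 else None
```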
Lines 216-254: LGTM! Well-structured test class.
The test class properly sets up a test environment for validating fparam and aparam support:
- Clear test scope through a dedicated class
- Proper initialization of test parameters
- Good inheritance structure
Lines 256-257: LGTM! Proper test cleanup.
The tearDown method correctly inherits from and calls the parent class's cleanup logic.
deepmd/dpmodel/infer/deep_eval.py (2)
Line 217: LGTM: Parameter forwarding is correct.
The fparam and aparam parameters are correctly forwarded to the _eval_model method.
Lines 339-344: Verify model interface compatibility.
The implementation looks correct. Let's verify that all model implementations support the fparam and aparam parameters in their call interface.
✅ Verification successful
Based on the search results, all model implementations in the codebase properly support the fparam and aparam parameters through their base classes and interfaces.
Model interface compatibility is verified.
The codebase shows consistent support for the fparam and aparam parameters across all model implementations (see the simplified sketch after this list):
- All atomic models inherit from BaseAtomicModel, which defines the interface with fparam and aparam support.
- The parameters are properly propagated through the model hierarchy:
  - Base models define the abstract methods get_dim_fparam() and get_dim_aparam().
  - Concrete implementations like DPAtomicModel, LinearEnergyAtomicModel, etc. implement these methods.
- Models handle both parameters in their __call__ and forward methods with proper type checking and shape validation.
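A simplified sketch of the hierarchy just described (a sketch under assumed constructor arguments; the real classes carry many more methods):

```python
from abc import ABC, abstractmethod

class BaseAtomicModel(ABC):
    """Sketch of the interface declaring fparam/aparam dimensions."""

    @abstractmethod
    def get_dim_fparam(self) -> int:
        """Dimension of frame parameters (0 if unused)."""

    @abstractmethod
    def get_dim_aparam(self) -> int:
        """Dimension of atomic parameters (0 if unused)."""

class DPAtomicModel(BaseAtomicModel):
    def __init__(self, numb_fparam: int, numb_aparam: int):
        self.numb_fparam = numb_fparam
        self.numb_aparam = numb_aparam

    def get_dim_fparam(self) -> int:
        return self.numb_fparam

    def get_dim_aparam(self) -> int:
        return self.numb_aparam
```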
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Description: Check if all model implementations support fparam and aparam parameters

# Search for model class definitions and their call methods
ast-grep --pattern 'class $_ {
  $$$
  def __call__($_, fparam=$_, aparam=$_) {
    $$$
  }
  $$$
}'
```
Length of output: 102
Script:
```bash
#!/bin/bash
# Let's try a different approach to find model implementations and their call signatures
# First, let's find all Python files that might contain model implementations
echo "=== Files containing 'class' and '__call__' ==="
rg -l "class.*:.*def __call__" --type py

echo -e "\n=== Model class implementations with their call signatures ==="
rg "class.*:.*\n.*\n.*def __call__" -A 5 --type py
```
Length of output: 545
Script:
```bash
#!/bin/bash
# Let's try a simpler approach first to find model classes
echo "=== Python files with potential model implementations ==="
rg "class.*Model" --type py -l

echo -e "\n=== Model class implementations ==="
for file in $(rg "class.*Model" --type py -l); do
  echo "=== File: $file ==="
  # Print the class definition and a few lines after it
  rg "class.*Model" "$file" -A 10
done

# Also search specifically for __call__ methods with fparam/aparam
echo -e "\n=== Call methods with fparam/aparam ==="
rg "def __call__.*(?:fparam|aparam)" --type py -A 5
```
Length of output: 136305
Script:
```bash
#!/bin/bash
# Let's search for __call__ methods in model implementations with a more focused approach
echo "=== Searching for __call__ methods in model files ==="
rg "def __call__" --type py deepmd/dpmodel/model/ deepmd/dpmodel/atomic_model/ -A 5

echo -e "\n=== Searching specifically for fparam/aparam usage in model files ==="
rg "fparam|aparam" --type py deepmd/dpmodel/model/ deepmd/dpmodel/atomic_model/ -C 2
```
Length of output: 29695
deepmd/jax/infer/deep_eval.py (3)
Line 227: LGTM: Parameters correctly passed to _eval_model.
The addition of the fparam and aparam parameters to the _eval_model call is consistent with the method signature and documentation.
Lines 317-318: LGTM: Method signature properly updated.
The addition of the fparam and aparam parameters with correct type hints (Optional[np.ndarray]) aligns with the parent method signature.
Lines 352-353: LGTM: Proper JAX array conversion and parameter passing.
The parameters are correctly converted to JAX arrays and passed to the model with appropriate names.
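The pattern being reviewed likely resembles the following, a sketch rather than the PR's exact code (only `jnp.asarray` is standard JAX; the variable names are assumptions):

```python
import jax.numpy as jnp
import numpy as np

def to_jax(x):
    """Convert an optional NumPy array to a JAX array, preserving None."""
    return jnp.asarray(x) if x is not None else None

fparam_input = np.zeros((2, 3))  # example frame parameters, shape (nf, dim_fparam)
aparam_input = None              # atomic parameters absent in this example

fparam_jax = to_jax(fparam_input)
aparam_jax = to_jax(aparam_input)
print(type(fparam_jax), aparam_jax)  # JAX array, None
```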
deepmd/dpmodel/fitting/general_fitting.py (2)
Lines 391-392: LGTM! Clear and informative error message.
The error message is well formatted using f-strings and provides clear information about the dimension mismatch by showing both the actual and expected values.
Lines 412-413: LGTM! Consistent error message style.
The error message maintains consistency with the fparam error message style, using f-strings and providing clear information about the dimension mismatch.
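The style of check being praised is roughly the following (an illustrative sketch with assumed names and wording, not the exact lines from general_fitting.py):

```python
def check_param_dims(fparam, aparam, numb_fparam: int, numb_aparam: int) -> None:
    """Illustrative re-statement of the dimension checks described above."""
    if fparam is not None and fparam.shape[-1] != numb_fparam:
        raise ValueError(
            f"got an input fparam of dim {fparam.shape[-1]}, "
            f"which is not consistent with {numb_fparam}"
        )
    if aparam is not None and aparam.shape[-1] != numb_aparam:
        raise ValueError(
            f"got an input aparam of dim {aparam.shape[-1]}, "
            f"which is not consistent with {numb_aparam}"
        )
```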
Codecov Report: All modified and coverable lines are covered by tests ✅
Additional details and impacted files:
```
@@            Coverage Diff             @@
##            devel    #4285      +/-   ##
==========================================
+ Coverage   84.29%   84.31%   +0.01%
==========================================
  Files         553      553
  Lines       51820    51828       +8
  Branches     3052     3052
==========================================
+ Hits        43683    43699      +16
+ Misses       7177     7169       -8
  Partials      960      960
```
☔ View full report in Codecov by Sentry.
For the frozen model, store two exported functions: one enables do_atomic_virial and the other doesn't. This PR is in conflict with #4285 (in `serialization.py`), and the conflict must be resolved after one is merged.

Summary by CodeRabbit (Release Notes)
New Features
- Introduced a new parameter for enhanced atomic virial data handling in model evaluations.
- Added support for atomic virial calculations in multiple model evaluation methods.
- Updated export functionality to dynamically include atomic virial data based on user input.
Bug Fixes
- Improved output structures across various backends to accommodate new atomic virial data.
Tests
- Enhanced test cases to verify the new atomic virial functionalities and ensure compatibility with existing evaluations.

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
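A minimal sketch of the two-exported-functions idea (all names here are assumptions; the real export in `serialization.py` goes through the JAX/TF export machinery, which is not shown):

```python
from functools import partial

def call_model(coord, atype, do_atomic_virial=False):
    """Stand-in for the real model call; returns a dict of outputs."""
    out = {"energy": 0.0}
    if do_atomic_virial:
        out["atom_virial"] = 0.0  # computed only in this variant
    return out

# two fixed-flag entry points stored side by side in the frozen model,
# so do_atomic_virial need not be a runtime input of the exported graph
exports = {
    "call": partial(call_model, do_atomic_virial=False),
    "call_atomic_virial": partial(call_model, do_atomic_virial=True),
}

print(exports["call"](None, None))                # {'energy': 0.0}
print(exports["call_atomic_virial"](None, None))  # {'energy': 0.0, 'atom_virial': 0.0}
```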
Summary by CodeRabbit
New Features
- Added fparam and aparam support to model evaluation in the dpmodel and JAX backends.
Bug Fixes
- Corrected the handling of fparam and aparam in DeepEval, with clearer dimension-mismatch error messages.
Tests
- Extended the consistency I/O tests to cover models that use fparam and aparam.