-
Notifications
You must be signed in to change notification settings - Fork 510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(jax): fix several serialization and jit issues for DPA-2 #4315
base: devel
Are you sure you want to change the base?
Conversation
- `deepmd/jax/descriptor/__init__.py` imports SeT and DPA-2 to let them found by the plugin; - `deepmd/dpmodel/descriptor/dpa1.py` fixes the jit issue regarding to the shape generated by `jnp.prod`. The shape should be static by using `math.prod`. - `deepmd/jax/model/ener_model.py` and `deepmd/jax/model/dp_zbl_model.py` stop the graident of coordinates when rebuilding the neighbor list. The graient of sort causes an error due to jax-ml/jax#24730. Signed-off-by: Jinzhe Zeng <[email protected]>
The universal test may need to add JAX models and test saving to the SavedModel. (again, we cannot run TF1 and TF2 at the same session. The regular jit doesn't trigger the thrid error, only saving the model does.) |
📝 WalkthroughWalkthroughThe changes in this pull request involve several modifications across multiple files, primarily focusing on the implementation of new methods and adjustments to existing methods within various classes. Key updates include the introduction of the Changes
Possibly related PRs
Suggested labels
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
deepmd/jax/model/ener_model.py (1)
53-66
: Add documentation explaining the stop_gradient usage.The implementation correctly prevents gradient computation during neighbor list rebuilding by using
jax.lax.stop_gradient
onextended_coord
. This addresses the gradient sorting issue referenced in JAX issue #24730.Consider adding a docstring explaining why
stop_gradient
is necessary to help future maintainers understand this critical implementation detail.Example docstring:
def format_nlist( self, extended_coord: jnp.ndarray, extended_atype: jnp.ndarray, nlist: jnp.ndarray, extra_nlist_sort: bool = False, ): """Format neighbor list while preventing gradient computation during rebuilding. Uses stop_gradient on coordinates to prevent gradient sorting issues (see JAX issue #24730) during neighbor list rebuilding. Args: extended_coord: Extended coordinates array extended_atype: Extended atom types array nlist: Neighbor list array extra_nlist_sort: Optional flag for additional sorting """deepmd/jax/model/dp_zbl_model.py (1)
53-66
: Consider adding docstring documentation.Adding documentation would help future maintainers understand:
- The purpose of stopping gradients
- The relationship with JAX issue #24730
- The impact on gradient computation
Here's a suggested docstring:
def format_nlist( self, extended_coord: jnp.ndarray, extended_atype: jnp.ndarray, nlist: jnp.ndarray, extra_nlist_sort: bool = False, ): + """Format neighbor list while preventing gradient computation for coordinates. + + This prevents gradient computation during neighbor list rebuilding to avoid + gradient sorting issues (see JAX #24730). The gradients are stopped only for + the coordinate inputs used in neighbor list operations. + """ return DPZBLModelDP.format_nlist(
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (4)
deepmd/dpmodel/descriptor/dpa1.py
(3 hunks)deepmd/jax/descriptor/__init__.py
(2 hunks)deepmd/jax/model/dp_zbl_model.py
(2 hunks)deepmd/jax/model/ener_model.py
(2 hunks)
🔇 Additional comments (8)
deepmd/jax/descriptor/__init__.py (3)
28-28
: LGTM: Updated exports list.
The __all__
list is correctly updated to include both new descriptors, making them part of the public API.
Also applies to: 30-30
20-22
: LGTM: New SeT-TEBD descriptor import.
The import of DescrptSeTTebd
aligns with the PR objectives to enhance the plugin system.
Let's verify the module exists:
#!/bin/bash
# Verify the se_t_tebd module exists and contains the expected class
fd "se_t_tebd.py" --type f --exec ast-grep --pattern 'class DescrptSeTTebd'
5-7
: LGTM: New DPA-2 descriptor import.
The import of DescrptDPA2
aligns with the PR objectives to enhance the plugin system.
Let's verify the module exists:
✅ Verification successful
LGTM: The DPA2 descriptor import is correctly implemented.
The import of DescrptDPA2
in deepmd/jax/descriptor/__init__.py
is properly implemented and matches the class definition in dpa2.py
. The class is also correctly exposed in the module's __all__
list.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Verify the DPA2 module exists and contains the expected class
fd "dpa2.py" --type f --exec ast-grep --pattern 'class DescrptDPA2'
Length of output: 75
Script:
#!/bin/bash
# Let's first check if the file exists
fd "dpa2.py" --type f
# If the file exists, let's examine its content
fd "dpa2.py" --type f --exec cat {}
Length of output: 115300
Script:
#!/bin/bash
# Let's check if the DPA2 class exists in the JAX descriptor module
rg "class DescrptDPA2" deepmd/jax/descriptor/dpa2.py
# Let's also verify the imports in the JAX descriptor module
cat deepmd/jax/descriptor/__init__.py
Length of output: 758
deepmd/jax/model/ener_model.py (1)
15-15
: LGTM: Import addition is necessary.
The addition of the jax
import is required for using jax.lax.stop_gradient
in the new format_nlist
method.
deepmd/jax/model/dp_zbl_model.py (2)
15-15
: LGTM: Import addition is appropriate.
The addition of the jax
import is necessary for using jax.lax.stop_gradient
in the new format_nlist
method.
53-66
: Implementation looks correct, verify the impact.
The use of jax.lax.stop_gradient
on extended_coord
effectively prevents gradient computation during neighbor list rebuilding, which should resolve the gradient sorting issue (JAX #24730).
Let's verify that this is the only place where neighbor list formatting occurs in this model:
✅ Verification successful
Implementation is correct and consistent across the codebase
The verification shows that stop_gradient
is consistently applied to extended_coord
in both JAX model implementations (dp_zbl_model.py
and ener_model.py
). This is the correct approach to prevent gradient computation during neighbor list rebuilding, while the base implementations in dpmodel
and pt.model
handle their respective frameworks appropriately.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for other potential neighbor list formatting operations
ast-grep --pattern 'format_nlist'
Length of output: 1089
Script:
#!/bin/bash
# Let's check the implementation of format_nlist in the parent classes
ast-grep --pattern $'class DPZBLModelDP {
$$$
format_nlist($$$) {
$$$
}
$$$
}'
ast-grep --pattern $'class EnergyModelDP {
$$$
format_nlist($$$) {
$$$
}
$$$
}'
Length of output: 184
Script:
#!/bin/bash
# Let's try a different approach to find the parent class implementations
rg "class (DPZBLModelDP|EnergyModelDP)" -A 10
rg "def format_nlist" -B 5 -A 10
Length of output: 4210
deepmd/dpmodel/descriptor/dpa1.py (2)
856-856
: Good optimization: Using math.prod
for shape calculations.
Using math.prod
instead of xp.prod(xp.asarray(...))
is a better choice for shape calculations as it:
- Works with static shapes during JIT compilation
- Avoids unnecessary array conversion
- Is more efficient for small tuples
870-870
: LGTM: Consistent use of math.prod
for shape calculations.
The change is consistent with the optimization in cal_g
, ensuring uniform behavior across both methods.
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## devel #4315 +/- ##
==========================================
- Coverage 84.40% 84.23% -0.18%
==========================================
Files 570 570
Lines 53071 53078 +7
Branches 3054 3054
==========================================
- Hits 44794 44709 -85
- Misses 7318 7410 +92
Partials 959 959 ☔ View full report in Codecov by Sentry. |
deepmd/jax/descriptor/__init__.py
imports SeT and DPA-2 to let them found by the plugin;deepmd/dpmodel/descriptor/dpa1.py
fixes the jit issue regarding to the shape generated byjnp.prod
. The shape should be static by usingmath.prod
.deepmd/jax/model/ener_model.py
anddeepmd/jax/model/dp_zbl_model.py
stop the graident of coordinates when rebuilding the neighbor list. The graient of sort causes an error due to InconclusiveDimensionOperation: Symbolic dimension comparison 'b' < '2147483647' is inconclusive. jax-ml/jax#24730.Summary by CodeRabbit
New Features
format_nlist
inDPZBLModel
andEnergyModel
classes for improved neighbor list formatting.DescrptDPA2
andDescrptSeTTebd
to the public API.Bug Fixes
DPZBLModel
andEnergyModel
to ensure proper serialization and deserialization ofatomic_model
.Documentation