Skip to content

Commit

Permalink
fix(dpmodel/jax): fix fparam and aparam support in DeepEval (#4285)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced error messages for improved clarity when input dimensions are
incorrect.
- Added support for optional fitting and atomic parameters in model
evaluations.

- **Bug Fixes**
- Removed restrictions on providing fitting and atomic parameters,
allowing for more flexible evaluations.

- **Tests**
- Introduced a new test class to validate the handling of fitting and
atomic parameters in model evaluations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Oct 31, 2024
1 parent 9c767ad commit ff04d8b
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 16 deletions.
8 changes: 4 additions & 4 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ def _call_common(
assert fparam is not None, "fparam should not be None"
if fparam.shape[-1] != self.numb_fparam:
raise ValueError(
"get an input fparam of dim {fparam.shape[-1]}, ",
"which is not consistent with {self.numb_fparam}.",
f"get an input fparam of dim {fparam.shape[-1]}, "
f"which is not consistent with {self.numb_fparam}."
)
fparam = (fparam - self.fparam_avg) * self.fparam_inv_std
fparam = xp.tile(
Expand All @@ -409,8 +409,8 @@ def _call_common(
assert aparam is not None, "aparam should not be None"
if aparam.shape[-1] != self.numb_aparam:
raise ValueError(
"get an input aparam of dim {aparam.shape[-1]}, ",
"which is not consistent with {self.numb_aparam}.",
f"get an input aparam of dim {aparam.shape[-1]}, "
f"which is not consistent with {self.numb_aparam}."
)
aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam])
aparam = (aparam - self.aparam_avg) * self.aparam_inv_std
Expand Down
21 changes: 17 additions & 4 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,6 @@ def eval(
The output of the evaluation. The keys are the names of the output
variables, and the values are the corresponding output arrays.
"""
if fparam is not None or aparam is not None:
raise NotImplementedError
# convert all of the input to numpy array
atom_types = np.array(atom_types, dtype=np.int32)
coords = np.array(coords)
Expand All @@ -216,7 +214,7 @@ def eval(
)
request_defs = self._get_request_defs(atomic)
out = self._eval_func(self._eval_model, numb_test, natoms)(
coords, cells, atom_types, request_defs
coords, cells, atom_types, fparam, aparam, request_defs
)
return dict(
zip(
Expand Down Expand Up @@ -306,6 +304,8 @@ def _eval_model(
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
request_defs: list[OutputVariableDef],
):
model = self.dp
Expand All @@ -323,12 +323,25 @@ def _eval_model(
box_input = cells.reshape([-1, 3, 3])
else:
box_input = None
if fparam is not None:
fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
else:
fparam_input = None
if aparam is not None:
aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam())
else:
aparam_input = None

do_atomic_virial = any(
x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs
)
batch_output = model(
coord_input, type_input, box=box_input, do_atomic_virial=do_atomic_virial
coord_input,
type_input,
box=box_input,
fparam=fparam_input,
aparam=aparam_input,
do_atomic_virial=do_atomic_virial,
)
if isinstance(batch_output, tuple):
batch_output = batch_output[0]
Expand Down
16 changes: 13 additions & 3 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ def eval(
The output of the evaluation. The keys are the names of the output
variables, and the values are the corresponding output arrays.
"""
if fparam is not None or aparam is not None:
raise NotImplementedError
# convert all of the input to numpy array
atom_types = np.array(atom_types, dtype=np.int32)
coords = np.array(coords)
Expand All @@ -226,7 +224,7 @@ def eval(
)
request_defs = self._get_request_defs(atomic)
out = self._eval_func(self._eval_model, numb_test, natoms)(
coords, cells, atom_types, request_defs
coords, cells, atom_types, fparam, aparam, request_defs
)
return dict(
zip(
Expand Down Expand Up @@ -316,6 +314,8 @@ def _eval_model(
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
request_defs: list[OutputVariableDef],
):
model = self.dp
Expand All @@ -333,6 +333,14 @@ def _eval_model(
box_input = cells.reshape([-1, 3, 3])
else:
box_input = None
if fparam is not None:
fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
else:
fparam_input = None
if aparam is not None:
aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam())
else:
aparam_input = None

do_atomic_virial = any(
x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs
Expand All @@ -341,6 +349,8 @@ def _eval_model(
to_jax_array(coord_input),
to_jax_array(type_input),
box=to_jax_array(box_input),
fparam=to_jax_array(fparam_input),
aparam=to_jax_array(aparam_input),
do_atomic_virial=do_atomic_virial,
)
if isinstance(batch_output, tuple):
Expand Down
8 changes: 3 additions & 5 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,16 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
model_def_script = data["model_def_script"]
call_lower = model.call_lower

nf, nloc, nghost, nfp, nap = jax_export.symbolic_shape(
"nf, nloc, nghost, nfp, nap"
)
nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost")
exported = jax_export.export(jax.jit(call_lower))(
jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype
jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping
jax.ShapeDtypeStruct((nf, nfp), jnp.float64)
jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64)
if model.get_dim_fparam()
else None, # fparam
jax.ShapeDtypeStruct((nf, nap), jnp.float64)
jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64)
if model.get_dim_aparam()
else None, # aparam
False, # do_atomic_virial
Expand Down
56 changes: 56 additions & 0 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def test_deep_eval(self):
[13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
dtype=GLOBAL_NP_FLOAT_PRECISION,
).reshape(1, 9)
natoms = self.atype.shape[1]
nframes = self.atype.shape[0]
prefix = "test_consistent_io_" + self.__class__.__name__.lower()
rets = []
for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"):
Expand All @@ -145,10 +147,20 @@ def test_deep_eval(self):
reference_data = copy.deepcopy(self.data)
self.save_data_to_model(prefix + backend.suffixes[0], reference_data)
deep_eval = DeepEval(prefix + backend.suffixes[0])
if deep_eval.get_dim_fparam() > 0:
fparam = np.ones((nframes, deep_eval.get_dim_fparam()))
else:
fparam = None
if deep_eval.get_dim_aparam() > 0:
aparam = np.ones((nframes, natoms, deep_eval.get_dim_aparam()))
else:
aparam = None
ret = deep_eval.eval(
self.coords,
self.box,
self.atype,
fparam=fparam,
aparam=aparam,
)
rets.append(ret)
for ret in rets[1:]:
Expand Down Expand Up @@ -199,3 +211,47 @@ def setUp(self):

def tearDown(self):
IOTest.tearDown(self)


class TestDeepPotFparamAparam(unittest.TestCase, IOTest):
    """IO consistency test for an energy model whose fitting net consumes
    frame parameters (fparam) and atomic parameters (aparam), both of
    dimension 2.

    ``setUp`` builds the serialized model payload consumed by the shared
    ``IOTest`` machinery; ``tearDown`` delegates cleanup to ``IOTest``.
    """

    def setUp(self):
        # Build the descriptor and fitting-net configurations separately,
        # then assemble them into the full model definition script.
        descriptor_cfg = {
            "type": "se_e2_a",
            "sel": [20, 20],
            "rcut_smth": 0.50,
            "rcut": 6.00,
            "neuron": [3, 6],
            "resnet_dt": False,
            "axis_neuron": 2,
            "precision": "float64",
            "type_one_side": True,
            "seed": 1,
        }
        fitting_cfg = {
            "type": "ener",
            "neuron": [5, 5],
            "resnet_dt": True,
            "precision": "float64",
            "atom_ener": [],
            "seed": 1,
            # Non-zero dims exercise the fparam/aparam evaluation path.
            "numb_fparam": 2,
            "numb_aparam": 2,
        }
        model_def_script = {
            "type_map": ["O", "H"],
            "descriptor": descriptor_cfg,
            "fitting_net": fitting_cfg,
        }
        # deepcopy so get_model cannot mutate the script we store below.
        model = get_model(copy.deepcopy(model_def_script))
        self.data = {
            "model": model.serialize(),
            "backend": "test",
            "model_def_script": model_def_script,
        }

    def tearDown(self):
        IOTest.tearDown(self)

0 comments on commit ff04d8b

Please sign in to comment.