From 38815b371162a2153b0c2b24a38867825665c7a3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 4 Nov 2024 15:15:03 -0500 Subject: [PATCH 1/4] feat(jax): export call_lower to SavedModel via jax2tf (#4254) ## Summary by CodeRabbit ## Release Notes - **New Features** - Added support for the TensorFlow SavedModel format, allowing users to handle additional model file types. - Introduced a new TensorFlow model wrapper class for enhanced integration with JAX functionalities. - **Bug Fixes** - Improved error handling for unsupported file formats during model deserialization. - **Documentation** - Updated backend documentation to reflect new file extensions and clarify backend capabilities. - **Tests** - Enhanced test structure for better clarity and maintainability regarding backend handling. - Added a new job for testing TensorFlow 2 in eager mode within the testing workflow. - Introduced a conditional skip for tests based on TensorFlow 2 compatibility. --------- Signed-off-by: Jinzhe Zeng --- .github/workflows/test_python.yml | 18 +- deepmd/backend/jax.py | 2 +- deepmd/jax/infer/deep_eval.py | 27 ++- deepmd/jax/jax2tf/__init__.py | 11 + deepmd/jax/jax2tf/serialization.py | 172 ++++++++++++++ deepmd/jax/jax2tf/tfmodel.py | 325 ++++++++++++++++++++++++++ deepmd/jax/utils/serialization.py | 12 +- doc/backend.md | 3 +- pyproject.toml | 1 + source/tests/consistent/io/test_io.py | 18 +- source/tests/utils.py | 1 + 11 files changed, 568 insertions(+), 22 deletions(-) create mode 100644 deepmd/jax/jax2tf/__init__.py create mode 100644 deepmd/jax/jax2tf/serialization.py create mode 100644 deepmd/jax/jax2tf/tfmodel.py diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index e46bddd98a..422dcb5f17 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -25,19 +25,23 @@ jobs: python-version: ${{ matrix.python }} - run: python -m pip install -U uv - run: | - source/install/uv_with_retry.sh pip install --system mpich + source/install/uv_with_retry.sh pip install --system openmpi tensorflow-cpu source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu + export TENSORFLOW_ROOT=$(python -c 'import tensorflow;print(tensorflow.__path__[0])') export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') - source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test,jax] horovod[tensorflow-cpu] mpi4py + source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py + source/install/uv_with_retry.sh pip install --system horovod --no-build-isolation env: # Please note that uv has some issues with finding # existing TensorFlow package. Currently, it uses # TensorFlow in the build dependency, but if it # changes, setting `TENSORFLOW_ROOT`. - TENSORFLOW_VERSION: 2.16.1 DP_ENABLE_PYTORCH: 1 DP_BUILD_TESTING: 1 - UV_EXTRA_INDEX_URL: "https://pypi.anaconda.org/njzjz/simple https://pypi.anaconda.org/mpi4py/simple" + UV_EXTRA_INDEX_URL: "https://pypi.anaconda.org/mpi4py/simple" + HOROVOD_WITH_TENSORFLOW: 1 + HOROVOD_WITHOUT_PYTORCH: 1 + HOROVOD_WITH_MPI: 1 - run: dp --version - name: Get durations from cache uses: actions/cache@v4 @@ -53,6 +57,12 @@ jobs: - run: pytest --cov=deepmd source/tests --durations=0 --splits 6 --group ${{ matrix.group }} --store-durations --durations-path=.test_durations --splitting-algorithm least_duration env: NUM_WORKERS: 0 + - name: Test TF2 eager mode + run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0 + env: + NUM_WORKERS: 0 + DP_TEST_TF2_ONLY: 1 + if: matrix.group == 1 - run: mv .test_durations .test_durations_${{ matrix.group }} - name: Upload partial durations uses: actions/upload-artifact@v4 diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index cfb0936bda..7a714c2090 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -38,7 +38,7 @@ class JAXBackend(Backend): | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".hlo", ".jax"] + suffixes: ClassVar[list[str]] = [".hlo", ".jax", ".savedmodel"] """The suffixes of the backend.""" def is_available(self) -> bool: diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index b60076c68c..fc526a502e 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -90,15 +90,24 @@ def __init__( self.output_def = output_def self.model_path = model_file - model_data = load_dp_model(model_file) - self.dp = HLO( - stablehlo=model_data["@variables"]["stablehlo"].tobytes(), - stablehlo_atomic_virial=model_data["@variables"][ - "stablehlo_atomic_virial" - ].tobytes(), - model_def_script=model_data["model_def_script"], - **model_data["constants"], - ) + if model_file.endswith(".hlo"): + model_data = load_dp_model(model_file) + self.dp = HLO( + stablehlo=model_data["@variables"]["stablehlo"].tobytes(), + stablehlo_atomic_virial=model_data["@variables"][ + "stablehlo_atomic_virial" + ].tobytes(), + model_def_script=model_data["model_def_script"], + **model_data["constants"], + ) + elif model_file.endswith(".savedmodel"): + from deepmd.jax.jax2tf.tfmodel import ( + TFModelWrapper, + ) + + self.dp = TFModelWrapper(model_file) + else: + raise ValueError("Unsupported file extension") self.rcut = self.dp.get_rcut() self.type_map = self.dp.get_type_map() if isinstance(auto_batch_size, bool): diff --git a/deepmd/jax/jax2tf/__init__.py b/deepmd/jax/jax2tf/__init__.py new file mode 100644 index 0000000000..88a928f04d --- /dev/null +++ b/deepmd/jax/jax2tf/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf + +if not tf.executing_eagerly(): + # TF disallow temporary eager execution + raise RuntimeError( + "Unfortunatly, jax2tf (requires eager execution) cannot be used with the " + "TensorFlow backend (disables eager execution). " + "If you are converting a model between different backends, " + "considering converting to the `.dp` format first." + ) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py new file mode 100644 index 0000000000..dff43a11fc --- /dev/null +++ b/deepmd/jax/jax2tf/serialization.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json + +import tensorflow as tf +from jax.experimental import ( + jax2tf, +) + +from deepmd.jax.model.base_model import ( + BaseModel, +) + + +def deserialize_to_file(model_file: str, data: dict) -> None: + """Deserialize the dictionary to a model file. + + Parameters + ---------- + model_file : str + The model file to be saved. + data : dict + The dictionary to be deserialized. + """ + if model_file.endswith(".savedmodel"): + model = BaseModel.deserialize(data["model"]) + model_def_script = data["model_def_script"] + call_lower = model.call_lower + + tf_model = tf.Module() + + def exported_whether_do_atomic_virial(do_atomic_virial): + def call_lower_with_fixed_do_atomic_virial( + coord, atype, nlist, mapping, fparam, aparam + ): + return call_lower( + coord, + atype, + nlist, + mapping, + fparam, + aparam, + do_atomic_virial=do_atomic_virial, + ) + + return jax2tf.convert( + call_lower_with_fixed_do_atomic_virial, + polymorphic_shapes=[ + "(nf, nloc + nghost, 3)", + "(nf, nloc + nghost)", + f"(nf, nloc, {model.get_nnei()})", + "(nf, nloc + nghost)", + f"(nf, {model.get_dim_fparam()})", + f"(nf, nloc, {model.get_dim_aparam()})", + ], + with_gradient=True, + ) + + # Save a function that can take scalar inputs. + # We need to explicit set the function name, so C++ can find it. + @tf.function( + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_lower_without_atomic_virial( + coord, atype, nlist, mapping, fparam, aparam + ): + return exported_whether_do_atomic_virial(do_atomic_virial=False)( + coord, atype, nlist, mapping, fparam, aparam + ) + + tf_model.call_lower = call_lower_without_atomic_virial + + @tf.function( + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): + return exported_whether_do_atomic_virial(do_atomic_virial=True)( + coord, atype, nlist, mapping, fparam, aparam + ) + + tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial + + # set functions to export other attributes + @tf.function + def get_type_map(): + return tf.constant(model.get_type_map(), dtype=tf.string) + + tf_model.get_type_map = get_type_map + + @tf.function + def get_rcut(): + return tf.constant(model.get_rcut(), dtype=tf.double) + + tf_model.get_rcut = get_rcut + + @tf.function + def get_dim_fparam(): + return tf.constant(model.get_dim_fparam(), dtype=tf.int64) + + tf_model.get_dim_fparam = get_dim_fparam + + @tf.function + def get_dim_aparam(): + return tf.constant(model.get_dim_aparam(), dtype=tf.int64) + + tf_model.get_dim_aparam = get_dim_aparam + + @tf.function + def get_sel_type(): + return tf.constant(model.get_sel_type(), dtype=tf.int64) + + tf_model.get_sel_type = get_sel_type + + @tf.function + def is_aparam_nall(): + return tf.constant(model.is_aparam_nall(), dtype=tf.bool) + + tf_model.is_aparam_nall = is_aparam_nall + + @tf.function + def model_output_type(): + return tf.constant(model.model_output_type(), dtype=tf.string) + + tf_model.model_output_type = model_output_type + + @tf.function + def mixed_types(): + return tf.constant(model.mixed_types(), dtype=tf.bool) + + tf_model.mixed_types = mixed_types + + if model.get_min_nbor_dist() is not None: + + @tf.function + def get_min_nbor_dist(): + return tf.constant(model.get_min_nbor_dist(), dtype=tf.double) + + tf_model.get_min_nbor_dist = get_min_nbor_dist + + @tf.function + def get_sel(): + return tf.constant(model.get_sel(), dtype=tf.int64) + + tf_model.get_sel = get_sel + + @tf.function + def get_model_def_script(): + return tf.constant( + json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string + ) + + tf_model.get_model_def_script = get_model_def_script + tf.saved_model.save( + tf_model, + model_file, + options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), + ) diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py new file mode 100644 index 0000000000..8f04014a97 --- /dev/null +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +import jax.experimental.jax2tf as jax2tf +import tensorflow as tf + +from deepmd.dpmodel.model.make_model import ( + model_call_from_call_lower, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, +) +from deepmd.jax.env import ( + jnp, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) + +OUTPUT_DEFS = { + "energy": OutputVariableDef( + "energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + "mask": OutputVariableDef( + "mask", + shape=[1], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), +} + + +def decode_list_of_bytes(list_of_bytes: list[bytes]) -> list[str]: + """Decode a list of bytes to a list of strings.""" + return [x.decode() for x in list_of_bytes] + + +class TFModelWrapper(tf.Module): + def __init__( + self, + model, + ) -> None: + self.model = tf.saved_model.load(model) + self._call_lower = jax2tf.call_tf(self.model.call_lower) + self._call_lower_atomic_virial = jax2tf.call_tf( + self.model.call_lower_atomic_virial + ) + self.type_map = decode_list_of_bytes(self.model.get_type_map().numpy().tolist()) + self.rcut = self.model.get_rcut().numpy().item() + self.dim_fparam = self.model.get_dim_fparam().numpy().item() + self.dim_aparam = self.model.get_dim_aparam().numpy().item() + self.sel_type = self.model.get_sel_type().numpy().tolist() + self._is_aparam_nall = self.model.is_aparam_nall().numpy().item() + self._model_output_type = decode_list_of_bytes( + self.model.model_output_type().numpy().tolist() + ) + self._mixed_types = self.model.mixed_types().numpy().item() + if hasattr(self.model, "get_min_nbor_dist"): + self.min_nbor_dist = self.model.get_min_nbor_dist().numpy().item() + else: + self.min_nbor_dist = None + self.sel = self.model.get_sel().numpy().tolist() + self.model_def_script = self.model.get_model_def_script().numpy().decode() + + def __call__( + self, + coord: jnp.ndarray, + atype: jnp.ndarray, + box: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ) -> Any: + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,jnp.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + return self.call(coord, atype, box, fparam, aparam, do_atomic_virial) + + def call( + self, + coord: jnp.ndarray, + atype: jnp.ndarray, + box: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ): + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,jnp.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + return model_call_from_call_lower( + call_lower=self.call_lower, + rcut=self.get_rcut(), + sel=self.get_sel(), + mixed_types=self.mixed_types(), + model_output_def=self.model_output_def(), + coord=coord, + atype=atype, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + def model_output_def(self): + return ModelOutputDef( + FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()]) + ) + + def call_lower( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ): + if do_atomic_virial: + call_lower = self._call_lower_atomic_virial + else: + call_lower = self._call_lower + # Attempt to convert a value (None) with an unsupported type () to a Tensor. + if fparam is None: + fparam = jnp.empty( + (extended_coord.shape[0], self.get_dim_fparam()), dtype=jnp.float64 + ) + if aparam is None: + aparam = jnp.empty( + (extended_coord.shape[0], nlist.shape[1], self.get_dim_aparam()), + dtype=jnp.float64, + ) + return call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + ) + + def get_type_map(self) -> list[str]: + """Get the type map.""" + return self.type_map + + def get_rcut(self): + """Get the cut-off radius.""" + return self.rcut + + def get_dim_fparam(self): + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.dim_fparam + + def get_dim_aparam(self): + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.dim_aparam + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.sel_type + + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + + If False, the shape is (nframes, nloc, ndim). + """ + return self._is_aparam_nall + + def model_output_type(self) -> list[str]: + """Get the output type for the model.""" + return self._model_output_type + + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + raise NotImplementedError("Not implemented") + + @classmethod + def deserialize(cls, data: dict) -> "TFModelWrapper": + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + BaseModel + The deserialized model + """ + raise NotImplementedError("Not implemented") + + def get_model_def_script(self) -> str: + """Get the model definition script.""" + return self.model_def_script + + def get_min_nbor_dist(self) -> Optional[float]: + """Get the minimum distance between two atoms.""" + return self.min_nbor_dist + + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.get_nsel() + + def get_sel(self) -> list[int]: + return self.sel + + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return sum(self.sel) + + def mixed_types(self) -> bool: + return self._mixed_types + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statictics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + raise NotImplementedError("Not implemented") + + @classmethod + def get_model(cls, model_params: dict) -> "TFModelWrapper": + """Get the model by the parameters. + + By default, all the parameters are directly passed to the constructor. + If not, override this method. + + Parameters + ---------- + model_params : dict + The model parameters + + Returns + ------- + BaseBaseModel + The model + """ + raise NotImplementedError("Not implemented") diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index ec2de3060e..6ab99a81f0 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -55,13 +55,13 @@ def deserialize_to_file(model_file: str, data: dict) -> None: def exported_whether_do_atomic_virial(do_atomic_virial): def call_lower_with_fixed_do_atomic_virial( - coord, atype, nlist, nlist_start, fparam, aparam + coord, atype, nlist, mapping, fparam, aparam ): return call_lower( coord, atype, nlist, - nlist_start, + mapping, fparam, aparam, do_atomic_virial=do_atomic_virial, @@ -107,8 +107,14 @@ def call_lower_with_fixed_do_atomic_virial( "sel": model.get_sel(), } save_dp_model(filename=model_file, model_dict=data) + elif model_file.endswith(".savedmodel"): + from deepmd.jax.jax2tf.serialization import ( + deserialize_to_file as deserialize_to_savedmodel, + ) + + return deserialize_to_savedmodel(model_file, data) else: - raise ValueError("JAX backend only supports converting .jax directory") + raise ValueError("Unsupported file extension") def serialize_from_file(model_file: str) -> dict: diff --git a/doc/backend.md b/doc/backend.md index cf99eea9cb..3fb70bee90 100644 --- a/doc/backend.md +++ b/doc/backend.md @@ -25,11 +25,12 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different ### JAX {{ jax_icon }} -- Model filename extension: `.xlo` +- Model filename extension: `.xlo`, `.savedmodel` - Checkpoint filename extension: `.jax` [JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required. Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions. +`.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow. Currently, this backend is developed actively, and has no support for training and the C++ interface. ### DP {{ dpmodel_icon }} diff --git a/pyproject.toml b/pyproject.toml index 1faacb973c..802e920014 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -444,6 +444,7 @@ select = [ [tool.uv.sources] mpich = { index = "mpi4py" } +openmpi = { index = "mpi4py" } [[tool.uv.index]] name = "mpi4py" diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 91cd391322..ca213da13c 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -23,6 +23,7 @@ from ...utils import ( CI, + DP_TEST_TF2_ONLY, TEST_DEVICE, ) @@ -72,6 +73,7 @@ def tearDown(self): shutil.rmtree(ii) @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") + @unittest.skipIf(DP_TEST_TF2_ONLY, "Conflict with TF2 eager mode.") def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() for backend_name, suffix_idx in ( @@ -140,13 +142,21 @@ def test_deep_eval(self): nframes = self.atype.shape[0] prefix = "test_consistent_io_" + self.__class__.__name__.lower() rets = [] - for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"): + for backend_name, suffix_idx in ( + # unfortunately, jax2tf cannot work with tf v1 behaviors + ("jax", 2) if DP_TEST_TF2_ONLY else ("tensorflow", 0), + ("pytorch", 0), + ("dpmodel", 0), + ("jax", 0), + ): backend = Backend.get_backend(backend_name)() if not backend.is_available(): continue reference_data = copy.deepcopy(self.data) - self.save_data_to_model(prefix + backend.suffixes[0], reference_data) - deep_eval = DeepEval(prefix + backend.suffixes[0]) + self.save_data_to_model( + prefix + backend.suffixes[suffix_idx], reference_data + ) + deep_eval = DeepEval(prefix + backend.suffixes[suffix_idx]) if deep_eval.get_dim_fparam() > 0: fparam = np.ones((nframes, deep_eval.get_dim_fparam())) else: @@ -169,7 +179,7 @@ def test_deep_eval(self): self.atype, fparam=fparam, aparam=aparam, - do_atomic_virial=True, + atomic=True, ) rets.append(ret) for ret in rets[1:]: diff --git a/source/tests/utils.py b/source/tests/utils.py index bfb3d445af..a9bf0f11ea 100644 --- a/source/tests/utils.py +++ b/source/tests/utils.py @@ -8,3 +8,4 @@ # see https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables CI = os.environ.get("CI") == "true" +DP_TEST_TF2_ONLY = os.environ.get("DP_TEST_TF2_ONLY") == "1" From 4b73fbe54d50546980cdd9a71f9e39c564cf75a8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Nov 2024 01:23:00 +0000 Subject: [PATCH 2/4] [pre-commit.ci] pre-commit autoupdate (#4310) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.7.1 → v0.7.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.7.1...v0.7.2) - [github.com/pre-commit/mirrors-clang-format: v19.1.2 → v19.1.3](https://github.com/pre-commit/mirrors-clang-format/compare/v19.1.2...v19.1.3) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6cb534fd22..721a0cd6eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,7 @@ repos: exclude: ^source/3rdparty - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.7.1 + rev: v0.7.2 hooks: - id: ruff args: ["--fix"] @@ -60,7 +60,7 @@ repos: - id: blacken-docs # C++ - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v19.1.2 + rev: v19.1.3 hooks: - id: clang-format exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$) From 9ed039765465229768535dda6d28f60888b2f42d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 4 Nov 2024 21:46:39 -0500 Subject: [PATCH 3/4] fix(cmake): Replace deprecated `FetchContent_Populate` with `FetchContent_MakeAvailable` (#4309) Update `source/lmp/plugin/CMakeLists.txt` to use `FetchContent_MakeAvailable` instead of `FetchContent_Populate`. * Replace `FetchContent_Populate(lammps_download)` with `FetchContent_MakeAvailable(lammps_download)` on line 13. * Remove `FetchContent_GetProperties` and `if(NOT lammps_download_POPULATED)` block. This fixes a CMake warning: ``` CMake Warning (dev) at /home/runner/work/_temp/-111029589/cmake-3.30.5-linux-x86_64/share/cmake-3.30/Modules/FetchContent.cmake:1953 (message): Calling FetchContent_Populate(lammps_download) is deprecated, call FetchContent_MakeAvailable(lammps_download) instead. Policy CMP0169 can be set to OLD to allow FetchContent_Populate(lammps_download) to be called directly for now, but the ability to call it with declared details will be removed completely in a future version. Call Stack (most recent call first): lmp/plugin/CMakeLists.txt:13 (FetchContent_Populate) This warning is for project developers. Use -Wno-dev to suppress it. ``` --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/njzjz/deepmd-kit?shareId=32a460fb-6c67-4397-b000-6f36e9841970). ## Summary by CodeRabbit - **Chores** - Simplified CMake configuration for the LAMMPS plugin, ensuring consistent availability of LAMMPS source. - Streamlined handling of LAMMPS versioning and installation logic. - Updated minimum required CMake version from 3.11 to 3.14. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- source/lmp/plugin/CMakeLists.txt | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/source/lmp/plugin/CMakeLists.txt b/source/lmp/plugin/CMakeLists.txt index f912059261..13da3d7114 100644 --- a/source/lmp/plugin/CMakeLists.txt +++ b/source/lmp/plugin/CMakeLists.txt @@ -2,17 +2,14 @@ if(DEFINED LAMMPS_SOURCE_ROOT OR DEFINED LAMMPS_VERSION) message(STATUS "enable LAMMPS plugin mode") add_library(lammps_interface INTERFACE) if(DEFINED LAMMPS_VERSION) - cmake_minimum_required(VERSION 3.11) + cmake_minimum_required(VERSION 3.14) include(FetchContent) FetchContent_Declare( lammps_download GIT_REPOSITORY https://github.com/lammps/lammps GIT_TAG ${LAMMPS_VERSION}) - FetchContent_GetProperties(lammps_download) - if(NOT lammps_download_POPULATED) - FetchContent_Populate(lammps_download) - set(LAMMPS_SOURCE_ROOT ${lammps_download_SOURCE_DIR}) - endif() + FetchContent_MakeAvailable(lammps_download) + set(LAMMPS_SOURCE_ROOT ${lammps_download_SOURCE_DIR}) endif() set(LAMMPS_HEADER_DIR ${LAMMPS_SOURCE_ROOT}/src) message(STATUS "LAMMPS_HEADER_DIR is ${LAMMPS_HEADER_DIR}") From dabedd230cd4541750707a9accbf729c50325d86 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 4 Nov 2024 23:23:30 -0500 Subject: [PATCH 4/4] fix(jax): calculate virial in `call_lower` (#4304) ## Summary by CodeRabbit - **New Features** - Enhanced output of the model by providing a reduced form of the virial tensor, improving usability for further calculations and analyses. - Introduced a new test class, `TestEnerLower`, to evaluate lower-level energy models, excluding TensorFlow functionality. --------- Signed-off-by: Jinzhe Zeng --- deepmd/jax/model/base_model.py | 2 + source/tests/consistent/model/test_ener.py | 220 ++++++++++++++++++++- 2 files changed, 221 insertions(+), 1 deletion(-) diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index 1e880700a2..44152a4c26 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -152,4 +152,6 @@ def eval_ce( avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2] ) model_predict[kk_derv_c] = extended_virial + # [nf, *def, 9] + model_predict[kk_derv_c + "_redu"] = jnp.sum(extended_virial, axis=1) return model_predict diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index ec73c57fa8..5d0253c5e8 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -6,8 +6,18 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, ) @@ -27,7 +37,8 @@ if INSTALLED_PT: from deepmd.pt.model.model import get_model as get_model_pt from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT - + from deepmd.pt.utils.utils import to_numpy_array as torch_to_numpy + from deepmd.pt.utils.utils import to_torch_tensor as numpy_to_torch else: EnergyModelPT = None if INSTALLED_TF: @@ -39,6 +50,9 @@ ) if INSTALLED_JAX: + from deepmd.jax.common import ( + to_jax_array, + ) from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX from deepmd.jax.model.model import get_model as get_model_jax else: @@ -243,3 +257,207 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["energy_derv_c"].ravel(), ) raise ValueError(f"Unknown backend: {backend}") + + +@parameterized( + ( + [], + [[0, 1]], + ), + ( + [], + [1], + ), +) +class TestEnerLower(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + pair_exclude_types, atom_exclude_types = self.param + return { + "type_map": ["O", "H"], + "pair_exclude_types": pair_exclude_types, + "atom_exclude_types": atom_exclude_types, + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 3, + 6, + ], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [ + 5, + 5, + ], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + } + + tf_class = EnergyModelTF + dp_class = EnergyModelDP + pt_class = EnergyModelPT + jax_class = EnergyModelJAX + args = model_args() + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_jax: + return self.RefBackend.JAX + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + @property + def skip_tf(self): + # TF does not have lower interface + return True + + @property + def skip_jax(self): + return not INSTALLED_JAX + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = data.copy() + if cls is EnergyModelDP: + return get_model_dp(data) + elif cls is EnergyModelPT: + return get_model_pt(data) + elif cls is EnergyModelJAX: + return get_model_jax(data) + return cls(**data, **self.additional_data) + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + + rcut = 6.0 + nframes, nloc = atype.shape[:2] + coord_normalized = normalize_coord( + coords.reshape(nframes, nloc, 3), + box.reshape(nframes, 3, 3), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, box, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + 6.0, + [20, 20], + distinguish_types=True, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + self.nlist = nlist + self.extended_coord = extended_coord + self.extended_atype = extended_atype + self.mapping = mapping + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + raise NotImplementedError("no TF in this test") + + def eval_dp(self, dp_obj: Any) -> Any: + return dp_obj.call_lower( + self.extended_coord, + self.extended_atype, + self.nlist, + self.mapping, + do_atomic_virial=True, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return { + kk: torch_to_numpy(vv) + for kk, vv in pt_obj.forward_lower( + numpy_to_torch(self.extended_coord), + numpy_to_torch(self.extended_atype), + numpy_to_torch(self.nlist), + numpy_to_torch(self.mapping), + do_atomic_virial=True, + ).items() + } + + def eval_jax(self, jax_obj: Any) -> Any: + return { + kk: to_numpy_array(vv) + for kk, vv in jax_obj.call_lower( + to_jax_array(self.extended_coord), + to_jax_array(self.extended_atype), + to_jax_array(self.nlist), + to_jax_array(self.mapping), + do_atomic_virial=True, + ).items() + } + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... + if backend is self.RefBackend.DP: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + SKIP_FLAG, + SKIP_FLAG, + SKIP_FLAG, + ) + elif backend is self.RefBackend.PT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["extended_force"].ravel(), + ret["virial"].ravel(), + ret["extended_virial"].ravel(), + ) + elif backend is self.RefBackend.JAX: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + ret["energy_derv_r"].ravel(), + ret["energy_derv_c_redu"].ravel(), + ret["energy_derv_c"].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}")