From a376a5221fd7af74ad9c951d918175593185e182 Mon Sep 17 00:00:00 2001 From: SamFerracin Date: Wed, 27 Mar 2024 15:00:18 -0400 Subject: [PATCH] the walrus --- .../strategies/compactFock/inputValidation.py | 34 ++++++++++++++++++- mrmustard/physics/gaussian.py | 7 ---- poetry.lock | 8 ++--- pyproject.toml | 2 +- tests/test_about.py | 2 +- 5 files changed, 39 insertions(+), 14 deletions(-) diff --git a/mrmustard/math/lattice/strategies/compactFock/inputValidation.py b/mrmustard/math/lattice/strategies/compactFock/inputValidation.py index 237c7de59..b42158dbe 100644 --- a/mrmustard/math/lattice/strategies/compactFock/inputValidation.py +++ b/mrmustard/math/lattice/strategies/compactFock/inputValidation.py @@ -18,7 +18,39 @@ from mrmustard.math.lattice.strategies.compactFock.singleLeftoverMode_grad import ( fock_representation_1leftoverMode_grad, ) -from thewalrus._hafnian import input_validation + +def input_validation(A, rtol=1e-05, atol=1e-08): + """Checks that the matrix A satisfies the requirements for Hafnian calculation. + These include: + * That the ``A`` is a NumPy array + * That ``A`` is square + * That ``A`` does not contain any NaNs + * That ``A`` is symmetric + + Args: + A (array): a NumPy array. + rtol (float): the relative tolerance parameter used in ``np.allclose`` + atol (float): the absolute tolerance parameter used in ``np.allclose`` + + Returns: + bool: returns ``True`` if the matrix satisfies all requirements + """ + + if not isinstance(A, np.ndarray): + raise TypeError("Input matrix must be a NumPy array.") + + n = A.shape + + if n[0] != n[1]: + raise ValueError("Input matrix must be square.") + + if np.isnan(A).any(): + raise ValueError("Input matrix must not contain NaNs.") + + if not np.allclose(A, A.T, rtol=rtol, atol=atol): + raise ValueError("Input matrix must be symmetric.") + + return True def hermite_multidimensional_diagonal(A, B, G0, cutoffs, rtol=1e-05, atol=1e-08): diff --git a/mrmustard/physics/gaussian.py b/mrmustard/physics/gaussian.py index f5737c34a..7611660ed 100644 --- a/mrmustard/physics/gaussian.py +++ b/mrmustard/physics/gaussian.py @@ -18,8 +18,6 @@ from typing import Any, Optional, Sequence, Tuple, Union -from thewalrus.quantum import is_pure_cov - from mrmustard import math, settings from mrmustard.math.tensor_wrappers.xptensor import XPMatrix, XPVector from mrmustard.utils.typing import Matrix, Scalar, Vector @@ -660,11 +658,6 @@ def number_cov(cov: Matrix, means: Vector) -> Matrix: ) -def is_mixed_cov(cov: Matrix) -> bool: # TODO: deprecate - r"""Returns ``True`` if the covariance matrix is mixed, ``False`` otherwise.""" - return not is_pure_cov(math.asnumpy(cov)) - - def trace(cov: Matrix, means: Vector, Bmodes: Sequence[int]) -> Tuple[Matrix, Vector]: r"""Returns the covariances and means after discarding the specified modes. diff --git a/poetry.lock b/poetry.lock index 830e825e7..aac3f8631 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3509,13 +3509,13 @@ tests = ["pytest", "pytest-cov"] [[package]] name = "thewalrus" -version = "0.21.0" +version = "0.19.0" description = "Open source library for hafnian calculation" optional = false python-versions = "*" files = [ - {file = "thewalrus-0.21.0-py3-none-any.whl", hash = "sha256:5f393d17fc8362e7156337faed769e99f15149040ef298d2a1be27f234aa8cb9"}, - {file = "thewalrus-0.21.0.tar.gz", hash = "sha256:a8e1d6a7dea1e2c70aeb172f2dba1dfc7fabfa6e000c8ace9c5f81c7df422637"}, + {file = "thewalrus-0.19.0-py3-none-any.whl", hash = "sha256:07b6e2969bf5405a2df736c442b1500857438bbd2afc2053b8b600b8b0c67f97"}, + {file = "thewalrus-0.19.0.tar.gz", hash = "sha256:06ff07a14cd8cd4650d9c82b8bb8301ef9a58dcdd4bafb14841768ccf80c98b9"}, ] [package.dependencies] @@ -3785,4 +3785,4 @@ ray = ["ray", "scikit-optimize"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "7d1198f211fdadc258ac6921669b414f48f6b61b279f81246aaef1f38b571b5c" +content-hash = "d82541767cbc68915c992d627228835dda2c9b836b922219c71d7fb42d899555" diff --git a/pyproject.toml b/pyproject.toml index acc60cf5f..f9d7845c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ grpcio = "1.60.0" numpy = "^1.23.5" scipy = "^1.8.0" numba = "^0.59" -thewalrus = "^0.21.0" +thewalrus = "^0.19.0" rich = "^10.15.1" matplotlib = "^3.5.0" ray = { version = "^2.5.0", extras = ["tune"], optional = true } diff --git a/tests/test_about.py b/tests/test_about.py index 0152d681f..487f42536 100644 --- a/tests/test_about.py +++ b/tests/test_about.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Unit tests for the :mod:`thewalrus` configuration class :class:`Configuration`. +Unit tests for the ``about`` method. """ import contextlib