Skip to content

Commit

Permalink
the walrus
Browse files Browse the repository at this point in the history
  • Loading branch information
SamFerracin committed Mar 27, 2024
1 parent 1bbb68f commit a376a52
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 14 deletions.
34 changes: 33 additions & 1 deletion mrmustard/math/lattice/strategies/compactFock/inputValidation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,39 @@
from mrmustard.math.lattice.strategies.compactFock.singleLeftoverMode_grad import (
fock_representation_1leftoverMode_grad,
)
from thewalrus._hafnian import input_validation

def input_validation(A, rtol=1e-05, atol=1e-08):
"""Checks that the matrix A satisfies the requirements for Hafnian calculation.
These include:
* That the ``A`` is a NumPy array
* That ``A`` is square
* That ``A`` does not contain any NaNs
* That ``A`` is symmetric
Args:
A (array): a NumPy array.
rtol (float): the relative tolerance parameter used in ``np.allclose``
atol (float): the absolute tolerance parameter used in ``np.allclose``
Returns:
bool: returns ``True`` if the matrix satisfies all requirements
"""

if not isinstance(A, np.ndarray):
raise TypeError("Input matrix must be a NumPy array.")

n = A.shape

if n[0] != n[1]:
raise ValueError("Input matrix must be square.")

if np.isnan(A).any():
raise ValueError("Input matrix must not contain NaNs.")

if not np.allclose(A, A.T, rtol=rtol, atol=atol):
raise ValueError("Input matrix must be symmetric.")

return True


def hermite_multidimensional_diagonal(A, B, G0, cutoffs, rtol=1e-05, atol=1e-08):
Expand Down
7 changes: 0 additions & 7 deletions mrmustard/physics/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from typing import Any, Optional, Sequence, Tuple, Union

from thewalrus.quantum import is_pure_cov

from mrmustard import math, settings
from mrmustard.math.tensor_wrappers.xptensor import XPMatrix, XPVector
from mrmustard.utils.typing import Matrix, Scalar, Vector
Expand Down Expand Up @@ -660,11 +658,6 @@ def number_cov(cov: Matrix, means: Vector) -> Matrix:
)


def is_mixed_cov(cov: Matrix) -> bool: # TODO: deprecate
r"""Returns ``True`` if the covariance matrix is mixed, ``False`` otherwise."""
return not is_pure_cov(math.asnumpy(cov))


def trace(cov: Matrix, means: Vector, Bmodes: Sequence[int]) -> Tuple[Matrix, Vector]:
r"""Returns the covariances and means after discarding the specified modes.
Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ grpcio = "1.60.0"
numpy = "^1.23.5"
scipy = "^1.8.0"
numba = "^0.59"
thewalrus = "^0.21.0"
thewalrus = "^0.19.0"
rich = "^10.15.1"
matplotlib = "^3.5.0"
ray = { version = "^2.5.0", extras = ["tune"], optional = true }
Expand Down
2 changes: 1 addition & 1 deletion tests/test_about.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for the :mod:`thewalrus` configuration class :class:`Configuration`.
Unit tests for the ``about`` method.
"""

import contextlib
Expand Down

0 comments on commit a376a52

Please sign in to comment.