Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New broadcasted gaussian integrals #502

Merged
merged 14 commits into from
Oct 15, 2024
3 changes: 2 additions & 1 deletion mrmustard/math/backend_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def assign(self, tensor: tf.Tensor, value: tf.Tensor) -> tf.Tensor:

def astensor(self, array: np.ndarray | tf.Tensor, dtype=None) -> tf.Tensor:
dtype = dtype or np.array(array).dtype.name
return tf.convert_to_tensor(array, dtype)
return tf.cast(tf.convert_to_tensor(array, dtype_hint=dtype), dtype)

def atleast_1d(self, array: tf.Tensor, dtype=None) -> tf.Tensor:
return tf.experimental.numpy.atleast_1d(self.cast(self.astensor(array), dtype))
Expand Down Expand Up @@ -205,6 +205,7 @@ def from_backend(self, value) -> bool:
return isinstance(value, (tf.Tensor, tf.Variable))

def gather(self, array: tf.Tensor, indices: tf.Tensor, axis: int) -> tf.Tensor:
indices = tf.convert_to_tensor(indices, dtype=tf.int32)
return tf.gather(array, indices, axis=axis)

def imag(self, array: tf.Tensor) -> tf.Tensor:
Expand Down
524 changes: 258 additions & 266 deletions mrmustard/physics/gaussian_integrals.py

Large diffs are not rendered by default.

26 changes: 9 additions & 17 deletions mrmustard/physics/representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@

from mrmustard import math, settings
from mrmustard.physics.gaussian_integrals import (
contract_two_Abc_poly,
reorder_abc,
complex_gaussian_integral,
complex_gaussian_integral_1,
complex_gaussian_integral_2,
)
from mrmustard.physics.ansatze import Ansatz, PolyExpAnsatz, ArrayAnsatz
from mrmustard.utils.typing import (
Expand Down Expand Up @@ -436,12 +436,7 @@ def trace(self, idx_z: tuple[int, ...], idx_zconj: tuple[int, ...]) -> Bargmann:
Returns:
Bargmann: the ansatz with the given indices traced over
"""
A, b, c = [], [], []
for Abc in zip(self.A, self.b, self.c):
Aij, bij, cij = complex_gaussian_integral(Abc, idx_z, idx_zconj, measure=-1.0)
A.append(Aij)
b.append(bij)
c.append(cij)
A, b, c = complex_gaussian_integral_1(self.triple, idx_z, idx_zconj, measure=-1.0)
ziofil marked this conversation as resolved.
Show resolved Hide resolved
return Bargmann(A, b, c)

def __call__(self, z: ComplexTensor) -> ComplexTensor:
Expand Down Expand Up @@ -497,22 +492,19 @@ def __matmul__(self, other: Bargmann) -> Bargmann:
idx_s = self._contract_idxs
idx_o = other._contract_idxs

Abc = []
if settings.UNSAFE_ZIP_BATCH:
if self.ansatz.batch_size != other.ansatz.batch_size:
raise ValueError(
f"Batch size of the two ansatze must match since the settings.UNSAFE_ZIP_BATCH is {settings.UNSAFE_ZIP_BATCH}."
)
for (A1, b1, c1), (A2, b2, c2) in zip(
zip(self.A, self.b, self.c), zip(other.A, other.b, other.c)
):
Abc.append(contract_two_Abc_poly((A1, b1, c1), (A2, b2, c2), idx_s, idx_o))
A, b, c = complex_gaussian_integral_2(
self.triple, other.triple, idx_s, idx_o, mode="zip"
)
else:
for A1, b1, c1 in zip(self.A, self.b, self.c):
for A2, b2, c2 in zip(other.A, other.b, other.c):
Abc.append(contract_two_Abc_poly((A1, b1, c1), (A2, b2, c2), idx_s, idx_o))
A, b, c = complex_gaussian_integral_2(
self.triple, other.triple, idx_s, idx_o, mode="kron"
)

A, b, c = zip(*Abc)
return Bargmann(A, b, c)

def to_dict(self) -> dict[str, ArrayLike]:
Expand Down
12 changes: 5 additions & 7 deletions mrmustard/physics/triples.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from mrmustard import math, settings
from mrmustard.utils.typing import Matrix, Vector, Scalar
from mrmustard.physics.gaussian_integrals import contract_two_Abc
from mrmustard.physics.gaussian_integrals import complex_gaussian_integral_2


# ~~~~~~~~~
Expand Down Expand Up @@ -239,12 +239,10 @@ def sauron_state_Abc(n: int, epsilon: float):
As = np.zeros([n + 1, 1, 1], dtype="complex128")

# normalization
prob = 0
for A1, b1, c1 in zip(As, bs, cs):
for A2, b2, c2 in zip(As, bs, cs):
prob += contract_two_Abc(
(np.conj(A1), np.conj(b1), np.conj(c1)), (A2, b2, c2), [0], [0]
)[2]
probs = complex_gaussian_integral_2(
(np.conj(As), np.conj(bs), np.conj(cs)), (As, bs, cs), [0], [0]
)[2]
prob = np.sum(probs)
cs /= np.sqrt(prob)

return As, bs, cs
Expand Down
14 changes: 7 additions & 7 deletions tests/test_lab_dev/test_circuit_components_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@

from mrmustard import math, settings
from mrmustard.physics.triples import identity_Abc, displacement_map_s_parametrized_Abc
from mrmustard.physics.representations import Bargmann
from mrmustard.physics.bargmann import wigner_to_bargmann_rho
from mrmustard.physics.gaussian_integrals import (
contract_two_Abc,
real_gaussian_integral,
complex_gaussian_integral,
complex_gaussian_integral_1,
complex_gaussian_integral_2,
join_Abc,
join_Abc_real,
)
from mrmustard.physics.representations import Bargmann
from mrmustard.lab_dev.circuit_components_utils import TraceOut, BtoPS, BtoQ
from mrmustard.lab_dev.circuit_components import CircuitComponent
from mrmustard.lab_dev.states import Coherent, DM
Expand Down Expand Up @@ -126,7 +126,7 @@ def testBtoPS_contraction_with_state(self):

# get new triple by contraction
Ds_bargmann_triple = displacement_map_s_parametrized_Abc(s=0, n_modes=1)
A2, b2, c2 = contract_two_Abc(
A2, b2, c2 = complex_gaussian_integral_2(
state_bargmann_triple, Ds_bargmann_triple, idx1=[0, 1], idx2=[1, 3]
)

Expand Down Expand Up @@ -154,7 +154,7 @@ def testBtoPS_contraction_with_state(self):

# get new triple by contraction
Ds_bargmann_triple = displacement_map_s_parametrized_Abc(s=0, n_modes=2)
A2, b2, c2 = contract_two_Abc(
A2, b2, c2 = complex_gaussian_integral_2(
state_bargmann_triple,
Ds_bargmann_triple,
idx1=[0, 1, 2, 3],
Expand Down Expand Up @@ -186,7 +186,7 @@ def testBtoQ_works_correctly_by_applying_it_twice_on_a_state(self):
BtoQ_CC1.representation.b[0],
BtoQ_CC1.representation.c[0],
)
Ainter, binter, cinter = complex_gaussian_integral(
Ainter, binter, cinter = complex_gaussian_integral_1(
join_Abc((A0, b0, c0), (step1A, step1b, step1c)),
idx_z=[0, 1],
idx_zconj=[4, 5],
Expand Down Expand Up @@ -220,7 +220,7 @@ def testBtoQ_works_correctly_by_applying_it_twice_on_a_state(self):
BtoQ_CC1.representation.b[0],
BtoQ_CC1.representation.c[0],
)
Ainter, binter, cinter = complex_gaussian_integral(
Ainter, binter, cinter = complex_gaussian_integral_1(
join_Abc((A0, b0, c0), (step1A, step1b, step1c)),
idx_z=[
0,
Expand Down
Loading
Loading