diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index d08659ea1..81c62e986 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -25,6 +25,7 @@ import numbers from functools import cached_property +import copy import numpy as np from numpy.typing import ArrayLike import ipywidgets as widgets @@ -46,6 +47,7 @@ from mrmustard.lab_dev.wires import Wires from mrmustard.physics.triples import identity_Abc + __all__ = ["CircuitComponent"] @@ -109,6 +111,7 @@ def __init__( ) if self._representation: self._representation = self._representation.reorder(tuple(perm)) + self._index_representation = {i: ("B", None) for i in self.wires.indices} def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: """ @@ -168,6 +171,13 @@ def adjoint(self) -> CircuitComponent: ret.short_name = self.short_name for param in self.parameter_set.all_parameters.values(): ret._add_parameter(param) + + # handling index representations: + for i, j in enumerate(kets): + ret._index_representation[i] = self._index_representation[j] + for i, j in enumerate(bras): + ret._index_representation[i + len(kets)] = self._index_representation[j] + return ret @property @@ -187,6 +197,16 @@ def dual(self) -> CircuitComponent: ret.short_name = self.short_name for param in self.parameter_set.all_parameters.values(): ret._add_parameter(param) + + # handling index representations: + for i, j in enumerate(ib): + ret._index_representation[i] = self._index_representation[j] + for i, j in enumerate(ob): + ret._index_representation[i + len(ib)] = self._index_representation[j] + for i, j in enumerate(ik): + ret._index_representation[i + len(ib + ob)] = self._index_representation[j] + for i, j in enumerate(ok): + ret._index_representation[i + len(ib + ob + ik)] = self._index_representation[j] return ret @cached_property @@ -423,6 +443,7 @@ def _from_attributes( ret._name = name ret._representation = representation ret._wires = wires + ret._index_representation = {i: ("B", None) for i in wires.indices} return ret return CircuitComponent(representation, wires, name) @@ -549,6 +570,8 @@ def on(self, modes: Sequence[int]) -> CircuitComponent: modes_in_ket=set(modes) if ik else set(), ) + ret._index_representation = copy.deepcopy(self._index_representation) + return ret def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: @@ -587,9 +610,10 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: del ret.manual_shape return ret - def to_bargmann(self) -> CircuitComponent: + def to_bargmann(self, indices: Sequence[int] | None = None) -> CircuitComponent: r""" - Returns a new circuit component with the same attributes as this and a ``Bargmann`` representation. + Returns a new circuit component with the same attributes as this and a ``Bargmann`` representation on the specified "indices." + If "indices" are not specified, all indices are transformed into bargmann. .. code-block:: >>> from mrmustard.lab_dev import Dgate @@ -604,9 +628,17 @@ def to_bargmann(self) -> CircuitComponent: >>> assert d_bargmann.wires == d.wires >>> assert isinstance(d_bargmann.representation, Bargmann) """ - if isinstance(self.representation, Bargmann): - return self - else: + + ret = copy.deepcopy(self) + if isinstance(self.representation, Bargmann): # TODO: better name for Bargmann class + # check cc rep + if not indices: + indices = self.wires.indices + + ret = ret._apply_btoq_for_change_of_rep(indices) + ret = ret._apply_btops_for_change_of_rep(indices) + + elif isinstance(self.representation, Fock): if self.representation.ansatz._original_abc_data: A, b, c = self.representation.ansatz._original_abc_data else: @@ -620,7 +652,7 @@ def to_bargmann(self) -> CircuitComponent: ret = self._from_attributes(bargmann, self.wires, self.name) if "manual_shape" in ret.__dict__: del ret.manual_shape - return ret + return ret def _add_parameter(self, parameter: Constant | Variable): r""" @@ -650,6 +682,76 @@ def _getitem_builtin(self, modes: set[int]): kwargs = self.parameter_set[items].to_dict() return self.__class__(modes=modes, **kwargs) + def _apply_btops_for_change_of_rep(self, indices: Sequence[int]) -> CircuitComponent: + r""" + Helper function for change of representation in to_bargmann() + + Args: + indices: the set of indices that we want to be represented in bargmann. + + Output: + the cc object with Bargmann representation on the specified indices. The representations on the other wires remain intact. + """ + + from .circuit_components_utils import BtoPS + + ret = copy.deepcopy(self) + + for i in indices: + if len(self._index_representation[i]) > 2: + continue + name, arg = self._index_representation[i] + + if name == "PS": + ret._index_representation[i] = ("B", None) + m = self.wires.index_to_mode_dict[i] + if i in self.wires.output.bra.indices: + if m not in self.wires.output.ket.modes: + raise ValueError( + f"The object does not have a consistent representation. Mode {m} with PS representation has appeared only on the output bra." + ) + friend_index = self.wires.index_dicts[2][m] + ret._index_representation[friend_index] = ("B", None) + ret = ret @ BtoPS([m], s=arg).adjoint.inverse() + + if i in self.wires.input.bra.indices: + if m not in self.wires.input.ket.modes: + raise ValueError( + f"The object does not have a consistent representation. Mode {m} with PS representation has appeared only on the input bra." + ) + friend_index = self.wires.index_dicts[3][m] + ret._index_representation[friend_index] = ("B", None) + ret = BtoPS([m], s=arg).dual.inverse() @ ret + + return ret + + def _apply_btoq_for_change_of_rep(self, indices: Sequence[int]) -> CircuitComponent: + r""" + Helper function for change of representation in to_bargmann() + """ + + from .circuit_components_utils import BtoQ + + ret = copy.deepcopy(self) + for i in indices: + if len(self._index_representation[i]) > 2: + continue + name, arg = self._index_representation[i] + if name == "Q": + ret._index_representation[i] = ("B", None) # perhaps not needed -- can be removed + if i in self.wires.output.bra.indices: + ret = ret @ BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).adjoint.inverse() + if i in self.wires.output.ket.indices: + ret = ret @ BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).inverse() + if i in self.wires.input.bra.indices: + ret = ( + BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).dual.adjoint.inverse() + @ ret + ) + if i in self.wires.input.ket.indices: + ret = BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).dual.inverse() @ ret + return ret + def _light_copy(self, wires: Wires | None = None) -> CircuitComponent: r""" Creates a "light" copy of this component by referencing its __dict__, except for the wires, @@ -690,8 +792,11 @@ def __add__(self, other: CircuitComponent) -> CircuitComponent: """ if self.wires != other.wires: raise ValueError("Cannot add components with different wires.") - rep = self.representation + other.representation + rep = ( + self.to_bargmann().representation + other.to_bargmann().representation + ) # addition occurs in bargmann always name = self.name if self.name == other.name else "" + # TODO: go back to bargmann on all modes return self._from_attributes(rep, self.wires, name) def __eq__(self, other) -> bool: @@ -700,7 +805,16 @@ def __eq__(self, other) -> bool: Compares representations and wires, but not the other attributes (e.g. name and parameter set). """ - return self.representation == other.representation and self.wires == other.wires + from .circuit_components_utils import BtoQ, BtoPS + + if (type(self.representation) == type(other.representation) == Fock) or isinstance( + self, (BtoQ, BtoPS) + ): + return self.representation == other.representation and self.wires == other.wires + else: + self_rep = self.to_bargmann().representation + other_rep = other.to_bargmann().representation + return self_rep == other_rep and self.wires == other.wires def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: r""" @@ -719,27 +833,100 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: >>> att = Attenuator([0], 0.5) >>> assert (coh @ att).wires.input.bra # the input bra is still uncontracted """ + from .circuit_components_utils import BtoQ, BtoPS + if isinstance(other, (numbers.Number, np.ndarray)): return self * other wires_result, perm = self.wires @ other.wires idx_z, idx_zconj = self._matmul_indices(other) - if type(self.representation) is type(other.representation): + + if type(self.representation) is type(other.representation) is Fock: self_rep = self.representation other_rep = other.representation else: - self_rep = self.to_bargmann().representation - other_rep = other.to_bargmann().representation + if ( + (not isinstance(self, BtoQ)) + and (not isinstance(other, BtoQ)) + and (not isinstance(self, BtoPS)) + and (not isinstance(other, BtoPS)) + ): + self_copy = copy.deepcopy(self) + other_copy = copy.deepcopy(other) + index_self, index_other = self_copy.wires.contracted_indices(other_copy.wires) + self_rep = self_copy.to_bargmann( + index_self + ).representation # this is where the copy is required (to not send the intial objects back to Bargmann) + other_rep = other_copy.to_bargmann(index_other).representation + else: + self_rep = self.representation + other_rep = other.representation rep = self_rep[idx_z] @ other_rep[idx_zconj] rep = rep.reorder(perm) if perm else rep - return CircuitComponent._from_attributes(rep, wires_result, None) + result = CircuitComponent._from_attributes(rep, wires_result, None) + + # REMEMBER the representations: + # set the index_representation of uncontracted indices: + # (this will be overwritten if we have a change of representation e.g. other == BtoQ) + for m in other.wires.output.bra.modes: + i = result.wires.index_dicts[0][m] + j = other.wires.index_dicts[0][m] + result._index_representation[i] = other._index_representation[j][:2] + for m in other.wires.output.ket.modes: + i = result.wires.index_dicts[2][m] + j = other.wires.index_dicts[2][m] + result._index_representation[i] = other._index_representation[j][:2] + + for m in self.wires.input.bra.modes: + i = result.wires.index_dicts[1][m] + j = self.wires.index_dicts[1][m] + result._index_representation[i] = self._index_representation[j][:2] + for m in self.wires.input.ket.modes: + i = result.wires.index_dicts[3][m] + j = self.wires.index_dicts[3][m] + result._index_representation[i] = self._index_representation[j][:2] + + # now we check for indices that might have been contracted: + idx_1, idx_2 = self.wires.contracted_indices(other.wires) + + for m in other.wires.input.bra.modes: + j = other.wires.index_dicts[1][m] + if j not in idx_2: + i = result.wires.index_dicts[1][m] + result._index_representation[i] = other._index_representation[j][:2] + for m in other.wires.input.ket.modes: + + j = other.wires.index_dicts[3][m] + if j not in idx_2: + i = result.wires.index_dicts[3][m] + result._index_representation[i] = other._index_representation[j][:2] + + for m in self.wires.output.bra.modes: + j = self.wires.index_dicts[0][m] + if j not in idx_1: + i = result.wires.index_dicts[0][m] + result._index_representation[i] = self._index_representation[j][:2] + + for m in self.wires.output.ket.modes: + j = self.wires.index_dicts[2][m] + if j not in idx_1: + i = result.wires.index_dicts[2][m] + result._index_representation[i] = self._index_representation[j][:2] + + if isinstance(self, (BtoQ, BtoPS)) and isinstance(other, (BtoQ, BtoPS)): + for i in result.wires.indices: + result._index_representation[i] = result._index_representation[i] + (None,) + + return result def __mul__(self, other: Scalar) -> CircuitComponent: r""" Implements the multiplication by a scalar from the right. """ - return self._from_attributes(self.representation * other, self.wires, self.name) + ret = self._from_attributes(self.representation * other, self.wires, self.name) + ret._index_representation = self._index_representation + return ret def __repr__(self) -> str: repr = self.representation @@ -828,6 +1015,7 @@ def __rshift__(self, other: CircuitComponent | numbers.Number) -> CircuitCompone msg = f"``>>`` not supported between {self} and {other} because it's not clear " msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components." raise ValueError(msg) + return self._rshift_return(ret) def __sub__(self, other: CircuitComponent) -> CircuitComponent: diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py index 59f4188ca..11fb91189 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_ps.py @@ -16,6 +16,7 @@ The class representing an operation that changes Bargmann into phase space. """ # pylint: disable=protected-access +# pylint: disable=protected-access from __future__ import annotations from typing import Sequence @@ -52,8 +53,13 @@ def __init__( ), name="BtoPS", ) + self._add_parameter(Constant(s, "s")) - self.s = s + d1 = {mode: ("PS", float(self.s.value)) for mode in range(len(modes))} + d2 = {mode + len(modes): ("B", None) for mode in range(len(modes))} + d3 = {mode + 2 * len(modes): ("PS", float(self.s.value)) for mode in range(len(modes))} + d4 = {mode + 3 * len(modes): ("B", None) for mode in range(len(modes))} + self._index_representation = {**d1, **d2, **d3, **d4} @property def adjoint(self) -> BtoPS: @@ -61,10 +67,17 @@ def adjoint(self) -> BtoPS: kets = self.wires.ket.indices rep = self.representation.reorder(kets + bras).conj() - ret = BtoPS(self.modes, self.s) + ret = BtoPS(self.modes, float(self.s.value)) ret._representation = rep ret._wires = self.wires.adjoint ret._name = self.name + "_adj" + + # handling index representations: + for i, j in enumerate(kets): + ret._index_representation[i] = self._index_representation[j] + for i, j in enumerate(bras): + ret._index_representation[i + len(kets)] = self._index_representation[j] + return ret @property @@ -75,16 +88,43 @@ def dual(self) -> BtoPS: ob = self.wires.bra.output.indices rep = self.representation.reorder(ib + ob + ik + ok).conj() - ret = BtoPS(self.modes, self.s) + ret = BtoPS(self.modes, float(self.s.value)) ret._representation = rep ret._wires = self.wires.dual ret._name = self.name + "_dual" + + # handling index representations: + for i, j in enumerate(ib): + ret._index_representation[i] = self._index_representation[j] + for i, j in enumerate(ob): + ret._index_representation[i + len(ib)] = self._index_representation[j] + for i, j in enumerate(ik): + ret._index_representation[i + len(ib + ob)] = self._index_representation[j] + for i, j in enumerate(ok): + ret._index_representation[i + len(ib + ob + ik)] = self._index_representation[j] + return ret def inverse(self) -> BtoPS: inv = super().inverse() - ret = BtoPS(self.modes, self.s) + ret = BtoPS(self.modes, float(self.s.value)) ret._representation = inv.representation ret._wires = inv.wires ret._name = inv.name + + ok = self.wires.ket.output.indices + ik = self.wires.ket.input.indices + ib = self.wires.bra.input.indices + ob = self.wires.bra.output.indices + + # handling index representations: + for i, j in enumerate(ib): + ret._index_representation[i] = self._index_representation[j] + for i, j in enumerate(ob): + ret._index_representation[i + len(ib)] = self._index_representation[j] + for i, j in enumerate(ik): + ret._index_representation[i + len(ib + ob)] = self._index_representation[j] + for i, j in enumerate(ok): + ret._index_representation[i + len(ib + ob + ik)] = self._index_representation[j] + return ret diff --git a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py index ec531f0b5..af89025b5 100644 --- a/mrmustard/lab_dev/circuit_components_utils/b_to_q.py +++ b/mrmustard/lab_dev/circuit_components_utils/b_to_q.py @@ -16,6 +16,7 @@ The class representing an operation that changes Bargmann into quadrature. """ # pylint: disable=protected-access +# pylint: disable=protected-access from __future__ import annotations from typing import Sequence @@ -53,8 +54,11 @@ def __init__( representation=repr, name="BtoQ", ) + self._add_parameter(Constant(phi, "phi")) - self.phi = phi + d1 = {mode: ("Q", float(self.phi.value)) for mode in range(len(modes))} + d2 = {mode + len(modes): ("B", None) for mode in range(len(modes))} + self._index_representation = {**d1, **d2} @property def adjoint(self) -> BtoQ: @@ -62,10 +66,19 @@ def adjoint(self) -> BtoQ: kets = self.wires.ket.indices rep = self.representation.reorder(kets + bras).conj() - ret = BtoQ(self.modes, self.phi) + ret = BtoQ(self.modes, float(self.phi.value)) ret._representation = rep ret._wires = self.wires.adjoint ret._name = self.name + "_adj" + + # handling index representations: + for i, j in enumerate(kets): + ret._index_representation[i] = self._index_representation[j] + for i, j in enumerate(bras): + ret._index_representation[i + len(kets)] = self._index_representation[j] + + ret._index_representation = self._index_representation + return ret @property @@ -76,16 +89,43 @@ def dual(self) -> BtoQ: ob = self.wires.bra.output.indices rep = self.representation.reorder(ib + ob + ik + ok).conj() - ret = BtoQ(self.modes, self.phi) + ret = BtoQ(self.modes, float(self.phi.value)) ret._representation = rep ret._wires = self.wires.dual ret._name = self.name + "_dual" + + # handling index representations: + for i, j in enumerate(ib): + ret._index_representation[i] = self._index_representation[j] + for i, j in enumerate(ob): + ret._index_representation[i + len(ib)] = self._index_representation[j] + for i, j in enumerate(ik): + ret._index_representation[i + len(ib + ob)] = self._index_representation[j] + for i, j in enumerate(ok): + ret._index_representation[i + len(ib + ob + ik)] = self._index_representation[j] + return ret def inverse(self) -> BtoQ: inv = super().inverse() - ret = BtoQ(self.modes, self.phi) + ret = BtoQ(self.modes, float(self.phi.value)) ret._representation = inv.representation ret._wires = inv.wires ret._name = inv.name + + ok = self.wires.ket.output.indices + ik = self.wires.ket.input.indices + ib = self.wires.bra.input.indices + ob = self.wires.bra.output.indices + + # handling index representations: + for i, j in enumerate(ib): + ret._index_representation[i] = self._index_representation[j] + for i, j in enumerate(ob): + ret._index_representation[i + len(ib)] = self._index_representation[j] + for i, j in enumerate(ik): + ret._index_representation[i + len(ib + ob)] = self._index_representation[j] + for i, j in enumerate(ok): + ret._index_representation[i + len(ib + ob + ik)] = self._index_representation[j] + return ret diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 95fc4dda5..95b196a98 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -964,7 +964,9 @@ def __rshift__(self, other: CircuitComponent) -> CircuitComponent: w = result.wires if not w.input and w.bra.modes == w.ket.modes: - return DM(w.modes, result.representation) + ret = DM(w.modes, result.representation) + ret._index_representation = result._index_representation + return ret return result @@ -1259,7 +1261,11 @@ def __rshift__(self, other: CircuitComponent | Scalar) -> CircuitComponent | Bat if not result.wires.input: if not result.wires.bra: - return Ket(result.wires.modes, result.representation) + ret = Ket(result.wires.modes, result.representation) + ret._index_representation = result._index_representation + return ret elif result.wires.bra.modes == result.wires.ket.modes: - result = DM(result.wires.modes, result.representation) + ret = DM(result.wires.modes, result.representation) + ret._index_representation = result._index_representation + return ret return result diff --git a/mrmustard/lab_dev/transformations/dgate.py b/mrmustard/lab_dev/transformations/dgate.py index a7640f3ef..1a6fd42dc 100644 --- a/mrmustard/lab_dev/transformations/dgate.py +++ b/mrmustard/lab_dev/transformations/dgate.py @@ -94,7 +94,6 @@ def __init__( xs, ys = list(reshape_params(len(modes), x=x, y=y)) self._add_parameter(make_parameter(x_trainable, xs, "x", x_bounds)) self._add_parameter(make_parameter(y_trainable, ys, "y", y_bounds)) - self._representation = Bargmann.from_function( fn=triples.displacement_gate_Abc, x=self.x, y=self.y ) diff --git a/mrmustard/lab_dev/wires.py b/mrmustard/lab_dev/wires.py index 7f896df1d..3c36237c6 100644 --- a/mrmustard/lab_dev/wires.py +++ b/mrmustard/lab_dev/wires.py @@ -325,6 +325,14 @@ def indices(self) -> tuple[int, ...]: self.index_dicts[t][m] for t, modes in enumerate(self.sorted_args) for m in modes ) + @cached_property + def index_to_mode_dict(self) -> dict[int, int]: + r""" + A dictionary that maps indecies to modes. For example, for D = Dgate([0],x=1), we + have {0: 0, 1:0} since all wires (indices) correspond to mode 0. + """ + return {i: m for d in self.index_dicts for m, i in d.items()} + @cached_property def input(self) -> Wires: r""" diff --git a/tests/test_lab_dev/test_circuit_components.py b/tests/test_lab_dev/test_circuit_components.py index b0b6eee0a..d5b80136d 100644 --- a/tests/test_lab_dev/test_circuit_components.py +++ b/tests/test_lab_dev/test_circuit_components.py @@ -43,6 +43,7 @@ Channel, Wires, ) +from mrmustard.lab_dev.circuit_components_utils import BtoQ, BtoPS from ..random import Abc_triple @@ -138,6 +139,20 @@ def test_adjoint(self): assert d1_adj_adj.parameter_set == d1.parameter_set assert d1_adj_adj.representation == d1.representation + # index representation test: + d2 = d1 >> BtoQ([1], 0.7) + d2_dual = d2.adjoint + d2_dual._index_representation + assert d2_dual._index_representation == { + 0: ("Q", 0.7), + 1: ("B", None), + 2: ("B", None), + 3: ("B", None), + } + + rho = DM.random([0]) @ BtoQ([0]) + assert rho.adjoint._index_representation == {0: ("Q", 0), 1: ("B", None)} + def test_dual(self): d1 = Dgate([1, 8], x=0.1, y=0.2) d1_dual = d1.dual @@ -157,6 +172,16 @@ def test_dual(self): assert d1_dual_dual.wires == d1.wires assert d1_dual_dual.representation == d1.representation + # index representation test + d2 = d1 >> BtoQ([1], 0.7) + d2_dual = d2.dual + assert d2_dual._index_representation == { + 0: ("B", None), + 1: ("B", None), + 2: ("Q", 0.7), + 3: ("B", None), + } + def test_light_copy(self): d1 = CircuitComponent( Bargmann(*displacement_gate_Abc(0.1, 0.1)), wires=[(), (), (1,), (1,)] @@ -182,6 +207,12 @@ def test_on(self): assert bool(d67.parameter_set) is True assert d67._representation is d89._representation + # index representation test + psi = Vacuum([0, 1]) + psi._index_representation[0] = ("Q", 0.1) + phi = psi.on([2, 3]) + assert phi._index_representation == psi._index_representation + def test_on_error(self): with pytest.raises(ValueError): Vacuum([1, 2]).on([3]) @@ -226,6 +257,12 @@ def test_add(self): d12 = d1 + d2 assert d12.representation == d1.representation + d2.representation + # checking if addition takes care of representations + psi = Ket.random([0]) + phi = Ket.random([0]) + + assert psi + (phi >> BtoQ([0])) == psi + phi + def test_sub(self): s1 = DisplacedSqueezed([1], x=1.0, y=0.5, r=0.1) s2 = DisplacedSqueezed([1], x=0.5, y=0.2, r=0.2) @@ -318,6 +355,40 @@ def test_matmul_is_associative(self): assert result1 == result3 assert result1 == result4 + def tets_matmul_respects_representations(self): + rho = Vacuum([0, 1]).dm() + psi = Vacuum([2]) + psi._index_representation[0] = ("Q", 0) + assert (rho @ psi.dual)._index_representation == { + 0: ("B", None), + 1: ("B", None), + 2: ("B", None), + 3: ("B", None), + 4: ("Q", 0.0), + } + + # the following example has no physical meaning and is just + # meant to check the logic of matmul in handling representations + rho = Vacuum([0]).dm() + psi = Vacuum([2]) + psi._index_representation[0] = ("Q", 0) + ch = Channel.random([0]) + ch._index_representation = {0: ("Q", 0), 1: ("Q", 1.5), 2: ("PS", 0.5), 3: ("B", None)} + assert (rho @ psi.dual @ ch)._index_representation == { + 0: ("Q", 0), + 1: ("PS", 0.5), + 2: ("Q", 0), + } + + rho = DM.random([0, 1]) + sigma = DM.random([0]) + assert rho >> BtoPS([0], 0) >> sigma.dual == rho >> sigma.dual + + rho = DM.random([0, 1]) + sigma = DM.random([0]) + + assert sigma >> (BtoPS([0], 0).dual >> rho.dual) == sigma >> rho.dual + def test_matmul_scalar(self): d0 = Dgate([0], x=0.1, y=0.1) result = d0 @ 0.8 @@ -329,6 +400,9 @@ def test_matmul_scalar(self): assert math.allclose(result2.representation.b, d0.representation.b) assert math.allclose(result2.representation.c, 0.8 * d0.representation.c) + psi = Ket.random([0]).to_quadrature() * 2 + assert psi._index_representation == {0: ("Q", 0)} + def test_rshift_all_bargmann(self): vac012 = Vacuum([0, 1, 2]) d0 = Dgate([0], x=0.1, y=0.1) @@ -401,6 +475,13 @@ def test_rshift_bargmann_and_fock(self, shape): settings.AUTOSHAPE_MAX = 50 + def test_rshift_mixed_representation(self): + psi = Ket.random([0]) + phi = Ket.random([0, 1]) + r1 = psi @ phi.dual + r2 = (psi @ BtoQ([0])) @ phi.dual + assert r1 == r2 + def test_rshift_error(self): vac012 = Vacuum([0, 1, 2]) d0 = Dgate([0], x=0.1, y=0.1) @@ -561,3 +642,25 @@ def __init__(self, rep, custom_modes): TypeError, match="MyComponent does not seem to have any wires construction method" ): cc._serialize() + + ## tests regarding multirepresentation: + def test_index_representation(self): + r""" + Tests the initialization and updating of index_representation dictionary + """ + + # testing initialization + assert DM.random([0, 1])._index_representation == {i: ("B", None) for i in range(4)} + + # testing update under BtoQ + phi = BtoQ([0]).dual @ Channel.random([0]) + assert phi._index_representation == { + 0: ("B", None), + 1: ("B", None), + 2: ("B", None), + 3: ("Q", 0.0), + } + + # testing update under BtoPS + rho = DM.random([0]) @ BtoPS([0], s=0.2) + assert rho._index_representation == {0: ("PS", 0.2), 1: ("PS", 0.2)} diff --git a/tests/test_lab_dev/test_circuit_components_utils.py b/tests/test_lab_dev/test_circuit_components_utils.py new file mode 100644 index 000000000..58f4659e2 --- /dev/null +++ b/tests/test_lab_dev/test_circuit_components_utils.py @@ -0,0 +1,311 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for circuit components utils.""" + +# pylint: disable=fixme, missing-function-docstring, protected-access, pointless-statement + +import numpy as np +import pytest + +from mrmustard import math, settings +from mrmustard.physics.triples import identity_Abc, displacement_map_s_parametrized_Abc +from mrmustard.physics.representations import Bargmann +from mrmustard.physics.bargmann import wigner_to_bargmann_rho +from mrmustard.physics.gaussian_integrals import ( + real_gaussian_integral, + complex_gaussian_integral_1, + complex_gaussian_integral_2, + join_Abc, + join_Abc_real, +) +from mrmustard.lab_dev.circuit_components_utils import TraceOut, BtoPS, BtoQ +from mrmustard.lab_dev.circuit_components import CircuitComponent +from mrmustard.lab_dev.states import Coherent, DM +from mrmustard.lab_dev.wires import Wires +from mrmustard.lab_dev.states import Ket +from mrmustard.lab_dev.states import Ket + + +# original settings +autocutoff_max0 = settings.AUTOCUTOFF_MAX_CUTOFF + + +class TestTraceOut: + r""" + Tests ``TraceOut`` objects. + """ + + @pytest.mark.parametrize("modes", [[0], [1, 2], [3, 4, 5]]) + def test_init(self, modes): + tr = TraceOut(modes) + + assert tr.name == "Tr" + assert tr.wires == Wires(modes_in_bra=set(modes), modes_in_ket=set(modes)) + assert tr.representation == Bargmann(*identity_Abc(len(modes))) + + def test_trace_out_bargmann_states(self): + state = Coherent([0, 1, 2], x=1) + + assert state >> TraceOut([0]) == Coherent([1, 2], x=1).dm() + assert state >> TraceOut([1, 2]) == Coherent([0], x=1).dm() + + trace = state >> TraceOut([0, 1, 2]) + assert np.isclose(trace, 1.0) + + def test_trace_out_complex(self): + cc = CircuitComponent.from_bargmann( + ( + np.array([[0.1 + 0.2j, 0.3 + 0.4j], [0.3 + 0.4j, 0.5 - 0.6j]]), + np.array([0.7 + 0.8j, -0.9 + 0.10j]), + 0.11 - 0.12j, + ), + modes_out_ket=[0], + modes_out_bra=[0], + ) + assert (cc >> TraceOut([0])).dtype == math.complex128 + + def test_trace_out_fock_states(self): + state = Coherent([0, 1, 2], x=1).to_fock(10) + assert state >> TraceOut([0]) == Coherent([1, 2], x=1).to_fock(7).dm() + assert state >> TraceOut([1, 2]) == Coherent([0], x=1).to_fock(7).dm() + + no_state = state >> TraceOut([0, 1, 2]) + assert np.isclose(no_state, 1.0) + + +class TestBtoPS: + r""" + Tests for the ``BtoPS`` class. + """ + + modes = [[0], [1, 2], [9, 7]] + s = [0, -1, 1] + + @pytest.mark.parametrize("modes,s", zip(modes, s)) + def test_init(self, modes, s): + dsmap = BtoPS(modes, s) # pylint: disable=protected-access + + assert dsmap.name == "BtoPS" + assert dsmap.modes == [modes] if not isinstance(modes, list) else sorted(modes) + + def test_representation(self): + rep1 = BtoPS(modes=[0], s=0).representation # pylint: disable=protected-access + A_correct, b_correct, c_correct = displacement_map_s_parametrized_Abc(s=0, n_modes=1) + assert math.allclose(rep1.A[0], A_correct) + assert math.allclose(rep1.b[0], b_correct) + assert math.allclose(rep1.c[0], c_correct) + + rep2 = BtoPS(modes=[5, 10], s=1).representation # pylint: disable=protected-access + A_correct, b_correct, c_correct = displacement_map_s_parametrized_Abc(s=1, n_modes=2) + assert math.allclose(rep2.A[0], A_correct) + assert math.allclose(rep2.b[0], b_correct) + assert math.allclose(rep2.c[0], c_correct) + + def testBtoPS_contraction_with_state(self): + # The init state cov and means comes from the random state 'state = Gaussian(1) >> Dgate([0.2], [0.3])' + state_cov = np.array([[0.32210229, -0.99732956], [-0.99732956, 6.1926484]]) + state_means = np.array([0.4, 0.6]) + A, b, c = wigner_to_bargmann_rho(state_cov, state_means) + state = DM.from_bargmann(modes=[0], triple=(A, b, c)) + state_bargmann_triple = state.bargmann_triple() + + # get new triple by right shift + state_after = state >> BtoPS(modes=[0], s=0) # pylint: disable=protected-access + A1, b1, c1 = state_after.bargmann_triple(batched=True) + + # get new triple by contraction + Ds_bargmann_triple = displacement_map_s_parametrized_Abc(s=0, n_modes=1) + A2, b2, c2 = complex_gaussian_integral_2( + state_bargmann_triple, Ds_bargmann_triple, idx1=[0, 1], idx2=[1, 3] + ) + + assert math.allclose(A1, A2) + assert math.allclose(b1, b2) + assert math.allclose(c1, c2) + + # The init state cov and means comes from the random state 'state = Gaussian(2) >> Dgate([0.2], [0.3])' + state_cov = np.array( + [ + [0.77969414, 0.10437996, 0.72706741, 0.29121535], + [0.10437996, 0.22846619, 0.1211067, 0.45983868], + [0.72706741, 0.1211067, 1.02215481, 0.16216756], + [0.29121535, 0.45983868, 0.16216756, 2.10006], + ] + ) + state_means = np.array([0.28284271, 0.0, 0.42426407, 0.0]) + A, b, c = wigner_to_bargmann_rho(state_cov, state_means) + state = DM.from_bargmann(modes=[0, 1], triple=(A, b, c)) + state_bargmann_triple = state.bargmann_triple() + + # get new triple by right shift + state_after = state >> BtoPS(modes=[0, 1], s=0) # pylint: disable=protected-access + A1, b1, c1 = state_after.bargmann_triple(batched=True) + + # get new triple by contraction + Ds_bargmann_triple = displacement_map_s_parametrized_Abc(s=0, n_modes=2) + A2, b2, c2 = complex_gaussian_integral_2( + state_bargmann_triple, + Ds_bargmann_triple, + idx1=[0, 1, 2, 3], + idx2=[2, 3, 6, 7], + ) + + assert math.allclose(A1, A2) + assert math.allclose(b1, b2) + assert math.allclose(c1, c2) + + psi = Ket.random([0]) + assert math.allclose((psi >> BtoPS([0], 1)).representation([0, 0]), [1.0]) + + def test_Bto_S_index_representation(self): + r""" + Tests the assingments of the index_representration of a BtoPS and its variants + """ + btops_1 = BtoPS([0], 0.1) + assert btops_1._index_representation == { + 0: ("PS", 0.1), + 1: ("B", None), + 2: ("PS", 0.1), + 3: ("B", None), + } + + btops_dual = BtoPS([0], 0.1).dual + assert btops_dual._index_representation == { + 0: ("B", None), + 1: ("PS", 0.1), + 2: ("B", None), + 3: ("PS", 0.1), + } + + btops_adjoint = btops_1.adjoint + assert btops_adjoint._index_representation == btops_1._index_representation + + btops_inv = btops_1.inverse() + assert btops_inv._index_representation == btops_dual._index_representation + + +class TestBtoQ: + r""" + Tests for the ``BtoQ`` class. + """ + + def testBtoQ_works_correctly_by_applying_it_twice_on_a_state(self): + A0 = np.array([[0.5, 0.3], [0.3, 0.5]]) + 0.0j + b0 = np.zeros(2, dtype=np.complex128) + c0 = 1.0 + 0j + + modes = [0, 1] + BtoQ_CC1 = BtoQ(modes, 0.0) + step1A, step1b, step1c = BtoQ_CC1.bargmann_triple(batched=False) + Ainter, binter, cinter = complex_gaussian_integral_1( + join_Abc((A0, b0, c0), (step1A, step1b, step1c)), + idx_z=[0, 1], + idx_zconj=[4, 5], + measure=-1, + ) + QtoBMap_CC2 = BtoQ(modes, 0.0).dual + step2A, step2b, step2c = QtoBMap_CC2.bargmann_triple(batched=False) + + new_A, new_b, new_c = join_Abc_real( + (Ainter[0], binter[0], cinter[0]), (step2A, step2b, step2c), [0, 1], [2, 3] + ) + + Af, bf, cf = real_gaussian_integral((new_A, new_b, new_c), idx=[0, 1]) + + assert math.allclose(A0, Af) + assert math.allclose(b0, bf) + assert math.allclose(c0, cf) + + A0 = np.array([[0.4895454]]) + b0 = np.zeros(1) + c0 = 1.0 + 0j + + modes = [0] + BtoQ_CC1 = BtoQ(modes, 0.0) + step1A, step1b, step1c = BtoQ_CC1.bargmann_triple(batched=False) + Ainter, binter, cinter = complex_gaussian_integral_1( + join_Abc((A0, b0, c0), (step1A, step1b, step1c)), + idx_z=[ + 0, + ], + idx_zconj=[2], + measure=-1, + ) + QtoBMap_CC2 = BtoQ(modes, 0.0).dual + step2A, step2b, step2c = QtoBMap_CC2.bargmann_triple(batched=False) + + new_A, new_b, new_c = join_Abc_real( + (Ainter[0], binter[0], cinter[0]), (step2A, step2b, step2c), [0], [1] + ) + + Af, bf, cf = real_gaussian_integral((new_A, new_b, new_c), idx=[0]) + + assert math.allclose(A0, Af) + assert math.allclose(b0, bf) + assert math.allclose(c0, cf) + + psi = Ket.random([0]) + phi = Ket.random([0]) + c1 = psi >> phi.dual + c2 = (psi >> BtoQ([0])) >> (phi >> BtoQ([0])).dual + assert math.allclose(c1, c2) + + def test_BtoQ_with_displacement(self): + "tests the BtoQ transformation with coherent states" + + def wavefunction_coh(alpha, quad, axis_angle): + "alpha = x+iy of coherent state, quad is quadrature variable, axis_angle of quad axis" + A = -1 / settings.HBAR + b = np.exp(-1j * axis_angle) * np.sqrt(2 / settings.HBAR) * alpha + c = ( + np.exp(-0.5 * np.abs(alpha) ** 2) + / np.power(np.pi * settings.HBAR, 0.25) + * np.exp(-0.5 * alpha**2 * np.exp(-2j * axis_angle)) + ) + return c * np.exp(0.5 * A * quad**2 + b * quad) + + x = np.random.random() + y = np.random.random() + axis_angle = np.random.random() + quad = np.random.random() + + state = Coherent([0], x, y) + wavefunction = (state >> BtoQ([0], axis_angle)).representation.ansatz + + assert np.allclose(wavefunction(quad), wavefunction_coh(x + 1j * y, quad, axis_angle)) + + def test_BtoQ_index_representatioin(self): + "Tests whether BtoQ, and its adjopint/dual and their combinations have the right representation" + + btoq_1 = BtoQ([0]) + assert btoq_1._index_representation == {0: ("Q", 0), 1: ("B", None)} + + btoq_dual = btoq_1.dual + assert btoq_dual._index_representation == {0: ("B", None), 1: ("Q", 0)} + + btoq_adjoint = btoq_1.adjoint + assert btoq_adjoint._index_representation == {0: ("Q", 0), 1: ("B", None)} + + btoq_dual_adj = btoq_1.dual.adjoint + assert btoq_dual_adj._index_representation == {0: ("B", None), 1: ("Q", 0)} + + btoq_adj_dual = btoq_1.adjoint.dual + assert btoq_adj_dual._index_representation == {0: ("B", None), 1: ("Q", 0)} + + btoq_inv = btoq_1.inverse() + assert btoq_inv._index_representation == btoq_dual._index_representation + + btoq_adj_inv = btoq_1.adjoint.inverse() + assert btoq_adj_inv._index_representation == {0: ("B", None), 1: ("Q", 0)} diff --git a/tests/test_lab_dev/test_circuit_components_utils/test_b_to_ps.py b/tests/test_lab_dev/test_circuit_components_utils/test_b_to_ps.py index 09308bc04..a50350b29 100644 --- a/tests/test_lab_dev/test_circuit_components_utils/test_b_to_ps.py +++ b/tests/test_lab_dev/test_circuit_components_utils/test_b_to_ps.py @@ -49,7 +49,7 @@ def test_adjoint(self): kets = btops.wires.ket.indices assert adjoint_btops.representation == btops.representation.reorder(kets + bras).conj() assert adjoint_btops.wires == btops.wires.adjoint - assert adjoint_btops.s == btops.s + assert adjoint_btops.s.value == btops.s.value assert isinstance(adjoint_btops, BtoPS) def test_dual(self): @@ -62,7 +62,7 @@ def test_dual(self): ob = btops.wires.bra.output.indices assert dual_btops.representation == btops.representation.reorder(ib + ob + ik + ok).conj() assert dual_btops.wires == btops.wires.dual - assert dual_btops.s == btops.s + assert dual_btops.s.value == btops.s.value assert isinstance(dual_btops, BtoPS) def test_inverse(self): diff --git a/tests/test_lab_dev/test_circuit_components_utils/test_b_to_q.py b/tests/test_lab_dev/test_circuit_components_utils/test_b_to_q.py index 3eaebf271..3cc1f7349 100644 --- a/tests/test_lab_dev/test_circuit_components_utils/test_b_to_q.py +++ b/tests/test_lab_dev/test_circuit_components_utils/test_b_to_q.py @@ -40,7 +40,7 @@ def test_adjoint(self): kets = btoq.wires.ket.indices assert adjoint_btoq.representation == btoq.representation.reorder(kets).conj() assert adjoint_btoq.wires == btoq.wires.adjoint - assert adjoint_btoq.phi == btoq.phi + assert adjoint_btoq.phi.value == btoq.phi.value assert isinstance(adjoint_btoq, BtoQ) def test_dual(self): @@ -51,7 +51,7 @@ def test_dual(self): ik = dual_btoq.wires.ket.input.indices assert dual_btoq.representation == btoq.representation.reorder(ik + ok).conj() assert dual_btoq.wires == btoq.wires.dual - assert dual_btoq.phi == btoq.phi + assert dual_btoq.phi.value == btoq.phi.value assert isinstance(dual_btoq, BtoQ) def test_inverse(self): diff --git a/tests/test_lab_dev/test_states/test_states_base.py b/tests/test_lab_dev/test_states/test_states_base.py index 396a5f401..3d0466722 100644 --- a/tests/test_lab_dev/test_states/test_states_base.py +++ b/tests/test_lab_dev/test_states/test_states_base.py @@ -39,6 +39,7 @@ from mrmustard.lab_dev.transformations import Attenuator, Dgate, Sgate from mrmustard.lab_dev.wires import Wires from mrmustard.widgets import state as state_widget +from mrmustard.lab_dev.circuit_components_utils import BtoQ, BtoPS # original settings autocutoff_max0 = int(settings.AUTOCUTOFF_MAX_CUTOFF) @@ -174,6 +175,9 @@ def test_to_from_quadrature(self): assert math.allclose(btest2, b0) assert math.allclose(ctest2, c0) + psi = Ket.random([0]) + assert psi.to_quadrature()._index_representation == {0: ("Q", 0)} + def test_L2_norm(self): state = Coherent([0], x=1) assert state.L2_norm == 1 @@ -580,6 +584,9 @@ def test_to_from_quadrature(self): assert np.allclose(btest2, b0) assert np.allclose(ctest2, c0) + rho = DM.random([0]) + assert rho.to_quadrature()._index_representation == {0: ("Q", 0), 1: ("Q", 0)} + def test_L2_norms(self): state = Coherent([0], x=1).dm() + Coherent([0], x=-1).dm() # incoherent assert len(state._L2_norms) == 2 @@ -789,6 +796,9 @@ def test_rshift(self): assert isinstance(dm >> Coherent([0], 1).dual, DM) assert isinstance(dm >> Coherent([0], 1).dm().dual, DM) + rho = DM.random([0, 1]) >> BtoPS([0], 0) >> BtoQ([1]) + assert rho._index_representation == {0: ("PS", 0), 1: ("Q", 0), 2: ("PS", 0), 3: ("Q", 0)} + @pytest.mark.parametrize("modes", [[5], [1, 2]]) def test_random(self, modes): m = len(modes) diff --git a/tests/test_lab_dev/test_wires.py b/tests/test_lab_dev/test_wires.py index c1d4af695..ad70cfa2f 100644 --- a/tests/test_lab_dev/test_wires.py +++ b/tests/test_lab_dev/test_wires.py @@ -115,6 +115,10 @@ def test_ids_dicts(self): assert w.input.ids_dicts == d assert w.input.bra.ids_dicts == d + def test_index_to_mode_dict(self): + w = Wires({0, 1}, {1}, {2}, {0, 4}) + assert w.index_to_mode_dict == {0: 0, 1: 1, 2: 1, 3: 2, 4: 0, 5: 4} + def test_adjoint(self): w = Wires({0, 1, 2}, {3, 4, 5}, {6, 7}, {8}) w_adj = w.adjoint