Skip to content

Commit

Permalink
Rich wires (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
ziofil authored Jan 13, 2025
1 parent 4d61a52 commit a4dfcd8
Show file tree
Hide file tree
Showing 39 changed files with 616 additions and 673 deletions.
11 changes: 6 additions & 5 deletions mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]:
if "name" in params: # assume abstract type, serialize the representation
ansatz_cls = type(self.ansatz)
serializable["name"] = self.name
serializable["wires"] = self.wires.sorted_args
serializable["wires"] = tuple(tuple(a) for a in self.wires.args)
serializable["ansatz_cls"] = f"{ansatz_cls.__module__}.{ansatz_cls.__qualname__}"
return serializable, self.ansatz.to_dict()

Expand Down Expand Up @@ -112,7 +112,9 @@ def _deserialize(cls, data: dict) -> CircuitComponent:
if "ansatz_cls" in data:
ansatz_cls, wires, name = map(data.pop, ["ansatz_cls", "wires", "name"])
ansatz = locate(ansatz_cls).from_dict(data)
return cls._from_attributes(Representation(ansatz, Wires(*map(set, wires))), name=name)
return cls._from_attributes(
Representation(ansatz, Wires(*tuple(set(m) for m in wires))), name=name
)

return cls(**data)

Expand Down Expand Up @@ -500,7 +502,6 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent:
>>> d_fock = d.to_fock(shape=3)
>>> assert d_fock.name == d.name
>>> assert d_fock.wires == d.wires
>>> assert isinstance(d_fock.ansatz, ArrayAnsatz)
Args:
Expand Down Expand Up @@ -683,9 +684,9 @@ def __rshift__(self, other: CircuitComponent | numbers.Number) -> CircuitCompone
if only_ket or only_bra or both_sides:
ret = self @ other
elif self_needs_bra or self_needs_ket:
ret = (self.adjoint @ self) @ other
ret = self.adjoint @ (self @ other)
elif other_needs_bra or other_needs_ket:
ret = self @ (other @ other.adjoint)
ret = (self @ other.adjoint) @ other
else:
msg = f"``>>`` not supported between {self} and {other} because it's not clear "
msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components."
Expand Down
9 changes: 4 additions & 5 deletions mrmustard/lab_dev/circuit_components_utils/b_to_ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ..transformations.base import Map
from ...physics.ansatz import PolyExpAnsatz
from ...physics.representations import RepEnum
from ...physics.wires import ReprEnum
from ..utils import make_parameter

__all__ = ["BtoPS"]
Expand Down Expand Up @@ -55,10 +55,9 @@ def __init__(
n_modes=len(modes),
),
).representation
for i in self.wires.input.indices:
self.representation._idx_reps[i] = (RepEnum.BARGMANN, None)
for i in self.wires.output.indices:
self.representation._idx_reps[i] = (RepEnum.PHASESPACE, float(self.parameters.s.value))
for w in self.representation.wires.output.wires:
w.repr = ReprEnum.CHARACTERISTIC
w.repr_params_func = lambda: self.parameters.s

def inverse(self):
ret = BtoPS(self.modes, self.parameters.s)
Expand Down
14 changes: 6 additions & 8 deletions mrmustard/lab_dev/circuit_components_utils/b_to_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ..transformations.base import Operation
from ...physics.ansatz import PolyExpAnsatz
from ...physics.representations import RepEnum
from ...physics.wires import ReprEnum
from ..utils import make_parameter

__all__ = ["BtoQ"]
Expand Down Expand Up @@ -53,13 +53,11 @@ def __init__(
fn=triples.bargmann_to_quadrature_Abc, n_modes=len(modes), phi=self.parameters.phi
),
).representation
for i in self.wires.input.indices:
self.representation._idx_reps[i] = (RepEnum.BARGMANN, None)
for i in self.wires.output.indices:
self.representation._idx_reps[i] = (
RepEnum.QUADRATURE,
float(self.parameters.phi.value),
)
for w in self.representation.wires.input.wires:
w.repr = ReprEnum.BARGMANN
for w in self.representation.wires.output.wires:
w.repr = ReprEnum.QUADRATURE
w.repr_params_func = lambda: self.parameters.phi

def inverse(self):
ret = BtoQ(self.modes, self.parameters.phi)
Expand Down
3 changes: 2 additions & 1 deletion mrmustard/lab_dev/circuit_components_utils/trace_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..circuit_components import CircuitComponent
from ...physics.ansatz import PolyExpAnsatz
from ...physics.representations import Representation
from ...physics.wires import Wires

__all__ = ["TraceOut"]

Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(
super().__init__(
Representation(
PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)),
[(), modes, (), modes],
Wires(set(), set(modes), set(), set(modes)),
),
name="Tr",
)
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def serialize(self, filestem: str = None):
@classmethod
def deserialize(cls, data: dict) -> Circuit:
r"""Deserialize a Circuit."""
comps, path = map(data.pop, ("components", "path"))
comps, path = data.pop("components"), data.pop("path")

for k, v in data.items():
kwarg, i = k.split(":")
Expand Down
27 changes: 16 additions & 11 deletions mrmustard/lab_dev/states/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

from __future__ import annotations
from typing import Sequence
from typing import Collection

from itertools import product
import warnings
Expand All @@ -30,7 +30,7 @@
from mrmustard.physics.bargmann_utils import wigner_to_bargmann_rho
from mrmustard.physics.gaussian_integrals import complex_gaussian_integral_2
from mrmustard.physics.representations import Representation
from mrmustard.physics.wires import Wires
from mrmustard.physics.wires import Wires, ReprEnum
from mrmustard.utils.typing import ComplexMatrix, ComplexVector, ComplexTensor, RealVector

from .base import State, _validate_operator, OperatorType
Expand Down Expand Up @@ -111,7 +111,7 @@ def _purities(self) -> RealVector:
@classmethod
def from_bargmann(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple[ComplexMatrix, ComplexVector, complex],
name: str | None = None,
) -> State:
Expand All @@ -120,7 +120,7 @@ def from_bargmann(
@classmethod
def from_fock(
cls,
modes: Sequence[int],
modes: Collection[int],
array: ComplexTensor,
name: str | None = None,
batched: bool = False,
Expand All @@ -130,22 +130,27 @@ def from_fock(
@classmethod
def from_ansatz(
cls,
modes: Sequence[int],
modes: Collection[int],
ansatz: PolyExpAnsatz | ArrayAnsatz | None = None,
name: str | None = None,
) -> State:
if not isinstance(modes, set) and sorted(modes) != list(modes):
raise ValueError(f"Modes must be sorted. got {modes}")
modes = set(modes)
if ansatz and ansatz.num_vars != 2 * len(modes):
raise ValueError(
f"Expected an ansatz with {2*len(modes)} variables, found {ansatz.num_vars}."
)
wires = Wires(modes_out_bra=modes, modes_out_ket=modes)
wires = Wires(modes_out_bra=set(modes), modes_out_ket=set(modes))
if isinstance(ansatz, ArrayAnsatz):
for w in wires:
w.repr = ReprEnum.FOCK
return DM(Representation(ansatz, wires), name)

@classmethod
def from_phase_space(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple,
name: str | None = None,
s: float = 0, # pylint: disable=unused-argument
Expand Down Expand Up @@ -175,7 +180,7 @@ def from_phase_space(
@classmethod
def from_quadrature(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple[ComplexMatrix, ComplexVector, complex],
phi: float = 0.0,
name: str | None = None,
Expand Down Expand Up @@ -203,7 +208,7 @@ def from_quadrature(
return DM.from_ansatz(modes, (Q >> QtoB).ansatz, name)

@classmethod
def random(cls, modes: Sequence[int], m: int | None = None, max_r: float = 1.0) -> DM:
def random(cls, modes: Collection[int], m: int | None = None, max_r: float = 1.0) -> DM:
r"""
Samples a random density matrix. The final state has zero displacement.
Expand Down Expand Up @@ -381,13 +386,13 @@ def _ipython_display_(self): # pragma: no cover
is_fock = isinstance(self.ansatz, ArrayAnsatz)
display(widgets.state(self, is_ket=False, is_fock=is_fock))

def __getitem__(self, modes: int | Sequence[int]) -> State:
def __getitem__(self, modes: int | Collection[int]) -> State:
r"""
Traces out all the modes except those given.
The result is returned with modes in increasing order.
"""
if isinstance(modes, int):
modes = [modes]
modes = {modes}
modes = set(modes)

if not modes.issubset(self.modes):
Expand Down
23 changes: 14 additions & 9 deletions mrmustard/lab_dev/states/ket.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from __future__ import annotations

from typing import Sequence
from typing import Collection
from itertools import product
import warnings
import numpy as np
Expand All @@ -30,7 +30,7 @@
from mrmustard.physics.bargmann_utils import wigner_to_bargmann_psi
from mrmustard.physics.gaussian import purity
from mrmustard.physics.representations import Representation
from mrmustard.physics.wires import Wires
from mrmustard.physics.wires import Wires, ReprEnum
from mrmustard.utils.typing import (
ComplexMatrix,
ComplexVector,
Expand Down Expand Up @@ -90,7 +90,7 @@ def _probabilities(self) -> RealVector:
@classmethod
def from_bargmann(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple[ComplexMatrix, ComplexVector, complex],
name: str | None = None,
) -> State:
Expand All @@ -99,7 +99,7 @@ def from_bargmann(
@classmethod
def from_fock(
cls,
modes: Sequence[int],
modes: Collection[int],
array: ComplexTensor,
name: str | None = None,
batched: bool = False,
Expand All @@ -109,22 +109,27 @@ def from_fock(
@classmethod
def from_ansatz(
cls,
modes: Sequence[int],
modes: Collection[int],
ansatz: PolyExpAnsatz | ArrayAnsatz | None = None,
name: str | None = None,
) -> State:
if not isinstance(modes, set) and sorted(modes) != list(modes):
raise ValueError(f"Modes must be sorted. Got {modes}")
modes = set(modes)
if ansatz and ansatz.num_vars != len(modes):
raise ValueError(
f"Expected an ansatz with {len(modes)} variables, found {ansatz.num_vars}."
)
wires = Wires(modes_out_ket=modes)
if isinstance(ansatz, ArrayAnsatz):
for w in wires.quantum_wires:
w.repr = ReprEnum.FOCK
return Ket(Representation(ansatz, wires), name)

@classmethod
def from_phase_space(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple,
name: str | None = None,
atol_purity: float | None = 1e-5,
Expand All @@ -147,7 +152,7 @@ def from_phase_space(
@classmethod
def from_quadrature(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple[ComplexMatrix, ComplexVector, complex],
phi: float = 0.0,
name: str | None = None,
Expand All @@ -158,7 +163,7 @@ def from_quadrature(
return Ket.from_ansatz(modes, (Q >> QtoB).ansatz, name)

@classmethod
def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Ket:
def random(cls, modes: Collection[int], max_r: float = 1.0) -> Ket:
r"""
Generates a random zero displacement state.
Expand Down Expand Up @@ -341,7 +346,7 @@ def _ipython_display_(self): # pragma: no cover
is_fock = isinstance(self.ansatz, ArrayAnsatz)
display(widgets.state(self, is_ket=True, is_fock=is_fock))

def __getitem__(self, modes: int | Sequence[int]) -> State:
def __getitem__(self, modes: int | Collection[int]) -> State:
r"""
Reduced density matrix obtained by tracing out all the modes except those in the given
``modes``. Note that the result is returned with modes in increasing order.
Expand Down
5 changes: 5 additions & 0 deletions mrmustard/lab_dev/states/number.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Sequence

from mrmustard.physics.ansatz import ArrayAnsatz
from mrmustard.physics.wires import ReprEnum
from mrmustard.physics.fock_utils import fock_state
from .ket import Ket
from ..utils import make_parameter, reshape_params
Expand Down Expand Up @@ -81,3 +82,7 @@ def __init__(
self.short_name = [str(int(n)) for n in self.parameters.n.value]
for i, cutoff in enumerate(self.parameters.cutoffs.value):
self.manual_shape[i] = int(cutoff) + 1

for w in self.representation.wires.output.wires:
w.repr = ReprEnum.FOCK
w.repr_params_func = lambda w=w: [int(self.parameters.n.value[w.index])]
8 changes: 8 additions & 0 deletions mrmustard/lab_dev/states/quadrature_eigenstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from mrmustard.physics.ansatz import PolyExpAnsatz
from mrmustard.physics import triples
from mrmustard.physics.wires import ReprEnum
from .ket import Ket
from ..utils import make_parameter, reshape_params

Expand Down Expand Up @@ -77,6 +78,13 @@ def __init__(
),
).representation

for w in self.representation.wires.output.wires:
w.repr = ReprEnum.QUADRATURE
w.repr_params_func = lambda w=w: [
self.parameters.x.value[w.index],
self.parameters.phi.value[w.index],
]

@property
def L2_norm(self):
r"""
Expand Down
1 change: 0 additions & 1 deletion mrmustard/lab_dev/states/sauron.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from mrmustard.lab_dev.states.ket import Ket
from mrmustard.physics.ansatz import PolyExpAnsatz
from mrmustard.physics import triples

from ..utils import make_parameter


Expand Down
Loading

0 comments on commit a4dfcd8

Please sign in to comment.