Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rich wires #523

Merged
merged 53 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
12b1118
new wires
ziofil Nov 14, 2024
4780a68
Merge branch 'develop' into rich_wires
ziofil Nov 14, 2024
c57a72f
updates
ziofil Nov 14, 2024
a63c66f
better format
ziofil Nov 14, 2024
a5dd993
replace old wires
ziofil Nov 14, 2024
1ab2057
fix codefactor issues
ziofil Nov 14, 2024
6b0f074
fixes wires tests
ziofil Nov 14, 2024
c3ea39b
default wires for components
ziofil Nov 15, 2024
9846ecd
fix set issue
ziofil Nov 15, 2024
5632af6
fix attribute name
ziofil Nov 15, 2024
5d84f9f
update representation and tests
ziofil Nov 15, 2024
dbeec31
fixed permutation
ziofil Nov 15, 2024
90333ff
updated gitignore
ziofil Nov 15, 2024
ea6ed3f
sync
ziofil Nov 18, 2024
753f67d
sync
ziofil Nov 19, 2024
1d2d23b
fix order and remove print
ziofil Nov 19, 2024
bd14277
sync
ziofil Nov 19, 2024
52effa3
fixed serialization
ziofil Nov 19, 2024
6c51310
only sets
ziofil Nov 20, 2024
3e0dbe5
fix doctest
ziofil Nov 21, 2024
7309aed
codefactor
ziofil Nov 25, 2024
7134db0
imports
ziofil Nov 25, 2024
cda49df
imports
ziofil Nov 25, 2024
36efac5
pylint
ziofil Nov 25, 2024
2cab118
pylint
ziofil Nov 25, 2024
f9e04d9
fixed import
ziofil Nov 25, 2024
2be6329
remove unused code
ziofil Nov 25, 2024
d9f5e85
removed commented code
ziofil Nov 25, 2024
321cf7b
removed unused code
ziofil Nov 25, 2024
00e0eac
fix wires
ziofil Nov 26, 2024
4e0b476
add test
ziofil Nov 26, 2024
da7e8be
removed import
ziofil Nov 26, 2024
d7a23db
iterable -> sequence
ziofil Nov 26, 2024
e527e23
test Number wires
ziofil Nov 26, 2024
e3dd1cb
fix attribute name
ziofil Nov 27, 2024
a416407
fix attr name
ziofil Nov 27, 2024
cd142a4
fixed wires
ziofil Nov 27, 2024
ed38377
better pylintrc
ziofil Nov 27, 2024
2ef30b2
Merge branch 'develop' into rich_wires
ziofil Nov 27, 2024
10121c9
better pylintrc
ziofil Nov 27, 2024
b8ac8f7
more tests
ziofil Nov 27, 2024
47384ad
pylint kwuygfkuqygrejhf
ziofil Nov 27, 2024
6ed1588
Sequence -> Collection
ziofil Nov 27, 2024
3b6be98
removed pylint disable
ziofil Nov 28, 2024
5e58350
lab_dev tests pass
ziofil Dec 18, 2024
c6c90d7
Merge branch 'develop' into rich_wires
ziofil Dec 18, 2024
db08c68
ids
ziofil Dec 18, 2024
c79abfa
removed slots
ziofil Dec 18, 2024
92f9b17
dict -> set
ziofil Dec 18, 2024
27fa3f9
Merge branch 'develop' into rich_wires
ziofil Jan 7, 2025
ce57db3
utilize the updated parameter access
ziofil Jan 7, 2025
b9cd0f3
Refactor representation parameters to use functions instead of static…
ziofil Jan 10, 2025
7b4494b
fix codefactor issues
ziofil Jan 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]:
if "name" in params: # assume abstract type, serialize the representation
ansatz_cls = type(self.ansatz)
serializable["name"] = self.name
serializable["wires"] = self.wires.sorted_args
serializable["wires"] = tuple(tuple(a) for a in self.wires.args)
serializable["ansatz_cls"] = f"{ansatz_cls.__module__}.{ansatz_cls.__qualname__}"
return serializable, self.ansatz.to_dict()

Expand Down Expand Up @@ -112,7 +112,9 @@ def _deserialize(cls, data: dict) -> CircuitComponent:
if "ansatz_cls" in data:
ansatz_cls, wires, name = map(data.pop, ["ansatz_cls", "wires", "name"])
ansatz = locate(ansatz_cls).from_dict(data)
return cls._from_attributes(Representation(ansatz, Wires(*map(set, wires))), name=name)
return cls._from_attributes(
Representation(ansatz, Wires(*tuple(set(m) for m in wires))), name=name
)

return cls(**data)

Expand Down Expand Up @@ -500,7 +502,6 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent:
>>> d_fock = d.to_fock(shape=3)

>>> assert d_fock.name == d.name
>>> assert d_fock.wires == d.wires
>>> assert isinstance(d_fock.ansatz, ArrayAnsatz)

Args:
Expand Down Expand Up @@ -683,9 +684,9 @@ def __rshift__(self, other: CircuitComponent | numbers.Number) -> CircuitCompone
if only_ket or only_bra or both_sides:
ret = self @ other
elif self_needs_bra or self_needs_ket:
ret = (self.adjoint @ self) @ other
ret = self.adjoint @ (self @ other)
ziofil marked this conversation as resolved.
Show resolved Hide resolved
elif other_needs_bra or other_needs_ket:
ret = self @ (other @ other.adjoint)
ret = (self @ other.adjoint) @ other
else:
msg = f"``>>`` not supported between {self} and {other} because it's not clear "
msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components."
Expand Down
9 changes: 4 additions & 5 deletions mrmustard/lab_dev/circuit_components_utils/b_to_ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ..transformations.base import Map
from ...physics.ansatz import PolyExpAnsatz
from ...physics.representations import RepEnum
from ...physics.wires import ReprEnum
from ..utils import make_parameter

__all__ = ["BtoPS"]
Expand Down Expand Up @@ -55,10 +55,9 @@ def __init__(
n_modes=len(modes),
),
).representation
for i in self.wires.input.indices:
self.representation._idx_reps[i] = (RepEnum.BARGMANN, None)
for i in self.wires.output.indices:
self.representation._idx_reps[i] = (RepEnum.PHASESPACE, float(self.parameters.s.value))
for w in self.representation.wires.output.wires:
w.repr = ReprEnum.CHARACTERISTIC
w.repr_params_func = lambda: self.parameters.s

def inverse(self):
ret = BtoPS(self.modes, self.parameters.s)
Expand Down
14 changes: 6 additions & 8 deletions mrmustard/lab_dev/circuit_components_utils/b_to_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ..transformations.base import Operation
from ...physics.ansatz import PolyExpAnsatz
from ...physics.representations import RepEnum
from ...physics.wires import ReprEnum
from ..utils import make_parameter

__all__ = ["BtoQ"]
Expand Down Expand Up @@ -53,13 +53,11 @@ def __init__(
fn=triples.bargmann_to_quadrature_Abc, n_modes=len(modes), phi=self.parameters.phi
),
).representation
for i in self.wires.input.indices:
self.representation._idx_reps[i] = (RepEnum.BARGMANN, None)
for i in self.wires.output.indices:
self.representation._idx_reps[i] = (
RepEnum.QUADRATURE,
float(self.parameters.phi.value),
)
for w in self.representation.wires.input.wires:
w.repr = ReprEnum.BARGMANN
for w in self.representation.wires.output.wires:
w.repr = ReprEnum.QUADRATURE
w.repr_params_func = lambda: self.parameters.phi

def inverse(self):
ret = BtoQ(self.modes, self.parameters.phi)
Expand Down
3 changes: 2 additions & 1 deletion mrmustard/lab_dev/circuit_components_utils/trace_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..circuit_components import CircuitComponent
from ...physics.ansatz import PolyExpAnsatz
from ...physics.representations import Representation
from ...physics.wires import Wires

__all__ = ["TraceOut"]

Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(
super().__init__(
Representation(
PolyExpAnsatz.from_function(fn=triples.identity_Abc, n_modes=len(modes)),
[(), modes, (), modes],
Wires(set(), set(modes), set(), set(modes)),
),
name="Tr",
)
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def serialize(self, filestem: str = None):
@classmethod
def deserialize(cls, data: dict) -> Circuit:
r"""Deserialize a Circuit."""
comps, path = map(data.pop, ("components", "path"))
comps, path = data.pop("components"), data.pop("path")
ziofil marked this conversation as resolved.
Show resolved Hide resolved

for k, v in data.items():
kwarg, i = k.split(":")
Expand Down
27 changes: 16 additions & 11 deletions mrmustard/lab_dev/states/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

from __future__ import annotations
from typing import Sequence
from typing import Collection

from itertools import product
import warnings
Expand All @@ -30,7 +30,7 @@
from mrmustard.physics.bargmann_utils import wigner_to_bargmann_rho
from mrmustard.physics.gaussian_integrals import complex_gaussian_integral_2
from mrmustard.physics.representations import Representation
from mrmustard.physics.wires import Wires
from mrmustard.physics.wires import Wires, ReprEnum
from mrmustard.utils.typing import ComplexMatrix, ComplexVector, ComplexTensor, RealVector

from .base import State, _validate_operator, OperatorType
Expand Down Expand Up @@ -111,7 +111,7 @@ def _purities(self) -> RealVector:
@classmethod
def from_bargmann(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple[ComplexMatrix, ComplexVector, complex],
name: str | None = None,
) -> State:
Expand All @@ -120,7 +120,7 @@ def from_bargmann(
@classmethod
def from_fock(
cls,
modes: Sequence[int],
modes: Collection[int],
array: ComplexTensor,
name: str | None = None,
batched: bool = False,
Expand All @@ -130,22 +130,27 @@ def from_fock(
@classmethod
def from_ansatz(
cls,
modes: Sequence[int],
modes: Collection[int],
ansatz: PolyExpAnsatz | ArrayAnsatz | None = None,
name: str | None = None,
) -> State:
if not isinstance(modes, set) and sorted(modes) != list(modes):
raise ValueError(f"Modes must be sorted. got {modes}")
ziofil marked this conversation as resolved.
Show resolved Hide resolved
modes = set(modes)
if ansatz and ansatz.num_vars != 2 * len(modes):
raise ValueError(
f"Expected an ansatz with {2*len(modes)} variables, found {ansatz.num_vars}."
)
wires = Wires(modes_out_bra=modes, modes_out_ket=modes)
wires = Wires(modes_out_bra=set(modes), modes_out_ket=set(modes))
if isinstance(ansatz, ArrayAnsatz):
for w in wires:
w.repr = ReprEnum.FOCK
return DM(Representation(ansatz, wires), name)

@classmethod
def from_phase_space(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple,
name: str | None = None,
s: float = 0, # pylint: disable=unused-argument
Expand Down Expand Up @@ -175,7 +180,7 @@ def from_phase_space(
@classmethod
def from_quadrature(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple[ComplexMatrix, ComplexVector, complex],
phi: float = 0.0,
name: str | None = None,
Expand Down Expand Up @@ -203,7 +208,7 @@ def from_quadrature(
return DM.from_ansatz(modes, (Q >> QtoB).ansatz, name)

@classmethod
def random(cls, modes: Sequence[int], m: int | None = None, max_r: float = 1.0) -> DM:
def random(cls, modes: Collection[int], m: int | None = None, max_r: float = 1.0) -> DM:
r"""
Samples a random density matrix. The final state has zero displacement.

Expand Down Expand Up @@ -381,13 +386,13 @@ def _ipython_display_(self): # pragma: no cover
is_fock = isinstance(self.ansatz, ArrayAnsatz)
display(widgets.state(self, is_ket=False, is_fock=is_fock))

def __getitem__(self, modes: int | Sequence[int]) -> State:
def __getitem__(self, modes: int | Collection[int]) -> State:
r"""
Traces out all the modes except those given.
The result is returned with modes in increasing order.
"""
if isinstance(modes, int):
modes = [modes]
modes = {modes}
modes = set(modes)

if not modes.issubset(self.modes):
Expand Down
23 changes: 14 additions & 9 deletions mrmustard/lab_dev/states/ket.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from __future__ import annotations

from typing import Sequence
from typing import Collection
from itertools import product
import warnings
import numpy as np
Expand All @@ -30,7 +30,7 @@
from mrmustard.physics.bargmann_utils import wigner_to_bargmann_psi
from mrmustard.physics.gaussian import purity
from mrmustard.physics.representations import Representation
from mrmustard.physics.wires import Wires
from mrmustard.physics.wires import Wires, ReprEnum
from mrmustard.utils.typing import (
ComplexMatrix,
ComplexVector,
Expand Down Expand Up @@ -90,7 +90,7 @@ def _probabilities(self) -> RealVector:
@classmethod
def from_bargmann(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple[ComplexMatrix, ComplexVector, complex],
name: str | None = None,
) -> State:
Expand All @@ -99,7 +99,7 @@ def from_bargmann(
@classmethod
def from_fock(
cls,
modes: Sequence[int],
modes: Collection[int],
array: ComplexTensor,
name: str | None = None,
batched: bool = False,
Expand All @@ -109,22 +109,27 @@ def from_fock(
@classmethod
def from_ansatz(
cls,
modes: Sequence[int],
modes: Collection[int],
ansatz: PolyExpAnsatz | ArrayAnsatz | None = None,
name: str | None = None,
) -> State:
if not isinstance(modes, set) and sorted(modes) != list(modes):
raise ValueError(f"Modes must be sorted. Got {modes}")
modes = set(modes)
if ansatz and ansatz.num_vars != len(modes):
raise ValueError(
f"Expected an ansatz with {len(modes)} variables, found {ansatz.num_vars}."
)
wires = Wires(modes_out_ket=modes)
if isinstance(ansatz, ArrayAnsatz):
for w in wires.quantum_wires:
w.repr = ReprEnum.FOCK
return Ket(Representation(ansatz, wires), name)

@classmethod
def from_phase_space(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple,
name: str | None = None,
atol_purity: float | None = 1e-5,
Expand All @@ -147,7 +152,7 @@ def from_phase_space(
@classmethod
def from_quadrature(
cls,
modes: Sequence[int],
modes: Collection[int],
triple: tuple[ComplexMatrix, ComplexVector, complex],
phi: float = 0.0,
name: str | None = None,
Expand All @@ -158,7 +163,7 @@ def from_quadrature(
return Ket.from_ansatz(modes, (Q >> QtoB).ansatz, name)

@classmethod
def random(cls, modes: Sequence[int], max_r: float = 1.0) -> Ket:
def random(cls, modes: Collection[int], max_r: float = 1.0) -> Ket:
r"""
Generates a random zero displacement state.

Expand Down Expand Up @@ -341,7 +346,7 @@ def _ipython_display_(self): # pragma: no cover
is_fock = isinstance(self.ansatz, ArrayAnsatz)
display(widgets.state(self, is_ket=True, is_fock=is_fock))

def __getitem__(self, modes: int | Sequence[int]) -> State:
def __getitem__(self, modes: int | Collection[int]) -> State:
r"""
Reduced density matrix obtained by tracing out all the modes except those in the given
``modes``. Note that the result is returned with modes in increasing order.
Expand Down
5 changes: 5 additions & 0 deletions mrmustard/lab_dev/states/number.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Sequence

from mrmustard.physics.ansatz import ArrayAnsatz
from mrmustard.physics.wires import ReprEnum
from mrmustard.physics.fock_utils import fock_state
from .ket import Ket
from ..utils import make_parameter, reshape_params
Expand Down Expand Up @@ -81,3 +82,7 @@ def __init__(
self.short_name = [str(int(n)) for n in self.parameters.n.value]
for i, cutoff in enumerate(self.parameters.cutoffs.value):
self.manual_shape[i] = int(cutoff) + 1

for w in self.representation.wires.output.wires:
w.repr = ReprEnum.FOCK
w.repr_params_func = lambda w=w: [int(self.parameters.n.value[w.index])]
8 changes: 8 additions & 0 deletions mrmustard/lab_dev/states/quadrature_eigenstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from mrmustard.physics.ansatz import PolyExpAnsatz
from mrmustard.physics import triples
from mrmustard.physics.wires import ReprEnum
from .ket import Ket
from ..utils import make_parameter, reshape_params

Expand Down Expand Up @@ -77,6 +78,13 @@ def __init__(
),
).representation

for w in self.representation.wires.output.wires:
w.repr = ReprEnum.QUADRATURE
w.repr_params_func = lambda w=w: [
self.parameters.x.value[w.index],
self.parameters.phi.value[w.index],
]

@property
def L2_norm(self):
r"""
Expand Down
1 change: 0 additions & 1 deletion mrmustard/lab_dev/states/sauron.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from mrmustard.lab_dev.states.ket import Ket
from mrmustard.physics.ansatz import PolyExpAnsatz
from mrmustard.physics import triples

from ..utils import make_parameter


Expand Down
Loading
Loading