Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcasted gaussian integral #494

Closed
wants to merge 13 commits into from
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.venv

.DS_Store
.vscode
*.egg-info*
Expand All @@ -19,3 +21,5 @@ doc/code/api/*
coverage.xml
.coverage
/.serialize_cache/

.venv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"data":{"layout-restorer:data":{"main":{"dock":{"type":"tab-area","currentIndex":1,"widgets":["notebook:Untitled1.ipynb","terminal:1"]},"current":"notebook:Untitled8.ipynb"},"down":{"size":0,"widgets":[]},"left":{"collapsed":true,"visible":false,"widgets":["filebrowser","running-sessions","@jupyterlab/toc:plugin","extensionmanager.main-view"],"widgetStates":{"jp-running-sessions":{"sizes":[0.16666666666666666,0.16666666666666666,0.16666666666666666,0.16666666666666666,0.16666666666666666,0.16666666666666666],"expansionStates":[false,false,false,false,false,false]},"extensionmanager.main-view":{"sizes":[0.3333333333333333,0.3333333333333333,0.3333333333333333],"expansionStates":[false,false,false]}}},"right":{"collapsed":true,"visible":false,"widgets":["jp-property-inspector","debugger-sidebar"],"widgetStates":{"jp-debugger-sidebar":{"sizes":[0.2,0.2,0.2,0.2,0.2],"expansionStates":[false,false,false,false,false]}}},"relativeSizes":[0,1,0],"top":{"simpleVisibility":true}},"docmanager:recents":{"opened":[{"path":"","contentType":"directory","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"Untitled8.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"Untitled7.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"Untitled5.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"Untitled4.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"Untitled1.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"Untitled3.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"Untitled2.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"flat_vanilla_take2.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"Untitled.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"}],"closed":[{"path":"Untitled3.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"Untitled.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"flat_vanilla_take2.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"},{"path":"Untitled2.ipynb","contentType":"notebook","factory":"Notebook","root":"~/Library/CloudStorage/Dropbox/Work/Xanadu/projects/MrMustard"}]},"notebook:Untitled1.ipynb":{"data":{"path":"Untitled1.ipynb","factory":"Notebook"}},"terminal:1":{"data":{"name":"1"}},"notebook:Untitled4.ipynb":{"data":{"path":"Untitled4.ipynb","factory":"Notebook"}},"notebook:Untitled5.ipynb":{"data":{"path":"Untitled5.ipynb","factory":"Notebook"}},"notebook:Untitled7.ipynb":{"data":{"path":"Untitled7.ipynb","factory":"Notebook"}},"notebook:Untitled8.ipynb":{"data":{"path":"Untitled8.ipynb","factory":"Notebook"}}},"metadata":{"id":"default"}}
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ extension-pkg-whitelist=numpy,tensorflow,scipy,thewalrus,strawberryfields
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=numpy,tensorflow,scipy,thewalrus,strawberryfields,strawberryfields.parameters,collections.abc
ignored-modules=numpy,tensorflow,scipy,thewalrus,strawberryfields,strawberryfields.parameters,collections.abc, mrmustard.math

# List of classes names for which member attributes should not be checked
# (useful for classes with attributes dynamically set). This supports can work
Expand Down
24 changes: 15 additions & 9 deletions mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class CircuitComponent:
def __init__(
self,
representation: Bargmann | Fock | None = None,
wires: Wires | Sequence[tuple[int]] | None = None,
wires: Wires | Sequence[tuple[int, ...]] | None = None,
name: str | None = None,
) -> None:
self._name = name
Expand Down Expand Up @@ -164,9 +164,10 @@ def adjoint(self) -> CircuitComponent:
bras = self.wires.bra.indices
kets = self.wires.ket.indices
rep = self.representation.reorder(kets + bras).conj() if self.representation else None

ret = CircuitComponent(rep, self.wires.adjoint, self.name)
ret.short_name = self.short_name
for param in self.parameter_set.all_parameters.values():
ret._add_parameter(param)
return ret

@property
Expand All @@ -184,7 +185,8 @@ def dual(self) -> CircuitComponent:

ret = CircuitComponent(rep, self.wires.dual, self.name)
ret.short_name = self.short_name

for param in self.parameter_set.all_parameters.values():
ret._add_parameter(param)
return ret

@cached_property
Expand Down Expand Up @@ -375,7 +377,8 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor:
# Find where all the bras and kets are so they can be conjugated appropriately
conjugates = [i not in self.wires.ket.indices for i in range(len(self.wires.indices))]
quad_basis = math.sum(
[quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays], axes=[0]
[quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays],
axes=[0],
)
return quad_basis

Expand All @@ -385,7 +388,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor:
@classmethod
def _from_attributes(
cls,
representation: Representation,
representation: Representation | None,
wires: Wires,
name: str | None = None,
) -> CircuitComponent:
Expand Down Expand Up @@ -487,6 +490,8 @@ def fock(self, shape: int | Sequence[int] | None = None, batched=False) -> Compl
)
if self.representation.ansatz.polynomial_shape[0] == 0:
arrays = [math.hermite_renormalized(A, b, c, shape) for A, b, c in zip(As, bs, cs)]
if self.representation.polynomial_shape[0] == 0:
arrays = [math.hermite_renormalized(A, b, c, shape) for A, b, c in zip(As, bs, cs)]
else:
arrays = [
math.sum(
Expand All @@ -497,12 +502,12 @@ def fock(self, shape: int | Sequence[int] | None = None, batched=False) -> Compl
)
for A, b, c in zip(As, bs, cs)
]
except AttributeError:
except AttributeError as e:
shape = shape or self.auto_shape()
if len(shape) != num_vars:
raise ValueError(
f"Expected Fock shape of length {num_vars}, got length {len(shape)}"
)
) from e
arrays = self.representation.reduce(shape).array
array = math.sum(arrays, axes=[0])
arrays = math.expand_dims(array, 0) if batched else array
Expand Down Expand Up @@ -572,7 +577,8 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent:
"""
fock = Fock(self.fock(shape, batched=True), batched=True)
try:
fock.ansatz._original_abc_data = self.representation.triple
if self.representation.ansatz.polynomial_shape[0] == 0:
fock.ansatz._original_abc_data = self.representation.triple
except AttributeError:
fock.ansatz._original_abc_data = None
try:
Expand Down Expand Up @@ -721,7 +727,7 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent:

wires_result, perm = self.wires @ other.wires
idx_z, idx_zconj = self._matmul_indices(other)
if type(self.representation) == type(other.representation):
if type(self.representation) is type(other.representation):
self_rep = self.representation
other_rep = other.representation
else:
Expand Down
240 changes: 240 additions & 0 deletions mrmustard/lab_dev/samplers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Samplers for measurement devices.
"""

from __future__ import annotations
from itertools import product

from abc import ABC, abstractmethod

from typing import Any, Sequence

import numpy as np

from mrmustard import math, settings

from .states import State, Number, Ket
from .circuit_components import CircuitComponent
from .circuit_components_utils import BtoQ

__all__ = ["Sampler", "PNRSampler", "HomodyneSampler"]


class Sampler(ABC):
r"""
A sampler for measurements of quantum circuits.

Args:
meas_outcomes: The measurement outcomes for this sampler.
povms: The (optional) POVMs of this sampler.
"""

def __init__(
self,
meas_outcomes: Sequence[Any],
povms: CircuitComponent | Sequence[CircuitComponent] | None = None,
):
self._povms = povms
self._meas_outcomes = meas_outcomes
self._outcome_arg = None

@property
def povms(self) -> CircuitComponent | Sequence[CircuitComponent] | None:
r"""
The POVMs of this sampler.
"""
return self._povms

@property
def meas_outcomes(self) -> Sequence[Any]:
r"""
The measurement outcomes of this sampler.
"""
return self._meas_outcomes

@abstractmethod
def probabilities(self, state: State, atol: float = 1e-4) -> Sequence[float]:
r"""
Returns the probability distribution of a state w.r.t. measurement outcomes.

Args:
state: The state to generate the probability distribution of. Note: the
input state must be normalized.
atol: The absolute tolerance used for validating that the computed
probability distribution sums to ``1``.
"""

def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) -> np.ndarray:
r"""
Returns an array of samples given a state.

Args:
state: The state to sample.
n_samples: The number of samples to generate.
seed: An optional seed for random sampling.

Returns:
An array of samples such that the shape is ``(n_samples, n_modes)``.
"""
initial_mode = state.modes[0]
initial_samples, probs = self.sample_prob_dist(state[initial_mode], n_samples, seed)

if len(state.modes) == 1:
return initial_samples

unique_samples, counts = np.unique(initial_samples, return_counts=True)
ret = []
for unique_sample, counts in zip(unique_samples, counts):
meas_op = self._get_povm(unique_sample, initial_mode).dual
prob = probs[initial_samples.tolist().index(unique_sample)]
norm = math.sqrt(prob) if isinstance(state, Ket) else prob
reduced_state = (state >> meas_op) / norm
samples = self.sample(reduced_state, counts)
for sample in samples:
ret.append(np.append([unique_sample], sample))
return np.array(ret)

def sample_prob_dist(
self, state: State, n_samples: int = 1000, seed: int | None = None
) -> tuple[np.ndarray, np.ndarray]:
r"""
Samples a state by computing the probability distribution.

Args:
state: The state to sample.
n_samples: The number of samples to generate.
seed: An optional seed for random sampling.

Returns:
A tuple of the generated samples and the probability
of obtaining the sample.
"""
rng = np.random.default_rng(seed) if seed else settings.rng
probs = self.probabilities(state)
meas_outcomes = list(product(self.meas_outcomes, repeat=len(state.modes)))
samples = rng.choice(
a=meas_outcomes,
p=self.probabilities(state),
size=n_samples,
)
return samples, np.array([probs[meas_outcomes.index(tuple(sample))] for sample in samples])

def _get_povm(self, meas_outcome: Any, mode: int) -> CircuitComponent:
r"""
Returns the POVM associated with a given outcome on a given mode.

Args:
meas_outcome: The measurement outcome.
mode: The mode.

Returns:
The POVM circuit component.

Raises:
ValueError: If this sampler has no POVMs.
"""
if self._povms is None:
raise ValueError("This sampler has no POVMs defined.")
if isinstance(self.povms, CircuitComponent):
kwargs = self.povms.parameter_set.to_dict()
kwargs[self._outcome_arg] = meas_outcome
return self.povms.__class__(modes=[mode], **kwargs)
else:
return self.povms[self.meas_outcomes.index(meas_outcome)].on([mode])

def _validate_probs(self, probs: Sequence[float], atol: float) -> Sequence[float]:
r"""
Validates that the given probability distribution sums to ``1`` within some
tolerance and returns a renormalized probability distribution to account for
small numerical errors.

Args:
probs: The probability distribution to validate.
atol: The absolute tolerance to validate with.
"""
atol = atol or settings.ATOL
prob_sum = sum(probs)
if not math.allclose(prob_sum, 1, atol):
raise ValueError(f"Probabilities sum to {prob_sum} and not 1.0.")
return math.real(probs / prob_sum)


class PNRSampler(Sampler):
r"""
A sampler for photon-number resolving (PNR) detectors.

Args:
cutoff: The photon number cutoff.
"""

def __init__(self, cutoff: int) -> None:
super().__init__(list(range(cutoff)), Number([0], 0))
self._cutoff = cutoff
self._outcome_arg = "n"

def probabilities(self, state, atol=1e-4):
return self._validate_probs(state.fock_distribution(self._cutoff), atol)


class HomodyneSampler(Sampler):
r"""
A sampler for homodyne measurements.

Args:
phi: The quadrature angle where ``0`` corresponds to ``x`` and ``\pi/2`` to ``p``.
bounds: The range of values to discretize over.
num: The number of points to discretize over.
"""

def __init__(
self,
phi: float = 0,
bounds: tuple[float, float] = (-10, 10),
num: int = 1000,
) -> None:
meas_outcomes, step = np.linspace(*bounds, num, retstep=True)
super().__init__(list(meas_outcomes))
self._step = step
self._phi = phi

def probabilities(self, state, atol=1e-4):
probs = state.quadrature_distribution(self.meas_outcomes, self._phi) * self._step ** len(
state.modes
)
return self._validate_probs(probs, atol)

def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) -> np.ndarray:
initial_mode = state.modes[0]
initial_samples, probs = self.sample_prob_dist(state[initial_mode], n_samples, seed)

if len(state.modes) == 1:
return initial_samples

unique_samples, counts = np.unique(initial_samples, return_counts=True)
ret = []
for unique_sample, counts in zip(unique_samples, counts):
quad = np.array([[unique_sample] + [None] * (state.n_modes - 1)])
quad = quad if isinstance(state, Ket) else math.tile(quad, (1, 2))
reduced_rep = (state >> BtoQ([initial_mode], phi=self._phi)).representation(quad)
reduced_state = state.__class__.from_bargmann(state.modes[1:], reduced_rep.triple)
prob = probs[initial_samples.tolist().index(unique_sample)] / self._step
norm = math.sqrt(prob) if isinstance(state, Ket) else prob
normalized_reduced_state = reduced_state / norm
samples = self.sample(normalized_reduced_state, counts)
for sample in samples:
ret.append(np.append([unique_sample], sample))
return np.array(ret)
Loading
Loading