Commit

picked low-hanging fruits
ziofil committed Oct 16, 2024
1 parent a24b0bf commit 227061e
Showing 4 changed files with 104 additions and 107 deletions.
3 changes: 0 additions & 3 deletions mrmustard/lab/abstract/state.py
@@ -383,9 +383,6 @@ def primal(self, other: State | Transformation) -> State:
Note that the returned state is not normalized. To normalize a state you can use
``mrmustard.physics.normalize``.
"""
# import pdb

# pdb.set_trace()
if isinstance(other, State):
return self._project_onto_state(other)
try:
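The three deleted lines above were leftover pdb debugging hooks. As an illustration only (not part of this commit), a debug hook can be kept toggleable instead of being committed as commented-out code, using just the standard library; the environment-variable name below is hypothetical:

    import os

    def project_with_optional_breakpoint(state, other):
        # breakpoint() honours PYTHONBREAKPOINT (PEP 553): setting
        # PYTHONBREAKPOINT=0 makes it a no-op, so the hook never fires in CI.
        # MM_DEBUG_PROJECTION is a made-up opt-in flag for this sketch.
        if os.environ.get("MM_DEBUG_PROJECTION"):
            breakpoint()
        return state.primal(other)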
48 changes: 27 additions & 21 deletions mrmustard/lab_dev/samplers.py
@@ -128,7 +128,7 @@ def sample_prob_dist(
meas_outcomes = list(product(self.meas_outcomes, repeat=len(state.modes)))
samples = rng.choice(
a=meas_outcomes,
p=self.probabilities(state),
p=probs,
size=n_samples,
)
return samples, np.array([probs[meas_outcomes.index(tuple(sample))] for sample in samples])
@@ -167,7 +167,7 @@ def _validate_probs(self, probs: Sequence[float], atol: float) -> Sequence[float]:
atol: The absolute tolerance to validate with.
"""
atol = atol or settings.ATOL
prob_sum = sum(probs)
prob_sum = math.sum(probs)
if not math.allclose(prob_sum, 1, atol):
raise ValueError(f"Probabilities sum to {prob_sum} and not 1.0.")
return math.real(probs / prob_sum)
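For reference, a plain-NumPy sketch of the same validation step (illustrative only; the method above uses mrmustard's math backend and settings.ATOL as the default tolerance):

    import numpy as np

    def validate_probs(probs, atol=1e-8):
        # Reject distributions that are not normalized within atol,
        # then renormalize exactly and drop any numerical imaginary part.
        probs = np.asarray(probs)
        prob_sum = probs.sum()
        if not np.isclose(prob_sum, 1.0, atol=atol):
            raise ValueError(f"Probabilities sum to {prob_sum} and not 1.0.")
        return np.real(probs / prob_sum)

    validate_probs([0.25, 0.25, 0.5])   # returns array([0.25, 0.25, 0.5])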
Expand Down Expand Up @@ -217,24 +217,30 @@ def probabilities(self, state, atol=1e-4):
)
return self._validate_probs(probs, atol)

def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) -> np.ndarray:
initial_mode = state.modes[0]
initial_samples, probs = self.sample_prob_dist(state[initial_mode], n_samples, seed)
def sample_prob_dist(
self, state: State, n_samples: int = 1000, seed: int | None = None
) -> np.ndarray:
r"""
Samples a state by computing the probability distribution.
if len(state.modes) == 1:
return initial_samples
Args:
state: The state to sample.
n_samples: The number of samples to generate.
seed: An optional seed for random sampling.
unique_samples, counts = np.unique(initial_samples, return_counts=True)
ret = []
for unique_sample, counts in zip(unique_samples, counts):
quad = np.array([[unique_sample] + [None] * (state.n_modes - 1)])
quad = quad if isinstance(state, Ket) else math.tile(quad, (1, 2))
reduced_rep = (state >> BtoQ([initial_mode], phi=self._phi)).representation(quad)
reduced_state = state.__class__.from_bargmann(state.modes[1:], reduced_rep.triple)
prob = probs[initial_samples.tolist().index(unique_sample)] / self._step
norm = math.sqrt(prob) if isinstance(state, Ket) else prob
normalized_reduced_state = reduced_state / norm
samples = self.sample(normalized_reduced_state, counts)
for sample in samples:
ret.append(np.append([unique_sample], sample))
return np.array(ret)
Returns:
A tuple of the generated samples and the probability
of obtaining the sample.
"""
rng = np.random.default_rng(seed) if seed else settings.rng
probs = self.probabilities(state)
meas_outcomes = list(product(self.meas_outcomes, repeat=len(state.modes)))
samples = rng.choice(
a=meas_outcomes,
p=probs,
size=n_samples,
)
return samples

def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) -> np.ndarray:
return self.sample_prob_dist(state, n_samples, seed)
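A self-contained sketch of the sampling pattern used in sample_prob_dist above: build the grid of joint outcomes with itertools.product and draw from it with a seeded NumPy generator (the outcome grid and distribution below are placeholders):

    from itertools import product

    import numpy as np

    single_mode_outcomes = [0, 1, 2]                   # placeholder per-mode outcomes
    n_modes = 2
    outcomes = list(product(single_mode_outcomes, repeat=n_modes))

    probs = np.full(len(outcomes), 1 / len(outcomes))  # placeholder flat distribution
    rng = np.random.default_rng(seed=42)

    # Sampling indices (rather than the tuples directly) sidesteps NumPy's
    # conversion of the outcome tuples into a 2D array inside rng.choice.
    idx = rng.choice(len(outcomes), p=probs, size=5)
    samples = [outcomes[i] for i in idx]
    sample_probs = probs[idx]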
6 changes: 4 additions & 2 deletions mrmustard/lab_dev/states/base.py
@@ -368,14 +368,16 @@ def quadrature_distribution(self, quad: Vector, phi: float = 0.0) -> ComplexTensor:
Returns:
The quadrature distribution.
"""
quad = math.astensor(quad)
quad = np.array(quad)
if len(quad.shape) != 1 and len(quad.shape) != self.n_modes:
raise ValueError(
"The dimensionality of quad should be 1, or match the number of modes."
)

if len(quad.shape) == 1:
quad = math.astensor(list(product(quad, repeat=len(self.modes))))
quad = math.astensor(np.meshgrid(*[quad] * len(self.modes))).T.reshape(
-1, len(self.modes)
)

if isinstance(self, Ket):
return math.abs(self.quadrature(quad, phi)) ** 2
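A small NumPy check of the replacement above (illustrative values): for two modes the meshgrid/reshape construction reproduces the itertools.product grid point for point; for more modes it contains the same quadrature points, though possibly in a different row order:

    from itertools import product

    import numpy as np

    quad = np.linspace(-1.0, 1.0, 3)   # placeholder quadrature values
    n_modes = 2

    grid_meshgrid = np.array(np.meshgrid(*[quad] * n_modes)).T.reshape(-1, n_modes)
    grid_product = np.array(list(product(quad, repeat=n_modes)))

    assert np.allclose(grid_meshgrid, grid_product)   # identical for two modes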
154 changes: 73 additions & 81 deletions tests/test_math/test_compactFock.py
@@ -18,8 +18,6 @@

from ..conftest import skip_np

original_precision = settings.PRECISION_BITS_HERMITE_POLY

do_julia = bool(importlib.util.find_spec("juliacall"))
precisions = [128, 256, 384, 512] if do_julia else [128]

@@ -41,28 +39,28 @@ def test_compactFock_diagonal(precision, A_B_G0):
r"""Test getting Fock amplitudes if all modes are
detected (math.hermite_renormalized_diagonal)
"""
settings.PRECISION_BITS_HERMITE_POLY = precision
cutoffs = (5, 5, 5)

A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement)

# Vanilla MM
G_ref = math.hermite_renormalized(
math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2
) # note: shape=[C1,C2,C3,...,C1,C2,C3,...]
G_ref = math.asnumpy(G_ref)

# Extract diagonal amplitudes from vanilla MM
ref_diag = np.zeros(cutoffs, dtype=np.complex128)
for inds in np.ndindex(*cutoffs):
inds_expanded = list(inds) + list(inds) # a,b,c,a,b,c
ref_diag[inds] = G_ref[tuple(inds_expanded)]

# New MM
G_diag = math.hermite_renormalized_diagonal(math.conj(-A), math.conj(B), math.conj(G0), cutoffs)
assert np.allclose(ref_diag, G_diag)

settings.PRECISION_BITS_HERMITE_POLY = original_precision
with settings(PRECISION_BITS_HERMITE_POLY=precision):
cutoffs = (5, 5, 5)

A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement)

# Vanilla MM
G_ref = math.hermite_renormalized(
math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2
) # note: shape=[C1,C2,C3,...,C1,C2,C3,...]
G_ref = math.asnumpy(G_ref)

# Extract diagonal amplitudes from vanilla MM
ref_diag = np.zeros(cutoffs, dtype=np.complex128)
for inds in np.ndindex(*cutoffs):
inds_expanded = list(inds) + list(inds) # a,b,c,a,b,c
ref_diag[inds] = G_ref[tuple(inds_expanded)]

# New MM
G_diag = math.hermite_renormalized_diagonal(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs
)
assert np.allclose(ref_diag, G_diag)
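For orientation, a plain-NumPy version of the reference check above: extracting the "diagonal" amplitudes G[a, b, c, a, b, c] from a full tensor with np.ndindex (random data stands in for the Fock amplitudes):

    import numpy as np

    cutoffs = (5, 5, 5)
    rng = np.random.default_rng(0)
    G_full = rng.normal(size=cutoffs * 2) + 1j * rng.normal(size=cutoffs * 2)

    ref_diag = np.zeros(cutoffs, dtype=np.complex128)
    for inds in np.ndindex(*cutoffs):
        ref_diag[inds] = G_full[inds + inds]   # index (a, b, c, a, b, c)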


@given(random_ABC(M=3))
@@ -74,31 +72,29 @@ def test_compactFock_1leftover(precision, A_B_G0):
"""
skip_np()

settings.PRECISION_BITS_HERMITE_POLY = precision
cutoffs = (5, 5, 5)

A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement)
with settings(PRECISION_BITS_HERMITE_POLY=precision):
cutoffs = (5, 5, 5)

# New algorithm
G_leftover = math.hermite_renormalized_1leftoverMode(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs
)
A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement)

# Vanilla MM
G_ref = math.hermite_renormalized(
math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2
) # note: shape=[C1,C2,C3,...,C1,C2,C3,...]
G_ref = math.asnumpy(G_ref)
# New algorithm
G_leftover = math.hermite_renormalized_1leftoverMode(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs
)

# Extract amplitudes of leftover mode from vanilla MM
ref_leftover = np.zeros([cutoffs[0]] * 2 + list(cutoffs)[1:], dtype=np.complex128)
for inds in np.ndindex(*cutoffs[1:]):
ref_leftover[tuple([slice(cutoffs[0]), slice(cutoffs[0])] + list(inds))] = G_ref[
tuple([slice(cutoffs[0])] + list(inds) + [slice(cutoffs[0])] + list(inds))
]
assert np.allclose(ref_leftover, G_leftover)
# Vanilla MM
G_ref = math.hermite_renormalized(
math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2
) # note: shape=[C1,C2,C3,...,C1,C2,C3,...]
G_ref = math.asnumpy(G_ref)

settings.PRECISION_BITS_HERMITE_POLY = original_precision
# Extract amplitudes of leftover mode from vanilla MM
ref_leftover = np.zeros([cutoffs[0]] * 2 + list(cutoffs)[1:], dtype=np.complex128)
for inds in np.ndindex(*cutoffs[1:]):
ref_leftover[tuple([slice(cutoffs[0]), slice(cutoffs[0])] + list(inds))] = G_ref[
tuple([slice(cutoffs[0])] + list(inds) + [slice(cutoffs[0])] + list(inds))
]
assert np.allclose(ref_leftover, G_leftover)
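These tests now set the precision through `with settings(PRECISION_BITS_HERMITE_POLY=precision):`, which scopes the override to the block instead of assigning it and restoring `original_precision` by hand. A generic sketch of such an override-and-restore pattern (an illustration only, not mrmustard's actual Settings implementation):

    from contextlib import contextmanager

    class ExampleSettings:
        PRECISION_BITS_HERMITE_POLY = 128

        @contextmanager
        def __call__(self, **overrides):
            previous = {name: getattr(self, name) for name in overrides}
            try:
                for name, value in overrides.items():
                    setattr(self, name, value)
                yield self
            finally:
                # Restored even when the with-body raises.
                for name, value in previous.items():
                    setattr(self, name, value)

    example_settings = ExampleSettings()
    with example_settings(PRECISION_BITS_HERMITE_POLY=512):
        assert example_settings.PRECISION_BITS_HERMITE_POLY == 512
    assert example_settings.PRECISION_BITS_HERMITE_POLY == 128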


@pytest.mark.parametrize("precision", precisions)
@@ -109,25 +105,23 @@ def test_compactFock_diagonal_gradients(precision):
"""
skip_np()

settings.PRECISION_BITS_HERMITE_POLY = precision
G = Ggate(num_modes=2, symplectic_trainable=True)

def cost_fn():
n1, n2 = 2, 4 # number of detected photons
state_opt = Vacuum(2) >> G
A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means)
probs = math.hermite_renormalized_diagonal(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[n1 + 1, n2 + 1]
)
p = probs[n1, n2]
return -math.real(p)
with settings(PRECISION_BITS_HERMITE_POLY=precision):
G = Ggate(num_modes=1, symplectic_trainable=True)

opt = Optimizer(symplectic_lr=0.5)
opt.minimize(cost_fn, by_optimizing=[G], max_steps=50)
for i in range(2, min(20, len(opt.opt_history))):
assert opt.opt_history[i - 1] >= opt.opt_history[i]
def cost_fn():
n1 = 2 # number of detected photons
state_opt = Vacuum(1) >> G
A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means)
probs = math.hermite_renormalized_diagonal(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[n1 + 1]
)
p = probs[n1]
return -math.real(p)

settings.PRECISION_BITS_HERMITE_POLY = original_precision
opt = Optimizer(symplectic_lr=0.5)
opt.minimize(cost_fn, by_optimizing=[G], max_steps=5)
for i in range(2, min(20, len(opt.opt_history))):
assert opt.opt_history[i - 1] >= opt.opt_history[i]
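The closing loop asserts that the recorded cost never increases after the first optimization step; essentially the same check as a one-liner over a placeholder history:

    import numpy as np

    opt_history = [1.00, 0.80, 0.65, 0.65, 0.60]    # placeholder cost history
    assert np.all(np.diff(opt_history[1:]) <= 0)    # non-increasing from step 1 on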


@pytest.mark.parametrize("precision", precisions)
@@ -138,22 +132,20 @@ def test_compactFock_1leftover_gradients(precision):
"""
skip_np()

settings.PRECISION_BITS_HERMITE_POLY = precision
G = Ggate(num_modes=2, symplectic_trainable=True)

def cost_fn():
n2 = 3 # number of detected photons
state_opt = Vacuum(2) >> G
A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means)
marginal = math.hermite_renormalized_1leftoverMode(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[8, n2 + 1]
)
conditional_state = normalize(State(dm=marginal[..., n2]))
return -fidelity(conditional_state, SqueezedVacuum(r=1))

opt = Optimizer(symplectic_lr=0.1)
opt.minimize(cost_fn, by_optimizing=[G], max_steps=50)
for i in range(2, min(20, len(opt.opt_history))):
assert opt.opt_history[i - 1] >= opt.opt_history[i]

settings.PRECISION_BITS_HERMITE_POLY = original_precision
with settings(PRECISION_BITS_HERMITE_POLY=precision):
G = Ggate(num_modes=2, symplectic_trainable=True)

def cost_fn():
n2 = 2 # number of detected photons
state_opt = Vacuum(2) >> G
A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means)
marginal = math.hermite_renormalized_1leftoverMode(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[2, n2 + 1]
)
conditional_state = normalize(State(dm=marginal[..., n2]))
return -fidelity(conditional_state, SqueezedVacuum(r=1))

opt = Optimizer(symplectic_lr=0.1)
opt.minimize(cost_fn, by_optimizing=[G], max_steps=5)
for i in range(2, len(opt.opt_history)):
assert opt.opt_history[i - 1] >= opt.opt_history[i]
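For context on the `marginal[..., n2]` step in the cost function above, a small NumPy sketch (illustrative, not mrmustard code) of conditioning on n2 detected photons: pick that Fock block of the leftover mode and trace-normalize it as a density matrix:

    import numpy as np

    cutoff, n2 = 4, 2
    rng = np.random.default_rng(1)
    # Placeholder array with assumed index order (bra_mode0, ket_mode0, photons_mode1).
    marginal = rng.random(size=(cutoff, cutoff, n2 + 1))

    block = marginal[..., n2]                  # unnormalized conditional state of mode 0
    conditional_dm = block / np.trace(block)   # renormalize so the trace is 1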
