Skip to content

Commit

Permalink
Speed up TF tests (#505)
Browse files Browse the repository at this point in the history
Co-authored-by: Anthony <[email protected]>
  • Loading branch information
ziofil and apchytr authored Oct 29, 2024
1 parent 67b8624 commit 5fe4a5f
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 107 deletions.
3 changes: 0 additions & 3 deletions mrmustard/lab/abstract/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,6 @@ def primal(self, other: State | Transformation) -> State:
Note that the returned state is not normalized. To normalize a state you can use
``mrmustard.physics.normalize``.
"""
# import pdb

# pdb.set_trace()
if isinstance(other, State):
return self._project_onto_state(other)
try:
Expand Down
20 changes: 12 additions & 8 deletions mrmustard/lab_dev/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,13 @@ def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) -
if len(state.modes) == 1:
return initial_samples

unique_samples, counts = np.unique(initial_samples, return_counts=True)
unique_samples, idxs, counts = np.unique(
initial_samples, return_index=True, return_counts=True
)
ret = []
for unique_sample, counts in zip(unique_samples, counts):
for unique_sample, idx, counts in zip(unique_samples, idxs, counts):
meas_op = self._get_povm(unique_sample, initial_mode).dual
prob = probs[initial_samples.tolist().index(unique_sample)]
prob = probs[idx]
norm = math.sqrt(prob) if isinstance(state, Ket) else prob
reduced_state = (state >> meas_op) / norm
samples = self.sample(reduced_state, counts)
Expand Down Expand Up @@ -128,7 +130,7 @@ def sample_prob_dist(
meas_outcomes = list(product(self.meas_outcomes, repeat=len(state.modes)))
samples = rng.choice(
a=meas_outcomes,
p=self.probabilities(state),
p=probs,
size=n_samples,
)
return samples, np.array([probs[meas_outcomes.index(tuple(sample))] for sample in samples])
Expand Down Expand Up @@ -167,7 +169,7 @@ def _validate_probs(self, probs: Sequence[float], atol: float) -> Sequence[float
atol: The absolute tolerance to validate with.
"""
atol = atol or settings.ATOL
prob_sum = sum(probs)
prob_sum = math.sum(probs)
if not math.allclose(prob_sum, 1, atol):
raise ValueError(f"Probabilities sum to {prob_sum} and not 1.0.")
return math.real(probs / prob_sum)
Expand Down Expand Up @@ -224,14 +226,16 @@ def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) -
if len(state.modes) == 1:
return initial_samples

unique_samples, counts = np.unique(initial_samples, return_counts=True)
unique_samples, idxs, counts = np.unique(
initial_samples, return_index=True, return_counts=True
)
ret = []
for unique_sample, counts in zip(unique_samples, counts):
for unique_sample, idx, counts in zip(unique_samples, idxs, counts):
quad = np.array([[unique_sample] + [None] * (state.n_modes - 1)])
quad = quad if isinstance(state, Ket) else math.tile(quad, (1, 2))
reduced_rep = (state >> BtoQ([initial_mode], phi=self._phi)).representation(quad)
reduced_state = state.__class__.from_bargmann(state.modes[1:], reduced_rep.triple)
prob = probs[initial_samples.tolist().index(unique_sample)] / self._step
prob = probs[idx] / self._step
norm = math.sqrt(prob) if isinstance(state, Ket) else prob
normalized_reduced_state = reduced_state / norm
samples = self.sample(normalized_reduced_state, counts)
Expand Down
6 changes: 4 additions & 2 deletions mrmustard/lab_dev/states/dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,16 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex
Returns:
The quadrature distribution.
"""
quad = math.astensor(quad)
quad = np.array(quad)
if len(quad.shape) != 1 and len(quad.shape) != self.n_modes:
raise ValueError(
"The dimensionality of quad should be 1, or match the number of modes."
)

if len(quad.shape) == 1:
quad = math.astensor(list(product(quad, repeat=len(self.modes))))
quad = math.astensor(np.meshgrid(*[quad] * len(self.modes))).T.reshape(
-1, len(self.modes)
)

quad = math.tile(quad, (1, 2))
return self.quadrature(quad, phi)
Expand Down
6 changes: 4 additions & 2 deletions mrmustard/lab_dev/states/ket.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,16 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex
Returns:
The quadrature distribution.
"""
quad = math.astensor(quad)
quad = np.array(quad)
if len(quad.shape) != 1 and len(quad.shape) != self.n_modes:
raise ValueError(
"The dimensionality of quad should be 1, or match the number of modes."
)

if len(quad.shape) == 1:
quad = math.astensor(list(product(quad, repeat=len(self.modes))))
quad = math.astensor(np.meshgrid(*[quad] * len(self.modes))).T.reshape(
-1, len(self.modes)
)

return math.abs(self.quadrature(quad, phi)) ** 2

Expand Down
22 changes: 11 additions & 11 deletions mrmustard/physics/ansatze.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,20 +686,20 @@ def _call_none(self, z: Batch[Vector]) -> PolyExpAnsatz:

batch_abc = self.batch_size
batch_arg = z.shape[0]
Abc = []
if batch_abc == 1 and batch_arg > 1:
for i in range(batch_arg):
Abc.append(self._call_none_single(self.A[0], self.b[0], self.c[0], z[i]))
elif batch_arg == 1 and batch_abc > 1:
for i in range(batch_abc):
Abc.append(self._call_none_single(self.A[i], self.b[i], self.c[i], z[0]))
elif batch_abc == batch_arg:
for i in range(batch_abc):
Abc.append(self._call_none_single(self.A[i], self.b[i], self.c[i], z[i]))
else:
if batch_abc != batch_arg and batch_abc != 1 and batch_arg != 1:
raise ValueError(
"Batch size of the ansatz and argument must match or one of the batch sizes must be 1."
)
Abc = []
max_batch = max(batch_abc, batch_arg)
for i in range(max_batch):
abc_index = 0 if batch_abc == 1 else i
arg_index = 0 if batch_arg == 1 else i
Abc.append(
self._call_none_single(
self.A[abc_index], self.b[abc_index], self.c[abc_index], z[arg_index]
)
)
A, b, c = zip(*Abc)
return self.__class__(A=A, b=b, c=c)

Expand Down
154 changes: 73 additions & 81 deletions tests/test_math/test_compactFock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from ..conftest import skip_np

original_precision = settings.PRECISION_BITS_HERMITE_POLY

do_julia = bool(importlib.util.find_spec("juliacall"))
precisions = [128, 256, 384, 512] if do_julia else [128]

Expand All @@ -41,28 +39,28 @@ def test_compactFock_diagonal(precision, A_B_G0):
r"""Test getting Fock amplitudes if all modes are
detected (math.hermite_renormalized_diagonal)
"""
settings.PRECISION_BITS_HERMITE_POLY = precision
cutoffs = (5, 5, 5)

A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement)

# Vanilla MM
G_ref = math.hermite_renormalized(
math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2
) # note: shape=[C1,C2,C3,...,C1,C2,C3,...]
G_ref = math.asnumpy(G_ref)

# Extract diagonal amplitudes from vanilla MM
ref_diag = np.zeros(cutoffs, dtype=np.complex128)
for inds in np.ndindex(*cutoffs):
inds_expanded = list(inds) + list(inds) # a,b,c,a,b,c
ref_diag[inds] = G_ref[tuple(inds_expanded)]

# New MM
G_diag = math.hermite_renormalized_diagonal(math.conj(-A), math.conj(B), math.conj(G0), cutoffs)
assert np.allclose(ref_diag, G_diag)

settings.PRECISION_BITS_HERMITE_POLY = original_precision
with settings(PRECISION_BITS_HERMITE_POLY=precision):
cutoffs = (5, 5, 5)

A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement)

# Vanilla MM
G_ref = math.hermite_renormalized(
math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2
) # note: shape=[C1,C2,C3,...,C1,C2,C3,...]
G_ref = math.asnumpy(G_ref)

# Extract diagonal amplitudes from vanilla MM
ref_diag = np.zeros(cutoffs, dtype=np.complex128)
for inds in np.ndindex(*cutoffs):
inds_expanded = list(inds) + list(inds) # a,b,c,a,b,c
ref_diag[inds] = G_ref[tuple(inds_expanded)]

# New MM
G_diag = math.hermite_renormalized_diagonal(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs
)
assert np.allclose(ref_diag, G_diag)


@given(random_ABC(M=3))
Expand All @@ -74,31 +72,29 @@ def test_compactFock_1leftover(precision, A_B_G0):
"""
skip_np()

settings.PRECISION_BITS_HERMITE_POLY = precision
cutoffs = (5, 5, 5)

A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement)
with settings(PRECISION_BITS_HERMITE_POLY=precision):
cutoffs = (5, 5, 5)

# New algorithm
G_leftover = math.hermite_renormalized_1leftoverMode(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs
)
A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement)

# Vanilla MM
G_ref = math.hermite_renormalized(
math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2
) # note: shape=[C1,C2,C3,...,C1,C2,C3,...]
G_ref = math.asnumpy(G_ref)
# New algorithm
G_leftover = math.hermite_renormalized_1leftoverMode(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs
)

# Extract amplitudes of leftover mode from vanilla MM
ref_leftover = np.zeros([cutoffs[0]] * 2 + list(cutoffs)[1:], dtype=np.complex128)
for inds in np.ndindex(*cutoffs[1:]):
ref_leftover[tuple([slice(cutoffs[0]), slice(cutoffs[0])] + list(inds))] = G_ref[
tuple([slice(cutoffs[0])] + list(inds) + [slice(cutoffs[0])] + list(inds))
]
assert np.allclose(ref_leftover, G_leftover)
# Vanilla MM
G_ref = math.hermite_renormalized(
math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2
) # note: shape=[C1,C2,C3,...,C1,C2,C3,...]
G_ref = math.asnumpy(G_ref)

settings.PRECISION_BITS_HERMITE_POLY = original_precision
# Extract amplitudes of leftover mode from vanilla MM
ref_leftover = np.zeros([cutoffs[0]] * 2 + list(cutoffs)[1:], dtype=np.complex128)
for inds in np.ndindex(*cutoffs[1:]):
ref_leftover[tuple([slice(cutoffs[0]), slice(cutoffs[0])] + list(inds))] = G_ref[
tuple([slice(cutoffs[0])] + list(inds) + [slice(cutoffs[0])] + list(inds))
]
assert np.allclose(ref_leftover, G_leftover)


@pytest.mark.parametrize("precision", precisions)
Expand All @@ -109,25 +105,23 @@ def test_compactFock_diagonal_gradients(precision):
"""
skip_np()

settings.PRECISION_BITS_HERMITE_POLY = precision
G = Ggate(num_modes=2, symplectic_trainable=True)

def cost_fn():
n1, n2 = 2, 4 # number of detected photons
state_opt = Vacuum(2) >> G
A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means)
probs = math.hermite_renormalized_diagonal(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[n1 + 1, n2 + 1]
)
p = probs[n1, n2]
return -math.real(p)
with settings(PRECISION_BITS_HERMITE_POLY=precision):
G = Ggate(num_modes=1, symplectic_trainable=True)

opt = Optimizer(symplectic_lr=0.5)
opt.minimize(cost_fn, by_optimizing=[G], max_steps=50)
for i in range(2, min(20, len(opt.opt_history))):
assert opt.opt_history[i - 1] >= opt.opt_history[i]
def cost_fn():
n1 = 2 # number of detected photons
state_opt = Vacuum(1) >> G
A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means)
probs = math.hermite_renormalized_diagonal(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[n1 + 1]
)
p = probs[n1]
return -math.real(p)

settings.PRECISION_BITS_HERMITE_POLY = original_precision
opt = Optimizer(symplectic_lr=0.5)
opt.minimize(cost_fn, by_optimizing=[G], max_steps=5)
for i in range(2, min(20, len(opt.opt_history))):
assert opt.opt_history[i - 1] >= opt.opt_history[i]


@pytest.mark.parametrize("precision", precisions)
Expand All @@ -138,22 +132,20 @@ def test_compactFock_1leftover_gradients(precision):
"""
skip_np()

settings.PRECISION_BITS_HERMITE_POLY = precision
G = Ggate(num_modes=2, symplectic_trainable=True)

def cost_fn():
n2 = 3 # number of detected photons
state_opt = Vacuum(2) >> G
A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means)
marginal = math.hermite_renormalized_1leftoverMode(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[8, n2 + 1]
)
conditional_state = normalize(State(dm=marginal[..., n2]))
return -fidelity(conditional_state, SqueezedVacuum(r=1))

opt = Optimizer(symplectic_lr=0.1)
opt.minimize(cost_fn, by_optimizing=[G], max_steps=50)
for i in range(2, min(20, len(opt.opt_history))):
assert opt.opt_history[i - 1] >= opt.opt_history[i]

settings.PRECISION_BITS_HERMITE_POLY = original_precision
with settings(PRECISION_BITS_HERMITE_POLY=precision):
G = Ggate(num_modes=2, symplectic_trainable=True)

def cost_fn():
n2 = 2 # number of detected photons
state_opt = Vacuum(2) >> G
A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means)
marginal = math.hermite_renormalized_1leftoverMode(
math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[2, n2 + 1]
)
conditional_state = normalize(State(dm=marginal[..., n2]))
return -fidelity(conditional_state, SqueezedVacuum(r=1))

opt = Optimizer(symplectic_lr=0.1)
opt.minimize(cost_fn, by_optimizing=[G], max_steps=5)
for i in range(2, len(opt.opt_history)):
assert opt.opt_history[i - 1] >= opt.opt_history[i]

0 comments on commit 5fe4a5f

Please sign in to comment.