diff --git a/mrmustard/lab/abstract/state.py b/mrmustard/lab/abstract/state.py index e52a52d66..97ede9741 100644 --- a/mrmustard/lab/abstract/state.py +++ b/mrmustard/lab/abstract/state.py @@ -383,9 +383,6 @@ def primal(self, other: State | Transformation) -> State: Note that the returned state is not normalized. To normalize a state you can use ``mrmustard.physics.normalize``. """ - # import pdb - - # pdb.set_trace() if isinstance(other, State): return self._project_onto_state(other) try: diff --git a/mrmustard/lab_dev/samplers.py b/mrmustard/lab_dev/samplers.py index c7fda9e15..5a690f975 100644 --- a/mrmustard/lab_dev/samplers.py +++ b/mrmustard/lab_dev/samplers.py @@ -96,11 +96,13 @@ def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) - if len(state.modes) == 1: return initial_samples - unique_samples, counts = np.unique(initial_samples, return_counts=True) + unique_samples, idxs, counts = np.unique( + initial_samples, return_index=True, return_counts=True + ) ret = [] - for unique_sample, counts in zip(unique_samples, counts): + for unique_sample, idx, counts in zip(unique_samples, idxs, counts): meas_op = self._get_povm(unique_sample, initial_mode).dual - prob = probs[initial_samples.tolist().index(unique_sample)] + prob = probs[idx] norm = math.sqrt(prob) if isinstance(state, Ket) else prob reduced_state = (state >> meas_op) / norm samples = self.sample(reduced_state, counts) @@ -128,7 +130,7 @@ def sample_prob_dist( meas_outcomes = list(product(self.meas_outcomes, repeat=len(state.modes))) samples = rng.choice( a=meas_outcomes, - p=self.probabilities(state), + p=probs, size=n_samples, ) return samples, np.array([probs[meas_outcomes.index(tuple(sample))] for sample in samples]) @@ -167,7 +169,7 @@ def _validate_probs(self, probs: Sequence[float], atol: float) -> Sequence[float atol: The absolute tolerance to validate with. """ atol = atol or settings.ATOL - prob_sum = sum(probs) + prob_sum = math.sum(probs) if not math.allclose(prob_sum, 1, atol): raise ValueError(f"Probabilities sum to {prob_sum} and not 1.0.") return math.real(probs / prob_sum) @@ -224,14 +226,16 @@ def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) - if len(state.modes) == 1: return initial_samples - unique_samples, counts = np.unique(initial_samples, return_counts=True) + unique_samples, idxs, counts = np.unique( + initial_samples, return_index=True, return_counts=True + ) ret = [] - for unique_sample, counts in zip(unique_samples, counts): + for unique_sample, idx, counts in zip(unique_samples, idxs, counts): quad = np.array([[unique_sample] + [None] * (state.n_modes - 1)]) quad = quad if isinstance(state, Ket) else math.tile(quad, (1, 2)) reduced_rep = (state >> BtoQ([initial_mode], phi=self._phi)).representation(quad) reduced_state = state.__class__.from_bargmann(state.modes[1:], reduced_rep.triple) - prob = probs[initial_samples.tolist().index(unique_sample)] / self._step + prob = probs[idx] / self._step norm = math.sqrt(prob) if isinstance(state, Ket) else prob normalized_reduced_state = reduced_state / norm samples = self.sample(normalized_reduced_state, counts) diff --git a/mrmustard/lab_dev/states/dm.py b/mrmustard/lab_dev/states/dm.py index a5bd16ce6..d80f2ad97 100644 --- a/mrmustard/lab_dev/states/dm.py +++ b/mrmustard/lab_dev/states/dm.py @@ -262,14 +262,16 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex Returns: The quadrature distribution. """ - quad = math.astensor(quad) + quad = np.array(quad) if len(quad.shape) != 1 and len(quad.shape) != self.n_modes: raise ValueError( "The dimensionality of quad should be 1, or match the number of modes." ) if len(quad.shape) == 1: - quad = math.astensor(list(product(quad, repeat=len(self.modes)))) + quad = math.astensor(np.meshgrid(*[quad] * len(self.modes))).T.reshape( + -1, len(self.modes) + ) quad = math.tile(quad, (1, 2)) return self.quadrature(quad, phi) diff --git a/mrmustard/lab_dev/states/ket.py b/mrmustard/lab_dev/states/ket.py index 2d550eae3..56f8efaed 100644 --- a/mrmustard/lab_dev/states/ket.py +++ b/mrmustard/lab_dev/states/ket.py @@ -212,14 +212,16 @@ def quadrature_distribution(self, quad: RealVector, phi: float = 0.0) -> Complex Returns: The quadrature distribution. """ - quad = math.astensor(quad) + quad = np.array(quad) if len(quad.shape) != 1 and len(quad.shape) != self.n_modes: raise ValueError( "The dimensionality of quad should be 1, or match the number of modes." ) if len(quad.shape) == 1: - quad = math.astensor(list(product(quad, repeat=len(self.modes)))) + quad = math.astensor(np.meshgrid(*[quad] * len(self.modes))).T.reshape( + -1, len(self.modes) + ) return math.abs(self.quadrature(quad, phi)) ** 2 diff --git a/mrmustard/physics/ansatze.py b/mrmustard/physics/ansatze.py index 470d3eced..beb34d2b9 100644 --- a/mrmustard/physics/ansatze.py +++ b/mrmustard/physics/ansatze.py @@ -686,20 +686,20 @@ def _call_none(self, z: Batch[Vector]) -> PolyExpAnsatz: batch_abc = self.batch_size batch_arg = z.shape[0] - Abc = [] - if batch_abc == 1 and batch_arg > 1: - for i in range(batch_arg): - Abc.append(self._call_none_single(self.A[0], self.b[0], self.c[0], z[i])) - elif batch_arg == 1 and batch_abc > 1: - for i in range(batch_abc): - Abc.append(self._call_none_single(self.A[i], self.b[i], self.c[i], z[0])) - elif batch_abc == batch_arg: - for i in range(batch_abc): - Abc.append(self._call_none_single(self.A[i], self.b[i], self.c[i], z[i])) - else: + if batch_abc != batch_arg and batch_abc != 1 and batch_arg != 1: raise ValueError( "Batch size of the ansatz and argument must match or one of the batch sizes must be 1." ) + Abc = [] + max_batch = max(batch_abc, batch_arg) + for i in range(max_batch): + abc_index = 0 if batch_abc == 1 else i + arg_index = 0 if batch_arg == 1 else i + Abc.append( + self._call_none_single( + self.A[abc_index], self.b[abc_index], self.c[abc_index], z[arg_index] + ) + ) A, b, c = zip(*Abc) return self.__class__(A=A, b=b, c=c) diff --git a/tests/test_math/test_compactFock.py b/tests/test_math/test_compactFock.py index 0e531f161..6c77f4d9e 100644 --- a/tests/test_math/test_compactFock.py +++ b/tests/test_math/test_compactFock.py @@ -18,8 +18,6 @@ from ..conftest import skip_np -original_precision = settings.PRECISION_BITS_HERMITE_POLY - do_julia = bool(importlib.util.find_spec("juliacall")) precisions = [128, 256, 384, 512] if do_julia else [128] @@ -41,28 +39,28 @@ def test_compactFock_diagonal(precision, A_B_G0): r"""Test getting Fock amplitudes if all modes are detected (math.hermite_renormalized_diagonal) """ - settings.PRECISION_BITS_HERMITE_POLY = precision - cutoffs = (5, 5, 5) - - A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement) - - # Vanilla MM - G_ref = math.hermite_renormalized( - math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2 - ) # note: shape=[C1,C2,C3,...,C1,C2,C3,...] - G_ref = math.asnumpy(G_ref) - - # Extract diagonal amplitudes from vanilla MM - ref_diag = np.zeros(cutoffs, dtype=np.complex128) - for inds in np.ndindex(*cutoffs): - inds_expanded = list(inds) + list(inds) # a,b,c,a,b,c - ref_diag[inds] = G_ref[tuple(inds_expanded)] - - # New MM - G_diag = math.hermite_renormalized_diagonal(math.conj(-A), math.conj(B), math.conj(G0), cutoffs) - assert np.allclose(ref_diag, G_diag) - - settings.PRECISION_BITS_HERMITE_POLY = original_precision + with settings(PRECISION_BITS_HERMITE_POLY=precision): + cutoffs = (5, 5, 5) + + A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement) + + # Vanilla MM + G_ref = math.hermite_renormalized( + math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2 + ) # note: shape=[C1,C2,C3,...,C1,C2,C3,...] + G_ref = math.asnumpy(G_ref) + + # Extract diagonal amplitudes from vanilla MM + ref_diag = np.zeros(cutoffs, dtype=np.complex128) + for inds in np.ndindex(*cutoffs): + inds_expanded = list(inds) + list(inds) # a,b,c,a,b,c + ref_diag[inds] = G_ref[tuple(inds_expanded)] + + # New MM + G_diag = math.hermite_renormalized_diagonal( + math.conj(-A), math.conj(B), math.conj(G0), cutoffs + ) + assert np.allclose(ref_diag, G_diag) @given(random_ABC(M=3)) @@ -74,31 +72,29 @@ def test_compactFock_1leftover(precision, A_B_G0): """ skip_np() - settings.PRECISION_BITS_HERMITE_POLY = precision - cutoffs = (5, 5, 5) - - A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement) + with settings(PRECISION_BITS_HERMITE_POLY=precision): + cutoffs = (5, 5, 5) - # New algorithm - G_leftover = math.hermite_renormalized_1leftoverMode( - math.conj(-A), math.conj(B), math.conj(G0), cutoffs - ) + A, B, G0 = A_B_G0 # Create random state (M mode Gaussian state with displacement) - # Vanilla MM - G_ref = math.hermite_renormalized( - math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2 - ) # note: shape=[C1,C2,C3,...,C1,C2,C3,...] - G_ref = math.asnumpy(G_ref) + # New algorithm + G_leftover = math.hermite_renormalized_1leftoverMode( + math.conj(-A), math.conj(B), math.conj(G0), cutoffs + ) - # Extract amplitudes of leftover mode from vanilla MM - ref_leftover = np.zeros([cutoffs[0]] * 2 + list(cutoffs)[1:], dtype=np.complex128) - for inds in np.ndindex(*cutoffs[1:]): - ref_leftover[tuple([slice(cutoffs[0]), slice(cutoffs[0])] + list(inds))] = G_ref[ - tuple([slice(cutoffs[0])] + list(inds) + [slice(cutoffs[0])] + list(inds)) - ] - assert np.allclose(ref_leftover, G_leftover) + # Vanilla MM + G_ref = math.hermite_renormalized( + math.conj(-A), math.conj(B), math.conj(G0), shape=list(cutoffs) * 2 + ) # note: shape=[C1,C2,C3,...,C1,C2,C3,...] + G_ref = math.asnumpy(G_ref) - settings.PRECISION_BITS_HERMITE_POLY = original_precision + # Extract amplitudes of leftover mode from vanilla MM + ref_leftover = np.zeros([cutoffs[0]] * 2 + list(cutoffs)[1:], dtype=np.complex128) + for inds in np.ndindex(*cutoffs[1:]): + ref_leftover[tuple([slice(cutoffs[0]), slice(cutoffs[0])] + list(inds))] = G_ref[ + tuple([slice(cutoffs[0])] + list(inds) + [slice(cutoffs[0])] + list(inds)) + ] + assert np.allclose(ref_leftover, G_leftover) @pytest.mark.parametrize("precision", precisions) @@ -109,25 +105,23 @@ def test_compactFock_diagonal_gradients(precision): """ skip_np() - settings.PRECISION_BITS_HERMITE_POLY = precision - G = Ggate(num_modes=2, symplectic_trainable=True) - - def cost_fn(): - n1, n2 = 2, 4 # number of detected photons - state_opt = Vacuum(2) >> G - A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means) - probs = math.hermite_renormalized_diagonal( - math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[n1 + 1, n2 + 1] - ) - p = probs[n1, n2] - return -math.real(p) + with settings(PRECISION_BITS_HERMITE_POLY=precision): + G = Ggate(num_modes=1, symplectic_trainable=True) - opt = Optimizer(symplectic_lr=0.5) - opt.minimize(cost_fn, by_optimizing=[G], max_steps=50) - for i in range(2, min(20, len(opt.opt_history))): - assert opt.opt_history[i - 1] >= opt.opt_history[i] + def cost_fn(): + n1 = 2 # number of detected photons + state_opt = Vacuum(1) >> G + A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means) + probs = math.hermite_renormalized_diagonal( + math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[n1 + 1] + ) + p = probs[n1] + return -math.real(p) - settings.PRECISION_BITS_HERMITE_POLY = original_precision + opt = Optimizer(symplectic_lr=0.5) + opt.minimize(cost_fn, by_optimizing=[G], max_steps=5) + for i in range(2, min(20, len(opt.opt_history))): + assert opt.opt_history[i - 1] >= opt.opt_history[i] @pytest.mark.parametrize("precision", precisions) @@ -138,22 +132,20 @@ def test_compactFock_1leftover_gradients(precision): """ skip_np() - settings.PRECISION_BITS_HERMITE_POLY = precision - G = Ggate(num_modes=2, symplectic_trainable=True) - - def cost_fn(): - n2 = 3 # number of detected photons - state_opt = Vacuum(2) >> G - A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means) - marginal = math.hermite_renormalized_1leftoverMode( - math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[8, n2 + 1] - ) - conditional_state = normalize(State(dm=marginal[..., n2])) - return -fidelity(conditional_state, SqueezedVacuum(r=1)) - - opt = Optimizer(symplectic_lr=0.1) - opt.minimize(cost_fn, by_optimizing=[G], max_steps=50) - for i in range(2, min(20, len(opt.opt_history))): - assert opt.opt_history[i - 1] >= opt.opt_history[i] - - settings.PRECISION_BITS_HERMITE_POLY = original_precision + with settings(PRECISION_BITS_HERMITE_POLY=precision): + G = Ggate(num_modes=2, symplectic_trainable=True) + + def cost_fn(): + n2 = 2 # number of detected photons + state_opt = Vacuum(2) >> G + A, B, G0 = wigner_to_bargmann_rho(state_opt.cov, state_opt.means) + marginal = math.hermite_renormalized_1leftoverMode( + math.conj(-A), math.conj(B), math.conj(G0), cutoffs=[2, n2 + 1] + ) + conditional_state = normalize(State(dm=marginal[..., n2])) + return -fidelity(conditional_state, SqueezedVacuum(r=1)) + + opt = Optimizer(symplectic_lr=0.1) + opt.minimize(cost_fn, by_optimizing=[G], max_steps=5) + for i in range(2, len(opt.opt_history)): + assert opt.opt_history[i - 1] >= opt.opt_history[i]