From a8e1d455e4f1012626f420509bf2b78d615468d3 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Wed, 28 Jun 2023 18:51:30 +0100 Subject: [PATCH] Add support for > 32 qudits to cirq.sample_state_vector. Fix for #6031 (#6090) * Add support for > 32 qudits to cirq.sample_state_vector. Fix for #6031 * refactor the method into a seprate util module * fix lint * accept only probabilities not complex numbers * added tests --- cirq/linalg/__init__.py | 2 + cirq/linalg/transformations.py | 64 +++++++++++++++++++++++++++++ cirq/linalg/transformations_test.py | 16 ++++++++ cirq/sim/density_matrix_utils.py | 28 +------------ cirq/sim/simulation_utils.py | 59 ++++++++++++++++++++++++++ cirq/sim/simulation_utils_test.py | 32 +++++++++++++++ cirq/sim/state_vector.py | 39 +++--------------- cirq/sim/state_vector_test.py | 61 ++++++++++++++++++++------- 8 files changed, 226 insertions(+), 75 deletions(-) create mode 100644 cirq/sim/simulation_utils.py create mode 100644 cirq/sim/simulation_utils_test.py diff --git a/cirq/linalg/__init__.py b/cirq/linalg/__init__.py index 62d21593551..e0181859837 100644 --- a/cirq/linalg/__init__.py +++ b/cirq/linalg/__init__.py @@ -81,4 +81,6 @@ targeted_conjugate_about, targeted_left_multiply, to_special, + transpose_flattened_array, + can_numpy_support_shape, ) diff --git a/cirq/linalg/transformations.py b/cirq/linalg/transformations.py index 6a04ce1dda0..ba02322c62f 100644 --- a/cirq/linalg/transformations.py +++ b/cirq/linalg/transformations.py @@ -29,6 +29,8 @@ # user provides a different np.array([]) value. RaiseValueErrorIfNotProvided: np.ndarray = np.array([]) +_NPY_MAXDIMS = 32 # Should be changed once numpy/numpy#5744 is resolved. + def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float): """Raises a matrix with two opposing eigenvalues to a power. @@ -746,3 +748,65 @@ def transpose_density_matrix_to_axis_order(t: np.ndarray, axes: Sequence[int]): """ axes = list(axes) + [i + len(axes) for i in axes] return transpose_state_vector_to_axis_order(t, axes) + + +def _volumes(shape: Sequence[int]) -> List[int]: + r"""Returns a list of the volume spanned by each dimension. + + Given a shape=[d_0, d_1, .., d_n] the volume spanned by each dimension is + volume[i] = `\prod_{j=i+1}^n d_j` + + Args: + shape: Sequence of the size of each dimension. + + Returns: + Sequence of the volume spanned of each dimension. + """ + volume = [0] * len(shape) + v = 1 + for i in reversed(range(len(shape))): + volume[i] = v + v *= shape[i] + return volume + + +def _coordinates_from_index(idx: int, volume: Sequence[int]) -> Sequence[int]: + ret = [] + for v in volume: + ret.append(idx // v) + idx %= v + return tuple(ret) + + +def _index_from_coordinates(s: Sequence[int], volume: Sequence[int]) -> int: + return np.dot(s, volume) + + +def transpose_flattened_array(t: np.ndarray, shape: Sequence[int], axes: Sequence[int]): + """Transposes a flattened array. + + Equivalent to np.transpose(t.reshape(shape), axes).reshape((-1,)). + + Args: + t: flat array. + shape: the shape of `t` before flattening. + axes: permutation of range(len(shape)). + + Returns: + Flattened transpose of `t`. + """ + if len(t.shape) != 1: + t = t.reshape((-1,)) + cur_volume = _volumes(shape) + new_volume = _volumes([shape[i] for i in axes]) + ret = np.zeros_like(t) + for idx in range(t.shape[0]): + cell = _coordinates_from_index(idx, cur_volume) + new_cell = [cell[i] for i in axes] + ret[_index_from_coordinates(new_cell, new_volume)] = t[idx] + return ret + + +def can_numpy_support_shape(shape: Sequence[int]) -> bool: + """Returns whether numpy supports the given shape or not numpy/numpy#5744.""" + return len(shape) <= _NPY_MAXDIMS diff --git a/cirq/linalg/transformations_test.py b/cirq/linalg/transformations_test.py index 160537045fa..bf30ad60089 100644 --- a/cirq/linalg/transformations_test.py +++ b/cirq/linalg/transformations_test.py @@ -17,6 +17,7 @@ import cirq import cirq.testing +from cirq import linalg def test_reflection_matrix_pow_consistent_results(): @@ -632,3 +633,18 @@ def test_factor_state_vector(state_1: int, state_2: int): # All phase goes into a1, and b1 is just the dephased state vector assert np.allclose(a1, a * phase) assert np.allclose(b1, b) + + +@pytest.mark.parametrize('num_dimensions', [*range(1, 7)]) +def test_transpose_flattened_array(num_dimensions): + np.random.seed(0) + for _ in range(10): + shape = np.random.randint(1, 5, (num_dimensions,)).tolist() + axes = np.random.permutation(num_dimensions).tolist() + volume = np.prod(shape) + A = np.random.permutation(volume) + want = np.transpose(A.reshape(shape), axes) + got = linalg.transpose_flattened_array(A, shape, axes).reshape(want.shape) + assert np.array_equal(want, got) + got = linalg.transpose_flattened_array(A.reshape(shape), shape, axes).reshape(want.shape) + assert np.array_equal(want, got) diff --git a/cirq/sim/density_matrix_utils.py b/cirq/sim/density_matrix_utils.py index a6235c70dba..bb3d87195e7 100644 --- a/cirq/sim/density_matrix_utils.py +++ b/cirq/sim/density_matrix_utils.py @@ -18,6 +18,7 @@ import numpy as np from cirq import linalg, value +from cirq.sim import simulation_utils if TYPE_CHECKING: import cirq @@ -188,33 +189,8 @@ def _probs( """Returns the probabilities for a measurement on the given indices.""" # Only diagonal elements matter. all_probs = np.diagonal(np.reshape(density_matrix, (np.prod(qid_shape, dtype=np.int64),) * 2)) - # Shape into a tensor - tensor = np.reshape(all_probs, qid_shape) - - # Calculate the probabilities for measuring the particular results. - if len(indices) == len(qid_shape): - # We're measuring every qudit, so no need for fancy indexing - probs = np.abs(tensor) - probs = np.transpose(probs, indices) - probs = probs.reshape(-1) - else: - # Fancy indexing required - meas_shape = tuple(qid_shape[i] for i in indices) - probs = np.abs( - [ - tensor[ - linalg.slice_for_qubits_equal_to( - indices, big_endian_qureg_value=b, qid_shape=qid_shape - ) - ] - for b in range(np.prod(meas_shape, dtype=np.int64)) - ] - ) - probs = np.sum(probs, axis=tuple(range(1, len(probs.shape)))) - # To deal with rounding issues, ensure that the probabilities sum to 1. - probs /= np.sum(probs) - return probs + return simulation_utils.state_probabilities_by_indices(all_probs.real, indices, qid_shape) def _validate_density_matrix_qid_shape( diff --git a/cirq/sim/simulation_utils.py b/cirq/sim/simulation_utils.py new file mode 100644 index 00000000000..00934962d2e --- /dev/null +++ b/cirq/sim/simulation_utils.py @@ -0,0 +1,59 @@ +# Copyright 2023 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Sequence, Tuple + +import numpy as np + +from cirq import linalg + + +def state_probabilities_by_indices( + state_probability: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...] +) -> np.ndarray: + """Returns the probabilities for a state/measurement on the given indices. + + Args: + state_probability: The multi-qubit state probability vector. This is an + array of 2 to the power of the number of real numbers, and + so state must be of size ``2**integer``. The `state_probability` can be + a vector of size ``2**integer`` or a tensor of shape + ``(2, 2, ..., 2)``. + indices: Which qubits are measured. The `state_probability` is assumed to be + supplied in big endian order. That is the xth index of v, when + expressed as a bitstring, has its largest values in the 0th index. + qid_shape: The qid shape of the `state_probability`. + + Returns: + State probabilities. + """ + probs = state_probability.reshape((-1,)) + not_measured = [i for i in range(len(qid_shape)) if i not in indices] + if linalg.can_numpy_support_shape(qid_shape): + # Use numpy transpose if we can since it's more efficient. + probs = probs.reshape(qid_shape) + probs = np.transpose(probs, list(indices) + not_measured) + probs = probs.reshape((-1,)) + else: + # If we can't use numpy due to numpy/numpy#5744, use a slower method. + probs = linalg.transpose_flattened_array(probs, qid_shape, list(indices) + not_measured) + + if len(not_measured): + # Not all qudits are measured. + volume = np.prod([qid_shape[i] for i in indices]) + # Reshape into a 2D array in which each of the measured states correspond to a row. + probs = probs.reshape((volume, -1)) + probs = np.sum(probs, axis=-1) + + # To deal with rounding issues, ensure that the probabilities sum to 1. + return probs / np.sum(probs) diff --git a/cirq/sim/simulation_utils_test.py b/cirq/sim/simulation_utils_test.py new file mode 100644 index 00000000000..2e8736029ca --- /dev/null +++ b/cirq/sim/simulation_utils_test.py @@ -0,0 +1,32 @@ +# Copyright 2023 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import numpy as np + +from cirq.sim import simulation_utils +from cirq import testing + + +@pytest.mark.parametrize('n,m', [(n, m) for n in range(1, 4) for m in range(1, n + 1)]) +def test_state_probabilities_by_indices(n: int, m: int): + np.random.seed(0) + state = testing.random_superposition(1 << n) + d = (state.conj() * state).real + desired_axes = list(np.random.choice(n, m, replace=False)) + not_wanted = [i for i in range(n) if i not in desired_axes] + got = simulation_utils.state_probabilities_by_indices(d, desired_axes, (2,) * n) + want = np.transpose(d.reshape((2,) * n), desired_axes + not_wanted) + want = np.sum(want.reshape((1 << len(desired_axes), -1)), axis=-1) + np.testing.assert_allclose(want, got) diff --git a/cirq/sim/state_vector.py b/cirq/sim/state_vector.py index 2beea5539f5..7250e6cb1f6 100644 --- a/cirq/sim/state_vector.py +++ b/cirq/sim/state_vector.py @@ -19,7 +19,7 @@ import numpy as np from cirq import linalg, qis, value -from cirq.sim import simulator +from cirq.sim import simulator, simulation_utils if TYPE_CHECKING: import cirq @@ -215,7 +215,8 @@ def sample_state_vector( prng = value.parse_random_state(seed) # Calculate the measurement probabilities. - probs = _probs(state_vector, indices, shape) + probs = (state_vector * state_vector.conj()).real + probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape) # We now have the probability vector, correctly ordered, so sample over # it. Note that we us ints here, since numpy's choice does not allow for @@ -288,7 +289,8 @@ def measure_state_vector( initial_shape = state_vector.shape # Calculate the measurement probabilities and then make the measurement. - probs = _probs(state_vector, indices, shape) + probs = (state_vector * state_vector.conj()).real + probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape) result = prng.choice(len(probs), p=probs) ###measurement_bits = [(1 & (result >> i)) for i in range(len(indices))] # Convert to individual qudit measurements. @@ -321,34 +323,3 @@ def measure_state_vector( assert out is not None # We mutate and return out, so mypy cannot identify that the out cannot be None. return measurement_bits, out - - -def _probs(state: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]) -> np.ndarray: - """Returns the probabilities for a measurement on the given indices.""" - tensor = np.reshape(state, qid_shape) - # Calculate the probabilities for measuring the particular results. - if len(indices) == len(qid_shape): - # We're measuring every qudit, so no need for fancy indexing - probs = np.abs(tensor) ** 2 - probs = np.transpose(probs, indices) - probs = probs.reshape(-1) - else: - # Fancy indexing required - meas_shape = tuple(qid_shape[i] for i in indices) - probs = ( - np.abs( - [ - tensor[ - linalg.slice_for_qubits_equal_to( - indices, big_endian_qureg_value=b, qid_shape=qid_shape - ) - ] - for b in range(np.prod(meas_shape, dtype=np.int64)) - ] - ) - ** 2 - ) - probs = np.sum(probs, axis=tuple(range(1, len(probs.shape)))) - - # To deal with rounding issues, ensure that the probabilities sum to 1. - return probs / np.sum(probs) diff --git a/cirq/sim/state_vector_test.py b/cirq/sim/state_vector_test.py index 4f5f2ea342c..206c2220fb1 100644 --- a/cirq/sim/state_vector_test.py +++ b/cirq/sim/state_vector_test.py @@ -21,6 +21,7 @@ import cirq import cirq.testing +from cirq import linalg def test_state_mixin(): @@ -172,7 +173,9 @@ def test_sample_no_indices_repetitions(): ) -def test_measure_state_computational_basis(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_computational_basis(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose results = [] for x in range(8): initial_state = cirq.to_valid_state_vector(x, 3) @@ -183,7 +186,9 @@ def test_measure_state_computational_basis(): assert results == expected -def test_measure_state_reshape(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_reshape(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose results = [] for x in range(8): initial_state = np.reshape(cirq.to_valid_state_vector(x, 3), [2] * 3) @@ -194,7 +199,9 @@ def test_measure_state_reshape(): assert results == expected -def test_measure_state_partial_indices(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_partial_indices(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose for index in range(3): for x in range(8): initial_state = cirq.to_valid_state_vector(x, 3) @@ -203,7 +210,9 @@ def test_measure_state_partial_indices(): assert bits == [bool(1 & (x >> (2 - index)))] -def test_measure_state_partial_indices_order(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_partial_indices_order(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose for x in range(8): initial_state = cirq.to_valid_state_vector(x, 3) bits, state = cirq.measure_state_vector(initial_state, [2, 1]) @@ -211,7 +220,9 @@ def test_measure_state_partial_indices_order(): assert bits == [bool(1 & (x >> 0)), bool(1 & (x >> 1))] -def test_measure_state_partial_indices_all_orders(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_partial_indices_all_orders(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose for perm in itertools.permutations([0, 1, 2]): for x in range(8): initial_state = cirq.to_valid_state_vector(x, 3) @@ -220,7 +231,9 @@ def test_measure_state_partial_indices_all_orders(): assert bits == [bool(1 & (x >> (2 - p))) for p in perm] -def test_measure_state_collapse(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_collapse(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = np.zeros(8, dtype=np.complex64) initial_state[0] = 1 / np.sqrt(2) initial_state[2] = 1 / np.sqrt(2) @@ -243,7 +256,9 @@ def test_measure_state_collapse(): assert bits == [False] -def test_measure_state_seed(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_seed(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose n = 10 initial_state = np.ones(2**n) / 2 ** (n / 2) @@ -262,7 +277,9 @@ def test_measure_state_seed(): np.testing.assert_allclose(state1, state2) -def test_measure_state_out_is_state(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_out_is_state(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = np.zeros(8, dtype=np.complex64) initial_state[0] = 1 / np.sqrt(2) initial_state[2] = 1 / np.sqrt(2) @@ -273,7 +290,9 @@ def test_measure_state_out_is_state(): assert state is initial_state -def test_measure_state_out_is_not_state(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_out_is_not_state(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = np.zeros(8, dtype=np.complex64) initial_state[0] = 1 / np.sqrt(2) initial_state[2] = 1 / np.sqrt(2) @@ -283,14 +302,18 @@ def test_measure_state_out_is_not_state(): assert out is state -def test_measure_state_not_power_of_two(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_not_power_of_two(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose with pytest.raises(ValueError, match='3'): _, _ = cirq.measure_state_vector(np.array([1, 0, 0]), [1]) with pytest.raises(ValueError, match='5'): cirq.measure_state_vector(np.array([0, 1, 0, 0, 0]), [1]) -def test_measure_state_index_out_of_range(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_index_out_of_range(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose state = cirq.to_valid_state_vector(0, 3) with pytest.raises(IndexError, match='-2'): cirq.measure_state_vector(state, [-2]) @@ -298,14 +321,18 @@ def test_measure_state_index_out_of_range(): cirq.measure_state_vector(state, [3]) -def test_measure_state_no_indices(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_no_indices(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = cirq.to_valid_state_vector(0, 3) bits, state = cirq.measure_state_vector(initial_state, []) assert [] == bits np.testing.assert_almost_equal(state, initial_state) -def test_measure_state_no_indices_out_is_state(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_no_indices_out_is_state(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = cirq.to_valid_state_vector(0, 3) bits, state = cirq.measure_state_vector(initial_state, [], out=initial_state) assert [] == bits @@ -313,7 +340,9 @@ def test_measure_state_no_indices_out_is_state(): assert state is initial_state -def test_measure_state_no_indices_out_is_not_state(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_no_indices_out_is_not_state(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = cirq.to_valid_state_vector(0, 3) out = np.zeros_like(initial_state) bits, state = cirq.measure_state_vector(initial_state, [], out=out) @@ -323,7 +352,9 @@ def test_measure_state_no_indices_out_is_not_state(): assert out is not initial_state -def test_measure_state_empty_state(): +@pytest.mark.parametrize('use_np_transpose', [False, True]) +def test_measure_state_empty_state(use_np_transpose: bool): + linalg.can_numpy_support_shape = lambda s: use_np_transpose initial_state = np.array([1.0]) bits, state = cirq.measure_state_vector(initial_state, []) assert [] == bits