Skip to content

Commit

Permalink
Add support for > 32 qudits to cirq.sample_state_vector. Fix for quan…
Browse files Browse the repository at this point in the history
…tumlib#6031 (quantumlib#6090)

* Add support for > 32 qudits to cirq.sample_state_vector. Fix for quantumlib#6031

* refactor the method into a seprate util module

* fix lint

* accept only probabilities not complex numbers

* added tests
  • Loading branch information
NoureldinYosri authored Jun 28, 2023
1 parent 7b143cc commit a8e1d45
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 75 deletions.
2 changes: 2 additions & 0 deletions cirq/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,6 @@
targeted_conjugate_about,
targeted_left_multiply,
to_special,
transpose_flattened_array,
can_numpy_support_shape,
)
64 changes: 64 additions & 0 deletions cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
# user provides a different np.array([]) value.
RaiseValueErrorIfNotProvided: np.ndarray = np.array([])

_NPY_MAXDIMS = 32 # Should be changed once numpy/numpy#5744 is resolved.


def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float):
"""Raises a matrix with two opposing eigenvalues to a power.
Expand Down Expand Up @@ -746,3 +748,65 @@ def transpose_density_matrix_to_axis_order(t: np.ndarray, axes: Sequence[int]):
"""
axes = list(axes) + [i + len(axes) for i in axes]
return transpose_state_vector_to_axis_order(t, axes)


def _volumes(shape: Sequence[int]) -> List[int]:
r"""Returns a list of the volume spanned by each dimension.
Given a shape=[d_0, d_1, .., d_n] the volume spanned by each dimension is
volume[i] = `\prod_{j=i+1}^n d_j`
Args:
shape: Sequence of the size of each dimension.
Returns:
Sequence of the volume spanned of each dimension.
"""
volume = [0] * len(shape)
v = 1
for i in reversed(range(len(shape))):
volume[i] = v
v *= shape[i]
return volume


def _coordinates_from_index(idx: int, volume: Sequence[int]) -> Sequence[int]:
ret = []
for v in volume:
ret.append(idx // v)
idx %= v
return tuple(ret)


def _index_from_coordinates(s: Sequence[int], volume: Sequence[int]) -> int:
return np.dot(s, volume)


def transpose_flattened_array(t: np.ndarray, shape: Sequence[int], axes: Sequence[int]):
"""Transposes a flattened array.
Equivalent to np.transpose(t.reshape(shape), axes).reshape((-1,)).
Args:
t: flat array.
shape: the shape of `t` before flattening.
axes: permutation of range(len(shape)).
Returns:
Flattened transpose of `t`.
"""
if len(t.shape) != 1:
t = t.reshape((-1,))
cur_volume = _volumes(shape)
new_volume = _volumes([shape[i] for i in axes])
ret = np.zeros_like(t)
for idx in range(t.shape[0]):
cell = _coordinates_from_index(idx, cur_volume)
new_cell = [cell[i] for i in axes]
ret[_index_from_coordinates(new_cell, new_volume)] = t[idx]
return ret


def can_numpy_support_shape(shape: Sequence[int]) -> bool:
"""Returns whether numpy supports the given shape or not numpy/numpy#5744."""
return len(shape) <= _NPY_MAXDIMS
16 changes: 16 additions & 0 deletions cirq/linalg/transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import cirq
import cirq.testing
from cirq import linalg


def test_reflection_matrix_pow_consistent_results():
Expand Down Expand Up @@ -632,3 +633,18 @@ def test_factor_state_vector(state_1: int, state_2: int):
# All phase goes into a1, and b1 is just the dephased state vector
assert np.allclose(a1, a * phase)
assert np.allclose(b1, b)


@pytest.mark.parametrize('num_dimensions', [*range(1, 7)])
def test_transpose_flattened_array(num_dimensions):
np.random.seed(0)
for _ in range(10):
shape = np.random.randint(1, 5, (num_dimensions,)).tolist()
axes = np.random.permutation(num_dimensions).tolist()
volume = np.prod(shape)
A = np.random.permutation(volume)
want = np.transpose(A.reshape(shape), axes)
got = linalg.transpose_flattened_array(A, shape, axes).reshape(want.shape)
assert np.array_equal(want, got)
got = linalg.transpose_flattened_array(A.reshape(shape), shape, axes).reshape(want.shape)
assert np.array_equal(want, got)
28 changes: 2 additions & 26 deletions cirq/sim/density_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np

from cirq import linalg, value
from cirq.sim import simulation_utils

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -188,33 +189,8 @@ def _probs(
"""Returns the probabilities for a measurement on the given indices."""
# Only diagonal elements matter.
all_probs = np.diagonal(np.reshape(density_matrix, (np.prod(qid_shape, dtype=np.int64),) * 2))
# Shape into a tensor
tensor = np.reshape(all_probs, qid_shape)

# Calculate the probabilities for measuring the particular results.
if len(indices) == len(qid_shape):
# We're measuring every qudit, so no need for fancy indexing
probs = np.abs(tensor)
probs = np.transpose(probs, indices)
probs = probs.reshape(-1)
else:
# Fancy indexing required
meas_shape = tuple(qid_shape[i] for i in indices)
probs = np.abs(
[
tensor[
linalg.slice_for_qubits_equal_to(
indices, big_endian_qureg_value=b, qid_shape=qid_shape
)
]
for b in range(np.prod(meas_shape, dtype=np.int64))
]
)
probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))

# To deal with rounding issues, ensure that the probabilities sum to 1.
probs /= np.sum(probs)
return probs
return simulation_utils.state_probabilities_by_indices(all_probs.real, indices, qid_shape)


def _validate_density_matrix_qid_shape(
Expand Down
59 changes: 59 additions & 0 deletions cirq/sim/simulation_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2023 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, Tuple

import numpy as np

from cirq import linalg


def state_probabilities_by_indices(
state_probability: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]
) -> np.ndarray:
"""Returns the probabilities for a state/measurement on the given indices.
Args:
state_probability: The multi-qubit state probability vector. This is an
array of 2 to the power of the number of real numbers, and
so state must be of size ``2**integer``. The `state_probability` can be
a vector of size ``2**integer`` or a tensor of shape
``(2, 2, ..., 2)``.
indices: Which qubits are measured. The `state_probability` is assumed to be
supplied in big endian order. That is the xth index of v, when
expressed as a bitstring, has its largest values in the 0th index.
qid_shape: The qid shape of the `state_probability`.
Returns:
State probabilities.
"""
probs = state_probability.reshape((-1,))
not_measured = [i for i in range(len(qid_shape)) if i not in indices]
if linalg.can_numpy_support_shape(qid_shape):
# Use numpy transpose if we can since it's more efficient.
probs = probs.reshape(qid_shape)
probs = np.transpose(probs, list(indices) + not_measured)
probs = probs.reshape((-1,))
else:
# If we can't use numpy due to numpy/numpy#5744, use a slower method.
probs = linalg.transpose_flattened_array(probs, qid_shape, list(indices) + not_measured)

if len(not_measured):
# Not all qudits are measured.
volume = np.prod([qid_shape[i] for i in indices])
# Reshape into a 2D array in which each of the measured states correspond to a row.
probs = probs.reshape((volume, -1))
probs = np.sum(probs, axis=-1)

# To deal with rounding issues, ensure that the probabilities sum to 1.
return probs / np.sum(probs)
32 changes: 32 additions & 0 deletions cirq/sim/simulation_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2023 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import numpy as np

from cirq.sim import simulation_utils
from cirq import testing


@pytest.mark.parametrize('n,m', [(n, m) for n in range(1, 4) for m in range(1, n + 1)])
def test_state_probabilities_by_indices(n: int, m: int):
np.random.seed(0)
state = testing.random_superposition(1 << n)
d = (state.conj() * state).real
desired_axes = list(np.random.choice(n, m, replace=False))
not_wanted = [i for i in range(n) if i not in desired_axes]
got = simulation_utils.state_probabilities_by_indices(d, desired_axes, (2,) * n)
want = np.transpose(d.reshape((2,) * n), desired_axes + not_wanted)
want = np.sum(want.reshape((1 << len(desired_axes), -1)), axis=-1)
np.testing.assert_allclose(want, got)
39 changes: 5 additions & 34 deletions cirq/sim/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np

from cirq import linalg, qis, value
from cirq.sim import simulator
from cirq.sim import simulator, simulation_utils

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -215,7 +215,8 @@ def sample_state_vector(
prng = value.parse_random_state(seed)

# Calculate the measurement probabilities.
probs = _probs(state_vector, indices, shape)
probs = (state_vector * state_vector.conj()).real
probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape)

# We now have the probability vector, correctly ordered, so sample over
# it. Note that we us ints here, since numpy's choice does not allow for
Expand Down Expand Up @@ -288,7 +289,8 @@ def measure_state_vector(
initial_shape = state_vector.shape

# Calculate the measurement probabilities and then make the measurement.
probs = _probs(state_vector, indices, shape)
probs = (state_vector * state_vector.conj()).real
probs = simulation_utils.state_probabilities_by_indices(probs, indices, shape)
result = prng.choice(len(probs), p=probs)
###measurement_bits = [(1 & (result >> i)) for i in range(len(indices))]
# Convert to individual qudit measurements.
Expand Down Expand Up @@ -321,34 +323,3 @@ def measure_state_vector(
assert out is not None
# We mutate and return out, so mypy cannot identify that the out cannot be None.
return measurement_bits, out


def _probs(state: np.ndarray, indices: Sequence[int], qid_shape: Tuple[int, ...]) -> np.ndarray:
"""Returns the probabilities for a measurement on the given indices."""
tensor = np.reshape(state, qid_shape)
# Calculate the probabilities for measuring the particular results.
if len(indices) == len(qid_shape):
# We're measuring every qudit, so no need for fancy indexing
probs = np.abs(tensor) ** 2
probs = np.transpose(probs, indices)
probs = probs.reshape(-1)
else:
# Fancy indexing required
meas_shape = tuple(qid_shape[i] for i in indices)
probs = (
np.abs(
[
tensor[
linalg.slice_for_qubits_equal_to(
indices, big_endian_qureg_value=b, qid_shape=qid_shape
)
]
for b in range(np.prod(meas_shape, dtype=np.int64))
]
)
** 2
)
probs = np.sum(probs, axis=tuple(range(1, len(probs.shape))))

# To deal with rounding issues, ensure that the probabilities sum to 1.
return probs / np.sum(probs)
Loading

0 comments on commit a8e1d45

Please sign in to comment.