Skip to content

Commit

Permalink
Merge pull request #555 from quantumlib/u/maffoo/state-vector
Browse files Browse the repository at this point in the history
Update for compatibility with cirq 1.0
  • Loading branch information
95-martin-orion authored Jul 18, 2022
2 parents b4fdca1 + d2f9c9b commit 2c7632a
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 128 deletions.
7 changes: 1 addition & 6 deletions qsimcirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,7 @@ def _load_qsim_custatevec():
qsim_custatevec = _load_qsim_custatevec()

from .qsim_circuit import add_op_to_opstring, add_op_to_circuit, QSimCircuit
from .qsim_simulator import (
QSimOptions,
QSimSimulatorState,
QSimSimulatorTrialResult,
QSimSimulator,
)
from .qsim_simulator import QSimOptions, QSimSimulator
from .qsimh_simulator import QSimhSimulator

from qsimcirq._version import (
Expand Down
33 changes: 21 additions & 12 deletions qsimcirq/qsim_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import warnings
from typing import Dict, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union

import cirq
import numpy as np
Expand Down Expand Up @@ -210,26 +210,33 @@ def _has_cirq_gate_kind(op: cirq.Operation):
return any(t in TYPE_TRANSLATOR for t in type(op.gate).mro())


def _control_details(gate: cirq.ControlledGate, qubits):
control_qubits = []
control_values = []
def _control_details(
gate: cirq.ControlledGate, qubits: Sequence[cirq.Qid]
) -> Tuple[List[cirq.Qid], List[int]]:
control_qubits: List[cirq.Qid] = []
control_values: List[int] = []
# TODO: support qudit control
for i, cvs in enumerate(gate.control_values):
assignments = list(gate.control_values.expand())
if len(qubits) > 1 and len(assignments) > 1:
raise ValueError(
f"Cannot translate controlled gate with multiple assignments for multiple qubits: {gate}"
)
for q, cvs in zip(qubits, zip(*assignments)):
if 0 in cvs and 1 in cvs:
# This qubit does not affect control.
continue
elif 0 not in cvs and 1 not in cvs:
# This gate will never trigger.
warnings.warn(f"Gate has no valid control value: {gate}", RuntimeWarning)
return (None, None)
elif any(cv not in (0, 1) for cv in cvs):
raise ValueError(
f"Cannot translate control values other than 0 and 1: cvs={cvs}"
)
# Either 0 or 1 is in cvs, but not both.
control_qubits.append(qubits[i])
control_qubits.append(q)
if 0 in cvs:
control_values.append(0)
elif 1 in cvs:
control_values.append(1)

return (control_qubits, control_values)
return control_qubits, control_values


def add_op_to_opstring(
Expand Down Expand Up @@ -271,7 +278,9 @@ def add_op_to_circuit(
qsim_qubits = qubits
is_controlled = isinstance(qsim_gate, cirq.ControlledGate)
if is_controlled:
control_qubits, control_values = _control_details(qsim_gate, qubits)
control_qubits, control_values = _control_details(
qsim_gate, qubits[: qsim_gate.num_controls()]
)
if control_qubits is None:
# This gate has no valid control, and will be omitted.
return
Expand Down
132 changes: 27 additions & 105 deletions qsimcirq/qsim_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from collections import deque
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import cirq

Expand All @@ -24,68 +24,6 @@
import qsimcirq.qsim_circuit as qsimc


class QSimSimulatorState(cirq.StateVectorSimulatorState):
def __init__(self, qsim_data: np.ndarray, qubit_map: Dict[cirq.Qid, int]):
state_vector = qsim_data.view(np.complex64)
super().__init__(state_vector=state_vector, qubit_map=qubit_map)


@cirq.value_equality(unhashable=True)
class QSimSimulatorTrialResult(cirq.StateVectorMixin, cirq.SimulationTrialResult):
def __init__(
self,
params: cirq.ParamResolver,
measurements: Dict[str, np.ndarray],
final_simulator_state: QSimSimulatorState,
):
super().__init__(
params=params,
measurements=measurements,
final_simulator_state=final_simulator_state,
)

# The following methods are (temporarily) copied here from
# cirq.StateVectorTrialResult due to incompatibility with the
# intermediate state simulation support which that class requires.
# TODO: remove these methods once inheritance is restored.

@property
def final_state_vector(self):
return self._final_simulator_state.state_vector

def state_vector(self):
"""Return the state vector at the end of the computation."""
return self._final_simulator_state.state_vector.copy()

def _value_equality_values_(self):
measurements = {k: v.tolist() for k, v in sorted(self.measurements.items())}
return (self.params, measurements, self._final_simulator_state)

def __str__(self) -> str:
samples = super().__str__()
final = self.state_vector()
if len([1 for e in final if abs(e) > 0.001]) < 16:
state_vector = self.dirac_notation(3)
else:
state_vector = str(final)
return f"measurements: {samples}\noutput vector: {state_vector}"

def _repr_pretty_(self, p: Any, cycle: bool) -> None:
"""Text output in Jupyter."""
if cycle:
# There should never be a cycle. This is just in case.
p.text("StateVectorTrialResult(...)")
else:
p.text(str(self))

def __repr__(self) -> str:
return (
f"cirq.StateVectorTrialResult(params={self.params!r}, "
f"measurements={self.measurements!r}, "
f"final_simulator_state={self._final_simulator_state!r})"
)


# This should probably live in Cirq...
# TODO: update to support CircuitOperations.
def _needs_trajectories(circuit: cirq.Circuit) -> bool:
Expand Down Expand Up @@ -189,7 +127,7 @@ class MeasInfo:
class QSimSimulator(
cirq.SimulatesSamples,
cirq.SimulatesAmplitudes,
cirq.SimulatesFinalState,
cirq.SimulatesFinalState[cirq.StateVectorTrialResult],
cirq.SimulatesExpectationValues,
):
def __init__(
Expand Down Expand Up @@ -438,13 +376,13 @@ def _sample_measure_results(

return results

def compute_amplitudes_sweep(
def compute_amplitudes_sweep_iter(
self,
program: cirq.Circuit,
bitstrings: Sequence[int],
params: cirq.Sweepable,
qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT,
) -> Sequence[Sequence[complex]]:
) -> Iterator[Sequence[complex]]:
"""Computes the desired amplitudes using qsim.
The initial state is assumed to be the all zeros state.
Expand All @@ -460,8 +398,8 @@ def compute_amplitudes_sweep(
often used in specifying the initial state, i.e. the ordering of the
computational basis states.
Returns:
List of amplitudes.
Yields:
Amplitudes.
"""

# Add noise to the circuit if a noise model was provided.
Expand All @@ -484,7 +422,6 @@ def compute_amplitudes_sweep(

param_resolvers = cirq.to_resolvers(params)

trials_results = []
if _needs_trajectories(program):
translator_fn_name = "translate_cirq_to_qtrajectory"
simulator_fn = self._sim_module.qtrajectory_simulate
Expand All @@ -500,18 +437,15 @@ def compute_amplitudes_sweep(
cirq_order,
)
options["s"] = self.get_seed()
amplitudes = simulator_fn(options)
trials_results.append(amplitudes)
yield simulator_fn(options)

return trials_results

def simulate_sweep(
def simulate_sweep_iter(
self,
program: cirq.Circuit,
params: cirq.Sweepable,
qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT,
initial_state: Optional[Union[int, np.ndarray]] = None,
) -> List["SimulationTrialResult"]:
) -> Iterator[cirq.StateVectorTrialResult]:
"""Simulates the supplied Circuit.
This method returns a result which allows access to the entire
Expand Down Expand Up @@ -572,7 +506,6 @@ def simulate_sweep(
f"Expected: {2**num_qubits * 2} Received: {len(input_vector)}"
)

trials_results = []
if _needs_trajectories(program):
translator_fn_name = "translate_cirq_to_qtrajectory"
fullstate_simulator_fn = self._sim_module.qtrajectory_simulate_fullstate
Expand All @@ -589,33 +522,32 @@ def simulate_sweep(
cirq_order,
)
options["s"] = self.get_seed()
qubit_map = {qubit: index for index, qubit in enumerate(qsim_order)}

if isinstance(initial_state, int):
qsim_state = fullstate_simulator_fn(options, initial_state)
elif isinstance(initial_state, np.ndarray):
qsim_state = fullstate_simulator_fn(options, input_vector)
assert qsim_state.dtype == np.float32
assert qsim_state.ndim == 1
final_state = QSimSimulatorState(qsim_state, qubit_map)

final_state = cirq.StateVectorSimulationState(
initial_state=qsim_state.view(np.complex64), qubits=cirq_order
)
# create result for this parameter
# TODO: We need to support measurements.
result = QSimSimulatorTrialResult(
yield cirq.StateVectorTrialResult(
params=prs, measurements={}, final_simulator_state=final_state
)
trials_results.append(result)

return trials_results

def simulate_expectation_values_sweep(
def simulate_expectation_values_sweep_iter(
self,
program: cirq.Circuit,
observables: Union[cirq.PauliSumLike, List[cirq.PauliSumLike]],
params: cirq.Sweepable,
qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT,
initial_state: Any = None,
permit_terminal_measurements: bool = False,
) -> List[List[float]]:
) -> Iterator[List[float]]:
"""Simulates the supplied circuit and calculates exact expectation
values for the given observables on its final state.
Expand All @@ -638,8 +570,8 @@ def simulate_expectation_values_sweep(
is set to True. This is meant to prevent measurements from
ruining expectation value calculations.
Returns:
A list of expectation values, with the value at index `n`
Yields:
Lists of expectation values, with the value at index `n`
corresponding to `observables[n]` from the input.
Raises:
Expand Down Expand Up @@ -703,7 +635,6 @@ def simulate_expectation_values_sweep(
f"Expected: {2**num_qubits * 2} Received: {len(input_vector)}"
)

results = []
if _needs_trajectories(program):
translator_fn_name = "translate_cirq_to_qtrajectory"
ev_simulator_fn = self._sim_module.qtrajectory_simulate_expectation_values
Expand All @@ -724,9 +655,7 @@ def simulate_expectation_values_sweep(
evs = ev_simulator_fn(options, opsums_and_qubit_counts, initial_state)
elif isinstance(initial_state, np.ndarray):
evs = ev_simulator_fn(options, opsums_and_qubit_counts, input_vector)
results.append(evs)

return results
yield evs

def simulate_moment_expectation_values(
self,
Expand Down Expand Up @@ -870,20 +799,13 @@ def _translate_circuit(
translator_fn_name: str,
qubit_order: cirq.QubitOrderOrList,
):
# If the circuit is memoized, reuse the corresponding translated
# circuit.
translated_circuit = None
for original, translated, m_indices in self._translated_circuits:
# If the circuit is memoized, reuse the corresponding translated circuit.
for original, translated, moment_indices in self._translated_circuits:
if original == circuit:
translated_circuit = translated
moment_indices = m_indices
break

if translated_circuit is None:
translator_fn = getattr(circuit, translator_fn_name)
translated_circuit, moment_indices = translator_fn(qubit_order)
self._translated_circuits.append(
(circuit, translated_circuit, moment_indices)
)
return translated, moment_indices

translator_fn = getattr(circuit, translator_fn_name)
translated, moment_indices = translator_fn(qubit_order)
self._translated_circuits.append((circuit, translated, moment_indices))

return translated_circuit, moment_indices
return translated, moment_indices
9 changes: 5 additions & 4 deletions qsimcirq_tests/qsimcirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,9 +839,10 @@ def test_control_values():
cirq.X(qubits[2]).controlled_by(*qubits[:2], control_values=[1, 2]),
)
qsimSim = qsimcirq.QSimSimulator()
with pytest.warns(RuntimeWarning, match="Gate has no valid control value"):
result = qsimSim.simulate(cirq_circuit, qubit_order=qubits)
assert result.state_vector()[0] == 1
with pytest.raises(
ValueError, match="Cannot translate control values other than 0 and 1"
):
_ = qsimSim.simulate(cirq_circuit, qubit_order=qubits)


def test_control_limits():
Expand Down Expand Up @@ -1659,7 +1660,7 @@ def test_cirq_qsim_all_supported_gates():
qsim_result = qsim_simulator.simulate(circuit)

assert cirq.linalg.allclose_up_to_global_phase(
qsim_result.state_vector(), cirq_result.state_vector()
qsim_result.state_vector(), cirq_result.state_vector(), atol=1e-5
)


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
absl-py
cirq-core
cirq-core~=1.0
numpy~=1.21
pybind11
typing_extensions

0 comments on commit 2c7632a

Please sign in to comment.