diff --git a/qsimcirq/__init__.py b/qsimcirq/__init__.py index 89c6197e..0f022a1a 100644 --- a/qsimcirq/__init__.py +++ b/qsimcirq/__init__.py @@ -53,12 +53,7 @@ def _load_qsim_custatevec(): qsim_custatevec = _load_qsim_custatevec() from .qsim_circuit import add_op_to_opstring, add_op_to_circuit, QSimCircuit -from .qsim_simulator import ( - QSimOptions, - QSimSimulatorState, - QSimSimulatorTrialResult, - QSimSimulator, -) +from .qsim_simulator import QSimOptions, QSimSimulator from .qsimh_simulator import QSimhSimulator from qsimcirq._version import ( diff --git a/qsimcirq/qsim_circuit.py b/qsimcirq/qsim_circuit.py index f7e2eda2..0da46be0 100644 --- a/qsimcirq/qsim_circuit.py +++ b/qsimcirq/qsim_circuit.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings -from typing import Dict, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import cirq import numpy as np @@ -210,26 +210,33 @@ def _has_cirq_gate_kind(op: cirq.Operation): return any(t in TYPE_TRANSLATOR for t in type(op.gate).mro()) -def _control_details(gate: cirq.ControlledGate, qubits): - control_qubits = [] - control_values = [] +def _control_details( + gate: cirq.ControlledGate, qubits: Sequence[cirq.Qid] +) -> Tuple[List[cirq.Qid], List[int]]: + control_qubits: List[cirq.Qid] = [] + control_values: List[int] = [] # TODO: support qudit control - for i, cvs in enumerate(gate.control_values): + assignments = list(gate.control_values.expand()) + if len(qubits) > 1 and len(assignments) > 1: + raise ValueError( + f"Cannot translate controlled gate with multiple assignments for multiple qubits: {gate}" + ) + for q, cvs in zip(qubits, zip(*assignments)): if 0 in cvs and 1 in cvs: # This qubit does not affect control. continue - elif 0 not in cvs and 1 not in cvs: - # This gate will never trigger. - warnings.warn(f"Gate has no valid control value: {gate}", RuntimeWarning) - return (None, None) + elif any(cv not in (0, 1) for cv in cvs): + raise ValueError( + f"Cannot translate control values other than 0 and 1: cvs={cvs}" + ) # Either 0 or 1 is in cvs, but not both. - control_qubits.append(qubits[i]) + control_qubits.append(q) if 0 in cvs: control_values.append(0) elif 1 in cvs: control_values.append(1) - return (control_qubits, control_values) + return control_qubits, control_values def add_op_to_opstring( @@ -271,7 +278,9 @@ def add_op_to_circuit( qsim_qubits = qubits is_controlled = isinstance(qsim_gate, cirq.ControlledGate) if is_controlled: - control_qubits, control_values = _control_details(qsim_gate, qubits) + control_qubits, control_values = _control_details( + qsim_gate, qubits[: qsim_gate.num_controls()] + ) if control_qubits is None: # This gate has no valid control, and will be omitted. return diff --git a/qsimcirq/qsim_simulator.py b/qsimcirq/qsim_simulator.py index 5835c574..b59a6f36 100644 --- a/qsimcirq/qsim_simulator.py +++ b/qsimcirq/qsim_simulator.py @@ -14,7 +14,7 @@ from collections import deque from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union import cirq @@ -24,68 +24,6 @@ import qsimcirq.qsim_circuit as qsimc -class QSimSimulatorState(cirq.StateVectorSimulatorState): - def __init__(self, qsim_data: np.ndarray, qubit_map: Dict[cirq.Qid, int]): - state_vector = qsim_data.view(np.complex64) - super().__init__(state_vector=state_vector, qubit_map=qubit_map) - - -@cirq.value_equality(unhashable=True) -class QSimSimulatorTrialResult(cirq.StateVectorMixin, cirq.SimulationTrialResult): - def __init__( - self, - params: cirq.ParamResolver, - measurements: Dict[str, np.ndarray], - final_simulator_state: QSimSimulatorState, - ): - super().__init__( - params=params, - measurements=measurements, - final_simulator_state=final_simulator_state, - ) - - # The following methods are (temporarily) copied here from - # cirq.StateVectorTrialResult due to incompatibility with the - # intermediate state simulation support which that class requires. - # TODO: remove these methods once inheritance is restored. - - @property - def final_state_vector(self): - return self._final_simulator_state.state_vector - - def state_vector(self): - """Return the state vector at the end of the computation.""" - return self._final_simulator_state.state_vector.copy() - - def _value_equality_values_(self): - measurements = {k: v.tolist() for k, v in sorted(self.measurements.items())} - return (self.params, measurements, self._final_simulator_state) - - def __str__(self) -> str: - samples = super().__str__() - final = self.state_vector() - if len([1 for e in final if abs(e) > 0.001]) < 16: - state_vector = self.dirac_notation(3) - else: - state_vector = str(final) - return f"measurements: {samples}\noutput vector: {state_vector}" - - def _repr_pretty_(self, p: Any, cycle: bool) -> None: - """Text output in Jupyter.""" - if cycle: - # There should never be a cycle. This is just in case. - p.text("StateVectorTrialResult(...)") - else: - p.text(str(self)) - - def __repr__(self) -> str: - return ( - f"cirq.StateVectorTrialResult(params={self.params!r}, " - f"measurements={self.measurements!r}, " - f"final_simulator_state={self._final_simulator_state!r})" - ) - - # This should probably live in Cirq... # TODO: update to support CircuitOperations. def _needs_trajectories(circuit: cirq.Circuit) -> bool: @@ -189,7 +127,7 @@ class MeasInfo: class QSimSimulator( cirq.SimulatesSamples, cirq.SimulatesAmplitudes, - cirq.SimulatesFinalState, + cirq.SimulatesFinalState[cirq.StateVectorTrialResult], cirq.SimulatesExpectationValues, ): def __init__( @@ -438,13 +376,13 @@ def _sample_measure_results( return results - def compute_amplitudes_sweep( + def compute_amplitudes_sweep_iter( self, program: cirq.Circuit, bitstrings: Sequence[int], params: cirq.Sweepable, qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT, - ) -> Sequence[Sequence[complex]]: + ) -> Iterator[Sequence[complex]]: """Computes the desired amplitudes using qsim. The initial state is assumed to be the all zeros state. @@ -460,8 +398,8 @@ def compute_amplitudes_sweep( often used in specifying the initial state, i.e. the ordering of the computational basis states. - Returns: - List of amplitudes. + Yields: + Amplitudes. """ # Add noise to the circuit if a noise model was provided. @@ -484,7 +422,6 @@ def compute_amplitudes_sweep( param_resolvers = cirq.to_resolvers(params) - trials_results = [] if _needs_trajectories(program): translator_fn_name = "translate_cirq_to_qtrajectory" simulator_fn = self._sim_module.qtrajectory_simulate @@ -500,18 +437,15 @@ def compute_amplitudes_sweep( cirq_order, ) options["s"] = self.get_seed() - amplitudes = simulator_fn(options) - trials_results.append(amplitudes) + yield simulator_fn(options) - return trials_results - - def simulate_sweep( + def simulate_sweep_iter( self, program: cirq.Circuit, params: cirq.Sweepable, qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT, initial_state: Optional[Union[int, np.ndarray]] = None, - ) -> List["SimulationTrialResult"]: + ) -> Iterator[cirq.StateVectorTrialResult]: """Simulates the supplied Circuit. This method returns a result which allows access to the entire @@ -572,7 +506,6 @@ def simulate_sweep( f"Expected: {2**num_qubits * 2} Received: {len(input_vector)}" ) - trials_results = [] if _needs_trajectories(program): translator_fn_name = "translate_cirq_to_qtrajectory" fullstate_simulator_fn = self._sim_module.qtrajectory_simulate_fullstate @@ -589,7 +522,6 @@ def simulate_sweep( cirq_order, ) options["s"] = self.get_seed() - qubit_map = {qubit: index for index, qubit in enumerate(qsim_order)} if isinstance(initial_state, int): qsim_state = fullstate_simulator_fn(options, initial_state) @@ -597,17 +529,17 @@ def simulate_sweep( qsim_state = fullstate_simulator_fn(options, input_vector) assert qsim_state.dtype == np.float32 assert qsim_state.ndim == 1 - final_state = QSimSimulatorState(qsim_state, qubit_map) + + final_state = cirq.StateVectorSimulationState( + initial_state=qsim_state.view(np.complex64), qubits=cirq_order + ) # create result for this parameter # TODO: We need to support measurements. - result = QSimSimulatorTrialResult( + yield cirq.StateVectorTrialResult( params=prs, measurements={}, final_simulator_state=final_state ) - trials_results.append(result) - return trials_results - - def simulate_expectation_values_sweep( + def simulate_expectation_values_sweep_iter( self, program: cirq.Circuit, observables: Union[cirq.PauliSumLike, List[cirq.PauliSumLike]], @@ -615,7 +547,7 @@ def simulate_expectation_values_sweep( qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT, initial_state: Any = None, permit_terminal_measurements: bool = False, - ) -> List[List[float]]: + ) -> Iterator[List[float]]: """Simulates the supplied circuit and calculates exact expectation values for the given observables on its final state. @@ -638,8 +570,8 @@ def simulate_expectation_values_sweep( is set to True. This is meant to prevent measurements from ruining expectation value calculations. - Returns: - A list of expectation values, with the value at index `n` + Yields: + Lists of expectation values, with the value at index `n` corresponding to `observables[n]` from the input. Raises: @@ -703,7 +635,6 @@ def simulate_expectation_values_sweep( f"Expected: {2**num_qubits * 2} Received: {len(input_vector)}" ) - results = [] if _needs_trajectories(program): translator_fn_name = "translate_cirq_to_qtrajectory" ev_simulator_fn = self._sim_module.qtrajectory_simulate_expectation_values @@ -724,9 +655,7 @@ def simulate_expectation_values_sweep( evs = ev_simulator_fn(options, opsums_and_qubit_counts, initial_state) elif isinstance(initial_state, np.ndarray): evs = ev_simulator_fn(options, opsums_and_qubit_counts, input_vector) - results.append(evs) - - return results + yield evs def simulate_moment_expectation_values( self, @@ -870,20 +799,13 @@ def _translate_circuit( translator_fn_name: str, qubit_order: cirq.QubitOrderOrList, ): - # If the circuit is memoized, reuse the corresponding translated - # circuit. - translated_circuit = None - for original, translated, m_indices in self._translated_circuits: + # If the circuit is memoized, reuse the corresponding translated circuit. + for original, translated, moment_indices in self._translated_circuits: if original == circuit: - translated_circuit = translated - moment_indices = m_indices - break - - if translated_circuit is None: - translator_fn = getattr(circuit, translator_fn_name) - translated_circuit, moment_indices = translator_fn(qubit_order) - self._translated_circuits.append( - (circuit, translated_circuit, moment_indices) - ) + return translated, moment_indices + + translator_fn = getattr(circuit, translator_fn_name) + translated, moment_indices = translator_fn(qubit_order) + self._translated_circuits.append((circuit, translated, moment_indices)) - return translated_circuit, moment_indices + return translated, moment_indices diff --git a/qsimcirq_tests/qsimcirq_test.py b/qsimcirq_tests/qsimcirq_test.py index 6ed7ca3a..8a92537f 100644 --- a/qsimcirq_tests/qsimcirq_test.py +++ b/qsimcirq_tests/qsimcirq_test.py @@ -839,9 +839,10 @@ def test_control_values(): cirq.X(qubits[2]).controlled_by(*qubits[:2], control_values=[1, 2]), ) qsimSim = qsimcirq.QSimSimulator() - with pytest.warns(RuntimeWarning, match="Gate has no valid control value"): - result = qsimSim.simulate(cirq_circuit, qubit_order=qubits) - assert result.state_vector()[0] == 1 + with pytest.raises( + ValueError, match="Cannot translate control values other than 0 and 1" + ): + _ = qsimSim.simulate(cirq_circuit, qubit_order=qubits) def test_control_limits(): @@ -1659,7 +1660,7 @@ def test_cirq_qsim_all_supported_gates(): qsim_result = qsim_simulator.simulate(circuit) assert cirq.linalg.allclose_up_to_global_phase( - qsim_result.state_vector(), cirq_result.state_vector() + qsim_result.state_vector(), cirq_result.state_vector(), atol=1e-5 ) diff --git a/requirements.txt b/requirements.txt index 981f22a2..6ee5dff9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ absl-py -cirq-core +cirq-core~=1.0 numpy~=1.21 pybind11 typing_extensions