Skip to content

Commit

Permalink
Merge pull request #531 from quantumlib/u/maffoo/repeated-keys
Browse files Browse the repository at this point in the history
Add support for repeated keys in QSimSimulator
  • Loading branch information
95-martin-orion authored Jun 6, 2022
2 parents 4ca3c88 + ce16a06 commit 425c819
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 36 deletions.
94 changes: 58 additions & 36 deletions qsimcirq/qsim_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

from collections import deque
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from xml.etree.ElementPath import ops
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import cirq

Expand Down Expand Up @@ -168,6 +167,25 @@ def as_dict(self):
}


@dataclass
class MeasInfo:
"""Info about each measure operation in the circuit being simulated.
Attributes:
key: The measurement key.
idx: The "instance" of a possibly-repeated measurement key.
invert_mask: True for any measurement bits that should be inverted.
start: Start index in qsim's output array for this measurement.
end: End index (non-inclusive) in qsim's output array.
"""

key: str
idx: int
invert_mask: Tuple[bool, ...]
start: int
end: int


class QSimSimulator(
cirq.SimulatesSamples,
cirq.SimulatesAmplitudes,
Expand Down Expand Up @@ -313,40 +331,54 @@ def _sample_measure_results(

qubit_map = {qubit: index for index, qubit in enumerate(ordered_qubits)}

# Computes
# - the list of qubits to be measured
# - the start (inclusive) and end (exclusive) indices of each measurement
# - a mapping from measurement key to measurement gate
# Compute:
# - number of qubits for each measurement key.
# - measurement ops for each measurement key.
# - measurement info for each measurement.
# - total number of measured bits.
measurement_ops = [
op
for _, op, _ in program.findall_operations_with_gate_type(
cirq.MeasurementGate
)
]
measured_qubits: List[cirq.Qid] = []
bounds: Dict[str, Tuple] = {}
num_qubits_by_key: Dict[str, int] = {}
meas_ops: Dict[str, List[cirq.GateOperation]] = {}
current_index = 0
meas_infos: List[MeasInfo] = []
num_bits = 0
for op in measurement_ops:
gate = op.gate
key = cirq.measurement_key_name(gate)
meas_ops.setdefault(key, [])
i = len(meas_ops[key])
meas_ops[key].append(op)
if key in bounds:
raise ValueError(f"Duplicate MeasurementGate with key {key}")
bounds[key] = (current_index, current_index + len(op.qubits))
measured_qubits.extend(op.qubits)
current_index += len(op.qubits)
n = len(op.qubits)
if key in num_qubits_by_key:
if n != num_qubits_by_key[key]:
raise ValueError(
f"repeated key {key!r} with different numbers of qubits: "
f"{num_qubits_by_key[key]} != {n}"
)
else:
num_qubits_by_key[key] = n
meas_infos.append(
MeasInfo(
key=key,
idx=i,
invert_mask=gate.full_invert_mask(),
start=num_bits,
end=num_bits + n,
)
)
num_bits += n

# Set qsim options
options = {}
options.update(self.qsim_options)
options = {**self.qsim_options}

results = {}
for key, bound in bounds.items():
results[key] = np.ndarray(
shape=(repetitions, len(meas_ops[key]), bound[1] - bound[0]), dtype=int
)
results = {
key: np.ndarray(shape=(repetitions, len(meas_ops[key]), n), dtype=int)
for key, n in num_qubits_by_key.items()
}

noisy = _needs_trajectories(program)
if not noisy and program.are_all_measurements_terminal() and repetitions > 1:
Expand Down Expand Up @@ -394,25 +426,15 @@ def _sample_measure_results(
translator_fn_name,
cirq.QubitOrder.DEFAULT,
)
measurements = np.empty(
shape=(
repetitions,
sum(
cirq.num_qubits(op)
for oplist in meas_ops.values()
for op in oplist
),
),
dtype=int,
)
measurements = np.empty(shape=(repetitions, num_bits), dtype=int)
for i in range(repetitions):
options["s"] = self.get_seed()
measurements[i] = sampler_fn(options)

for key, (start, end) in bounds.items():
for i, op in enumerate(meas_ops[key]):
invert_mask = op.gate.full_invert_mask()
results[key][:, i, :] = measurements[:, start:end] ^ invert_mask
for m in meas_infos:
results[m.key][:, m.idx, :] = (
measurements[:, m.start : m.end] ^ m.invert_mask
)

return results

Expand Down
43 changes: 43 additions & 0 deletions qsimcirq_tests/qsimcirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,49 @@ def test_empty_moment(mode: str):
assert result.final_state_vector.shape == (4,)


def test_repeated_keys():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.Moment(cirq.measure(q0, key="m")),
cirq.Moment(cirq.X(q1)),
cirq.Moment(cirq.measure(q1, key="m")),
cirq.Moment(cirq.X(q0)),
cirq.Moment(cirq.measure(q0, key="m")),
cirq.Moment(cirq.X(q1)),
cirq.Moment(cirq.measure(q1, key="m")),
)
result = qsimcirq.QSimSimulator().run(circuit, repetitions=10)
assert result.records["m"].shape == (10, 4, 1)
assert np.all(result.records["m"][:, 0, :] == 0)
assert np.all(result.records["m"][:, 1, :] == 1)
assert np.all(result.records["m"][:, 2, :] == 1)
assert np.all(result.records["m"][:, 3, :] == 0)


def test_repeated_keys_same_moment():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.Moment(cirq.X(q1)),
cirq.Moment(cirq.measure(q0, key="m"), cirq.measure(q1, key="m")),
)
result = qsimcirq.QSimSimulator().run(circuit, repetitions=10)
assert result.records["m"].shape == (10, 2, 1)
assert np.all(result.records["m"][:, 0, :] == 0)
assert np.all(result.records["m"][:, 1, :] == 1)


def test_repeated_keys_different_numbers_of_qubits():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0, key="m"),
cirq.measure(q0, q1, key="m"),
)
with pytest.raises(
ValueError, match="repeated key 'm' with different numbers of qubits"
):
_ = qsimcirq.QSimSimulator().run(circuit, repetitions=10)


def test_cirq_too_big_gate():
# Pick qubits.
a, b, c, d, e, f, g = [
Expand Down

0 comments on commit 425c819

Please sign in to comment.