Skip to content

Commit

Permalink
Fix snapshot wire order when simulation indices different than device…
Browse files Browse the repository at this point in the history
… wires (#6461)

**Context:**

Internally, default qubit only performs the simulation with the wires
that it needs for apply the operation. Any wires only present in
measurements are only added in at the very end. This can be
substantially more memory and time efficient in certain edge cases.

Unfortunately, this can cause confusion about the snapshotted state
taken mid-simulation. Mid simulation we have fewer wires and a different
wire order.

**Description of the Change:**

1) Add a `map_wires` method to snapshot to allow us to map the wires
2) Fill in device wires on Snapshots during `validate_device_wires`.
3) Allow `StateMP.process_state` to add in blank subsystems when the
measurement has wires not present in the state's wire order.

These three steps are sufficient to make sure that snapshots always
match the device's wire order.

**Benefits:**

Less Confusion

**Possible Drawbacks:**

It's no longer the state being used during the simulation at that point.
We are hiding away the fact we are working with a different state during
the simulation.

**Related GitHub Issues:**

Fixes #6427 [sc-76515]

---------

Co-authored-by: Yushao Chen (Jerry) <[email protected]>
Co-authored-by: Astral Cai <[email protected]>
  • Loading branch information
3 people authored Oct 30, 2024
1 parent 76ca29e commit d5d3c2b
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 53 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-0.39.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@

<h3>Bug fixes 🐛</h3>

* The wire order for `Snapshot`'s now matches the wire order of the device, rather than the simulation.
[(#6461)](https://github.com/PennyLaneAI/pennylane/pull/6461)

* Fixes a bug where `QNSPSAOptimizer`, `QNGOptimizer` and `MomentumQNGOptimizer` calculate invalid
parameter updates if the metric tensor becomes singular.
[(#6471)](https://github.com/PennyLaneAI/pennylane/pull/6471)
Expand Down
19 changes: 17 additions & 2 deletions pennylane/devices/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,31 @@ def validate_device_wires(
f"Cannot run circuit(s) on {name} as they contain wires "
f"not found on the device: {extra_wires}"
)
measurements = tape.measurements.copy()

modified = False
new_ops = None
for i, op in enumerate(tape.operations):
if isinstance(op, qml.Snapshot):
mp = op.hyperparameters["measurement"]
if not mp.wires:
if not new_ops:
new_ops = list(tape.operations)
modified = True
new_mp = copy(mp)
new_mp._wires = wires # pylint:disable=protected-access
new_ops[i] = qml.Snapshot(measurement=new_mp, tag=op.tag)
if not new_ops:
new_ops = tape.operations # no copy in this case

measurements = tape.measurements.copy()
for m_idx, mp in enumerate(measurements):
if not mp.obs and not mp.wires:
modified = True
new_mp = copy(mp)
new_mp._wires = wires # pylint:disable=protected-access
measurements[m_idx] = new_mp
if modified:
tape = type(tape)(tape.operations, measurements, shots=tape.shots)
tape = tape.copy(ops=new_ops, measurements=measurements)

return (tape,), null_postprocessing

Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ def _apply_grover_without_matrix(state, op_wires, is_state_batched):
def apply_snapshot(
op: qml.Snapshot, state, is_state_batched: bool = False, debugger=None, **execution_kwargs
):
"""Take a snapshot of the state"""
"""Take a snapshot of the state."""
if debugger is not None and debugger.active:
measurement = op.hyperparameters["measurement"]

Expand Down
35 changes: 21 additions & 14 deletions pennylane/measurements/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,27 +174,34 @@ def cast_to_complex(state):
floating_single = "float32" in dtype or "complex64" in dtype
return qml.math.cast(state, "complex64" if floating_single else "complex128")

wires = self.wires
if not wires or wire_order == wires:
if not self.wires or wire_order == self.wires:
return cast_to_complex(state)

if set(wires) != set(wire_order):
if not all(w in self.wires for w in wire_order):
bad_wires = [w for w in wire_order if w not in self.wires]
raise WireError(
f"Unexpected unique wires {Wires.unique_wires([wires, wire_order])} found. "
f"Expected wire order {wire_order} to be a rearrangement of {wires}"
f"State wire order has wires {bad_wires} not present in "
f"measurement with wires {self.wires}. StateMP.process_state cannot trace out wires."
)

shape = (2,) * len(wires)
flat_shape = (2 ** len(wires),)
desired_axes = [wire_order.index(w) for w in wires]
if qml.math.ndim(state) == 2: # batched state
batch_size = qml.math.shape(state)[0]
shape = (batch_size,) + shape
flat_shape = (batch_size,) + flat_shape
desired_axes = [0] + [i + 1 for i in desired_axes]

shape = (2,) * len(wire_order)
batch_size = None if qml.math.ndim(state) == 1 else qml.math.shape(state)[0]
shape = (batch_size,) + shape if batch_size else shape
state = qml.math.reshape(state, shape)

if wires_to_add := Wires(set(self.wires) - set(wire_order)):
for _ in wires_to_add:
state = qml.math.stack([state, qml.math.zeros_like(state)], axis=-1)
wire_order = wire_order + wires_to_add

desired_axes = [wire_order.index(w) for w in self.wires]
if batch_size:
desired_axes = [0] + [i + 1 for i in desired_axes]
state = qml.math.transpose(state, desired_axes)

flat_shape = (2 ** len(self.wires),)
if batch_size:
flat_shape = (batch_size,) + flat_shape
state = qml.math.reshape(state, flat_shape)
return cast_to_complex(state)

Expand Down
11 changes: 8 additions & 3 deletions pennylane/ops/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
This submodule contains the discrete-variable quantum operations that do
not depend on any parameters.
"""
from collections.abc import Hashable

# pylint:disable=abstract-method,arguments-differ,protected-access,invalid-overridden-method, no-member
from copy import copy
from typing import Optional
Expand Down Expand Up @@ -216,10 +218,9 @@ def __init__(self, tag: Optional[str] = None, measurement=None):

if measurement is None:
measurement = qml.state()

if isinstance(measurement, qml.measurements.MidMeasureMP):
raise ValueError("Mid-circuit measurements can not be used in snapshots.")
if isinstance(measurement, qml.measurements.MeasurementProcess):
if isinstance(measurement, qml.measurements.MidMeasureMP):
raise ValueError("Mid-circuit measurements can not be used in snapshots.")
qml.queuing.QueuingManager.remove(measurement)
else:
raise ValueError(
Expand Down Expand Up @@ -251,6 +252,10 @@ def _controlled(self, _):
def adjoint(self):
return Snapshot(tag=self.tag, measurement=self.hyperparameters["measurement"])

def map_wires(self, wire_map: dict[Hashable, Hashable]) -> "Snapshot":
new_measurement = self.hyperparameters["measurement"].map_wires(wire_map)
return Snapshot(tag=self.tag, measurement=new_measurement)


# Since measurements are captured as variables in plxpr with the capture module,
# the measurement is treated as a traceable argument.
Expand Down
10 changes: 10 additions & 0 deletions tests/devices/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@ def jit_wires_dev(wires):
):
jax.jit(jit_wires_dev)([0, 1])

def test_fill_in_wires_on_snapshots(self):
"""Test that validate_device_wires also fills in the wires on snapshots."""

tape = qml.tape.QuantumScript([qml.Snapshot(), qml.Snapshot(measurement=qml.probs())])

(output,), _ = validate_device_wires(tape, wires=qml.wires.Wires((0, 1, 2)))
mp0 = qml.measurements.StateMP(wires=qml.wires.Wires((0, 1, 2)))
qml.assert_equal(output[0], qml.Snapshot(measurement=mp0))
qml.assert_equal(output[1], qml.Snapshot(measurement=qml.probs(wires=(0, 1, 2))))


class TestDecomposeValidation:
"""Unit tests for helper functions in qml.devices.qubit.preprocess"""
Expand Down
22 changes: 18 additions & 4 deletions tests/measurements/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,22 @@ def get_state(ket):

def test_wire_ordering_error(self):
"""Test that a wire order error is raised when unknown wires are given."""
with pytest.raises(WireError, match=r"Unexpected unique wires Wires\(\[0, 1, 2\]\) found"):
StateMP(wires=[0, 1]).process_state([1, 0], wire_order=Wires(2))
with pytest.raises(
WireError, match=r"State wire order has wires \[2\] not present in measurement"
):
StateMP(wires=[0, 1]).process_state(np.array([1, 0]), wire_order=Wires(2))

def test_adding_wires(self):
"""Test that process_state can add wires not present in the original wire order."""

orig_state = np.array([0, 1])
mp = StateMP(wires=Wires([0, 1]))
new_state = mp.process_state(orig_state, wire_order=Wires([1]))

expected = np.zeros((2, 2))
expected[0, 1] = 1 # zero for wire order, 1 for wire 1
expected = np.reshape(expected, (4,))
assert qml.math.allclose(expected, new_state)

@pytest.mark.parametrize(
"dm",
Expand Down Expand Up @@ -854,7 +868,7 @@ def func():
elif len(return_wire_order) == 2:
i, j = return_wire_order
exp_statevector = np.kron(single_states[i], single_states[j])
elif len(return_wire_order) == 3:
else: # len(return_wire_order) == 3
i, j, k = return_wire_order
exp_statevector = np.kron(np.kron(single_states[i], single_states[j]), single_states[k])

Expand Down Expand Up @@ -885,7 +899,7 @@ def func():
elif len(return_wire_order) == 2:
i, j = return_wire_order
exp_statevector = np.kron(single_states[i], single_states[j])
elif len(return_wire_order) == 3:
else: # len(return_wire_order) == 3
i, j, k = return_wire_order
exp_statevector = np.kron(np.kron(single_states[i], single_states[j]), single_states[k])

Expand Down
69 changes: 40 additions & 29 deletions tests/ops/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,32 +220,43 @@ def test_op_matrix_fails(self):
op.matrix()


def test_decomposition():
"""Test the decomposition of the Snapshot operation."""

assert Snapshot.compute_decomposition() == []
assert Snapshot().decomposition() == []


def test_label_method():
"""Test the label method for the Snapshot operation."""
assert Snapshot().label() == "|Snap|"
assert Snapshot("my_label").label() == "|Snap|"


def test_control():
"""Test the _controlled method for the Snapshot operation."""
assert isinstance(Snapshot()._controlled(0), Snapshot)
assert Snapshot("my_label")._controlled(0).tag == Snapshot("my_label").tag


def test_adjoint():
"""Test the adjoint method for the Snapshot operation."""
assert isinstance(Snapshot().adjoint(), Snapshot)
assert Snapshot("my_label").adjoint().tag == Snapshot("my_label").tag


def test_snapshot_no_empty_wire_list_error():
"""Test that Snapshot does not raise an empty wire error."""
snapshot = qml.Snapshot()
assert isinstance(snapshot, qml.Snapshot)
class TestSnapshot:
"""Unit tests for the snapshot class."""

def test_decomposition(self):
"""Test the decomposition of the Snapshot operation."""

assert Snapshot.compute_decomposition() == []
assert Snapshot().decomposition() == []

def test_label_method(self):
"""Test the label method for the Snapshot operation."""
assert Snapshot().label() == "|Snap|"
assert Snapshot("my_label").label() == "|Snap|"

def test_control(self):
"""Test the _controlled method for the Snapshot operation."""
assert isinstance(Snapshot()._controlled(0), Snapshot)
assert Snapshot("my_label")._controlled(0).tag == Snapshot("my_label").tag

def test_adjoint(self):
"""Test the adjoint method for the Snapshot operation."""
assert isinstance(Snapshot().adjoint(), Snapshot)
assert Snapshot("my_label").adjoint().tag == Snapshot("my_label").tag

def test_snapshot_no_empty_wire_list_error(self):
"""Test that Snapshot does not raise an empty wire error."""
snapshot = qml.Snapshot()
assert isinstance(snapshot, qml.Snapshot)

@pytest.mark.parametrize(
"mp", (qml.expval(qml.Z(0)), qml.measurements.StateMP(wires=(2, 1, 0)))
)
def test_map_wires(self, mp):
"""Test that the wires of the measurement are mapped"""
op = Snapshot(measurement=mp, tag="my tag")
wire_map = {0: "a", 1: "b"}
new_op = op.map_wires(wire_map)
target_mp = mp.map_wires(wire_map)
qml.assert_equal(target_mp, new_op.hyperparameters["measurement"])
assert new_op.tag == "my tag"
13 changes: 13 additions & 0 deletions tests/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,19 @@ def circuit():

qml.snapshots(circuit)()

def test_state_wire_order_preservation(self):
"""Test that the snapshots wire order reflects the wire order on the device."""

@qml.qnode(qml.device("default.qubit", wires=2))
def circuit():
qml.X(1)
qml.Snapshot()
return qml.state()

out = qml.snapshots(circuit)()

assert qml.math.allclose(out[0], out["execution_results"])

# pylint: disable=protected-access
@pytest.mark.parametrize("method", [None, "parameter-shift"])
def test_default_qutrit(self, method):
Expand Down

0 comments on commit d5d3c2b

Please sign in to comment.