From d5d3c2bb44cf535c7d07dfaa37a4b3b4e5a88b05 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Wed, 30 Oct 2024 15:31:48 -0400 Subject: [PATCH] Fix snapshot wire order when simulation indices different than device wires (#6461) **Context:** Internally, default qubit only performs the simulation with the wires that it needs for apply the operation. Any wires only present in measurements are only added in at the very end. This can be substantially more memory and time efficient in certain edge cases. Unfortunately, this can cause confusion about the snapshotted state taken mid-simulation. Mid simulation we have fewer wires and a different wire order. **Description of the Change:** 1) Add a `map_wires` method to snapshot to allow us to map the wires 2) Fill in device wires on Snapshots during `validate_device_wires`. 3) Allow `StateMP.process_state` to add in blank subsystems when the measurement has wires not present in the state's wire order. These three steps are sufficient to make sure that snapshots always match the device's wire order. **Benefits:** Less Confusion **Possible Drawbacks:** It's no longer the state being used during the simulation at that point. We are hiding away the fact we are working with a different state during the simulation. **Related GitHub Issues:** Fixes #6427 [sc-76515] --------- Co-authored-by: Yushao Chen (Jerry) Co-authored-by: Astral Cai --- doc/releases/changelog-0.39.0.md | 3 + pennylane/devices/preprocess.py | 19 +++++- pennylane/devices/qubit/apply_operation.py | 2 +- pennylane/measurements/state.py | 35 ++++++----- pennylane/ops/meta.py | 11 +++- tests/devices/test_preprocess.py | 10 ++++ tests/measurements/test_state.py | 22 +++++-- tests/ops/test_meta.py | 69 +++++++++++++--------- tests/test_debugging.py | 13 ++++ 9 files changed, 131 insertions(+), 53 deletions(-) diff --git a/doc/releases/changelog-0.39.0.md b/doc/releases/changelog-0.39.0.md index 489ae447183..da0fe901fae 100644 --- a/doc/releases/changelog-0.39.0.md +++ b/doc/releases/changelog-0.39.0.md @@ -400,6 +400,9 @@

Bug fixes 🐛

+* The wire order for `Snapshot`'s now matches the wire order of the device, rather than the simulation. + [(#6461)](https://github.com/PennyLaneAI/pennylane/pull/6461) + * Fixes a bug where `QNSPSAOptimizer`, `QNGOptimizer` and `MomentumQNGOptimizer` calculate invalid parameter updates if the metric tensor becomes singular. [(#6471)](https://github.com/PennyLaneAI/pennylane/pull/6471) diff --git a/pennylane/devices/preprocess.py b/pennylane/devices/preprocess.py index 2ada3461589..b50db2bb61e 100644 --- a/pennylane/devices/preprocess.py +++ b/pennylane/devices/preprocess.py @@ -149,8 +149,23 @@ def validate_device_wires( f"Cannot run circuit(s) on {name} as they contain wires " f"not found on the device: {extra_wires}" ) - measurements = tape.measurements.copy() + modified = False + new_ops = None + for i, op in enumerate(tape.operations): + if isinstance(op, qml.Snapshot): + mp = op.hyperparameters["measurement"] + if not mp.wires: + if not new_ops: + new_ops = list(tape.operations) + modified = True + new_mp = copy(mp) + new_mp._wires = wires # pylint:disable=protected-access + new_ops[i] = qml.Snapshot(measurement=new_mp, tag=op.tag) + if not new_ops: + new_ops = tape.operations # no copy in this case + + measurements = tape.measurements.copy() for m_idx, mp in enumerate(measurements): if not mp.obs and not mp.wires: modified = True @@ -158,7 +173,7 @@ def validate_device_wires( new_mp._wires = wires # pylint:disable=protected-access measurements[m_idx] = new_mp if modified: - tape = type(tape)(tape.operations, measurements, shots=tape.shots) + tape = tape.copy(ops=new_ops, measurements=measurements) return (tape,), null_postprocessing diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index bef7cfd1472..95aaaa7c931 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -612,7 +612,7 @@ def _apply_grover_without_matrix(state, op_wires, is_state_batched): def apply_snapshot( op: qml.Snapshot, state, is_state_batched: bool = False, debugger=None, **execution_kwargs ): - """Take a snapshot of the state""" + """Take a snapshot of the state.""" if debugger is not None and debugger.active: measurement = op.hyperparameters["measurement"] diff --git a/pennylane/measurements/state.py b/pennylane/measurements/state.py index 5b8baecf8b7..920358e606d 100644 --- a/pennylane/measurements/state.py +++ b/pennylane/measurements/state.py @@ -174,27 +174,34 @@ def cast_to_complex(state): floating_single = "float32" in dtype or "complex64" in dtype return qml.math.cast(state, "complex64" if floating_single else "complex128") - wires = self.wires - if not wires or wire_order == wires: + if not self.wires or wire_order == self.wires: return cast_to_complex(state) - if set(wires) != set(wire_order): + if not all(w in self.wires for w in wire_order): + bad_wires = [w for w in wire_order if w not in self.wires] raise WireError( - f"Unexpected unique wires {Wires.unique_wires([wires, wire_order])} found. " - f"Expected wire order {wire_order} to be a rearrangement of {wires}" + f"State wire order has wires {bad_wires} not present in " + f"measurement with wires {self.wires}. StateMP.process_state cannot trace out wires." ) - shape = (2,) * len(wires) - flat_shape = (2 ** len(wires),) - desired_axes = [wire_order.index(w) for w in wires] - if qml.math.ndim(state) == 2: # batched state - batch_size = qml.math.shape(state)[0] - shape = (batch_size,) + shape - flat_shape = (batch_size,) + flat_shape - desired_axes = [0] + [i + 1 for i in desired_axes] - + shape = (2,) * len(wire_order) + batch_size = None if qml.math.ndim(state) == 1 else qml.math.shape(state)[0] + shape = (batch_size,) + shape if batch_size else shape state = qml.math.reshape(state, shape) + + if wires_to_add := Wires(set(self.wires) - set(wire_order)): + for _ in wires_to_add: + state = qml.math.stack([state, qml.math.zeros_like(state)], axis=-1) + wire_order = wire_order + wires_to_add + + desired_axes = [wire_order.index(w) for w in self.wires] + if batch_size: + desired_axes = [0] + [i + 1 for i in desired_axes] state = qml.math.transpose(state, desired_axes) + + flat_shape = (2 ** len(self.wires),) + if batch_size: + flat_shape = (batch_size,) + flat_shape state = qml.math.reshape(state, flat_shape) return cast_to_complex(state) diff --git a/pennylane/ops/meta.py b/pennylane/ops/meta.py index e455457b74d..6d63d820609 100644 --- a/pennylane/ops/meta.py +++ b/pennylane/ops/meta.py @@ -15,6 +15,8 @@ This submodule contains the discrete-variable quantum operations that do not depend on any parameters. """ +from collections.abc import Hashable + # pylint:disable=abstract-method,arguments-differ,protected-access,invalid-overridden-method, no-member from copy import copy from typing import Optional @@ -216,10 +218,9 @@ def __init__(self, tag: Optional[str] = None, measurement=None): if measurement is None: measurement = qml.state() - + if isinstance(measurement, qml.measurements.MidMeasureMP): + raise ValueError("Mid-circuit measurements can not be used in snapshots.") if isinstance(measurement, qml.measurements.MeasurementProcess): - if isinstance(measurement, qml.measurements.MidMeasureMP): - raise ValueError("Mid-circuit measurements can not be used in snapshots.") qml.queuing.QueuingManager.remove(measurement) else: raise ValueError( @@ -251,6 +252,10 @@ def _controlled(self, _): def adjoint(self): return Snapshot(tag=self.tag, measurement=self.hyperparameters["measurement"]) + def map_wires(self, wire_map: dict[Hashable, Hashable]) -> "Snapshot": + new_measurement = self.hyperparameters["measurement"].map_wires(wire_map) + return Snapshot(tag=self.tag, measurement=new_measurement) + # Since measurements are captured as variables in plxpr with the capture module, # the measurement is treated as a traceable argument. diff --git a/tests/devices/test_preprocess.py b/tests/devices/test_preprocess.py index b2a6c107e6f..e782b4fec84 100644 --- a/tests/devices/test_preprocess.py +++ b/tests/devices/test_preprocess.py @@ -221,6 +221,16 @@ def jit_wires_dev(wires): ): jax.jit(jit_wires_dev)([0, 1]) + def test_fill_in_wires_on_snapshots(self): + """Test that validate_device_wires also fills in the wires on snapshots.""" + + tape = qml.tape.QuantumScript([qml.Snapshot(), qml.Snapshot(measurement=qml.probs())]) + + (output,), _ = validate_device_wires(tape, wires=qml.wires.Wires((0, 1, 2))) + mp0 = qml.measurements.StateMP(wires=qml.wires.Wires((0, 1, 2))) + qml.assert_equal(output[0], qml.Snapshot(measurement=mp0)) + qml.assert_equal(output[1], qml.Snapshot(measurement=qml.probs(wires=(0, 1, 2)))) + class TestDecomposeValidation: """Unit tests for helper functions in qml.devices.qubit.preprocess""" diff --git a/tests/measurements/test_state.py b/tests/measurements/test_state.py index 838cad68eab..78c8d17fb8c 100644 --- a/tests/measurements/test_state.py +++ b/tests/measurements/test_state.py @@ -118,8 +118,22 @@ def get_state(ket): def test_wire_ordering_error(self): """Test that a wire order error is raised when unknown wires are given.""" - with pytest.raises(WireError, match=r"Unexpected unique wires Wires\(\[0, 1, 2\]\) found"): - StateMP(wires=[0, 1]).process_state([1, 0], wire_order=Wires(2)) + with pytest.raises( + WireError, match=r"State wire order has wires \[2\] not present in measurement" + ): + StateMP(wires=[0, 1]).process_state(np.array([1, 0]), wire_order=Wires(2)) + + def test_adding_wires(self): + """Test that process_state can add wires not present in the original wire order.""" + + orig_state = np.array([0, 1]) + mp = StateMP(wires=Wires([0, 1])) + new_state = mp.process_state(orig_state, wire_order=Wires([1])) + + expected = np.zeros((2, 2)) + expected[0, 1] = 1 # zero for wire order, 1 for wire 1 + expected = np.reshape(expected, (4,)) + assert qml.math.allclose(expected, new_state) @pytest.mark.parametrize( "dm", @@ -854,7 +868,7 @@ def func(): elif len(return_wire_order) == 2: i, j = return_wire_order exp_statevector = np.kron(single_states[i], single_states[j]) - elif len(return_wire_order) == 3: + else: # len(return_wire_order) == 3 i, j, k = return_wire_order exp_statevector = np.kron(np.kron(single_states[i], single_states[j]), single_states[k]) @@ -885,7 +899,7 @@ def func(): elif len(return_wire_order) == 2: i, j = return_wire_order exp_statevector = np.kron(single_states[i], single_states[j]) - elif len(return_wire_order) == 3: + else: # len(return_wire_order) == 3 i, j, k = return_wire_order exp_statevector = np.kron(np.kron(single_states[i], single_states[j]), single_states[k]) diff --git a/tests/ops/test_meta.py b/tests/ops/test_meta.py index 9cb6b1eafc3..19f5c303ddd 100644 --- a/tests/ops/test_meta.py +++ b/tests/ops/test_meta.py @@ -220,32 +220,43 @@ def test_op_matrix_fails(self): op.matrix() -def test_decomposition(): - """Test the decomposition of the Snapshot operation.""" - - assert Snapshot.compute_decomposition() == [] - assert Snapshot().decomposition() == [] - - -def test_label_method(): - """Test the label method for the Snapshot operation.""" - assert Snapshot().label() == "|Snap|" - assert Snapshot("my_label").label() == "|Snap|" - - -def test_control(): - """Test the _controlled method for the Snapshot operation.""" - assert isinstance(Snapshot()._controlled(0), Snapshot) - assert Snapshot("my_label")._controlled(0).tag == Snapshot("my_label").tag - - -def test_adjoint(): - """Test the adjoint method for the Snapshot operation.""" - assert isinstance(Snapshot().adjoint(), Snapshot) - assert Snapshot("my_label").adjoint().tag == Snapshot("my_label").tag - - -def test_snapshot_no_empty_wire_list_error(): - """Test that Snapshot does not raise an empty wire error.""" - snapshot = qml.Snapshot() - assert isinstance(snapshot, qml.Snapshot) +class TestSnapshot: + """Unit tests for the snapshot class.""" + + def test_decomposition(self): + """Test the decomposition of the Snapshot operation.""" + + assert Snapshot.compute_decomposition() == [] + assert Snapshot().decomposition() == [] + + def test_label_method(self): + """Test the label method for the Snapshot operation.""" + assert Snapshot().label() == "|Snap|" + assert Snapshot("my_label").label() == "|Snap|" + + def test_control(self): + """Test the _controlled method for the Snapshot operation.""" + assert isinstance(Snapshot()._controlled(0), Snapshot) + assert Snapshot("my_label")._controlled(0).tag == Snapshot("my_label").tag + + def test_adjoint(self): + """Test the adjoint method for the Snapshot operation.""" + assert isinstance(Snapshot().adjoint(), Snapshot) + assert Snapshot("my_label").adjoint().tag == Snapshot("my_label").tag + + def test_snapshot_no_empty_wire_list_error(self): + """Test that Snapshot does not raise an empty wire error.""" + snapshot = qml.Snapshot() + assert isinstance(snapshot, qml.Snapshot) + + @pytest.mark.parametrize( + "mp", (qml.expval(qml.Z(0)), qml.measurements.StateMP(wires=(2, 1, 0))) + ) + def test_map_wires(self, mp): + """Test that the wires of the measurement are mapped""" + op = Snapshot(measurement=mp, tag="my tag") + wire_map = {0: "a", 1: "b"} + new_op = op.map_wires(wire_map) + target_mp = mp.map_wires(wire_map) + qml.assert_equal(target_mp, new_op.hyperparameters["measurement"]) + assert new_op.tag == "my tag" diff --git a/tests/test_debugging.py b/tests/test_debugging.py index e575017f5a1..f8065e36d02 100644 --- a/tests/test_debugging.py +++ b/tests/test_debugging.py @@ -630,6 +630,19 @@ def circuit(): qml.snapshots(circuit)() + def test_state_wire_order_preservation(self): + """Test that the snapshots wire order reflects the wire order on the device.""" + + @qml.qnode(qml.device("default.qubit", wires=2)) + def circuit(): + qml.X(1) + qml.Snapshot() + return qml.state() + + out = qml.snapshots(circuit)() + + assert qml.math.allclose(out[0], out["execution_results"]) + # pylint: disable=protected-access @pytest.mark.parametrize("method", [None, "parameter-shift"]) def test_default_qutrit(self, method):