Skip to content

Commit

Permalink
Fix bug where unexpected queuing occurs in qml.ctrl among other funct…
Browse files Browse the repository at this point in the history
…ions (#6474)

This PR is a copy of this [other
one](#6284) pointing to
v0.39.0-rc0

[[sc-73690](https://app.shortcut.com/xanaduai/story/73690)]

---------

Co-authored-by: Astral Cai <[email protected]>
  • Loading branch information
KetpuntoG and astralcai authored Oct 31, 2024
1 parent cd9fac9 commit b851ce3
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 0 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-0.39.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,9 @@
* Fixes a bug where `default.tensor` raises an error when applying `Identity`/`GlobalPhase` on no wires, and `PauliRot`/`MultiRZ` on a single wire.
[(#6448)](https://github.com/PennyLaneAI/pennylane/pull/6448)

* Fixes a bug where applying `qml.ctrl` and `qml.adjoint` on an operator type instead of an operator instance results in extra operators in the queue.
[(#6284)](https://github.com/PennyLaneAI/pennylane/pull/6284)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
4 changes: 4 additions & 0 deletions pennylane/ops/op_math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,10 @@ def _adjoint_transform(qfunc: Callable, lazy=True) -> Callable:
@wraps(qfunc)
def wrapper(*args, **kwargs):
qscript = make_qscript(qfunc)(*args, **kwargs)

leaves, _ = qml.pytrees.flatten((args, kwargs), lambda obj: isinstance(obj, Operator))
_ = [qml.QueuingManager.remove(l) for l in leaves if isinstance(l, Operator)]

if lazy:
adjoint_ops = [Adjoint(op) for op in reversed(qscript.operations)]
else:
Expand Down
3 changes: 3 additions & 0 deletions pennylane/ops/op_math/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ def _ctrl_transform(op, control, control_values, work_wires):
def wrapper(*args, **kwargs):
qscript = qml.tape.make_qscript(op)(*args, **kwargs)

leaves, _ = qml.pytrees.flatten((args, kwargs), lambda obj: isinstance(obj, Operator))
_ = [qml.QueuingManager.remove(l) for l in leaves if isinstance(l, Operator)]

# flip control_values == 0 wires here, so we don't have to do it for each individual op.
flip_control_on_zero = (len(qscript) > 1) and (control_values is not None)
op_control_values = None if flip_control_on_zero else control_values
Expand Down
12 changes: 12 additions & 0 deletions tests/ops/op_math/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,18 @@ def test_single_observable(self):
qs = qml.tape.QuantumScript.from_queue(q)
assert len(qs) == 0

def test_correct_queued_operators(self):
"""Test that args and kwargs do not add operators to the queue."""

with qml.queuing.AnnotatedQueue() as q:
qml.adjoint(qml.QSVT)(qml.X(1), [qml.Z(1)])
qml.adjoint(qml.QSVT(qml.X(1), [qml.Z(1)]))

for op in q.queue:
assert op.name == "Adjoint(QSVT)"

assert len(q.queue) == 2


class TestAdjointConstructorDifferentCallableTypes:
"""Test the adjoint transform on a variety of possible inputs."""
Expand Down
11 changes: 11 additions & 0 deletions tests/ops/op_math/test_controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -1872,6 +1872,17 @@ def test_nested_pauli_x_based_ctrl_ops(self):
expected = qml.MultiControlledX(wires=[3, 2, 1, 0], control_values=[1, 0, 1])
assert op == expected

def test_correct_queued_operators(self):
"""Test that args and kwargs do not add operators to the queue."""

with qml.queuing.AnnotatedQueue() as q:
qml.ctrl(qml.QSVT, control=0)(qml.X(1), [qml.Z(1)])
qml.ctrl(qml.QSVT(qml.X(1), [qml.Z(1)]), control=0)
for op in q.queue:
assert op.name == "C(QSVT)"

assert len(q.queue) == 2


class _Rot(Operation):
"""A rotation operation that is not an instance of Rot
Expand Down
12 changes: 12 additions & 0 deletions tests/ops/op_math/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,18 @@ def test_nonlazy_mode_queueing(self):
assert len(q) == 1
assert q.queue[0] is prod2

def test_correct_queued_operators(self):
"""Test that args and kwargs do not add operators to the queue."""

with qml.queuing.AnnotatedQueue() as q:
qml.prod(qml.QSVT)(qml.X(1), [qml.Z(1)])
qml.prod(qml.QSVT(qml.X(1), [qml.Z(1)]))

for op in q.queue:
assert op.name == "QSVT"

assert len(q.queue) == 2


class TestIntegration:
"""Integration tests for the Prod class."""
Expand Down

0 comments on commit b851ce3

Please sign in to comment.