Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug where unexpected queuing occurs in qml.ctrl among other functions #6284

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@
* Fixes a bug where a simple circuit with no parameters or only builtin/numpy arrays as parameters returns autograd tensors.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)

* Fixes a bug where `qml.ctrl` and `qml.adjoint` queued extra operators if they were defined as the arguments.
[(#6284)](https://github.com/PennyLaneAI/pennylane/pull/6284)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
8 changes: 8 additions & 0 deletions pennylane/ops/op_math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pennylane.compiler import compiler
from pennylane.math import conj, moveaxis, transpose
from pennylane.operation import Observable, Operation, Operator
from pennylane.ops.op_math.controlled import remove_from_queue_args_and_kwargs
from pennylane.queuing import QueuingManager
from pennylane.tape import make_qscript

Expand Down Expand Up @@ -236,6 +237,13 @@ def _adjoint_transform(qfunc: Callable, lazy=True) -> Callable:
@wraps(qfunc)
def wrapper(*args, **kwargs):
qscript = make_qscript(qfunc)(*args, **kwargs)

for arg in args:
remove_from_queue_args_and_kwargs(arg)

for value in kwargs.values():
remove_from_queue_args_and_kwargs(value)

if lazy:
adjoint_ops = [Adjoint(op) for op in reversed(qscript.operations)]
else:
Expand Down
18 changes: 18 additions & 0 deletions pennylane/ops/op_math/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,29 @@ def create_controlled_op(op, control, control_values=None, work_wires=None):
return _ctrl_transform(op, control, control_values, work_wires)


def remove_from_queue_args_and_kwargs(item):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe move this helper method to the tape module and just call it something like recursively_remove_operators_from_queue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, is okey in tape/operation_recorder.py?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be more suitable in pennylane/queueing.py actually.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use pytrees and simply do:

leaves, _ = qml.pytrees.flatten((args, kwargs), lambda obj: isinstance(obj, Operator))
_ = [qml.QueuingManager.remove(l) for l in leaves if isinstance(l, Operator)]

Given this is just two lines, I'd be worried that adding a public function adds more complexity and coupling that it solves. Duplicate code is not a bad thing when it helps keep modules simple and loosely coupled.

I'd prefer to just copy those two lines into adjoint, controlled, and TransformDispatcher._qfunc_transform.

"""function used to recursively remove operators that have been added to the queue in an argument or kwarg."""
if isinstance(item, (list, tuple, set)):
for elem in item:
remove_from_queue_args_and_kwargs(elem)
elif isinstance(item, dict):
for value in item.values():
remove_from_queue_args_and_kwargs(value)
elif isinstance(item, Operator):
qml.queuing.QueuingManager.remove(item)


def _ctrl_transform(op, control, control_values, work_wires):
@wraps(op)
def wrapper(*args, **kwargs):
qscript = qml.tape.make_qscript(op)(*args, **kwargs)

for arg in args:
remove_from_queue_args_and_kwargs(arg)

for value in kwargs.values():
remove_from_queue_args_and_kwargs(value)

# flip control_values == 0 wires here, so we don't have to do it for each individual op.
flip_control_on_zero = (len(qscript) > 1) and (control_values is not None)
op_control_values = None if flip_control_on_zero else control_values
Expand Down
15 changes: 15 additions & 0 deletions tests/ops/op_math/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,21 @@ def test_single_op_defined_outside_queue_eager(self):
assert len(q) == 1
assert q.queue[0] is out

def test_correct_queued_operators(self):
"""Test that args and kwargs do not add operators to the queue."""

dev = qml.device("default.qubit")

@qml.qnode(dev)
def circuit():
qml.adjoint(qml.QSVT)(qml.X(1), [qml.Z(1)])
qml.adjoint(qml.QSVT(qml.X(1), [qml.Z(1)]))
return qml.state()

circuit()
for op in circuit.tape.operations:
assert op.name == "Adjoint(QSVT)"

@pytest.mark.usefixtures("use_legacy_opmath")
def test_single_observable(self):
"""Test passing a single preconstructed observable in a queuing context."""
Expand Down
17 changes: 17 additions & 0 deletions tests/ops/op_math/test_controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -1899,6 +1899,23 @@ def test_nested_pauli_x_based_ctrl_ops(self):
expected = qml.MultiControlledX(wires=[3, 2, 1, 0], control_values=[1, 0, 1])
assert op == expected

def test_correct_queued_operators(self):
"""Test that args and kwargs do not add operators to the queue."""

def func(dic):
for gate in dic.values():
qml.apply(gate)

dev = qml.device("default.qubit")

@qml.qnode(dev)
def circuit():
qml.ctrl(func, control=0)({1: qml.X(1), 2: qml.Z(1)})
return qml.state()

circuit()
assert len(circuit.tape.operations) == 2


class _Rot(Operation):
"""A rotation operation that is not an instance of Rot
Expand Down
15 changes: 15 additions & 0 deletions tests/ops/op_math/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,21 @@ def test_nonlazy_mode_queueing(self):
assert len(q) == 1
assert q.queue[0] is prod2

def test_correct_queued_operators(self):
"""Test that args and kwargs do not add operators to the queue."""

dev = qml.device("default.qubit")

@qml.qnode(dev)
def circuit():
qml.prod(qml.QSVT)(qml.X(1), [qml.Z(1)])
qml.prod(qml.QSVT(qml.X(1), [qml.Z(1)]))
return qml.state()

circuit()
for op in circuit.tape.operations:
assert op.name == "QSVT"


class TestIntegration:
"""Integration tests for the Prod class."""
Expand Down
Loading