Skip to content

Commit

Permalink
Merge branch 'v0.39.0-rc0' into queueing_controls
Browse files Browse the repository at this point in the history
  • Loading branch information
KetpuntoG authored Oct 29, 2024
2 parents 30ec1d1 + 92e48d3 commit e4f55e0
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 22 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-0.39.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@

<h3>Improvements 🛠</h3>

* `qml.metric_tensor` can now be JIT compiled.
[(#6468)](https://github.com/PennyLaneAI/pennylane/pull/6468)

* RTD support for `qml.labs` added to API.
[(#6397)](https://github.com/PennyLaneAI/pennylane/pull/6397)

Expand Down
3 changes: 2 additions & 1 deletion pennylane/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.40.0-dev1"

__version__ = "0.39.0"
26 changes: 12 additions & 14 deletions pennylane/gradients/metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from pennylane.typing import PostprocessingFn


def _mt_cjac_tdot(mt, c):
return qml.math.tensordot(c, qml.math.tensordot(mt, c, axes=[[-1], [0]]), axes=[[0], [0]])


def _contract_metric_tensor_with_cjac(mt, cjac, tape): # pylint: disable=unused-argument
"""Execute the contraction of pre-computed classical Jacobian(s)
and the metric tensor of a tape in order to obtain the hybrid
Expand All @@ -47,25 +51,19 @@ def _contract_metric_tensor_with_cjac(mt, cjac, tape): # pylint: disable=unused
if isinstance(mt, tuple) and len(mt) == 1:
mt = mt[0]
if isinstance(cjac, tuple):
# Classical processing of multiple arguments is present. Return cjac.T @ mt @ cjac
# Classical processing of multiple arguments is present. Return cjac[i].T @ mt @ cjac[i]
# as a tuple of contractions.
metric_tensors = tuple(
qml.math.tensordot(c, qml.math.tensordot(mt, c, axes=[[-1], [0]]), axes=[[0], [0]])
for c in cjac
if c is not None
)
metric_tensors = tuple(_mt_cjac_tdot(mt, c) for c in cjac if c is not None)
return metric_tensors[0] if len(metric_tensors) == 1 else metric_tensors

is_square = cjac.shape == (1,) or (cjac.ndim == 2 and cjac.shape[0] == cjac.shape[1])
if not qml.math.is_abstract(cjac):
is_square = cjac.shape == (1,) or (cjac.ndim == 2 and cjac.shape[0] == cjac.shape[1])

if is_square and qml.math.allclose(cjac, qml.numpy.eye(cjac.shape[0])):
# Classical Jacobian is the identity. No classical processing
# is present inside the QNode.
return mt
mt_cjac = qml.math.tensordot(mt, cjac, axes=[[-1], [0]])
mt = qml.math.tensordot(cjac, mt_cjac, axes=[[0], [0]])
if is_square and qml.math.allclose(cjac, qml.math.eye(cjac.shape[0])):
# Classical Jacobian is the identity. No classical processing in the QNode
return mt

return mt
return _mt_cjac_tdot(mt, cjac)


def _expand_metric_tensor(
Expand Down
17 changes: 10 additions & 7 deletions tests/gradients/core/test_metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
"""
import importlib

# pylint: disable=too-many-arguments,too-many-public-methods,too-few-public-methods,not-callable,too-many-statements
# pylint: disable=too-many-arguments,too-many-public-methods,too-few-public-methods
# pylint: disable=not-callable,too-many-statements
import pytest
from scipy.linalg import block_diag

Expand Down Expand Up @@ -997,7 +998,8 @@ def circuit(*params):
@pytest.mark.parametrize("ansatz, params", zip(fubini_ansatze, fubini_params))
@pytest.mark.parametrize("interface", ["auto", "jax"])
@pytest.mark.parametrize("dev_name", ("default.qubit", "lightning.qubit"))
def test_correct_output_jax(self, dev_name, ansatz, params, interface):
@pytest.mark.parametrize("use_jit", [False, True])
def test_correct_output_jax(self, dev_name, ansatz, params, interface, use_jit):
import jax
from jax import numpy as jnp

Expand All @@ -1019,11 +1021,12 @@ def circuit(*params):
ansatz(*params, dev.wires[:-1])
return qml.expval(qml.PauliZ(0))

if len(params) > 1:
# pylint:disable=unexpected-keyword-arg
mt = qml.metric_tensor(circuit, argnums=range(0, len(params)), approx=None)(*params)
else:
mt = qml.metric_tensor(circuit, approx=None)(*params)
argnums = range(0, len(params)) if len(params) > 1 else None
# pylint:disable=unexpected-keyword-arg
mt_fn = qml.metric_tensor(circuit, argnums=argnums, approx=None)
if use_jit:
mt_fn = jax.jit(mt_fn)
mt = mt_fn(*params)

if isinstance(mt, tuple):
assert all(qml.math.allclose(_mt, _exp) for _mt, _exp in zip(mt, expected))
Expand Down

0 comments on commit e4f55e0

Please sign in to comment.