Skip to content

Commit

Permalink
Fix ZNE with measurements_from_samples (#1165)
Browse files Browse the repository at this point in the history
**Context:**
1) Jax generates additional const arguments with [...], and take.

2) The `mitigate_with_zne` does not work when using the
`measurements_from_sample` function. The root of the the issue is the
function `removeQuantumMeasurements` that tries to remove all users of
the measurments. But potentially users are located in a block of an
operations, therefore emptying the block but not the operation.

**Description of the Change:**
1) The JaxPr const are added to the ZNE op args.
2) removeQuantumMeasurements is replaced by replaceQuantumMeasurements,
this functions is replacing quantum measurements with empty tensors. The
tensors are remove in the bufferization pass. (`--inline,
--canonicalize`)

**Benefits:**
ZNE with shots works with hardware devices (counts and samples)

---------

Co-authored-by: David Ittah <[email protected]>
  • Loading branch information
rmoyard and dime10 authored Oct 7, 2024
1 parent 2fb956c commit 2a49e34
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 4 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@

<h3>Bug fixes</h3>

* Resolve a bug where `mitigate_with_zne` does not work properly with shots and devices
supporting only Counts and Samples (e.g. Qrack). (transform: `measurements_from_sample`).
[(#1165)](https://github.com/PennyLaneAI/catalyst/pull/1165)

* Resolve a bug in the `vmap` function when passing shapeless values to the target.
[(#1150)](https://github.com/PennyLaneAI/catalyst/pull/1150)

Expand Down
15 changes: 14 additions & 1 deletion frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,10 +1057,23 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn):
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
num_folds = args[-1]

constants = []
for const in jaxpr.consts:
const_type = shape_dtype_to_ir_type(const.shape, const.dtype)
nparray = np.asarray(const)
# TODO: Fix bool case
if not const.dtype == bool:
attr = ir.DenseElementsAttr.get(nparray, type=const_type)
constantVals = StableHLOConstantOp(attr).results
constants.append(constantVals)

args_and_consts = constants + list(args[0:-1])

return ZneOp(
flat_output_types,
symbol_ref,
mlir.flatten_lowering_ir_args(args[0:-1]),
mlir.flatten_lowering_ir_args(args_and_consts),
_folding_attribute(ctx, folding),
num_folds,
).results
Expand Down
26 changes: 26 additions & 0 deletions frontend/test/pytest/test_mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,5 +378,31 @@ def mitigated_qnode():
assert np.allclose(mitigated_qnode(), circuit())


def test_jaxpr_with_const():
"""test mitigate_with_zne with a circuit that generates arguments in MLIR"""
dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
def circuit():
a = jax.numpy.array([0.1, 0.2, 0.3, 0.4])
b = jax.numpy.take(a, 2)
qml.Hadamard(wires=0)
qml.RZ(0.1, wires=0)
qml.RZ(b, wires=0)
qml.CNOT(wires=[1, 0])
qml.Hadamard(wires=1)
return qml.expval(qml.PauliY(wires=0))

@catalyst.qjit
def mitigated_qnode():
return catalyst.mitigate_with_zne(
circuit,
scale_factors=[1, 3, 5, 7],
extrapolate=quadratic_extrapolation,
)()

assert np.allclose(mitigated_qnode(), circuit())


if __name__ == "__main__":
pytest.main(["-x", __file__])
1 change: 1 addition & 0 deletions mlir/include/Quantum/Utils/RemoveQuantum.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace catalyst {
namespace quantum {

void removeQuantumMeasurements(mlir::func::FuncOp &function, mlir::PatternRewriter &rewriter);
void replaceQuantumMeasurements(mlir::func::FuncOp &function, mlir::PatternRewriter &rewriter);
mlir::LogicalResult verifyQuantumFree(mlir::func::FuncOp function);

} // namespace quantum
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew
fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp,
fnWithoutMeasurementsOp, fnWithMeasurementsOp);
}

rewriter.cloneRegionBefore(fnOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end());

Block *fnFoldedOpBlock = &fnFoldedOp.getBody().front();
Expand Down Expand Up @@ -413,7 +412,7 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFnWithoutMeasurements(Location loc,

quantum::DeallocOp localDealloc = *fnWithoutMeasurementsOp.getOps<quantum::DeallocOp>().begin();
rewriter.eraseOp(localDealloc);
quantum::removeQuantumMeasurements(fnWithoutMeasurementsOp, rewriter);
quantum::replaceQuantumMeasurements(fnWithoutMeasurementsOp, rewriter);
return SymbolRefAttr::get(ctx, fnWithoutMeasurementsName);
}
FlatSymbolRefAttr ZneLowering::getOrInsertFnWithMeasurements(Location loc,
Expand Down
38 changes: 37 additions & 1 deletion mlir/lib/Quantum/Utils/RemoveQuantum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

#include "llvm/ADT/SmallPtrSet.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

#include "Quantum/IR/QuantumInterfaces.h"
#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Utils/RemoveQuantum.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

Expand Down Expand Up @@ -59,6 +62,39 @@ void removeQuantumMeasurements(func::FuncOp &function, PatternRewriter &rewriter
}
}

void replaceQuantumMeasurements(func::FuncOp &function, PatternRewriter &rewriter)
{
function.walk([&](MeasurementProcess op) {
auto types = op->getResults().getTypes();
auto loc = op.getLoc();
SmallVector<Value> results;
for (auto type : types) {
if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
auto shape = tensorType.getShape();
auto elemType = tensorType.getElementType();
auto res = rewriter.create<tensor::EmptyOp>(loc, shape, elemType);
results.push_back(res);
}
else {
if (type.isInteger()) {
auto res = rewriter.create<arith::ConstantOp>(loc, type,
rewriter.getIntegerAttr(type, 0));
results.push_back(res);
}
else if (type.isIntOrFloat()) {
auto res = rewriter.create<arith::ConstantOp>(loc, type,
rewriter.getFloatAttr(type, 0.0));
results.push_back(res);
}
else {
op.emitError() << "Unexpected measurement type " << *op;
}
}
}
rewriter.replaceOp(op, results);
});
}

LogicalResult verifyQuantumFree(func::FuncOp function)
{
assert(function->hasAttr("QuantumFree") &&
Expand Down

0 comments on commit 2a49e34

Please sign in to comment.