diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3a524b3fb6..095564bb8f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -230,6 +230,10 @@

Bug fixes

+* Resolve a bug where `mitigate_with_zne` does not work properly with shots and devices + supporting only Counts and Samples (e.g. Qrack). (transform: `measurements_from_sample`). + [(#1165)](https://github.com/PennyLaneAI/catalyst/pull/1165) + * Resolve a bug in the `vmap` function when passing shapeless values to the target. [(#1150)](https://github.com/PennyLaneAI/catalyst/pull/1150) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 50b8a5bb01..e4021e6149 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -1057,10 +1057,23 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn): output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) flat_output_types = util.flatten(output_types) num_folds = args[-1] + + constants = [] + for const in jaxpr.consts: + const_type = shape_dtype_to_ir_type(const.shape, const.dtype) + nparray = np.asarray(const) + # TODO: Fix bool case + if not const.dtype == bool: + attr = ir.DenseElementsAttr.get(nparray, type=const_type) + constantVals = StableHLOConstantOp(attr).results + constants.append(constantVals) + + args_and_consts = constants + list(args[0:-1]) + return ZneOp( flat_output_types, symbol_ref, - mlir.flatten_lowering_ir_args(args[0:-1]), + mlir.flatten_lowering_ir_args(args_and_consts), _folding_attribute(ctx, folding), num_folds, ).results diff --git a/frontend/test/pytest/test_mitigation.py b/frontend/test/pytest/test_mitigation.py index 64b491423f..ec370cdaaa 100644 --- a/frontend/test/pytest/test_mitigation.py +++ b/frontend/test/pytest/test_mitigation.py @@ -378,5 +378,31 @@ def mitigated_qnode(): assert np.allclose(mitigated_qnode(), circuit()) +def test_jaxpr_with_const(): + """test mitigate_with_zne with a circuit that generates arguments in MLIR""" + dev = qml.device("lightning.qubit", wires=2) + + @qml.qnode(device=dev) + def circuit(): + a = jax.numpy.array([0.1, 0.2, 0.3, 0.4]) + b = jax.numpy.take(a, 2) + qml.Hadamard(wires=0) + qml.RZ(0.1, wires=0) + qml.RZ(b, wires=0) + qml.CNOT(wires=[1, 0]) + qml.Hadamard(wires=1) + return qml.expval(qml.PauliY(wires=0)) + + @catalyst.qjit + def mitigated_qnode(): + return catalyst.mitigate_with_zne( + circuit, + scale_factors=[1, 3, 5, 7], + extrapolate=quadratic_extrapolation, + )() + + assert np.allclose(mitigated_qnode(), circuit()) + + if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/mlir/include/Quantum/Utils/RemoveQuantum.h b/mlir/include/Quantum/Utils/RemoveQuantum.h index de45daa540..d074fddddb 100644 --- a/mlir/include/Quantum/Utils/RemoveQuantum.h +++ b/mlir/include/Quantum/Utils/RemoveQuantum.h @@ -21,6 +21,7 @@ namespace catalyst { namespace quantum { void removeQuantumMeasurements(mlir::func::FuncOp &function, mlir::PatternRewriter &rewriter); +void replaceQuantumMeasurements(mlir::func::FuncOp &function, mlir::PatternRewriter &rewriter); mlir::LogicalResult verifyQuantumFree(mlir::func::FuncOp function); } // namespace quantum diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 0405c5f24f..920b1b632d 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -319,7 +319,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, fnWithoutMeasurementsOp, fnWithMeasurementsOp); } - rewriter.cloneRegionBefore(fnOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); Block *fnFoldedOpBlock = &fnFoldedOp.getBody().front(); @@ -413,7 +412,7 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFnWithoutMeasurements(Location loc, quantum::DeallocOp localDealloc = *fnWithoutMeasurementsOp.getOps().begin(); rewriter.eraseOp(localDealloc); - quantum::removeQuantumMeasurements(fnWithoutMeasurementsOp, rewriter); + quantum::replaceQuantumMeasurements(fnWithoutMeasurementsOp, rewriter); return SymbolRefAttr::get(ctx, fnWithoutMeasurementsName); } FlatSymbolRefAttr ZneLowering::getOrInsertFnWithMeasurements(Location loc, diff --git a/mlir/lib/Quantum/Utils/RemoveQuantum.cpp b/mlir/lib/Quantum/Utils/RemoveQuantum.cpp index eca1b01519..dbb3b51097 100644 --- a/mlir/lib/Quantum/Utils/RemoveQuantum.cpp +++ b/mlir/lib/Quantum/Utils/RemoveQuantum.cpp @@ -16,10 +16,13 @@ #include "llvm/ADT/SmallPtrSet.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" + #include "Quantum/IR/QuantumInterfaces.h" #include "Quantum/IR/QuantumOps.h" #include "Quantum/Utils/RemoveQuantum.h" -#include "mlir/IR/PatternMatch.h" using namespace mlir; @@ -59,6 +62,39 @@ void removeQuantumMeasurements(func::FuncOp &function, PatternRewriter &rewriter } } +void replaceQuantumMeasurements(func::FuncOp &function, PatternRewriter &rewriter) +{ + function.walk([&](MeasurementProcess op) { + auto types = op->getResults().getTypes(); + auto loc = op.getLoc(); + SmallVector results; + for (auto type : types) { + if (auto tensorType = dyn_cast(type)) { + auto shape = tensorType.getShape(); + auto elemType = tensorType.getElementType(); + auto res = rewriter.create(loc, shape, elemType); + results.push_back(res); + } + else { + if (type.isInteger()) { + auto res = rewriter.create(loc, type, + rewriter.getIntegerAttr(type, 0)); + results.push_back(res); + } + else if (type.isIntOrFloat()) { + auto res = rewriter.create(loc, type, + rewriter.getFloatAttr(type, 0.0)); + results.push_back(res); + } + else { + op.emitError() << "Unexpected measurement type " << *op; + } + } + } + rewriter.replaceOp(op, results); + }); +} + LogicalResult verifyQuantumFree(func::FuncOp function) { assert(function->hasAttr("QuantumFree") &&