diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 3a524b3fb6..095564bb8f 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -230,6 +230,10 @@
Bug fixes
+* Resolve a bug where `mitigate_with_zne` does not work properly with shots and devices
+ supporting only Counts and Samples (e.g. Qrack). (transform: `measurements_from_sample`).
+ [(#1165)](https://github.com/PennyLaneAI/catalyst/pull/1165)
+
* Resolve a bug in the `vmap` function when passing shapeless values to the target.
[(#1150)](https://github.com/PennyLaneAI/catalyst/pull/1150)
diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py
index 50b8a5bb01..e4021e6149 100644
--- a/frontend/catalyst/jax_primitives.py
+++ b/frontend/catalyst/jax_primitives.py
@@ -1057,10 +1057,23 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn):
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
num_folds = args[-1]
+
+ constants = []
+ for const in jaxpr.consts:
+ const_type = shape_dtype_to_ir_type(const.shape, const.dtype)
+ nparray = np.asarray(const)
+ # TODO: Fix bool case
+ if not const.dtype == bool:
+ attr = ir.DenseElementsAttr.get(nparray, type=const_type)
+ constantVals = StableHLOConstantOp(attr).results
+ constants.append(constantVals)
+
+ args_and_consts = constants + list(args[0:-1])
+
return ZneOp(
flat_output_types,
symbol_ref,
- mlir.flatten_lowering_ir_args(args[0:-1]),
+ mlir.flatten_lowering_ir_args(args_and_consts),
_folding_attribute(ctx, folding),
num_folds,
).results
diff --git a/frontend/test/pytest/test_mitigation.py b/frontend/test/pytest/test_mitigation.py
index 64b491423f..ec370cdaaa 100644
--- a/frontend/test/pytest/test_mitigation.py
+++ b/frontend/test/pytest/test_mitigation.py
@@ -378,5 +378,31 @@ def mitigated_qnode():
assert np.allclose(mitigated_qnode(), circuit())
+def test_jaxpr_with_const():
+ """test mitigate_with_zne with a circuit that generates arguments in MLIR"""
+ dev = qml.device("lightning.qubit", wires=2)
+
+ @qml.qnode(device=dev)
+ def circuit():
+ a = jax.numpy.array([0.1, 0.2, 0.3, 0.4])
+ b = jax.numpy.take(a, 2)
+ qml.Hadamard(wires=0)
+ qml.RZ(0.1, wires=0)
+ qml.RZ(b, wires=0)
+ qml.CNOT(wires=[1, 0])
+ qml.Hadamard(wires=1)
+ return qml.expval(qml.PauliY(wires=0))
+
+ @catalyst.qjit
+ def mitigated_qnode():
+ return catalyst.mitigate_with_zne(
+ circuit,
+ scale_factors=[1, 3, 5, 7],
+ extrapolate=quadratic_extrapolation,
+ )()
+
+ assert np.allclose(mitigated_qnode(), circuit())
+
+
if __name__ == "__main__":
pytest.main(["-x", __file__])
diff --git a/mlir/include/Quantum/Utils/RemoveQuantum.h b/mlir/include/Quantum/Utils/RemoveQuantum.h
index de45daa540..d074fddddb 100644
--- a/mlir/include/Quantum/Utils/RemoveQuantum.h
+++ b/mlir/include/Quantum/Utils/RemoveQuantum.h
@@ -21,6 +21,7 @@ namespace catalyst {
namespace quantum {
void removeQuantumMeasurements(mlir::func::FuncOp &function, mlir::PatternRewriter &rewriter);
+void replaceQuantumMeasurements(mlir::func::FuncOp &function, mlir::PatternRewriter &rewriter);
mlir::LogicalResult verifyQuantumFree(mlir::func::FuncOp function);
} // namespace quantum
diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp
index 0405c5f24f..920b1b632d 100644
--- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp
+++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp
@@ -319,7 +319,6 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew
fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp,
fnWithoutMeasurementsOp, fnWithMeasurementsOp);
}
-
rewriter.cloneRegionBefore(fnOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end());
Block *fnFoldedOpBlock = &fnFoldedOp.getBody().front();
@@ -413,7 +412,7 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFnWithoutMeasurements(Location loc,
quantum::DeallocOp localDealloc = *fnWithoutMeasurementsOp.getOps().begin();
rewriter.eraseOp(localDealloc);
- quantum::removeQuantumMeasurements(fnWithoutMeasurementsOp, rewriter);
+ quantum::replaceQuantumMeasurements(fnWithoutMeasurementsOp, rewriter);
return SymbolRefAttr::get(ctx, fnWithoutMeasurementsName);
}
FlatSymbolRefAttr ZneLowering::getOrInsertFnWithMeasurements(Location loc,
diff --git a/mlir/lib/Quantum/Utils/RemoveQuantum.cpp b/mlir/lib/Quantum/Utils/RemoveQuantum.cpp
index eca1b01519..dbb3b51097 100644
--- a/mlir/lib/Quantum/Utils/RemoveQuantum.cpp
+++ b/mlir/lib/Quantum/Utils/RemoveQuantum.cpp
@@ -16,10 +16,13 @@
#include "llvm/ADT/SmallPtrSet.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+
#include "Quantum/IR/QuantumInterfaces.h"
#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Utils/RemoveQuantum.h"
-#include "mlir/IR/PatternMatch.h"
using namespace mlir;
@@ -59,6 +62,39 @@ void removeQuantumMeasurements(func::FuncOp &function, PatternRewriter &rewriter
}
}
+void replaceQuantumMeasurements(func::FuncOp &function, PatternRewriter &rewriter)
+{
+ function.walk([&](MeasurementProcess op) {
+ auto types = op->getResults().getTypes();
+ auto loc = op.getLoc();
+ SmallVector results;
+ for (auto type : types) {
+ if (auto tensorType = dyn_cast(type)) {
+ auto shape = tensorType.getShape();
+ auto elemType = tensorType.getElementType();
+ auto res = rewriter.create(loc, shape, elemType);
+ results.push_back(res);
+ }
+ else {
+ if (type.isInteger()) {
+ auto res = rewriter.create(loc, type,
+ rewriter.getIntegerAttr(type, 0));
+ results.push_back(res);
+ }
+ else if (type.isIntOrFloat()) {
+ auto res = rewriter.create(loc, type,
+ rewriter.getFloatAttr(type, 0.0));
+ results.push_back(res);
+ }
+ else {
+ op.emitError() << "Unexpected measurement type " << *op;
+ }
+ }
+ }
+ rewriter.replaceOp(op, results);
+ });
+}
+
LogicalResult verifyQuantumFree(func::FuncOp function)
{
assert(function->hasAttr("QuantumFree") &&