Skip to content

Commit

Permalink
Remove qml.Rot (and CRot) from merge rotation pass patterns (#1206)
Browse files Browse the repository at this point in the history
**Context:**
In #1162 the mlir pass for merging rotation was added, and the pattern
included `qml.Rot` (and `CRot`).
However, it was then realized that these two kinds of rotations should
not be merged, as [the rotation is specified via
](https://docs.pennylane.ai/en/stable/code/api/pennylane.Rot.html)
`rot(a, b, c) = rz(c) ry(b) rz(a)`

This means 
`rot(a, b, c) rot(d, e, f) != rot(a+d, b+e, c+f)`
since y and z rotations do not commute.


**Description of the Change:**
Remove `qml.Rot` (and `CRot`) from merge rotation pass patterns


**Benefits:**
No numerical errors
  • Loading branch information
paul0403 authored Oct 15, 2024
1 parent b8e28be commit ea55348
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 34 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
* A peephole merge rotations pass is now available in MLIR. It can be added to `catalyst.passes.pipeline`, or the
Python function `catalyst.passes.merge_rotations` can be directly called on a `QNode`.
[(#1162)](https://github.com/PennyLaneAI/catalyst/pull/1162)
[(#1206)](https://github.com/PennyLaneAI/catalyst/pull/1206)

Using the pipeline, one can run:

Expand Down
12 changes: 12 additions & 0 deletions frontend/test/pytest/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def workflow():
def f(x):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.PhaseShift(x, wires=0)
qml.PhaseShift(x, wires=0)
qml.Hadamard(wires=0)
qml.Hadamard(wires=0)
return qml.probs()
Expand All @@ -81,6 +85,10 @@ def f(x):
def g(x):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.PhaseShift(x, wires=0)
qml.PhaseShift(x, wires=0)
qml.Hadamard(wires=0)
qml.Hadamard(wires=0)
return qml.probs()
Expand All @@ -91,6 +99,10 @@ def g(x):
def reference(x):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.Rot(x, x, x, wires=0)
qml.PhaseShift(x, wires=0)
qml.PhaseShift(x, wires=0)
qml.Hadamard(wires=0)
qml.Hadamard(wires=0)
return qml.probs()
Expand Down
9 changes: 1 addition & 8 deletions mlir/include/Quantum/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,7 @@ def MergeRotationsPass : Pass<"merge-rotations"> {
let summary = "merge rotation boilerplate words";

let constructor = "catalyst::createMergeRotationsPass()";
let options = !listconcat(
QuantumCircuitTransformationPass.options,
[
Option<"MyOption", "my-option",
"std::string", /*default=*/"\"\"",
"Boilerplate option. Delete if unnecessary.">,
]
);
let options = QuantumCircuitTransformationPass.options;
}

// ----- Quantum circuit transformation passes end ----- //
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ using llvm::dbgs;
using namespace mlir;
using namespace catalyst::quantum;

static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "PhaseShift", "Rot",
"CRX", "CRY", "CRZ", "ControlledPhaseShift", "CRot"};
static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "PhaseShift",
"CRX", "CRY", "CRZ", "ControlledPhaseShift"};

namespace {

Expand Down
42 changes: 18 additions & 24 deletions mlir/test/Quantum/MergeRotationsTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,14 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> !quantum.
// CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit
%0 = quantum.alloc( 1) : !quantum.reg
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[angle00:%.+]] = arith.addf %arg1, %arg2 : f64
// CHECK: [[angle10:%.+]] = arith.addf %arg2, %arg0 : f64
// CHECK: [[angle20:%.+]] = arith.addf %arg0, %arg1 : f64
// CHECK: [[angle01:%.+]] = arith.addf %arg0, [[angle00]] : f64
// CHECK: [[angle11:%.+]] = arith.addf %arg1, [[angle10]] : f64
// CHECK: [[angle21:%.+]] = arith.addf %arg2, [[angle20]] : f64
// CHECK: [[ret:%.+]] = quantum.custom "Rot"([[angle01]], [[angle11]], [[angle21]]) [[qubit]] : !quantum.bit
// CHECK-NOT: quantum.custom "Rot"

// CHECK: quantum.custom "Rot"
// CHECK: quantum.custom "Rot"
// CHECK: [[ret:%.+]] = quantum.custom "Rot"
%2 = quantum.custom "Rot"(%arg0, %arg1, %arg2) %1 : !quantum.bit
%3 = quantum.custom "Rot"(%arg1, %arg2, %arg0) %2 : !quantum.bit
%4 = quantum.custom "Rot"(%arg2, %arg0, %arg1) %3 : !quantum.bit

// CHECK: return [[ret]]
return %4 : !quantum.bit
}
Expand All @@ -179,17 +176,14 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum
%0 = quantum.alloc( 2) : !quantum.reg
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
%2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
// CHECK: [[angle00:%.+]] = arith.addf %arg1, %arg2 : f64
// CHECK: [[angle10:%.+]] = arith.addf %arg2, %arg0 : f64
// CHECK: [[angle20:%.+]] = arith.addf %arg0, %arg1 : f64
// CHECK: [[angle01:%.+]] = arith.addf %arg0, [[angle00]] : f64
// CHECK: [[angle11:%.+]] = arith.addf %arg1, [[angle10]] : f64
// CHECK: [[angle21:%.+]] = arith.addf %arg2, [[angle20]] : f64
// CHECK: [[ret:%.+]]:2 = quantum.custom "CRot"([[angle01]], [[angle11]], [[angle21]]) [[qubit1]], [[qubit2]] : !quantum.bit
// CHECK-NOT: quantum.custom "CRot"

// CHECK: quantum.custom "CRot"
// CHECK: quantum.custom "CRot"
// CHECK: [[ret:%.+]]:2 = quantum.custom "CRot"
%3:2 = quantum.custom "CRot"(%arg0, %arg1, %arg2) %1, %2 : !quantum.bit, !quantum.bit
%4:2 = quantum.custom "CRot"(%arg1, %arg2, %arg0) %3#0, %3#1 : !quantum.bit, !quantum.bit
%5:2 = quantum.custom "CRot"(%arg2, %arg0, %arg1) %4#0, %4#1 : !quantum.bit, !quantum.bit

// CHECK: return [[ret]]#0, [[ret]]#1
return %5#0, %5#1 : !quantum.bit, !quantum.bit
}
Expand Down Expand Up @@ -256,7 +250,7 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum

// -----

func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit, !quantum.bit) {
func.func @test_merge_rotations(%arg0: f64) -> (!quantum.bit, !quantum.bit, !quantum.bit) {
// CHECK: [[true:%.+]] = llvm.mlir.constant
// CHECK: [[false:%.+]] = llvm.mlir.constant
%true = llvm.mlir.constant (1 : i1) :i1
Expand All @@ -270,12 +264,12 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum
%0 = quantum.extract %reg[ 0] : !quantum.reg -> !quantum.bit
%1 = quantum.extract %reg[ 1] : !quantum.reg -> !quantum.bit
%2 = quantum.extract %reg[ 2] : !quantum.reg -> !quantum.bit
// CHECK: [[angle0:%.+]] = arith.addf %arg0, %arg1 : f64
// CHECK: [[angle1:%.+]] = arith.addf %arg1, %arg2 : f64
// CHECK: [[angle2:%.+]] = arith.addf %arg2, %arg0 : f64
// CHECK: [[ret:%.+]], [[ctrlret:%.+]]:2 = quantum.custom "Rot"([[angle0]], [[angle1]], [[angle2]]) [[qubit0]] ctrls([[qubit1]], [[qubit2]]) ctrlvals([[true]], [[false]]) : !quantum.bit ctrls !quantum.bit, !quantum.bit
%out_qubits, %out_ctrl_qubits:2 = quantum.custom "Rot"(%arg0, %arg1, %arg2) %0 ctrls(%1, %2) ctrlvals(%true, %false) : !quantum.bit ctrls !quantum.bit, !quantum.bit
%out_qubits_1, %out_ctrl_qubits_1:2 = quantum.custom "Rot"(%arg1, %arg2, %arg0) %out_qubits ctrls(%out_ctrl_qubits#0, %out_ctrl_qubits#1) ctrlvals(%true, %false) : !quantum.bit ctrls !quantum.bit, !quantum.bit

// CHECK: [[angle:%.+]] = arith.addf %arg0, %arg0 : f64
// CHECK: [[ret:%.+]], [[ctrlret:%.+]]:2 = quantum.custom "RX"([[angle]]) [[qubit0]] ctrls([[qubit1]], [[qubit2]]) ctrlvals([[true]], [[false]]) : !quantum.bit ctrls !quantum.bit, !quantum.bit
%out_qubits, %out_ctrl_qubits:2 = quantum.custom "RX"(%arg0) %0 ctrls(%1, %2) ctrlvals(%true, %false) : !quantum.bit ctrls !quantum.bit, !quantum.bit
%out_qubits_1, %out_ctrl_qubits_1:2 = quantum.custom "RX"(%arg0) %out_qubits ctrls(%out_ctrl_qubits#0, %out_ctrl_qubits#1) ctrlvals(%true, %false) : !quantum.bit ctrls !quantum.bit, !quantum.bit

// CHECK: return [[ret]], [[ctrlret]]#0, [[ctrlret]]#1
return %out_qubits_1, %out_ctrl_qubits_1#0, %out_ctrl_qubits_1#1 : !quantum.bit, !quantum.bit, !quantum.bit
}
}

0 comments on commit ea55348

Please sign in to comment.