Skip to content

Commit

Permalink
improvement + Pgate tests added
Browse files Browse the repository at this point in the history
  • Loading branch information
arsalan-motamedi committed Nov 6, 2024
1 parent 3b4dbfa commit b66efa2
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 11 deletions.
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/transformations/cxgate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
f"The number of modes for a CXgate must be 2 (your input has {len(modes)} many modes)."
)
super().__init__(name="CXgate")
self.parameter_set.add_parameter(make_parameter(s_trainable, s, "s", s_bounds))
self._add_parameter(make_parameter(s_trainable, s, "s", s_bounds))
symplectic = math.astensor(

Check warning on line 66 in mrmustard/lab_dev/transformations/cxgate.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/lab_dev/transformations/cxgate.py#L64-L66

Added lines #L64 - L66 were not covered by tests
[
[1, 0, 0, 0],
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/transformations/czgate.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
f"The number of modes for a CZgate must be 2 (your input has {len(modes)} many modes)."
)
super().__init__(name="CZgate")
self.parameter_set.add_parameter(make_parameter(s_trainable, s, "s", s_bounds))
self._add_parameter(make_parameter(s_trainable, s, "s", s_bounds))
symplectic = math.astensor(

Check warning on line 68 in mrmustard/lab_dev/transformations/czgate.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/lab_dev/transformations/czgate.py#L66-L68

Added lines #L66 - L68 were not covered by tests
[
[1, 0, 0, 0],
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/transformations/interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
if unitary is None:
unitary = math.random_unitary(num_modes)
super().__init__(name="Interferometer")
self.parameter_set.add_parameter(
self._add_parameter(

Check warning on line 59 in mrmustard/lab_dev/transformations/interferometer.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/lab_dev/transformations/interferometer.py#L56-L59

Added lines #L56 - L59 were not covered by tests
make_parameter(unitary_trainable, unitary, "unitary", (None, None), update_unitary)
)
symplectic = math.block(

Check warning on line 62 in mrmustard/lab_dev/transformations/interferometer.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/lab_dev/transformations/interferometer.py#L62

Added line #L62 was not covered by tests
Expand Down
8 changes: 2 additions & 6 deletions mrmustard/lab_dev/transformations/mzgate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,8 @@ def __init__(
internal: bool = False,
):
super().__init__(name="MZgate")
self.parameter_set.add_parameter(
make_parameter(phi_a_trainable, phi_a, "phi_a", phi_a_bounds)
)
self.parameter_set.add_parameter(
make_parameter(phi_b_trainable, phi_b, "phi_b", phi_b_bounds)
)
self._add_parameter(make_parameter(phi_a_trainable, phi_a, "phi_a", phi_a_bounds))
self._add_parameter(make_parameter(phi_b_trainable, phi_b, "phi_b", phi_b_bounds))

Check warning on line 66 in mrmustard/lab_dev/transformations/mzgate.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/lab_dev/transformations/mzgate.py#L64-L66

Added lines #L64 - L66 were not covered by tests

ca = math.cos(phi_a)
sa = math.sin(phi_a)
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/transformations/pgate.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
shearing_bounds: tuple[float | None, float | None] = (None, None),
):
super().__init__(name="Pgate")
self.parameter_set.add_parameter(
self._add_parameter(
make_parameter(shearing_trainable, shearing, "shearing", shearing_bounds)
)

Expand Down
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/transformations/realinterferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
orthogonal = math.random_orthogonal(num_modes)

Check warning on line 55 in mrmustard/lab_dev/transformations/realinterferometer.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/lab_dev/transformations/realinterferometer.py#L54-L55

Added lines #L54 - L55 were not covered by tests

super().__init__(name="RealInterferometer")
self.parameter_set.add_parameter(
self._add_parameter(

Check warning on line 58 in mrmustard/lab_dev/transformations/realinterferometer.py

View check run for this annotation

Codecov / codecov/patch

mrmustard/lab_dev/transformations/realinterferometer.py#L57-L58

Added lines #L57 - L58 were not covered by tests
make_parameter(
orthogonal_trainable, orthogonal, "orthogonal", (None, None), update_orthogonal
)
Expand Down
45 changes: 45 additions & 0 deletions tests/test_lab_dev/test_transformations/test_pgate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the ``Pgate`` class."""

# pylint: disable=missing-function-docstring, expression-not-assigned

import pytest

from mrmustard import math
from mrmustard.lab_dev.states import Vacuum
from mrmustard.lab_dev.transformations import Pgate


class TestPgate:
r"""
Tests for the ``Pgate`` class.
"""

def test_init(self):
"Tests the Pgate initialization."
up = Pgate([0, 1], 0.3)
assert up.modes == [0, 1]
assert up.name == "Pgate"
assert up.shearing.value == 0.3

@pytest.mark.parametrize("s", [0.1, 0.5, 1])
def test_application(self, s):
"Tests if Pgate is being applied correctly."
up = Pgate([0], s)
rho = Vacuum([0]) >> up
cov, _, _ = rho.phase_space(s=0)
temp = math.astensor([[1, 0], [s, 1]], dtype="complex128")
assert math.allclose(cov[0], temp @ math.eye(2) @ temp.T / 2)

0 comments on commit b66efa2

Please sign in to comment.