diff --git a/mrmustard/lab_dev/transformations/cxgate.py b/mrmustard/lab_dev/transformations/cxgate.py index 6cc3115f1..5126abc54 100644 --- a/mrmustard/lab_dev/transformations/cxgate.py +++ b/mrmustard/lab_dev/transformations/cxgate.py @@ -62,7 +62,7 @@ def __init__( f"The number of modes for a CXgate must be 2 (your input has {len(modes)} many modes)." ) super().__init__(name="CXgate") - self.parameter_set.add_parameter(make_parameter(s_trainable, s, "s", s_bounds)) + self._add_parameter(make_parameter(s_trainable, s, "s", s_bounds)) symplectic = math.astensor( [ [1, 0, 0, 0], diff --git a/mrmustard/lab_dev/transformations/czgate.py b/mrmustard/lab_dev/transformations/czgate.py index 3e8e39aff..4bf582a89 100644 --- a/mrmustard/lab_dev/transformations/czgate.py +++ b/mrmustard/lab_dev/transformations/czgate.py @@ -64,7 +64,7 @@ def __init__( f"The number of modes for a CZgate must be 2 (your input has {len(modes)} many modes)." ) super().__init__(name="CZgate") - self.parameter_set.add_parameter(make_parameter(s_trainable, s, "s", s_bounds)) + self._add_parameter(make_parameter(s_trainable, s, "s", s_bounds)) symplectic = math.astensor( [ [1, 0, 0, 0], diff --git a/mrmustard/lab_dev/transformations/interferometer.py b/mrmustard/lab_dev/transformations/interferometer.py index 81796c73a..ccccd0867 100644 --- a/mrmustard/lab_dev/transformations/interferometer.py +++ b/mrmustard/lab_dev/transformations/interferometer.py @@ -56,7 +56,7 @@ def __init__( if unitary is None: unitary = math.random_unitary(num_modes) super().__init__(name="Interferometer") - self.parameter_set.add_parameter( + self._add_parameter( make_parameter(unitary_trainable, unitary, "unitary", (None, None), update_unitary) ) symplectic = math.block( diff --git a/mrmustard/lab_dev/transformations/mzgate.py b/mrmustard/lab_dev/transformations/mzgate.py index cc0bbf0ce..950ed2ac7 100644 --- a/mrmustard/lab_dev/transformations/mzgate.py +++ b/mrmustard/lab_dev/transformations/mzgate.py @@ -62,12 +62,8 @@ def __init__( internal: bool = False, ): super().__init__(name="MZgate") - self.parameter_set.add_parameter( - make_parameter(phi_a_trainable, phi_a, "phi_a", phi_a_bounds) - ) - self.parameter_set.add_parameter( - make_parameter(phi_b_trainable, phi_b, "phi_b", phi_b_bounds) - ) + self._add_parameter(make_parameter(phi_a_trainable, phi_a, "phi_a", phi_a_bounds)) + self._add_parameter(make_parameter(phi_b_trainable, phi_b, "phi_b", phi_b_bounds)) ca = math.cos(phi_a) sa = math.sin(phi_a) diff --git a/mrmustard/lab_dev/transformations/pgate.py b/mrmustard/lab_dev/transformations/pgate.py index 09fdf8184..fbeeeede0 100644 --- a/mrmustard/lab_dev/transformations/pgate.py +++ b/mrmustard/lab_dev/transformations/pgate.py @@ -59,7 +59,7 @@ def __init__( shearing_bounds: tuple[float | None, float | None] = (None, None), ): super().__init__(name="Pgate") - self.parameter_set.add_parameter( + self._add_parameter( make_parameter(shearing_trainable, shearing, "shearing", shearing_bounds) ) diff --git a/mrmustard/lab_dev/transformations/realinterferometer.py b/mrmustard/lab_dev/transformations/realinterferometer.py index 668615da0..76f1575bd 100644 --- a/mrmustard/lab_dev/transformations/realinterferometer.py +++ b/mrmustard/lab_dev/transformations/realinterferometer.py @@ -55,7 +55,7 @@ def __init__( orthogonal = math.random_orthogonal(num_modes) super().__init__(name="RealInterferometer") - self.parameter_set.add_parameter( + self._add_parameter( make_parameter( orthogonal_trainable, orthogonal, "orthogonal", (None, None), update_orthogonal ) diff --git a/tests/test_lab_dev/test_transformations/test_pgate.py b/tests/test_lab_dev/test_transformations/test_pgate.py new file mode 100644 index 000000000..4e6730c31 --- /dev/null +++ b/tests/test_lab_dev/test_transformations/test_pgate.py @@ -0,0 +1,45 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the ``Pgate`` class.""" + +# pylint: disable=missing-function-docstring, expression-not-assigned + +import pytest + +from mrmustard import math +from mrmustard.lab_dev.states import Vacuum +from mrmustard.lab_dev.transformations import Pgate + + +class TestPgate: + r""" + Tests for the ``Pgate`` class. + """ + + def test_init(self): + "Tests the Pgate initialization." + up = Pgate([0, 1], 0.3) + assert up.modes == [0, 1] + assert up.name == "Pgate" + assert up.shearing.value == 0.3 + + @pytest.mark.parametrize("s", [0.1, 0.5, 1]) + def test_application(self, s): + "Tests if Pgate is being applied correctly." + up = Pgate([0], s) + rho = Vacuum([0]) >> up + cov, _, _ = rho.phase_space(s=0) + temp = math.astensor([[1, 0], [s, 1]], dtype="complex128") + assert math.allclose(cov[0], temp @ math.eye(2) @ temp.T / 2)