Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deep copy from Program.compile() #688

Merged
merged 9 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
instead of `phi = np.pi / 2`) for the phase shift of the beamsplitters.
[(#674)](https://github.com/XanaduAI/strawberryfields/pull/674)

* `Program.compile()` returns a deep copy of the program, instead of a shallow copy, while still keeping
the same register references.
[(#675)](https://github.com/XanaduAI/strawberryfields/pull/675)

<h3>Documentation</h3>

<h3>Contributors</h3>
Expand Down
10 changes: 9 additions & 1 deletion strawberryfields/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,15 @@ def _linked_copy(self):
Program: a copy of the Program
"""
self.lock()
p = copy.copy(self) # shares RegRefs with the source
p = copy.copy(self)

for name, val in self.__dict__.items():
# Deep-copy all attributes except 'circuit' and 'reg_refs', since the programs
# should share the same register references. Program.circuit potentially
# contains FreeParameters/MeasuredParameters, which contain RegRefs.
if name not in ("circuit", "reg_refs", "init_reg_refs"):
setattr(p, name, copy.deepcopy(val))
Comment on lines +555 to +556
Copy link
Contributor Author

@thisac thisac Mar 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimally, we should deep-copy the circuit as well, but since the may contain regrefs, this isn't as straight-forward to do. Just making a deepcopy here would cause two errors:

  1. Due to MeasuredParameter not supporting deepcopy (because of the different signature of the __new__() method). A __deepcopy__() method would potentially need to be added.
  2. Any copied (not referenced) regrefs would raise the unnerving "RegRef state has become inconsistent." issue we had earlier.

We could e.g., override the __deepcopy__ method for the Command class, but that would require it to only deep-copy everything except symbolic parameters. Alternatively, add a copy_everything_except_regref method to Command.

Copy link
Contributor

@sduquemesa sduquemesa Mar 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this is a potentially actionable item, perhaps is worth to have it as a (TODO) comment in the code. It might get lost here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've created an issue for this instead (#691). I think it's better to keep track of it there rather than as a TODO which can easily be lost. 🙂

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've created an issue for this instead (#691). I think it's better to keep track of it there rather than as a TODO which can easily be lost. 🙂


# link to the original source Program
if self.source is None:
p.source = self
Expand Down
21 changes: 20 additions & 1 deletion tests/frontend/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_eq_symmetric_bsgate(self, compare_params):
assert prog_2.equivalence(prog_1, compare_params=compare_params)

@pytest.mark.parametrize("compare_params", [True, False])
def test_neq_operator_equivalent(self, compare_params):
def test_equivalence_different_circuits(self, compare_params):
"""Programs with differnet, but equivalent, circuits."""
thisac marked this conversation as resolved.
Show resolved Hide resolved
prog_1 = sf.Program(3)
prog_2 = sf.Program(3)
Expand Down Expand Up @@ -534,6 +534,25 @@ def test_has_feed_forward(self):
assert prog_2.has_feed_forward is False
assert prog_2.has_post_selection is False

def test_linked_copy(self, prog):
"""Check that the ``_linked_copy`` method copies a program correctly."""

with prog.context as q:
ops.Fock(2) | q[0]
ops.BSgate() | (q[0], q[1])
ops.MeasureFock() | q[1]

prog_copy = prog._linked_copy()

# registers should be the same
for i, regref in prog_copy.reg_refs.items():
assert regref is prog.reg_refs[i]

for i, cmd in enumerate(prog_copy.circuit):
assert cmd is prog.circuit[i]

assert prog_copy.source is prog


class TestRegRefs:
"""Testing register references."""
Expand Down
27 changes: 27 additions & 0 deletions tests/frontend/test_tdmprogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,33 @@ def singleloop_program(r, alpha, phi, theta):
device = Device(device_spec)


def test_linked_copy():
"""Check that the ``_linked_copy`` method copies a TDM program correctly."""
sq_r = 0.5643
c = 2
alpha = [np.pi / 4, 0] * c
phi = [0, np.pi / 2] * c
theta = [0, 0, np.pi / 2, np.pi / 2]
prog = singleloop_program(sq_r, alpha, phi, theta)

prog_copy = prog._linked_copy()
assert prog_copy.circuit
assert prog.circuit

# registers should be the same
for i, regref in prog_copy.reg_refs.items():
assert regref is prog.reg_refs[i]

for i, cmd in enumerate(prog_copy.circuit):
assert cmd is prog.circuit[i]

# tdm_params should be equal, but not the same
assert prog_copy.tdm_params is not prog.tdm_params
assert prog_copy.tdm_params == prog.tdm_params

assert prog_copy.source is prog


class TestTDMcompiler:
"""Test class for checking error messages from the compiler"""

Expand Down