Skip to content

Commit

Permalink
Corrected the issue of having a multiplication of BtoQ stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
arsalan-motamedi committed Oct 22, 2024
1 parent 7316083 commit 8a5c93e
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,8 @@ def _apply_btops_for_change_of_rep(self, indices: Sequence[int]) -> CircuitCompo
ret = copy.deepcopy(self)

for i in indices:
if len(self._index_representation[i]) > 2:
continue
name, arg = self._index_representation[i]

if name == "PS":
Expand Down Expand Up @@ -732,6 +734,8 @@ def _apply_btoq_for_change_of_rep(self, indices: Sequence[int]) -> CircuitCompon

ret = copy.deepcopy(self)
for i in indices:
if len(self._index_representation[i]) > 2:
continue
name, arg = self._index_representation[i]
if name == "Q":
ret._index_representation[i] = ("B", None) # perhaps not needed -- can be removed
Expand Down Expand Up @@ -868,20 +872,20 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent:
for m in other.wires.output.bra.modes:
i = result.wires.index_dicts[0][m]
j = other.wires.index_dicts[0][m]
result._index_representation[i] = other._index_representation[j]
result._index_representation[i] = other._index_representation[j][:2]
for m in other.wires.output.ket.modes:
i = result.wires.index_dicts[2][m]
j = other.wires.index_dicts[2][m]
result._index_representation[i] = other._index_representation[j]
result._index_representation[i] = other._index_representation[j][:2]

for m in self.wires.input.bra.modes:
i = result.wires.index_dicts[1][m]
j = self.wires.index_dicts[1][m]
result._index_representation[i] = self._index_representation[j]
result._index_representation[i] = self._index_representation[j][:2]
for m in self.wires.input.ket.modes:
i = result.wires.index_dicts[3][m]
j = self.wires.index_dicts[3][m]
result._index_representation[i] = self._index_representation[j]
result._index_representation[i] = self._index_representation[j][:2]

# now we check for indices that might have been contracted:
idx_1, idx_2 = self.wires.contracted_indices(other.wires)
Expand All @@ -890,25 +894,29 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent:
j = other.wires.index_dicts[1][m]
if j not in idx_2:
i = result.wires.index_dicts[1][m]
result._index_representation[i] = other._index_representation[j]
result._index_representation[i] = other._index_representation[j][:2]
for m in other.wires.input.ket.modes:

j = other.wires.index_dicts[3][m]
if j not in idx_2:
i = result.wires.index_dicts[3][m]
result._index_representation[i] = other._index_representation[j]
result._index_representation[i] = other._index_representation[j][:2]

for m in self.wires.output.bra.modes:
j = self.wires.index_dicts[0][m]
if j not in idx_1:
i = result.wires.index_dicts[0][m]
result._index_representation[i] = self._index_representation[j]
result._index_representation[i] = self._index_representation[j][:2]

for m in self.wires.output.ket.modes:
j = self.wires.index_dicts[2][m]
if j not in idx_1:
i = result.wires.index_dicts[2][m]
result._index_representation[i] = self._index_representation[j]
result._index_representation[i] = self._index_representation[j][:2]

if ((isinstance(self, BtoQ) or isinstance(self, BtoPS) )and (isinstance(other, BtoQ) or isinstance(other,BtoPS))):
for i in result.wires.indices:
result._index_representation[i] = result._index_representation[i] + (None,)

return result

Expand Down

0 comments on commit 8a5c93e

Please sign in to comment.