diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index cf5e8fa13..d6067f397 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -698,6 +698,8 @@ def _apply_btops_for_change_of_rep(self, indices: Sequence[int]) -> CircuitCompo ret = copy.deepcopy(self) for i in indices: + if len(self._index_representation[i]) > 2: + continue name, arg = self._index_representation[i] if name == "PS": @@ -732,6 +734,8 @@ def _apply_btoq_for_change_of_rep(self, indices: Sequence[int]) -> CircuitCompon ret = copy.deepcopy(self) for i in indices: + if len(self._index_representation[i]) > 2: + continue name, arg = self._index_representation[i] if name == "Q": ret._index_representation[i] = ("B", None) # perhaps not needed -- can be removed @@ -868,20 +872,20 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: for m in other.wires.output.bra.modes: i = result.wires.index_dicts[0][m] j = other.wires.index_dicts[0][m] - result._index_representation[i] = other._index_representation[j] + result._index_representation[i] = other._index_representation[j][:2] for m in other.wires.output.ket.modes: i = result.wires.index_dicts[2][m] j = other.wires.index_dicts[2][m] - result._index_representation[i] = other._index_representation[j] + result._index_representation[i] = other._index_representation[j][:2] for m in self.wires.input.bra.modes: i = result.wires.index_dicts[1][m] j = self.wires.index_dicts[1][m] - result._index_representation[i] = self._index_representation[j] + result._index_representation[i] = self._index_representation[j][:2] for m in self.wires.input.ket.modes: i = result.wires.index_dicts[3][m] j = self.wires.index_dicts[3][m] - result._index_representation[i] = self._index_representation[j] + result._index_representation[i] = self._index_representation[j][:2] # now we check for indices that might have been contracted: idx_1, idx_2 = self.wires.contracted_indices(other.wires) @@ -890,25 +894,29 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: j = other.wires.index_dicts[1][m] if j not in idx_2: i = result.wires.index_dicts[1][m] - result._index_representation[i] = other._index_representation[j] + result._index_representation[i] = other._index_representation[j][:2] for m in other.wires.input.ket.modes: j = other.wires.index_dicts[3][m] if j not in idx_2: i = result.wires.index_dicts[3][m] - result._index_representation[i] = other._index_representation[j] + result._index_representation[i] = other._index_representation[j][:2] for m in self.wires.output.bra.modes: j = self.wires.index_dicts[0][m] if j not in idx_1: i = result.wires.index_dicts[0][m] - result._index_representation[i] = self._index_representation[j] + result._index_representation[i] = self._index_representation[j][:2] for m in self.wires.output.ket.modes: j = self.wires.index_dicts[2][m] if j not in idx_1: i = result.wires.index_dicts[2][m] - result._index_representation[i] = self._index_representation[j] + result._index_representation[i] = self._index_representation[j][:2] + + if ((isinstance(self, BtoQ) or isinstance(self, BtoPS) )and (isinstance(other, BtoQ) or isinstance(other,BtoPS))): + for i in result.wires.indices: + result._index_representation[i] = result._index_representation[i] + (None,) return result