Skip to content

Commit

Permalink
WIP.
Browse files Browse the repository at this point in the history
  • Loading branch information
fdmalone committed Jun 29, 2024
1 parent 3f9ef9d commit f124e06
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
10 changes: 4 additions & 6 deletions recirq/qcqmc/blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import (Dict, Iterable, Iterator, List, Optional, Sequence, Tuple,
Union)
from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union

import attrs
import cirq
Expand All @@ -22,8 +21,7 @@

from recirq.qcqmc.config import OUTDIRS
from recirq.qcqmc.data import Data, Params
from recirq.qcqmc.trial_wf import (TrialWavefunctionData,
TrialWavefunctionParams)
from recirq.qcqmc.trial_wf import TrialWavefunctionData, TrialWavefunctionParams

BlueprintParams = Union["BlueprintParamsTrialWf", "BlueprintParamsRobustShadow"]

Expand Down Expand Up @@ -102,7 +100,7 @@ def _get_resolvers(
)

for clifford_set in truncated_cliffords:
yield quaff.get_truncated_clifford_resolver(clifford_set)
yield quaff.get_truncated_cliffords_resolver(clifford_set)


@attrs.frozen
Expand Down Expand Up @@ -222,7 +220,7 @@ def build_blueprint_from_base_circuit(
)

parameterized_clifford_ops: Iterable[cirq.OP_TREE] = (
quaff.get_parameterized_truncated_clifford_ops(params.qubit_partition)
quaff.get_parameterized_truncated_cliffords_ops(params.qubit_partition)
)

parameterized_clifford_circuits = tuple(
Expand Down
6 changes: 6 additions & 0 deletions recirq/qcqmc/blueprint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def test_small_blueprint(
):
_, trial_wf_data = fixture_4_qubit_ham_and_trial_wf
trial_wf_params = trial_wf_data.params
import attrs

for k, v in attrs.asdict((trial_wf_params)).items():
print(k, v)

blueprint_params = BlueprintParamsTrialWf(
name="blueprint_test",
Expand Down Expand Up @@ -150,8 +154,10 @@ def test_medium(
assert len(list(blueprint.resolvers)) == 3

resolved_circuits = list(blueprint.resolved_clifford_circuits)
# print(resolved_circuits)
assert len(resolved_circuits) == 3
for circuit_tuple in resolved_circuits:
print(len(circuit_tuple), len(circuit_tuple[0]))
assert len(circuit_tuple) == 8
for circuit, qubits in zip(circuit_tuple, blueprint_params.qubit_partition):
assert len(circuit.all_qubits()) == len(qubits)
Expand Down
12 changes: 9 additions & 3 deletions recirq/qcqmc/trial_wf.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,15 @@ def _to_tuple(x: Iterable[layer_spec.LayerSpec]) -> Sequence[layer_spec.LayerSpe
return tuple(x)


@attrs.frozen(repr=False)
def _array_cmp(a: Optional[np.ndarray], b: Optional[np.ndarray]) -> bool:
if a is None and b is None:
return True
if a is None or b is None:
return False
return np.array_equal(a, b)


@attrs.frozen(repr=False, eq=False)
class PerfectPairingPlusTrialWavefunctionParams(TrialWavefunctionParams):
"""Class for storing the parameters that specify the trial wavefunction.
Expand Down Expand Up @@ -112,12 +120,10 @@ class PerfectPairingPlusTrialWavefunctionParams(TrialWavefunctionParams):
initial_orbital_rotation: Optional[np.ndarray] = attrs.field(
default=None,
converter=lambda v: _to_numpy(v) if v is not None else None,
eq=attrs.cmp_using(eq=np.array_equal),
)
initial_two_body_qchem_amplitudes: Optional[np.ndarray] = attrs.field(
default=None,
converter=lambda v: _to_numpy(v) if v is not None else None,
eq=attrs.cmp_using(eq=np.array_equal),
)
do_optimization: bool = True
use_fast_gradients: bool = False
Expand Down

0 comments on commit f124e06

Please sign in to comment.