From 2451dfec9ade5752deb2508c8a80968924fd9caf Mon Sep 17 00:00:00 2001 From: Fionn Malone Date: Thu, 13 Jun 2024 22:46:12 +0000 Subject: [PATCH] Working on trial. --- recirq/qcqmc/trial_wf.py | 52 +++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/recirq/qcqmc/trial_wf.py b/recirq/qcqmc/trial_wf.py index b9570691..f162952c 100644 --- a/recirq/qcqmc/trial_wf.py +++ b/recirq/qcqmc/trial_wf.py @@ -1,7 +1,6 @@ import abc import copy import itertools -from dataclasses import dataclass from typing import ( Callable, Dict, @@ -15,6 +14,7 @@ Union, ) +import attrs import cirq import fqe import fqe.algorithm.low_rank @@ -36,7 +36,7 @@ from recirq.qcqmc.hamiltonian import HamiltonianData, HamiltonianParams -@dataclass(frozen=True) +@attrs.frozen class FermionicMode: orb_ind: int spin: str # Should be "a" or "b" only @@ -59,7 +59,7 @@ def openfermion_standard_index(self) -> int: return 2 * self.orb_ind + (self.spin == "b") -@dataclass(frozen=True) +@attrs.frozen class LayerSpec: """A specification of a hardware-efficient layer of gates. @@ -89,7 +89,7 @@ def _json_dict_(self): return cirq.dataclass_json_dict(self) -@dataclass(frozen=True) +@attrs.frozen class TrialWavefunctionParams(Params, metaclass=abc.ABCMeta): name: str hamiltonian_params: HamiltonianParams @@ -113,7 +113,15 @@ def qubits_linearly_connected(self) -> Tuple[cirq.GridQubit, ...]: ) -@dataclass(frozen=True, repr=False, eq=False) +def _to_numpy(x: Optional[Iterable] = None) -> Optional[np.ndarray]: + return np.asarray(x) + + +def _to_tuple(x: Iterable[LayerSpec]) -> Sequence[LayerSpec]: + return tuple(x) + + +@attrs.frozen(repr=False) class PerfectPairingPlusTrialWavefunctionParams(TrialWavefunctionParams): """Class for storing the parameters that specify a TrialWavefunctionData. @@ -124,26 +132,21 @@ class PerfectPairingPlusTrialWavefunctionParams(TrialWavefunctionParams): name: str hamiltonian_params: HamiltonianParams - heuristic_layers: Tuple[LayerSpec, ...] + heuristic_layers: Tuple[LayerSpec, ...] = attrs.field(converter=_to_tuple) do_pp: bool = True restricted: bool = False random_parameter_scale: float = 1.0 n_optimization_restarts: int = 1 seed: int = 0 - initial_orbital_rotation: Optional[np.ndarray] = None - initial_two_body_qchem_amplitudes: Optional[np.ndarray] = None + initial_orbital_rotation: Optional[np.ndarray] = attrs.field( + default=None, converter=lambda v: _to_numpy(v) if v is not None else None + ) + initial_two_body_qchem_amplitudes: Optional[np.ndarray] = attrs.field( + default=None, converter=lambda v: _to_numpy if v is not None else None + ) do_optimization: bool = True use_fast_gradients: bool = False - def __post_init__(self): - """A little special sauce to make sure that this ends up as a tuple.""" - object.__setattr__(self, "heuristic_layers", tuple(self.heuristic_layers)) - - array_like = ["initial_orbital_rotation", "initial_two_body_qchem_amplitudes"] - for field in array_like: - if getattr(self, field) is not None: - object.__setattr__(self, field, np.asarray(getattr(self, field))) - @property def n_orb(self) -> int: return self.hamiltonian_params.n_orb @@ -367,7 +370,7 @@ def _get_fqe_wavefunctions( return wf, unrotated_wf -@dataclass(frozen=True, eq=False) +@attrs.frozen class TrialWavefunctionData(Data): """Class for storing a trial wavefunction's data.""" @@ -377,16 +380,9 @@ class TrialWavefunctionData(Data): hf_energy: float ansatz_energy: float fci_energy: float - one_body_basis_change_mat: np.ndarray - one_body_params: np.ndarray - two_body_params: np.ndarray - - def __post_init__(self): - """We need to make some inputs into np.ndarrays if aren't provided that way.""" - array_like = ["one_body_basis_change_mat", "one_body_params", "two_body_params"] - - for field in array_like: - object.__setattr__(self, field, np.asarray(getattr(self, field))) + one_body_basis_change_mat: np.ndarray = attrs.field(converter=_to_numpy) + one_body_params: np.ndarray = attrs.field(converter=_to_numpy) + two_body_params: np.ndarray = attrs.field(converter=_to_numpy) def _json_dict_(self): return cirq.dataclass_json_dict(self)