Skip to content

Commit

Permalink
Working on trial.
Browse files Browse the repository at this point in the history
  • Loading branch information
fdmalone committed Jun 13, 2024
1 parent bb0013c commit 2451dfe
Showing 1 changed file with 24 additions and 28 deletions.
52 changes: 24 additions & 28 deletions recirq/qcqmc/trial_wf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import abc
import copy
import itertools
from dataclasses import dataclass
from typing import (
Callable,
Dict,
Expand All @@ -15,6 +14,7 @@
Union,
)

import attrs
import cirq
import fqe
import fqe.algorithm.low_rank
Expand All @@ -36,7 +36,7 @@
from recirq.qcqmc.hamiltonian import HamiltonianData, HamiltonianParams


@dataclass(frozen=True)
@attrs.frozen
class FermionicMode:
orb_ind: int
spin: str # Should be "a" or "b" only
Expand All @@ -59,7 +59,7 @@ def openfermion_standard_index(self) -> int:
return 2 * self.orb_ind + (self.spin == "b")


@dataclass(frozen=True)
@attrs.frozen
class LayerSpec:
"""A specification of a hardware-efficient layer of gates.
Expand Down Expand Up @@ -89,7 +89,7 @@ def _json_dict_(self):
return cirq.dataclass_json_dict(self)


@dataclass(frozen=True)
@attrs.frozen
class TrialWavefunctionParams(Params, metaclass=abc.ABCMeta):
name: str
hamiltonian_params: HamiltonianParams
Expand All @@ -113,7 +113,15 @@ def qubits_linearly_connected(self) -> Tuple[cirq.GridQubit, ...]:
)


@dataclass(frozen=True, repr=False, eq=False)
def _to_numpy(x: Optional[Iterable] = None) -> Optional[np.ndarray]:
return np.asarray(x)


def _to_tuple(x: Iterable[LayerSpec]) -> Sequence[LayerSpec]:
return tuple(x)


@attrs.frozen(repr=False)
class PerfectPairingPlusTrialWavefunctionParams(TrialWavefunctionParams):
"""Class for storing the parameters that specify a TrialWavefunctionData.
Expand All @@ -124,26 +132,21 @@ class PerfectPairingPlusTrialWavefunctionParams(TrialWavefunctionParams):

name: str
hamiltonian_params: HamiltonianParams
heuristic_layers: Tuple[LayerSpec, ...]
heuristic_layers: Tuple[LayerSpec, ...] = attrs.field(converter=_to_tuple)
do_pp: bool = True
restricted: bool = False
random_parameter_scale: float = 1.0
n_optimization_restarts: int = 1
seed: int = 0
initial_orbital_rotation: Optional[np.ndarray] = None
initial_two_body_qchem_amplitudes: Optional[np.ndarray] = None
initial_orbital_rotation: Optional[np.ndarray] = attrs.field(
default=None, converter=lambda v: _to_numpy(v) if v is not None else None
)
initial_two_body_qchem_amplitudes: Optional[np.ndarray] = attrs.field(
default=None, converter=lambda v: _to_numpy if v is not None else None
)
do_optimization: bool = True
use_fast_gradients: bool = False

def __post_init__(self):
"""A little special sauce to make sure that this ends up as a tuple."""
object.__setattr__(self, "heuristic_layers", tuple(self.heuristic_layers))

array_like = ["initial_orbital_rotation", "initial_two_body_qchem_amplitudes"]
for field in array_like:
if getattr(self, field) is not None:
object.__setattr__(self, field, np.asarray(getattr(self, field)))

@property
def n_orb(self) -> int:
return self.hamiltonian_params.n_orb
Expand Down Expand Up @@ -367,7 +370,7 @@ def _get_fqe_wavefunctions(
return wf, unrotated_wf


@dataclass(frozen=True, eq=False)
@attrs.frozen
class TrialWavefunctionData(Data):
"""Class for storing a trial wavefunction's data."""

Expand All @@ -377,16 +380,9 @@ class TrialWavefunctionData(Data):
hf_energy: float
ansatz_energy: float
fci_energy: float
one_body_basis_change_mat: np.ndarray
one_body_params: np.ndarray
two_body_params: np.ndarray

def __post_init__(self):
"""We need to make some inputs into np.ndarrays if aren't provided that way."""
array_like = ["one_body_basis_change_mat", "one_body_params", "two_body_params"]

for field in array_like:
object.__setattr__(self, field, np.asarray(getattr(self, field)))
one_body_basis_change_mat: np.ndarray = attrs.field(converter=_to_numpy)
one_body_params: np.ndarray = attrs.field(converter=_to_numpy)
two_body_params: np.ndarray = attrs.field(converter=_to_numpy)

def _json_dict_(self):
return cirq.dataclass_json_dict(self)
Expand Down

0 comments on commit 2451dfe

Please sign in to comment.