Skip to content

Commit

Permalink
Cleanup implementation and add arbitrary waveform forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Oct 19, 2023
1 parent ce42fc0 commit c724725
Showing 1 changed file with 59 additions and 55 deletions.
114 changes: 59 additions & 55 deletions qupulse/program/linspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,25 @@
from qupulse.parameter_scope import Scope, MappedScope, FrozenDict
from qupulse.program import (ProgramBuilder, HardwareTime, HardwareVoltage, Waveform, RepetitionCount, TimeType,
SimpleExpression)
from qupulse.expressions import sympy as sym_expr
from qupulse.program.waveforms import MultiChannelWaveform


DEFAULT_RESOLUTION: float = 1e-9
# this resolution is used to unify increments
# the increments themselves remain floats
DEFAULT_INCREMENT_RESOLUTION: float = 1e-9


@dataclass(frozen=True)
class DepKey:
"""The key that identifies how a certain set command depends on iteration indices."""
factors: Tuple[int, ...]

@classmethod
def from_voltages(cls, voltages: Sequence[float], resolution: float):
# remove trailing zeros
while voltages and voltages[-1] == 0:
voltages = voltages[:-1]
return cls(tuple(int(round(voltage / resolution)) for voltage in voltages))


@dataclass
Expand All @@ -21,19 +36,6 @@ def dependencies(self) -> Mapping[int, set]:
raise NotImplementedError


@dataclass
class LinSpaceSet:
channel: int
base: float
factors: Optional[Tuple[float, ...]]



@dataclass
class Wait:
duration: TimeType


@dataclass
class LinSpaceHold(LinSpaceNode):
bases: Tuple[float, ...]
Expand All @@ -47,12 +49,6 @@ def dependencies(self) -> Mapping[int, set]:
for idx, factors in enumerate(self.factors)
if factors}

def to_atomic_commands(self):
if self.duration_factors:
raise NotImplementedError('Variable durations are not implemented for commands yet')
return [LinSpaceSet(idx, base, factors)
for idx, (base, factors) in enumerate(zip(self.bases, self.factors))] + [Wait(self.duration_base)]

def to_increment_commands(self, previous: Tuple[float, ...], iter_advance: Sequence[bool]):
if self.duration_factors:
raise NotImplementedError('Variable durations are not implemented for increment commands yet')
Expand Down Expand Up @@ -81,6 +77,12 @@ def to_increment_commands(self, previous: Tuple[float, ...], iter_advance: Seque
set_vals.append(set_val)


@dataclass
class LinSpaceArbitraryWaveform(LinSpaceNode):
waveform: Waveform
channels: Tuple[ChannelID, ...]


@dataclass
class LinSpaceRepeat(LinSpaceNode):
body: Tuple[LinSpaceNode, ...]
Expand Down Expand Up @@ -111,22 +113,21 @@ def dependencies(self):


class LinSpaceBuilder(ProgramBuilder):
def __init__(self, channels: Tuple[Optional[ChannelID], ...]):
"""This program builder supports efficient translation of pulse templates that use symbolic linearly
spaced voltages and durations.
The channel identifiers are reduced to their index in the given channel tuple.
Arbitrary waveforms are not implemented yet
"""
def __init__(self, channels: Tuple[ChannelID, ...]):
super().__init__()
self._name_to_idx = {name: idx for idx, name in enumerate(channels) if name is not None}
self._name_to_idx = {name: idx for idx, name in enumerate(channels)}
self._idx_to_name = channels

self._stack = [[]]
self._ranges = []

@classmethod
def from_channel_dict(cls, channels: Mapping[ChannelID, int]):
assert len(set(channels.values())) == len(channels), "no duplicate target channels"
channel_list = [None] * 20
for ch_name, ch_idx in channels.items():
channel_list[ch_idx] = ch_name
return cls(tuple(channel_list))

def _root(self):
return self._stack[0]

Expand Down Expand Up @@ -187,7 +188,7 @@ def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, Hard
self._stack[-1].append(set_cmd)

def play_arbitrary_waveform(self, waveform: Waveform):
raise NotImplementedError('Not implemented yet (postponed)')
return self._stack[-1].append(LinSpaceArbitraryWaveform(waveform, self._idx_to_name))

def measure(self, measurements: Optional[Sequence[MeasurementWindow]]):
"""Ignores measurements"""
Expand Down Expand Up @@ -249,39 +250,43 @@ class Set:
key: 'DepKey' = dataclasses.field(default_factory=lambda: DepKey(()))


@dataclass
class Wait:
duration: TimeType


@dataclass
class LoopJmp:
idx: int


@dataclass
class Play:
waveform: Waveform
channels: Tuple[ChannelID]


Command = Increment | Set | LoopLabel | LoopJmp | Wait | Play


@dataclass(frozen=True)
class DepState:
base: float
iterations: Tuple[int, ...]


@dataclass(frozen=True)
class DepKey:
"""The key that identifies how a certain set command depends on iteration indices."""
factors: Tuple[int, ...]

@classmethod
def from_voltages(cls, voltages: Sequence[float], resolution: float):
# remove trailing zeros
while voltages and voltages[-1] == 0:
voltages = voltages[:-1]
return cls(tuple(int(round(voltage / resolution)) for voltage in voltages))


@dataclass
class TranslationState:
label_num: int
commands: list
iterations: list
active_dep: Dict[int, DepKey]
dep_states: Dict[int, Dict[DepKey, DepState]]
plain_voltage: Dict[int, float]
resolution: float = dataclasses.field(default_factory=lambda: DEFAULT_RESOLUTION)
label_num: int = dataclasses.field(default=0)
commands: List[Command] = dataclasses.field(default_factory=list)
iterations: List[int] = dataclasses.field(default_factory=list)
active_dep: Dict[int, DepKey] = dataclasses.field(default_factory=dict)
dep_states: Dict[int, Dict[DepKey, DepState]] = dataclasses.field(default_factory=dict)
plain_voltage: Dict[int, float] = dataclasses.field(default_factory=dict)
resolution: float = dataclasses.field(default_factory=lambda: DEFAULT_INCREMENT_RESOLUTION)

def new_loop(self, count: int):
label = LoopLabel(self.label_num, count)
Expand Down Expand Up @@ -376,14 +381,13 @@ def to_atomic_commands(node: Union[LinSpaceNode, Sequence[LinSpaceNode]], state:
state.active_dep[ch] = dep_key
state.dep_states[ch][dep_key] = new_dep_state
state.commands.append(Wait(node.duration_base))
elif isinstance(node, LinSpaceArbitraryWaveform):
state.commands.append(Play(node.waveform, node.channels))
else:
raise TypeError("The node type is not handled", type(node), node)


def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode]) -> list:
def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode]) -> List[Command]:
state = TranslationState(0, [], [], active_dep={}, dep_states={}, plain_voltage={})
to_atomic_commands(linspace_nodes, state)
return state.commands





0 comments on commit c724725

Please sign in to comment.