Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add simulator for linspace program #861

Merged
merged 7 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 70 additions & 1 deletion qupulse/program/linspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _add_iteration_node(self, node: LinSpaceIter):
self.add_node(node.body)

if node.length > 1:
self.iterations[-1] = node.length
self.iterations[-1] = node.length - 1
label, jmp = self.new_loop(node.length - 1)
self.commands.append(label)
self.add_node(node.body)
Expand Down Expand Up @@ -412,3 +412,72 @@ def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode]) -> List[Comman
state.add_node(linspace_nodes)
return state.commands


class LinSpaceVM:
def __init__(self, channels: int):
self.current_values = [np.nan] * channels
self.time = TimeType(0)
self.registers = tuple({} for _ in range(channels))

self.history: List[Tuple[TimeType, Tuple[float, ...]]] = []

self.commands = None
self.label_targets = None
self.label_counts = None
self.current_command = None

def change_state(self, cmd: Union[Set, Increment, Wait, Play]):
if isinstance(cmd, Play):
raise NotImplementedError("TODO: Implement arbitrary waveform simulation")
elif isinstance(cmd, Wait):
self.history.append(
(self.time, self.current_values.copy())
)
self.time += cmd.duration
elif isinstance(cmd, Set):
self.current_values[cmd.channel] = cmd.value
self.registers[cmd.channel][cmd.key] = cmd.value
elif isinstance(cmd, Increment):
value = self.registers[cmd.channel][cmd.dependency_key]
value += cmd.value
self.registers[cmd.channel][cmd.dependency_key] = value
self.current_values[cmd.channel] = value
else:
raise NotImplementedError(cmd)

def set_commands(self, commands: Sequence[Command]):
self.commands = []
self.label_targets = {}
self.label_counts = {}
self.current_command = None

for cmd in commands:
self.commands.append(cmd)
if isinstance(cmd, LoopLabel):
# a loop label signifies a reset count followed by the actual label that targets the following command
assert cmd.idx not in self.label_targets
self.label_targets[cmd.idx] = len(self.commands)

self.current_command = 0

def step(self):
cmd = self.commands[self.current_command]
if isinstance(cmd, LoopJmp):
if self.label_counts[cmd.idx] > 0:
self.label_counts[cmd.idx] -= 1
self.current_command = self.label_targets[cmd.idx]
else:
# ignore jump
self.current_command += 1
elif isinstance(cmd, LoopLabel):
self.label_counts[cmd.idx] = cmd.count - 1
self.current_command += 1
else:
self.change_state(cmd)
self.current_command += 1

def run(self):
while self.current_command < len(self.commands):
self.step()


99 changes: 92 additions & 7 deletions tests/program/linspace_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@
from qupulse.program.linspace import *
from qupulse.program.transformation import *


def assert_vm_output_almost_equal(test: TestCase, expected, actual):
"""Compare two vm outputs with default TestCase.assertAlmostEqual accuracy"""
test.assertEqual(len(expected), len(actual))
for idx, ((t_e, vals_e), (t_a, vals_a)) in enumerate(zip(expected, actual)):
test.assertEqual(t_e, t_a, f"Differing times in {idx} element")
test.assertEqual(len(vals_e), len(vals_a), f"Differing channel count in {idx} element")
for ch, (val_e, val_a) in enumerate(zip(vals_e, vals_a)):
test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} element channel {ch}")


class SingleRampTest(TestCase):
def setUp(self):
hold = ConstantPT(10 ** 6, {'a': '-1. + idx * 0.01'})
Expand All @@ -32,6 +43,10 @@ def setUp(self):
LoopJmp(0)
]

self.output = [
(TimeType(10**6 * idx), [sum([-1.0] + [0.01] * idx)]) for idx in range(200)
]

def test_program(self):
program_builder = LinSpaceBuilder(('a',))
program = self.pulse_template.create_program(program_builder=program_builder)
Expand All @@ -41,6 +56,12 @@ def test_commands(self):
commands = to_increment_commands([self.program])
self.assertEqual(self.commands, commands)

def test_output(self):
vm = LinSpaceVM(1)
vm.set_commands(commands=self.commands)
vm.run()
assert_vm_output_almost_equal(self, self.output, vm.history)


class PlainCSDTest(TestCase):
def setUp(self):
Expand Down Expand Up @@ -74,7 +95,7 @@ def setUp(self):

LoopLabel(1, 99),

Increment(0, -2.0, key_0),
Increment(0, -1.99, key_0),
Increment(1, 0.02, key_1),
Wait(TimeType(10 ** 6)),

Expand All @@ -86,6 +107,16 @@ def setUp(self):
LoopJmp(1),
]

a_values = [sum([-1.] + [0.01] * i) for i in range(200)]
b_values = [sum([-.5] + [0.02] * j) for j in range(100)]

self.output = [
(
TimeType(10 ** 6 * (i + 200 * j)),
[a_values[i], b_values[j]]
) for j in range(100) for i in range(200)
]

def test_program(self):
program_builder = LinSpaceBuilder(('a', 'b'))
program = self.pulse_template.create_program(program_builder=program_builder)
Expand All @@ -95,13 +126,20 @@ def test_increment_commands(self):
commands = to_increment_commands([self.program])
self.assertEqual(self.commands, commands)

def test_output(self):
vm = LinSpaceVM(2)
vm.set_commands(self.commands)
vm.run()
assert_vm_output_almost_equal(self, self.output, vm.history)


class TiltedCSDTest(TestCase):
def setUp(self):
repetition_count = 3
hold = ConstantPT(10**6, {'a': '-1. + idx_a * 0.01 + idx_b * 1e-3', 'b': '-.5 + idx_b * 0.02 - 3e-3 * idx_a'})
scan_a = hold.with_iteration('idx_a', 200)
self.pulse_template = scan_a.with_iteration('idx_b', 100)
self.repeated_pt = self.pulse_template.with_repetition(42)
self.repeated_pt = self.pulse_template.with_repetition(repetition_count)

self.program = LinSpaceIter(length=100, body=(LinSpaceIter(
length=200,
Expand All @@ -113,7 +151,7 @@ def setUp(self):
duration_factors=None
),)
),))
self.repeated_program = LinSpaceRepeat(body=(self.program,), count=42)
self.repeated_program = LinSpaceRepeat(body=(self.program,), count=repetition_count)

key_0 = DepKey.from_voltages((1e-3, 0.01,), DEFAULT_INCREMENT_RESOLUTION)
key_1 = DepKey.from_voltages((0.02, -3e-3), DEFAULT_INCREMENT_RESOLUTION)
Expand All @@ -131,8 +169,8 @@ def setUp(self):

LoopLabel(1, 99),

Increment(0, 1e-3 + -200 * 1e-2, key_0),
Increment(1, 0.02 + -200 * -3e-3, key_1),
Increment(0, 1e-3 + -199 * 1e-2, key_0),
Increment(1, 0.02 + -199 * -3e-3, key_1),
Wait(TimeType(10 ** 6)),

LoopLabel(2, 199),
Expand All @@ -147,7 +185,19 @@ def setUp(self):
for cmd in inner_commands:
if hasattr(cmd, 'idx'):
cmd.idx += 1
self.repeated_commands = [LoopLabel(0, 42)] + inner_commands + [LoopJmp(0)]
self.repeated_commands = [LoopLabel(0, repetition_count)] + inner_commands + [LoopJmp(0)]

self.output = [
(
TimeType(10 ** 6 * (i + 200 * j)),
[-1. + i * 0.01 + j * 1e-3, -.5 + j * 0.02 - 3e-3 * i]
) for j in range(100) for i in range(200)
]
self.repeated_output = [
(t + TimeType(10**6) * (n * 100 * 200), vals)
for n in range(repetition_count)
for t, vals in self.output
]

def test_program(self):
program_builder = LinSpaceBuilder(('a', 'b'))
Expand All @@ -167,6 +217,18 @@ def test_repeated_increment_commands(self):
commands = to_increment_commands([self.repeated_program])
self.assertEqual(self.repeated_commands, commands)

def test_output(self):
vm = LinSpaceVM(2)
vm.set_commands(self.commands)
vm.run()
assert_vm_output_almost_equal(self, self.output, vm.history)

def test_repeated_output(self):
vm = LinSpaceVM(2)
vm.set_commands(self.repeated_commands)
vm.run()
assert_vm_output_almost_equal(self, self.repeated_output, vm.history)


class SingletLoadProcessing(TestCase):
def setUp(self):
Expand Down Expand Up @@ -223,7 +285,7 @@ def setUp(self):
Set(0, -0.4),
Set(1, -0.3),
Wait(TimeType(10 ** 5)),
Increment(0, -2.0, key_0),
Increment(0, -1.99, key_0),
Increment(1, 0.02, key_1),
Wait(TimeType(10 ** 6)),
Set(0, 0.05),
Expand All @@ -247,6 +309,23 @@ def setUp(self):
LoopJmp(1),
]

self.output = []
time = TimeType(0)
for idx_b in range(100):
for idx_a in range(200):
self.output.append(
(time, [-.4, -.3])
)
time += 10 ** 5
self.output.append(
(time, [-1. + idx_a * 0.01, -.5 + idx_b * 0.02])
)
time += 10 ** 6
self.output.append(
(time, [0.05, 0.06])
)
time += 10 ** 5

def test_singlet_scan_program(self):
program_builder = LinSpaceBuilder(('a', 'b'))
program = self.pulse_template.create_program(program_builder=program_builder)
Expand All @@ -256,6 +335,12 @@ def test_singlet_scan_commands(self):
commands = to_increment_commands([self.program])
self.assertEqual(self.commands, commands)

def test_singlet_scan_output(self):
vm = LinSpaceVM(2)
vm.set_commands(self.commands)
vm.run()
assert_vm_output_almost_equal(self, self.output, vm.history)


class TransformedRampTest(TestCase):
def setUp(self):
Expand Down
Loading