From 485c0d8a0266f4435e671aafbad8be5bf46c6855 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9sar=20Benito=20Lamata?= Date: Tue, 22 Oct 2024 19:56:40 +0200 Subject: [PATCH 1/2] Exclude discarded shots when asserting decoder prediction size --- glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py index ee28c53e..c9a70268 100755 --- a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py +++ b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py @@ -197,8 +197,8 @@ def sample(self, max_shots: int) -> AnonTaskStats: raise ValueError("predictions.dtype != np.uint8") if len(predictions.shape) != 2: raise ValueError("len(predictions.shape) != 2") - if predictions.shape[0] != num_shots: - raise ValueError("predictions.shape[0] != num_shots") + if predictions.shape[0] != num_shots - num_discards_1: + raise ValueError("predictions.shape[0] != num_shots - num_discards_1") if predictions.shape[1] < actual_obs.shape[1]: raise ValueError("predictions.shape[1] < actual_obs.shape[1]") if predictions.shape[1] > actual_obs.shape[1] + 1: From d81c4aa4e0ae507f95b2e5261f70c5d4501eb7a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9sar=20Benito=20Lamata?= Date: Tue, 29 Oct 2024 16:50:47 +0100 Subject: [PATCH 2/2] Add unit test for detector post-selection sampling --- .../_stim_then_decode_sampler_test.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler_test.py b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler_test.py index 413015f0..2ef3f1a7 100755 --- a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler_test.py +++ b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler_test.py @@ -1,9 +1,12 @@ import collections import numpy as np +import stim +from sinter._data import Task +from sinter._decoding._decoding_vacuous import VacuousDecoder from sinter._decoding._stim_then_decode_sampler import \ - classify_discards_and_errors + classify_discards_and_errors, _CompiledStimThenDecodeSampler def test_classify_discards_and_errors(): @@ -190,3 +193,23 @@ def test_classify_discards_and_errors(): num_obs=13, ) == (0, 1) assert counter == collections.Counter(["obs_mistake_mask=_________E___"]) + +def test_detector_post_selection(): + circuit = stim.Circuit(""" + X_ERROR(1) 0 + M 0 + DETECTOR rec[-1] + """) + sampler = _CompiledStimThenDecodeSampler( + decoder=VacuousDecoder(), + task = Task( + circuit=circuit, + detector_error_model=circuit.detector_error_model(), + postselection_mask=np.array([1], dtype=np.uint8), + ), + count_observable_error_combos=False, + count_detection_events=False, + tmp_dir=None + ) + result = sampler.sample(max_shots=1) + assert result.discards == 1 \ No newline at end of file