From 5fbd36f1e4130cae11cb5ffa329ce0ce4a3f178e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20M=2E=20K=C3=B6hler?= Date: Fri, 11 Oct 2024 17:34:40 +0200 Subject: [PATCH 1/3] Reject last sample of annotations - Keep bad annotations that end at the start sample (annotations that are potentialy one sample long): end = start - Do not discard bad annotations that are exactly one sample long: onset = end - Also discard last sample of bad annotation: end - start + 1 --- mne/io/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mne/io/base.py b/mne/io/base.py index 79cbbe192ba..b1c473a6ae3 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -912,6 +912,8 @@ def get_data( Whether to reject by annotation. If None (default), no rejection is done. If 'omit', segments annotated with description starting with 'bad' are omitted. If 'NaN', the bad samples are filled with NaNs. + Note that the last sample of each annotation will also be omitted + or replaced with NaN. return_times : bool Whether to return times as well. Defaults to False. %(units)s @@ -983,7 +985,7 @@ def get_data( "reject_by_annotation", reject_by_annotation.lower(), ["omit", "nan"] ) onsets, ends = _annotations_starts_stops(self, ["BAD"]) - keep = (onsets < stop) & (ends > start) + keep = (onsets < stop) & (ends >= start) onsets = np.maximum(onsets[keep], start) ends = np.minimum(ends[keep], stop) if len(onsets) == 0: @@ -996,9 +998,9 @@ def get_data( n_samples = stop - start # total number of samples used = np.ones(n_samples, bool) for onset, end in zip(onsets, ends): - if onset >= end: + if onset > end: continue - used[onset - start : end - start] = False + used[onset - start : end - start + 1] = False used = np.concatenate([[False], used, [False]]) starts = np.where(~used[:-1] & used[1:])[0] + start stops = np.where(used[:-1] & ~used[1:])[0] + start From 47a010460e776145d91bc5712225eaafffc95263 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20M=2E=20K=C3=B6hler?= Date: Fri, 11 Oct 2024 17:35:39 +0200 Subject: [PATCH 2/3] Fix and add tests - Fix now broken tests - Add tests for annotations that are one sample long --- mne/io/tests/test_raw.py | 40 +++++++++++++++++++++++++++-------- mne/tests/test_annotations.py | 10 ++++----- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index 5f4556d6d8e..e7ac3369c23 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -706,29 +706,51 @@ def test_meas_date_orig_time(): def test_get_data_reject(): """Test if reject_by_annotation is working correctly.""" - fs = 256 + fs = 100 ch_names = ["C3", "Cz", "C4"] info = create_info(ch_names, sfreq=fs) - raw = RawArray(np.zeros((len(ch_names), 10 * fs)), info) + n_times = 10 * fs + raw = RawArray(np.zeros((len(ch_names), n_times)), info) raw.set_annotations(Annotations(onset=[2, 4], duration=[3, 2], description="bad")) with catch_logging() as log: data = raw.get_data(reject_by_annotation="omit", verbose=True) msg = ( - "Omitting 1024 of 2560 (40.00%) samples, retaining 1536" - + " (60.00%) samples." + "Omitting 401 of 1000 (40.10%) samples, retaining 599" + + " (59.90%) samples." ) assert log.getvalue().strip() == msg - assert data.shape == (len(ch_names), 1536) + assert data.shape == (len(ch_names), 599) with catch_logging() as log: data = raw.get_data(reject_by_annotation="nan", verbose=True) msg = ( - "Setting 1024 of 2560 (40.00%) samples to NaN, retaining 1536" - + " (60.00%) samples." + "Setting 401 of 1000 (40.10%) samples to NaN, retaining 599" + + " (59.90%) samples." ) assert log.getvalue().strip() == msg - assert data.shape == (len(ch_names), 2560) # shape doesn't change - assert np.isnan(data).sum() == 3072 # but NaNs are introduced instead + assert data.shape == (len(ch_names), n_times) # shape doesn't change + assert np.isnan(data).sum() == 1203 # but NaNs are introduced instead + + # Test that 1-sample annotations at start and end of recording are handled + raw.set_annotations(Annotations(onset=[0], duration=[0], description="bad")) + data = raw.get_data(reject_by_annotation="omit", verbose=True) + assert data.shape == (len(ch_names), n_times - 1) + raw.set_annotations(Annotations(onset=[raw.times[-1]], duration=[0], description="bad")) + data = raw.get_data(reject_by_annotation="omit", verbose=True) + assert data.shape == (len(ch_names), n_times - 1) + + # Test that 1-sample annotations are handled correctly, when they occur + # because of cropping + raw.set_annotations(Annotations(onset=[raw.times[-1]], duration=[1/fs], description="bad")) + with catch_logging() as log: + data = raw.get_data(reject_by_annotation="omit", start=1, verbose=True) + msg = ( + "Omitting 1 of 999 (0.10%) samples, retaining 998 (99.90%)" + + " samples." + ) + assert log.getvalue().strip() == msg + print(data.shape) + assert data.shape == (len(ch_names), n_times - 2) def test_5839(): diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 6b1356ae107..167163ff8b7 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -432,7 +432,7 @@ def test_raw_reject(first_samp): return_times=True, # 1-112 s ) bad_times = np.concatenate( - [np.arange(200, 400), np.arange(10000, 10800), np.arange(10500, 11000)] + [np.arange(200, 401), np.arange(10000, 10801), np.arange(10500, 11001)] ) expected_times = np.setdiff1d(np.arange(100, 11200), bad_times) / sfreq assert_allclose(times, expected_times) @@ -450,7 +450,7 @@ def test_raw_reject(first_samp): t_stop = 18.0 assert raw.times[-1] > t_stop n_stop = int(round(t_stop * raw.info["sfreq"])) - n_drop = int(round(4 * raw.info["sfreq"])) + n_drop = int(round(4 * raw.info["sfreq"]) + 2) assert len(raw.times) >= n_stop data, times = raw.get_data(range(10), 0, n_stop, "omit", True) assert data.shape == (10, n_stop - n_drop) @@ -558,8 +558,8 @@ def test_annotation_filtering(first_samp): raw = raws[0].copy() raw.set_annotations(Annotations([0.0], [0.5], ["BAD_ACQ_SKIP"])) my_data, times = raw.get_data(reject_by_annotation="omit", return_times=True) - assert_allclose(times, raw.times[500:]) - assert my_data.shape == (1, 500) + assert_allclose(times, raw.times[501:]) + assert my_data.shape == (1, 499) raw_filt = raw.copy().filter(skip_by_annotation="bad_acq_skip", **kwargs_stop) expected = data.copy() expected[:, 500:] = 0 @@ -586,7 +586,7 @@ def test_annotation_omit(first_samp): expected = raw[0][0] assert_allclose(raw.get_data(reject_by_annotation=None), expected) # nan - expected[0, 500:1500] = np.nan + expected[0, 500:1501] = np.nan assert_allclose(raw.get_data(reject_by_annotation="nan"), expected) got = np.concatenate( [ From 16dfd8a2b8ba28eedfaac43b65b724ed7955a208 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:48:30 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/io/tests/test_raw.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index e7ac3369c23..3217f3500fc 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -735,19 +735,20 @@ def test_get_data_reject(): raw.set_annotations(Annotations(onset=[0], duration=[0], description="bad")) data = raw.get_data(reject_by_annotation="omit", verbose=True) assert data.shape == (len(ch_names), n_times - 1) - raw.set_annotations(Annotations(onset=[raw.times[-1]], duration=[0], description="bad")) + raw.set_annotations( + Annotations(onset=[raw.times[-1]], duration=[0], description="bad") + ) data = raw.get_data(reject_by_annotation="omit", verbose=True) assert data.shape == (len(ch_names), n_times - 1) # Test that 1-sample annotations are handled correctly, when they occur # because of cropping - raw.set_annotations(Annotations(onset=[raw.times[-1]], duration=[1/fs], description="bad")) + raw.set_annotations( + Annotations(onset=[raw.times[-1]], duration=[1 / fs], description="bad") + ) with catch_logging() as log: data = raw.get_data(reject_by_annotation="omit", start=1, verbose=True) - msg = ( - "Omitting 1 of 999 (0.10%) samples, retaining 998 (99.90%)" - + " samples." - ) + msg = "Omitting 1 of 999 (0.10%) samples, retaining 998 (99.90%)" + " samples." assert log.getvalue().strip() == msg print(data.shape) assert data.shape == (len(ch_names), n_times - 2)