diff --git a/mne/io/base.py b/mne/io/base.py index 79cbbe192ba..b1c473a6ae3 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -912,6 +912,8 @@ def get_data( Whether to reject by annotation. If None (default), no rejection is done. If 'omit', segments annotated with description starting with 'bad' are omitted. If 'NaN', the bad samples are filled with NaNs. + Note that the last sample of each annotation will also be omitted + or replaced with NaN. return_times : bool Whether to return times as well. Defaults to False. %(units)s @@ -983,7 +985,7 @@ def get_data( "reject_by_annotation", reject_by_annotation.lower(), ["omit", "nan"] ) onsets, ends = _annotations_starts_stops(self, ["BAD"]) - keep = (onsets < stop) & (ends > start) + keep = (onsets < stop) & (ends >= start) onsets = np.maximum(onsets[keep], start) ends = np.minimum(ends[keep], stop) if len(onsets) == 0: @@ -996,9 +998,9 @@ def get_data( n_samples = stop - start # total number of samples used = np.ones(n_samples, bool) for onset, end in zip(onsets, ends): - if onset >= end: + if onset > end: continue - used[onset - start : end - start] = False + used[onset - start : end - start + 1] = False used = np.concatenate([[False], used, [False]]) starts = np.where(~used[:-1] & used[1:])[0] + start stops = np.where(used[:-1] & ~used[1:])[0] + start diff --git a/mne/io/tests/test_raw.py b/mne/io/tests/test_raw.py index 5f4556d6d8e..3217f3500fc 100644 --- a/mne/io/tests/test_raw.py +++ b/mne/io/tests/test_raw.py @@ -706,29 +706,52 @@ def test_meas_date_orig_time(): def test_get_data_reject(): """Test if reject_by_annotation is working correctly.""" - fs = 256 + fs = 100 ch_names = ["C3", "Cz", "C4"] info = create_info(ch_names, sfreq=fs) - raw = RawArray(np.zeros((len(ch_names), 10 * fs)), info) + n_times = 10 * fs + raw = RawArray(np.zeros((len(ch_names), n_times)), info) raw.set_annotations(Annotations(onset=[2, 4], duration=[3, 2], description="bad")) with catch_logging() as log: data = raw.get_data(reject_by_annotation="omit", verbose=True) msg = ( - "Omitting 1024 of 2560 (40.00%) samples, retaining 1536" - + " (60.00%) samples." + "Omitting 401 of 1000 (40.10%) samples, retaining 599" + + " (59.90%) samples." ) assert log.getvalue().strip() == msg - assert data.shape == (len(ch_names), 1536) + assert data.shape == (len(ch_names), 599) with catch_logging() as log: data = raw.get_data(reject_by_annotation="nan", verbose=True) msg = ( - "Setting 1024 of 2560 (40.00%) samples to NaN, retaining 1536" - + " (60.00%) samples." + "Setting 401 of 1000 (40.10%) samples to NaN, retaining 599" + + " (59.90%) samples." ) assert log.getvalue().strip() == msg - assert data.shape == (len(ch_names), 2560) # shape doesn't change - assert np.isnan(data).sum() == 3072 # but NaNs are introduced instead + assert data.shape == (len(ch_names), n_times) # shape doesn't change + assert np.isnan(data).sum() == 1203 # but NaNs are introduced instead + + # Test that 1-sample annotations at start and end of recording are handled + raw.set_annotations(Annotations(onset=[0], duration=[0], description="bad")) + data = raw.get_data(reject_by_annotation="omit", verbose=True) + assert data.shape == (len(ch_names), n_times - 1) + raw.set_annotations( + Annotations(onset=[raw.times[-1]], duration=[0], description="bad") + ) + data = raw.get_data(reject_by_annotation="omit", verbose=True) + assert data.shape == (len(ch_names), n_times - 1) + + # Test that 1-sample annotations are handled correctly, when they occur + # because of cropping + raw.set_annotations( + Annotations(onset=[raw.times[-1]], duration=[1 / fs], description="bad") + ) + with catch_logging() as log: + data = raw.get_data(reject_by_annotation="omit", start=1, verbose=True) + msg = "Omitting 1 of 999 (0.10%) samples, retaining 998 (99.90%)" + " samples." + assert log.getvalue().strip() == msg + print(data.shape) + assert data.shape == (len(ch_names), n_times - 2) def test_5839(): diff --git a/mne/tests/test_annotations.py b/mne/tests/test_annotations.py index 6b1356ae107..167163ff8b7 100644 --- a/mne/tests/test_annotations.py +++ b/mne/tests/test_annotations.py @@ -432,7 +432,7 @@ def test_raw_reject(first_samp): return_times=True, # 1-112 s ) bad_times = np.concatenate( - [np.arange(200, 400), np.arange(10000, 10800), np.arange(10500, 11000)] + [np.arange(200, 401), np.arange(10000, 10801), np.arange(10500, 11001)] ) expected_times = np.setdiff1d(np.arange(100, 11200), bad_times) / sfreq assert_allclose(times, expected_times) @@ -450,7 +450,7 @@ def test_raw_reject(first_samp): t_stop = 18.0 assert raw.times[-1] > t_stop n_stop = int(round(t_stop * raw.info["sfreq"])) - n_drop = int(round(4 * raw.info["sfreq"])) + n_drop = int(round(4 * raw.info["sfreq"]) + 2) assert len(raw.times) >= n_stop data, times = raw.get_data(range(10), 0, n_stop, "omit", True) assert data.shape == (10, n_stop - n_drop) @@ -558,8 +558,8 @@ def test_annotation_filtering(first_samp): raw = raws[0].copy() raw.set_annotations(Annotations([0.0], [0.5], ["BAD_ACQ_SKIP"])) my_data, times = raw.get_data(reject_by_annotation="omit", return_times=True) - assert_allclose(times, raw.times[500:]) - assert my_data.shape == (1, 500) + assert_allclose(times, raw.times[501:]) + assert my_data.shape == (1, 499) raw_filt = raw.copy().filter(skip_by_annotation="bad_acq_skip", **kwargs_stop) expected = data.copy() expected[:, 500:] = 0 @@ -586,7 +586,7 @@ def test_annotation_omit(first_samp): expected = raw[0][0] assert_allclose(raw.get_data(reject_by_annotation=None), expected) # nan - expected[0, 500:1500] = np.nan + expected[0, 500:1501] = np.nan assert_allclose(raw.get_data(reject_by_annotation="nan"), expected) got = np.concatenate( [