diff --git a/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py b/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py index f36ce8e11..bf7a1cac9 100644 --- a/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py @@ -56,6 +56,7 @@ def get_input_fnames_apply_ica( processing="ica", suffix="components", extension=".tsv" ) in_files["epochs"] = bids_basename.copy().update(suffix="epo", extension=".fif") + _update_for_splits(in_files, "epochs", single=True) return in_files diff --git a/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py b/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py index d34800bb7..65fc27b70 100644 --- a/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py +++ b/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py @@ -41,6 +41,7 @@ def get_input_fnames_apply_ssp( ) in_files = dict() in_files["epochs"] = bids_basename.copy().update(suffix="epo", check=False) + _update_for_splits(in_files, "epochs", single=True) in_files["proj"] = bids_basename.copy().update(suffix="proj", check=False) return in_files @@ -60,7 +61,7 @@ def apply_ssp( # compute SSP on first run of raw out_files = dict() out_files["epochs"] = ( - in_files["epochs"].copy().update(processing="ssp", check=False) + in_files["epochs"].copy().update(processing="ssp", split=None, check=False) ) msg = f"Input epochs: {in_files['epochs'].basename}" logger.info(**gen_log_kwargs(message=msg)) diff --git a/mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py b/mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py index 141910ad3..7f7caefbb 100644 --- a/mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py +++ b/mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py @@ -48,6 +48,7 @@ def get_input_fnames_drop_ptp( ) in_files = dict() in_files["epochs"] = bids_path.copy().update(processing=cfg.spatial_filter) + _update_for_splits(in_files, "epochs", single=True) return in_files @@ -63,7 +64,14 @@ def drop_ptp( in_files: dict, ) -> dict: out_files = dict() - out_files["epochs"] = in_files["epochs"].copy().update(processing="clean") + out_files["epochs"] = ( + in_files["epochs"] + .copy() + .update( + processing="clean", + split=None, + ) + ) msg = f'Input: {in_files["epochs"].basename}' logger.info(**gen_log_kwargs(message=msg)) msg = f'Output: {out_files["epochs"].basename}' diff --git a/mne_bids_pipeline/steps/sensor/_01_make_evoked.py b/mne_bids_pipeline/steps/sensor/_01_make_evoked.py index a0f7d1e3e..2ec0ea714 100644 --- a/mne_bids_pipeline/steps/sensor/_01_make_evoked.py +++ b/mne_bids_pipeline/steps/sensor/_01_make_evoked.py @@ -17,7 +17,13 @@ from ..._logging import gen_log_kwargs, logger from ..._parallel import parallel_func, get_parallel_backend from ..._report import _open_report, _sanitize_cond_tag, _all_conditions -from ..._run import failsafe_run, save_logs, _sanitize_callable, _prep_out_files +from ..._run import ( + failsafe_run, + save_logs, + _sanitize_callable, + _prep_out_files, + _update_for_splits, +) def get_input_fnames_evoked( @@ -43,6 +49,7 @@ def get_input_fnames_evoked( ) in_files = dict() in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files @@ -59,7 +66,14 @@ def run_evoked( ) -> dict: out_files = dict() out_files["evoked"] = ( - in_files["epochs"].copy().update(suffix="ave", processing=None, check=False) + in_files["epochs"] + .copy() + .update( + suffix="ave", + processing=None, + check=False, + split=None, + ) ) msg = f'Input: {in_files["epochs"].basename}' diff --git a/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py b/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py index 287ace7bc..9960c2670 100644 --- a/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py +++ b/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py @@ -35,7 +35,7 @@ from ..._logging import gen_log_kwargs, logger from ..._decoding import LogReg from ..._parallel import parallel_func, get_parallel_backend -from ..._run import failsafe_run, save_logs, _prep_out_files +from ..._run import failsafe_run, save_logs, _prep_out_files, _update_for_splits from ..._report import ( _open_report, _contrasts_to_names, @@ -68,6 +68,7 @@ def get_input_fnames_epochs_decoding( ) in_files = dict() in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files @@ -89,7 +90,7 @@ def run_epochs_decoding( msg = f"Contrasting conditions: {condition1} – {condition2}" logger.info(**gen_log_kwargs(message=msg)) out_files = dict() - bids_path = in_files["epochs"].copy() + bids_path = in_files["epochs"].copy().update(split=None) epochs = mne.read_epochs(in_files.pop("epochs")) _restrict_analyze_channels(epochs, cfg) diff --git a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py index f78e1d0cf..1dda99dad 100644 --- a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py +++ b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py @@ -38,7 +38,7 @@ ) from ..._decoding import LogReg from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, save_logs, _prep_out_files +from ..._run import failsafe_run, save_logs, _prep_out_files, _update_for_splits from ..._parallel import get_parallel_backend, get_parallel_backend_name from ..._report import ( _open_report, @@ -73,6 +73,7 @@ def get_input_fnames_time_decoding( ) in_files = dict() in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files @@ -98,7 +99,7 @@ def run_time_decoding( msg = f"Contrasting conditions ({kind}): {condition1} – {condition2}" logger.info(**gen_log_kwargs(message=msg)) out_files = dict() - bids_path = in_files["epochs"].copy() + bids_path = in_files["epochs"].copy().update(split=None) epochs = mne.read_epochs(in_files.pop("epochs")) _restrict_analyze_channels(epochs, cfg) diff --git a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py index 0e7b9b720..92a403c11 100644 --- a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py +++ b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py @@ -22,7 +22,7 @@ _restrict_analyze_channels, ) from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, save_logs, _prep_out_files +from ..._run import failsafe_run, save_logs, _prep_out_files, _update_for_splits from ..._parallel import get_parallel_backend, parallel_func from ..._report import _open_report, _sanitize_cond_tag @@ -53,6 +53,7 @@ def get_input_fnames_time_frequency( ) in_files = dict() in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files @@ -73,7 +74,7 @@ def run_time_frequency( msg = f"Reading {epochs_path.basename}" logger.info(**gen_log_kwargs(message=msg)) epochs = mne.read_epochs(epochs_path) - bids_path = epochs_path.copy().update(processing=None) + bids_path = epochs_path.copy().update(processing=None, split=None) del epochs_path _restrict_analyze_channels(epochs, cfg) diff --git a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py index a18385d93..14087fba7 100644 --- a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py +++ b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py @@ -27,7 +27,7 @@ from ..._decoding import LogReg, _handle_csp_args from ..._logging import logger, gen_log_kwargs from ..._parallel import parallel_func, get_parallel_backend -from ..._run import failsafe_run, save_logs, _prep_out_files +from ..._run import failsafe_run, save_logs, _prep_out_files, _update_for_splits from ..._report import ( _open_report, _sanitize_cond_tag, @@ -132,6 +132,7 @@ def get_input_fnames_csp( ) in_files = dict() in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files @@ -157,7 +158,7 @@ def one_subject_decoding( msg = f"Contrasting conditions: {condition1} – {condition2}" logger.info(**gen_log_kwargs(msg)) - bids_path = in_files["epochs"].copy().update(processing=None) + bids_path = in_files["epochs"].copy().update(processing=None, split=None) epochs = mne.read_epochs(in_files.pop("epochs")) _restrict_analyze_channels(epochs, cfg) diff --git a/mne_bids_pipeline/steps/sensor/_06_make_cov.py b/mne_bids_pipeline/steps/sensor/_06_make_cov.py index 27323d680..4c8d25eac 100644 --- a/mne_bids_pipeline/steps/sensor/_06_make_cov.py +++ b/mne_bids_pipeline/steps/sensor/_06_make_cov.py @@ -20,7 +20,13 @@ from ..._logging import gen_log_kwargs, logger from ..._parallel import get_parallel_backend, parallel_func from ..._report import _open_report, _sanitize_cond_tag, _all_conditions -from ..._run import failsafe_run, save_logs, _sanitize_callable, _prep_out_files +from ..._run import ( + failsafe_run, + save_logs, + _sanitize_callable, + _prep_out_files, + _update_for_splits, +) def get_input_fnames_cov( @@ -47,9 +53,8 @@ def get_input_fnames_cov( root=cfg.deriv_root, check=False, ) - in_files["report_info"] = fname_epochs.copy().update( - processing="clean", suffix="epo" - ) + in_files["report_info"] = fname_epochs.copy().update(processing="clean") + _update_for_splits(in_files, "report_info", single=True) fname_evoked = fname_epochs.copy().update( suffix="ave", processing=None, check=False ) @@ -83,6 +88,7 @@ def get_input_fnames_cov( else: assert cov_type == "epochs", cov_type in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_base.py b/mne_bids_pipeline/tests/configs/config_ds000248_base.py index 22f29b35f..b80b6f0f0 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_base.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_base.py @@ -26,6 +26,8 @@ def noise_cov(bp): # Use pre-stimulus period as noise source bp = bp.copy().update(processing="clean", suffix="epo") + if not bp.fpath.exists(): + bp.update(split="01") epo = mne.read_epochs(bp) cov = mne.compute_covariance(epo, rank="info", tmax=0) return cov