From 1179ae4441dd9092e42acf26376000a2fbacb743 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 17 Jul 2023 17:28:56 -0400 Subject: [PATCH 01/15] ENH: Cache grand-average steps --- .circleci/run_dataset_and_copy_files.sh | 9 + docs/source/v1.5.md.inc | 1 + mne_bids_pipeline/_report.py | 710 ++++--------- mne_bids_pipeline/_run.py | 4 +- .../steps/sensor/_99_group_average.py | 950 ++++++++++++------ mne_bids_pipeline/tests/conftest.py | 3 + 6 files changed, 877 insertions(+), 800 deletions(-) diff --git a/.circleci/run_dataset_and_copy_files.sh b/.circleci/run_dataset_and_copy_files.sh index 63a49c8b4..16f926481 100755 --- a/.circleci/run_dataset_and_copy_files.sh +++ b/.circleci/run_dataset_and_copy_files.sh @@ -14,7 +14,16 @@ else COPY_FILES="true" fi +SECONDS=0 pytest mne_bids_pipeline --junit-xml=test-results/junit-results.xml -k ${DS_RUN} +echo "Runtime: ${SECONDS} seconds" + +# rerun test! +SECONDS=0 +pytest mne_bids_pipeline -k $DS_RUN +RUN_TIME=$SECONDS +echo "Runtime: ${RUN_TIME} seconds (should be < 10)" +test $RUN_TIME -lt 10 if [[ "$COPY_FILES" == "false" ]]; then exit 0 diff --git a/docs/source/v1.5.md.inc b/docs/source/v1.5.md.inc index 352abd4c7..f4771dff6 100644 --- a/docs/source/v1.5.md.inc +++ b/docs/source/v1.5.md.inc @@ -4,6 +4,7 @@ - Added support for annotating bad segments based on head movement velocity (#757 by @larsoner) - Added examples of T1 and FLASH BEM to website (#758 by @larsoner) +- Added caching of sensor and source average steps (#765 by @larsoner) [//]: # (### :warning: Behavior changes) diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py index bc8ddca4c..03d7a4dc4 100644 --- a/mne_bids_pipeline/_report.py +++ b/mne_bids_pipeline/_report.py @@ -1,7 +1,6 @@ import contextlib from functools import lru_cache from io import StringIO -import os.path as op from pathlib import Path from typing import Optional, List, Literal from types import SimpleNamespace @@ -18,7 +17,7 @@ from mne_bids import BIDSPath from 
mne_bids.stats import count_events -from ._config_utils import sanitize_cond_name, get_subjects, _restrict_analyze_channels +from ._config_utils import sanitize_cond_name from ._decoding import _handle_csp_args from ._logging import logger, gen_log_kwargs @@ -472,7 +471,8 @@ def add_event_counts( except ValueError: msg = "Could not read events." logger.warning(**gen_log_kwargs(message=msg)) - df_events = None + return + logger.info(**gen_log_kwargs(message="Adding event counts to report …")) if df_events is not None: css_classes = ("table", "table-striped", "table-borderless", "table-hover") @@ -553,105 +553,6 @@ def _all_conditions(*, cfg): return conditions -def run_report_average_sensor( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, - session: Optional[str], -) -> None: - msg = "Generating grand average report …" - logger.info(**gen_log_kwargs(message=msg)) - assert matplotlib.get_backend() == "agg", matplotlib.get_backend() - - evoked_fname = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix="ave", - extension=".fif", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - - title = f"sub-{subject}" - if session is not None: - title += f", ses-{session}" - if cfg.task is not None: - title += f", task-{cfg.task}" - - all_evokeds = mne.read_evokeds(evoked_fname) - for evoked in all_evokeds: - _restrict_analyze_channels(evoked, cfg) - conditions = _all_conditions(cfg=cfg) - assert len(conditions) == len(all_evokeds) - all_evokeds = {cond: evoked for cond, evoked in zip(conditions, all_evokeds)} - - with _open_report( - cfg=cfg, exec_params=exec_params, subject=subject, session=session - ) as report: - ####################################################################### - # - # Add event stats. 
- # - add_event_counts( - cfg=cfg, - report=report, - subject=subject, - session=session, - ) - - ####################################################################### - # - # Visualize evoked responses. - # - if all_evokeds: - msg = ( - f"Adding {len(all_evokeds)} evoked signals and contrasts to " - "the report." - ) - else: - msg = "No evoked conditions or contrasts found." - logger.info(**gen_log_kwargs(message=msg)) - for condition, evoked in all_evokeds.items(): - tags = ("evoked", _sanitize_cond_tag(condition)) - if condition in cfg.conditions: - title = f"Condition: {condition}" - else: # It's a contrast of two conditions. - title = f"Contrast: {condition}" - tags = tags + ("contrast",) - - report.add_evokeds( - evokeds=evoked, - titles=title, - projs=False, - tags=tags, - n_time_points=cfg.report_evoked_n_time_points, - # captions=evoked.comment, # TODO upstream - replace=True, - n_jobs=1, # don't auto parallelize - ) - - ####################################################################### - # - # Visualize decoding results. - # - if cfg.decode and cfg.decoding_contrasts: - msg = "Adding decoding results." 
- logger.info(**gen_log_kwargs(message=msg)) - add_decoding_grand_average(session=session, cfg=cfg, report=report) - - if cfg.decode and cfg.decoding_csp: - # No need for a separate message here because these are very quick - # and the general message above is sufficient - add_csp_grand_average(session=session, cfg=cfg, report=report) - - def run_report_average_source( *, cfg: SimpleNamespace, @@ -714,165 +615,8 @@ def run_report_average_source( ) -def add_decoding_grand_average( - *, - session: Optional[str], - cfg: SimpleNamespace, - report: mne.Report, -): - """Add decoding results to the grand average report.""" - import matplotlib.pyplot as plt # nested import to help joblib - - bids_path = BIDSPath( - subject="average", - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix="ave", - extension=".fif", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - - # Full-epochs decoding - all_decoding_scores = [] - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f"{cond_1}+{cond_2}".replace(op.sep, "") - processing = f"{a_vs_b}+FullEpochs+{cfg.decoding_metric}" - processing = processing.replace("_", "-").replace("-", "") - fname_decoding = bids_path.copy().update( - processing=processing, suffix="decoding", extension=".mat" - ) - decoding_data = loadmat(fname_decoding) - all_decoding_scores.append(np.atleast_1d(decoding_data["scores"].squeeze())) - del fname_decoding, processing, a_vs_b, decoding_data - - fig, caption = _plot_full_epochs_decoding_scores( - contrast_names=_contrasts_to_names(cfg.decoding_contrasts), - scores=all_decoding_scores, - metric=cfg.decoding_metric, - kind="grand-average", - ) - title = f"Full-epochs decoding: {cond_1} vs. 
{cond_2}" - report.add_figure( - fig=fig, - title=title, - section="Decoding: full-epochs", - caption=caption, - tags=( - "epochs", - "contrast", - "decoding", - *[ - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}" - for cond_1, cond_2 in cfg.decoding_contrasts - ], - ), - replace=True, - ) - # close figure to save memory - plt.close(fig) - del fig, caption, title - - # Time-by-time decoding - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f"{cond_1}+{cond_2}".replace(op.sep, "") - section = "Decoding: time-by-time" - tags = ( - "epochs", - "contrast", - "decoding", - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", - ) - processing = f"{a_vs_b}+TimeByTime+{cfg.decoding_metric}" - processing = processing.replace("_", "-").replace("-", "") - fname_decoding = bids_path.copy().update( - processing=processing, suffix="decoding", extension=".mat" - ) - decoding_data = loadmat(fname_decoding) - del fname_decoding, processing, a_vs_b - - # Plot scores - fig = _plot_time_by_time_decoding_scores_gavg( - cfg=cfg, - decoding_data=decoding_data, - ) - caption = ( - f'Based on N={decoding_data["N"].squeeze()} ' - f"subjects. Standard error and confidence interval " - f"of the mean were bootstrapped with {cfg.n_boot} " - f"resamples. CI must not be used for statistical inference here, " - f"as it is not corrected for multiple testing." - ) - if len(get_subjects(cfg)) > 1: - caption += ( - f" Time periods with decoding performance significantly above " - f"chance, if any, were derived with a one-tailed " - f"cluster-based permutation test " - f'({decoding_data["cluster_n_permutations"].squeeze()} ' - f"permutations) and are highlighted in yellow." - ) - title = f"Decoding over time: {cond_1} vs. 
{cond_2}" - report.add_figure( - fig=fig, - title=title, - caption=caption, - section=section, - tags=tags, - replace=True, - ) - plt.close(fig) - - # Plot t-values used to form clusters - if len(get_subjects(cfg)) > 1: - fig = plot_time_by_time_decoding_t_values(decoding_data=decoding_data) - t_threshold = np.round(decoding_data["cluster_t_threshold"], 3).item() - caption = ( - f"Observed t-values. Time points with " - f"t-values > {t_threshold} were used to form clusters." - ) - report.add_figure( - fig=fig, - title=f"t-values across time: {cond_1} vs. {cond_2}", - caption=caption, - section=section, - tags=tags, - replace=True, - ) - plt.close(fig) - - if cfg.decoding_time_generalization: - fig = _plot_decoding_time_generalization( - decoding_data=decoding_data, - metric=cfg.decoding_metric, - kind="grand-average", - ) - caption = ( - f"Time generalization (generalization across time, GAT): " - f"each classifier is trained on each time point, and tested " - f"on all other time points. The results were averaged across " - f'N={decoding_data["N"].item()} subjects.' - ) - title = f"Time generalization: {cond_1} vs. 
{cond_2}" - report.add_figure( - fig=fig, - title=title, - caption=caption, - section=section, - tags=tags, - replace=True, - ) - plt.close(fig) - - def _sanitize_cond_tag(cond): - return cond.lower().replace(" ", "-") + return str(cond).lower().replace(" ", "-") def _imshow_tf( @@ -913,27 +657,18 @@ def _imshow_tf( def add_csp_grand_average( *, - session: str, cfg: SimpleNamespace, + subject: str, + session: str, report: mne.Report, + cond_1: str, + cond_2: str, + fname_csp_freq_results: BIDSPath, + fname_csp_cluster_results: pd.DataFrame, ): """Add CSP decoding results to the grand average report.""" import matplotlib.pyplot as plt # nested import to help joblib - bids_path = BIDSPath( - subject="average", - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix="decoding", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - # First, plot decoding scores across frequency bins (entire epochs). section = "Decoding: CSP" freq_name_to_bins_map = _handle_csp_args( @@ -941,242 +676,217 @@ def add_csp_grand_average( cfg.decoding_csp_freqs, cfg.decoding_metric, ) - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f"{cond_1}+{cond_2}".replace(op.sep, "") - processing = f"{a_vs_b}+CSP+{cfg.decoding_metric}" - processing = processing.replace("_", "-").replace("-", "") - fname_csp_freq_results = bids_path.copy().update( - processing=processing, - extension=".xlsx", - ) - csp_freq_results = pd.read_excel( - fname_csp_freq_results, sheet_name="CSP Frequency" - ) - freq_bin_starts = list() - freq_bin_widths = list() - decoding_scores = list() - error_bars = list() - for freq_range_name, freq_bins in freq_name_to_bins_map.items(): - results = csp_freq_results.loc[ - csp_freq_results["freq_range_name"] == freq_range_name, : - ] - results.reset_index(drop=True, inplace=True) - assert len(results) == len(freq_bins) - for bi, freq_bin in enumerate(freq_bins): - 
freq_bin_starts.append(freq_bin[0]) - freq_bin_widths.append(np.diff(freq_bin)[0]) - decoding_scores.append(results["mean"][bi]) - cis_lower = results["mean_ci_lower"][bi] - cis_upper = results["mean_ci_upper"][bi] - error_bars_lower = decoding_scores[-1] - cis_lower - error_bars_upper = cis_upper - decoding_scores[-1] - error_bars.append(np.stack([error_bars_lower, error_bars_upper])) - assert len(error_bars[-1]) == 2 # lower, upper - del cis_lower, cis_upper, error_bars_lower, error_bars_upper - error_bars = np.array(error_bars, float).T - - if cfg.decoding_metric == "roc_auc": - metric = "ROC AUC" - - fig, ax = plt.subplots(constrained_layout=True) - ax.bar( - x=freq_bin_starts, - width=freq_bin_widths, - height=decoding_scores, - align="edge", - yerr=error_bars, - edgecolor="black", - ) - ax.set_ylim([0, 1.02]) - offset = matplotlib.transforms.offset_copy( - ax.transData, fig, 0, 5, units="points" - ) - for freq_range_name, freq_bins in freq_name_to_bins_map.items(): - start = freq_bins[0][0] - stop = freq_bins[-1][1] - width = stop - start - ax.text( - x=start + width / 2, - y=0.0, - transform=offset, - s=freq_range_name, - ha="center", - va="bottom", - ) - ax.axhline(0.5, color="black", linestyle="--", label="chance") - ax.legend() - ax.set_xlabel("Frequency (Hz)") - ax.set_ylabel(f"Mean decoding score ({metric})") - tags = ( - "epochs", - "contrast", - "decoding", - "csp", - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", - ) - title = f"CSP decoding: {cond_1} vs. {cond_2}" - report.add_figure( - fig=fig, - title=title, - section=section, - caption="Mean decoding scores. 
Error bars represent " - "bootstrapped 95% confidence intervals.", - tags=tags, - replace=True, + freq_bin_starts = list() + freq_bin_widths = list() + decoding_scores = list() + error_bars = list() + csp_freq_results = pd.read_excel(fname_csp_freq_results, sheet_name="CSP Frequency") + for freq_range_name, freq_bins in freq_name_to_bins_map.items(): + results = csp_freq_results.loc[ + csp_freq_results["freq_range_name"] == freq_range_name, : + ] + results.reset_index(drop=True, inplace=True) + assert len(results) == len(freq_bins) + for bi, freq_bin in enumerate(freq_bins): + freq_bin_starts.append(freq_bin[0]) + freq_bin_widths.append(np.diff(freq_bin)[0]) + decoding_scores.append(results["mean"][bi]) + cis_lower = results["mean_ci_lower"][bi] + cis_upper = results["mean_ci_upper"][bi] + error_bars_lower = decoding_scores[-1] - cis_lower + error_bars_upper = cis_upper - decoding_scores[-1] + error_bars.append(np.stack([error_bars_lower, error_bars_upper])) + assert len(error_bars[-1]) == 2 # lower, upper + del cis_lower, cis_upper, error_bars_lower, error_bars_upper + error_bars = np.array(error_bars, float).T + + if cfg.decoding_metric == "roc_auc": + metric = "ROC AUC" + + fig, ax = plt.subplots(constrained_layout=True) + ax.bar( + x=freq_bin_starts, + width=freq_bin_widths, + height=decoding_scores, + align="edge", + yerr=error_bars, + edgecolor="black", + ) + ax.set_ylim([0, 1.02]) + offset = matplotlib.transforms.offset_copy(ax.transData, fig, 0, 5, units="points") + for freq_range_name, freq_bins in freq_name_to_bins_map.items(): + start = freq_bins[0][0] + stop = freq_bins[-1][1] + width = stop - start + ax.text( + x=start + width / 2, + y=0.0, + transform=offset, + s=freq_range_name, + ha="center", + va="bottom", ) + ax.axhline(0.5, color="black", linestyle="--", label="chance") + ax.legend() + ax.set_xlabel("Frequency (Hz)") + ax.set_ylabel(f"Mean decoding score ({metric})") + tags = ( + "epochs", + "contrast", + "decoding", + "csp", + 
f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", + ) + title = f"CSP decoding: {cond_1} vs. {cond_2}" + report.add_figure( + fig=fig, + title=title, + section=section, + caption="Mean decoding scores. Error bars represent " + "bootstrapped 95% confidence intervals.", + tags=tags, + replace=True, + ) # Now, plot decoding scores across time-frequency bins. - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f"{cond_1}+{cond_2}".replace(op.sep, "") - processing = f"{a_vs_b}+CSP+{cfg.decoding_metric}" - processing = processing.replace("_", "-").replace("-", "") - fname_csp_cluster_results = bids_path.copy().update( - processing=processing, - extension=".mat", + csp_cluster_results = loadmat(fname_csp_cluster_results) + fig, ax = plt.subplots( + nrows=1, ncols=2, sharex=True, sharey=True, constrained_layout=True + ) + n_clu = 0 + cbar = None + lims = [np.inf, -np.inf, np.inf, -np.inf] + for freq_range_name, bins in freq_name_to_bins_map.items(): + results = csp_cluster_results[freq_range_name][0][0] + mean_crossval_scores = results["mean_crossval_scores"].ravel() + # t_vals = results['t_vals'] + clusters = results["clusters"] + cluster_p_vals = np.atleast_1d(results["cluster_p_vals"].squeeze()) + tmin = results["time_bin_edges"].ravel() + tmin, tmax = tmin[:-1], tmin[1:] + fmin = results["freq_bin_edges"].ravel() + fmin, fmax = fmin[:-1], fmin[1:] + lims[0] = min(lims[0], tmin.min()) + lims[1] = max(lims[1], tmax.max()) + lims[2] = min(lims[2], fmin.min()) + lims[3] = max(lims[3], fmax.max()) + # replicate, matching time-frequency order during clustering + fmin, fmax = np.tile(fmin, len(tmin)), np.tile(fmax, len(tmax)) + tmin, tmax = np.repeat(tmin, len(bins)), np.repeat(tmax, len(bins)) + assert fmin.shape == fmax.shape == tmin.shape == tmax.shape + assert fmin.shape == mean_crossval_scores.shape + cluster_t_threshold = results["cluster_t_threshold"].ravel().item() + + significant_cluster_idx = np.where( + cluster_p_vals < 
cfg.cluster_permutation_p_threshold + )[0] + significant_clusters = clusters[significant_cluster_idx] + n_clu += len(significant_cluster_idx) + + # XXX Add support for more metrics + assert cfg.decoding_metric == "roc_auc" + metric = "ROC AUC" + vmax = ( + max( + np.abs(mean_crossval_scores.min() - 0.5), + np.abs(mean_crossval_scores.max() - 0.5), + ) + + 0.5 ) - csp_cluster_results = loadmat(fname_csp_cluster_results) - - fig, ax = plt.subplots( - nrows=1, ncols=2, sharex=True, sharey=True, constrained_layout=True + vmin = 0.5 - (vmax - 0.5) + # For diverging gray colormap, we need to combine two existing + # colormaps, as there is no diverging colormap with gray/black at + # both endpoints. + from matplotlib.cm import gray, gray_r + from matplotlib.colors import ListedColormap + + black_to_white = gray(np.linspace(start=0, stop=1, endpoint=False, num=128)) + white_to_black = gray_r(np.linspace(start=0, stop=1, endpoint=False, num=128)) + black_to_white_to_black = np.vstack((black_to_white, white_to_black)) + diverging_gray_cmap = ListedColormap( + black_to_white_to_black, name="DivergingGray" ) - n_clu = 0 - cbar = None - lims = [np.inf, -np.inf, np.inf, -np.inf] - for freq_range_name, bins in freq_name_to_bins_map.items(): - results = csp_cluster_results[freq_range_name][0][0] - mean_crossval_scores = results["mean_crossval_scores"].ravel() - # t_vals = results['t_vals'] - clusters = results["clusters"] - cluster_p_vals = np.atleast_1d(results["cluster_p_vals"].squeeze()) - tmin = results["time_bin_edges"].ravel() - tmin, tmax = tmin[:-1], tmin[1:] - fmin = results["freq_bin_edges"].ravel() - fmin, fmax = fmin[:-1], fmin[1:] - lims[0] = min(lims[0], tmin.min()) - lims[1] = max(lims[1], tmax.max()) - lims[2] = min(lims[2], fmin.min()) - lims[3] = max(lims[3], fmax.max()) - # replicate, matching time-frequency order during clustering - fmin, fmax = np.tile(fmin, len(tmin)), np.tile(fmax, len(tmax)) - tmin, tmax = np.repeat(tmin, len(bins)), np.repeat(tmax, 
len(bins)) - assert fmin.shape == fmax.shape == tmin.shape == tmax.shape - assert fmin.shape == mean_crossval_scores.shape - cluster_t_threshold = results["cluster_t_threshold"].ravel().item() - - significant_cluster_idx = np.where( - cluster_p_vals < cfg.cluster_permutation_p_threshold - )[0] - significant_clusters = clusters[significant_cluster_idx] - n_clu += len(significant_cluster_idx) - - # XXX Add support for more metrics - assert cfg.decoding_metric == "roc_auc" - metric = "ROC AUC" - vmax = ( - max( - np.abs(mean_crossval_scores.min() - 0.5), - np.abs(mean_crossval_scores.max() - 0.5), - ) - + 0.5 - ) - vmin = 0.5 - (vmax - 0.5) - # For diverging gray colormap, we need to combine two existing - # colormaps, as there is no diverging colormap with gray/black at - # both endpoints. - from matplotlib.cm import gray, gray_r - from matplotlib.colors import ListedColormap - - black_to_white = gray(np.linspace(start=0, stop=1, endpoint=False, num=128)) - white_to_black = gray_r( - np.linspace(start=0, stop=1, endpoint=False, num=128) - ) - black_to_white_to_black = np.vstack((black_to_white, white_to_black)) - diverging_gray_cmap = ListedColormap( - black_to_white_to_black, name="DivergingGray" - ) - cmap_gray = diverging_gray_cmap - img = _imshow_tf( - mean_crossval_scores, - ax[0], - tmin=tmin, - tmax=tmax, - fmin=fmin, - fmax=fmax, - vmin=vmin, - vmax=vmax, - ) - if cbar is None: - ax[0].set_xlabel("Time (s)") - ax[0].set_ylabel("Frequency (Hz)") - ax[1].set_xlabel("Time (s)") - cbar = fig.colorbar( - ax=ax[1], shrink=0.75, orientation="vertical", mappable=img - ) - cbar.set_label(f"Mean decoding score ({metric})") - offset = matplotlib.transforms.offset_copy( - ax[0].transData, fig, 6, 0, units="points" - ) - ax[0].text( - tmin.min(), - 0.5 * fmin.min() + 0.5 * fmax.max(), - freq_range_name, - transform=offset, - ha="left", - va="center", - rotation=90, + cmap_gray = diverging_gray_cmap + img = _imshow_tf( + mean_crossval_scores, + ax[0], + tmin=tmin, + 
tmax=tmax, + fmin=fmin, + fmax=fmax, + vmin=vmin, + vmax=vmax, + ) + if cbar is None: + ax[0].set_xlabel("Time (s)") + ax[0].set_ylabel("Frequency (Hz)") + ax[1].set_xlabel("Time (s)") + cbar = fig.colorbar( + ax=ax[1], shrink=0.75, orientation="vertical", mappable=img ) + cbar.set_label(f"Mean decoding score ({metric})") + offset = matplotlib.transforms.offset_copy( + ax[0].transData, fig, 6, 0, units="points" + ) + ax[0].text( + tmin.min(), + 0.5 * fmin.min() + 0.5 * fmax.max(), + freq_range_name, + transform=offset, + ha="left", + va="center", + rotation=90, + ) - if len(significant_clusters): - # Create a masked array that only shows the T-values for - # time-frequency bins that belong to significant clusters. - if len(significant_clusters) == 1: - mask = ~significant_clusters[0].astype(bool) - else: - mask = ~np.logical_or(*significant_clusters) - mask = mask.ravel() + if len(significant_clusters): + # Create a masked array that only shows the T-values for + # time-frequency bins that belong to significant clusters. + if len(significant_clusters) == 1: + mask = ~significant_clusters[0].astype(bool) else: - mask = np.ones(mean_crossval_scores.shape, dtype=bool) - _imshow_tf( - mean_crossval_scores, - ax[1], - tmin=tmin, - tmax=tmax, - fmin=fmin, - fmax=fmax, - vmin=vmin, - vmax=vmax, - mask=mask, - cmap_masked=cmap_gray, - ) - - ax[0].set_xlim(lims[:2]) - ax[0].set_ylim(lims[2:]) - ax[0].set_title("Scores") - ax[1].set_title("Masked") - tags = ( - "epochs", - "contrast", - "decoding", - "csp", - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", - ) - title = f"CSP TF decoding: {cond_1} vs. 
{cond_2}" - report.add_figure( - fig=fig, - title=title, - section=section, - caption=f"Found {n_clu} " - f"cluster{_pl(n_clu)} with " - f"p < {cfg.cluster_permutation_p_threshold} " - f"(clustering bins with absolute t-values > " - f"{round(cluster_t_threshold, 3)}).", - tags=tags, - replace=True, + mask = ~np.logical_or(*significant_clusters) + mask = mask.ravel() + else: + mask = np.ones(mean_crossval_scores.shape, dtype=bool) + _imshow_tf( + mean_crossval_scores, + ax[1], + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + vmin=vmin, + vmax=vmax, + mask=mask, + cmap_masked=cmap_gray, ) + ax[0].set_xlim(lims[:2]) + ax[0].set_ylim(lims[2:]) + ax[0].set_title("Scores") + ax[1].set_title("Masked") + tags = ( + "epochs", + "contrast", + "decoding", + "csp", + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", + ) + title = f"CSP TF decoding: {cond_1} vs. {cond_2}" + report.add_figure( + fig=fig, + title=title, + section=section, + caption=f"Found {n_clu} " + f"cluster{_pl(n_clu)} with " + f"p < {cfg.cluster_permutation_p_threshold} " + f"(clustering bins with absolute t-values > " + f"{round(cluster_t_threshold, 3)}).", + tags=tags, + replace=True, + ) + @contextlib.contextmanager def _agg_backend(): diff --git a/mne_bids_pipeline/_run.py b/mne_bids_pipeline/_run.py index 7d7bf50f0..0c997e998 100644 --- a/mne_bids_pipeline/_run.py +++ b/mne_bids_pipeline/_run.py @@ -220,7 +220,9 @@ def wrapper(*args, **kwargs): emoji = "🔂" else: # Check our output file hashes - out_files_hashes = memorized_func(*args, **kwargs) + # Need to make a copy of kwargs["in_files"] in particular + use_kwargs = copy.deepcopy(kwargs) + out_files_hashes = memorized_func(*args, **use_kwargs) for key, (fname, this_hash) in out_files_hashes.items(): fname = pathlib.Path(fname) if not fname.exists(): diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index b126eecec..45bacb257 100644 --- 
a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -6,6 +6,7 @@ import os import os.path as op from collections import defaultdict +from functools import partial from typing import Optional, TypedDict, List, Tuple from types import SimpleNamespace @@ -27,21 +28,29 @@ from ..._decoding import _handle_csp_args from ..._logging import gen_log_kwargs, logger from ..._parallel import get_parallel_backend, parallel_func -from ..._run import failsafe_run, save_logs -from ..._report import run_report_average_sensor +from ..._run import failsafe_run, save_logs, _prep_out_files, _update_for_splits +from ..._report import ( + _open_report, + _sanitize_cond_tag, + add_event_counts, + add_csp_grand_average, + _plot_full_epochs_decoding_scores, + _plot_time_by_time_decoding_scores_gavg, + plot_time_by_time_decoding_t_values, + _plot_decoding_time_generalization, + _contrasts_to_names, +) -def average_evokeds( +def get_input_fnames_average_evokeds( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> List[mne.Evoked]: - # Container for all conditions: - all_evokeds = defaultdict(list) - + session: Optional[dict], +) -> dict: + in_files = dict() for this_subject in cfg.subjects: - fname_in = BIDSPath( + in_files[f"evoked-{this_subject}"] = BIDSPath( subject=this_subject, session=session, task=cfg.task, @@ -55,10 +64,29 @@ def average_evokeds( root=cfg.deriv_root, check=False, ) + return in_files - msg = f"Input: {fname_in.basename}" - logger.info(**gen_log_kwargs(message=msg)) +@failsafe_run( + get_input_fnames=get_input_fnames_average_evokeds, +) +def average_evokeds( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + in_files: dict, +) -> dict: + logger.info(**gen_log_kwargs(message="Creating grand averages")) + # Container for all conditions: + all_evokeds = defaultdict(list) + + keys = list(in_files) + for key in keys: + if not 
key.startswith("evoked-"): + continue + fname_in = in_files.pop(key) evokeds = mne.read_evokeds(fname_in) for idx, evoked in enumerate(evokeds): all_evokeds[idx].append(evoked) # Insert into the container @@ -70,7 +98,8 @@ def average_evokeds( # Keep condition in comment all_evokeds[idx].comment = "Grand average: " + evokeds[0].comment - fname_out = BIDSPath( + out_files = dict() + fname_out = out_files["evokeds"] = BIDSPath( subject=subject, session=session, task=cfg.task, @@ -91,8 +120,54 @@ def average_evokeds( msg = f"Saving grand-averaged evoked sensor data: {fname_out.basename}" logger.info(**gen_log_kwargs(message=msg)) - mne.write_evokeds(fname_out, list(all_evokeds.values()), overwrite=True) - return list(all_evokeds.values()) + evokeds = list(all_evokeds.values()) + mne.write_evokeds(fname_out, evokeds, overwrite=True) + if exec_params.interactive: + for evoked in evokeds: + evoked.plot() + + # Reporting + with _open_report( + cfg=cfg, exec_params=exec_params, subject=subject, session=session + ) as report: + # Add event stats. + add_event_counts( + cfg=cfg, + report=report, + subject=subject, + session=session, + ) + + # Evoked responses + if all_evokeds: + msg = ( + f"Adding {len(all_evokeds)} evoked signals and contrasts to " + "the report." + ) + else: + msg = "No evoked conditions or contrasts found." + logger.info(**gen_log_kwargs(message=msg)) + for condition, evoked in all_evokeds.items(): + tags = ("evoked", _sanitize_cond_tag(condition)) + if condition in cfg.conditions: + title = f"Condition: {condition}" + else: # It's a contrast of two conditions. 
+ title = f"Contrast: {condition}" + tags = tags + ("contrast",) + + report.add_evokeds( + evokeds=evoked, + titles=title, + projs=False, + tags=tags, + n_time_points=cfg.report_evoked_n_time_points, + # captions=evoked.comment, # TODO upstream + replace=True, + n_jobs=1, # don't auto parallelize + ) + + assert len(in_files) == 0, list(in_files) + return _prep_out_files(exec_params=exec_params, out_files=out_files) class ClusterAcrossTime(TypedDict): @@ -134,10 +209,14 @@ def _decoding_cluster_permutation_test( return t_vals, clusters, n_permutations -def average_time_by_time_decoding(cfg: SimpleNamespace, session: str): - # Get the time points from the very first subject. They are identical - # across all subjects and conditions, so this should suffice. - fname_epo = BIDSPath( +def _get_epochs_in_files( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> dict: + in_files = dict() + in_files["epochs"] = BIDSPath( subject=cfg.subjects[0], session=session, task=cfg.task, @@ -151,276 +230,493 @@ def average_time_by_time_decoding(cfg: SimpleNamespace, session: str): root=cfg.deriv_root, check=False, ) - epochs = mne.read_epochs(fname_epo) + _update_for_splits(in_files, "epochs", single=True) + return in_files + + +def _decoding_out_fname( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + cond_1: str, + cond_2: str, + kind: str, + extension: str = ".mat", +): + processing = ( + f"{cond_1}+{cond_2}+{kind}+{cfg.decoding_metric}".replace(op.sep, "") + .replace("_", "-") + .replace("-", "") + ) + return BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + run=None, + recording=cfg.rec, + space=cfg.space, + processing=processing, + suffix="decoding", + extension=extension, + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False, + ) + + +def _get_input_fnames_decoding( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + cond_1: str, + cond_2: str, + kind: str, + 
extension: str = ".mat", +) -> dict: + in_files = _get_epochs_in_files(cfg=cfg, subject=subject, session=session) + for this_subject in cfg.subjects: + in_files[f"scores-{this_subject}"] = _decoding_out_fname( + cfg=cfg, + subject=this_subject, + session=session, + cond_1=cond_1, + cond_2=cond_2, + kind=kind, + extension=extension, + ) + return in_files + + +@failsafe_run( + get_input_fnames=partial( + _get_input_fnames_decoding, + kind="TimeByTime", + ), +) +def average_time_by_time_decoding( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + cond_1: str, + cond_2: str, + in_files: dict, +) -> dict: + logger.info(**gen_log_kwargs(message="Averaging time-by-time decoding results")) + # Get the time points from the very first subject. They are identical + # across all subjects and conditions, so this should suffice. + epochs = mne.read_epochs(in_files.pop("epochs"), preload=False) dtg_decim = cfg.decoding_time_generalization_decim if cfg.decoding_time_generalization and dtg_decim > 1: epochs.decimate(dtg_decim, verbose="error") times = epochs.times - subjects = cfg.subjects - del epochs, fname_epo + del epochs + + if cfg.decoding_time_generalization: + time_points_shape = (len(times), len(times)) + else: + time_points_shape = (len(times),) + + n_subjects = len(cfg.subjects) + contrast_score_stats = { + "cond_1": cond_1, + "cond_2": cond_2, + "times": times, + "N": n_subjects, + "decim": dtg_decim, + "mean": np.empty(time_points_shape), + "mean_min": np.empty(time_points_shape), + "mean_max": np.empty(time_points_shape), + "mean_se": np.empty(time_points_shape), + "mean_ci_lower": np.empty(time_points_shape), + "mean_ci_upper": np.empty(time_points_shape), + "cluster_all_times": np.array([]), + "cluster_all_t_values": np.array([]), + "cluster_t_threshold": np.nan, + "cluster_n_permutations": np.nan, + "clusters": list(), + } + + # Extract mean CV scores from all subjects.
+ mean_scores = np.empty((n_subjects, *time_points_shape)) + + # Remaining in_files are all decoding data + for sub_idx, key in enumerate(list(in_files)): + decoding_data = loadmat(in_files.pop(key)) + mean_scores[sub_idx, :] = decoding_data["scores"].mean(axis=0) + + # Cluster permutation test. + # We can only permute for two or more subjects + # + # If we've performed time generalization, we will only use the diagonal + # CV scores here (classifiers trained and tested at the same time + # points). + + if n_subjects > 1: + # Constrain cluster permutation test to time points of the + # time-locked event or later. + # We subtract the chance level from the scores as we'll be + # performing a 1-sample test (i.e., test against 0)! + idx = np.where(times >= 0)[0] - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast if cfg.decoding_time_generalization: - time_points_shape = (len(times), len(times)) + cluster_permutation_scores = mean_scores[:, idx, idx] - 0.5 else: - time_points_shape = (len(times),) - - contrast_score_stats = { - "cond_1": cond_1, - "cond_2": cond_2, - "times": times, - "N": len(subjects), - "decim": dtg_decim, - "mean": np.empty(time_points_shape), - "mean_min": np.empty(time_points_shape), - "mean_max": np.empty(time_points_shape), - "mean_se": np.empty(time_points_shape), - "mean_ci_lower": np.empty(time_points_shape), - "mean_ci_upper": np.empty(time_points_shape), - "cluster_all_times": np.array([]), - "cluster_all_t_values": np.array([]), - "cluster_t_threshold": np.nan, - "cluster_n_permutations": np.nan, - "clusters": list(), - } + cluster_permutation_scores = mean_scores[:, idx] - 0.5 - processing = ( - f"{cond_1}+{cond_2}+TimeByTime+{cfg.decoding_metric}".replace(op.sep, "") - .replace("_", "-") - .replace("-", "") + cluster_permutation_times = times[idx] + if cfg.cluster_forming_t_threshold is None: + import scipy.stats + + cluster_forming_t_threshold = scipy.stats.t.ppf( + 1 - 0.05, len(cluster_permutation_scores) - 1 + ) + 
else: + cluster_forming_t_threshold = cfg.cluster_forming_t_threshold + + t_vals, clusters, n_perm = _decoding_cluster_permutation_test( + scores=cluster_permutation_scores, + times=cluster_permutation_times, + cluster_forming_t_threshold=cluster_forming_t_threshold, + n_permutations=cfg.cluster_n_permutations, + random_seed=cfg.random_state, ) - # Extract mean CV scores from all subjects. - mean_scores = np.empty((len(subjects), *time_points_shape)) + contrast_score_stats.update( + { + "cluster_all_times": cluster_permutation_times, + "cluster_all_t_values": t_vals, + "cluster_t_threshold": cluster_forming_t_threshold, + "clusters": clusters, + "cluster_n_permutations": n_perm, + } + ) - for sub_idx, subject in enumerate(subjects): - fname_mat = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - processing=processing, - suffix="decoding", - extension=".mat", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) + del cluster_permutation_scores, cluster_permutation_times, n_perm + + # Now we can calculate some descriptive statistics on the mean scores. + # We use the [:] here as a safeguard to ensure we don't mess up the + # dimensions. + # + # For time generalization, all values (each time point vs each other) + # are considered. + contrast_score_stats["mean"][:] = mean_scores.mean(axis=0) + contrast_score_stats["mean_min"][:] = mean_scores.min(axis=0) + contrast_score_stats["mean_max"][:] = mean_scores.max(axis=0) + + # Finally, for each time point, bootstrap the mean, and calculate the + # SD of the bootstrapped distribution: this is the standard error of + # the mean. We also derive 95% confidence intervals. 
+ rng = np.random.default_rng(seed=cfg.random_state) + for time_idx in range(len(times)): + if cfg.decoding_time_generalization: + data = mean_scores[:, time_idx, time_idx] + else: + data = mean_scores[:, time_idx] + scores_resampled = rng.choice(data, size=(cfg.n_boot, n_subjects), replace=True) + bootstrapped_means = scores_resampled.mean(axis=1) - decoding_data = loadmat(fname_mat) - mean_scores[sub_idx, :] = decoding_data["scores"].mean(axis=0) - - # Cluster permutation test. - # We can only permute for two or more subjects - # - # If we've performed time generalization, we will only use the diagonal - # CV scores here (classifiers trained and tested at the same time - # points). - - if len(subjects) > 1: - # Constrain cluster permutation test to time points of the - # time-locked event or later. - # We subtract the chance level from the scores as we'll be - # performing a 1-sample test (i.e., test against 0)! - idx = np.where(times >= 0)[0] - - if cfg.decoding_time_generalization: - cluster_permutation_scores = mean_scores[:, idx, idx] - 0.5 - else: - cluster_permutation_scores = mean_scores[:, idx] - 0.5 - - cluster_permutation_times = times[idx] - if cfg.cluster_forming_t_threshold is None: - import scipy.stats - - cluster_forming_t_threshold = scipy.stats.t.ppf( - 1 - 0.05, len(cluster_permutation_scores) - 1 - ) - else: - cluster_forming_t_threshold = cfg.cluster_forming_t_threshold - - t_vals, clusters, n_perm = _decoding_cluster_permutation_test( - scores=cluster_permutation_scores, - times=cluster_permutation_times, - cluster_forming_t_threshold=cluster_forming_t_threshold, - n_permutations=cfg.cluster_n_permutations, - random_seed=cfg.random_state, + # SD of the bootstrapped distribution == SE of the metric. 
+ se = bootstrapped_means.std(ddof=1) + ci_lower = np.quantile(bootstrapped_means, q=0.025) + ci_upper = np.quantile(bootstrapped_means, q=0.975) + + contrast_score_stats["mean_se"][time_idx] = se + contrast_score_stats["mean_ci_lower"][time_idx] = ci_lower + contrast_score_stats["mean_ci_upper"][time_idx] = ci_upper + + del bootstrapped_means, se, ci_lower, ci_upper + + out_files = dict() + out_files["mat"] = _decoding_out_fname( + cfg=cfg, + subject=subject, + session=session, + cond_1=cond_1, + cond_2=cond_2, + kind="TimeByTime", + ) + savemat(out_files["mat"], contrast_score_stats) + + section = "Decoding: time-by-time" + with _open_report( + cfg=cfg, exec_params=exec_params, subject=subject, session=session + ) as report: + logger.info(**gen_log_kwargs(message="Adding time-by-time decoding results")) + import matplotlib.pyplot as plt + + tags = ( + "epochs", + "contrast", + "decoding", + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", + ) + decoding_data = loadmat(out_files["mat"]) + + # Plot scores + fig = _plot_time_by_time_decoding_scores_gavg( + cfg=cfg, + decoding_data=decoding_data, + ) + caption = ( + f'Based on N={decoding_data["N"].squeeze()} ' + f"subjects. Standard error and confidence interval " + f"of the mean were bootstrapped with {cfg.n_boot} " + f"resamples. CI must not be used for statistical inference here, " + f"as it is not corrected for multiple testing." + ) + if len(get_subjects(cfg)) > 1: + caption += ( + f" Time periods with decoding performance significantly above " + f"chance, if any, were derived with a one-tailed " + f"cluster-based permutation test " + f'({decoding_data["cluster_n_permutations"].squeeze()} ' + f"permutations) and are highlighted in yellow." + ) + title = f"Decoding over time: {cond_1} vs. 
{cond_2}" + report.add_figure( + fig=fig, + title=title, + caption=caption, + section=section, + tags=tags, + replace=True, + ) + plt.close(fig) + + # Plot t-values used to form clusters + if len(get_subjects(cfg)) > 1: + fig = plot_time_by_time_decoding_t_values(decoding_data=decoding_data) + t_threshold = np.round(decoding_data["cluster_t_threshold"], 3).item() + caption = ( + f"Observed t-values. Time points with " + f"t-values > {t_threshold} were used to form clusters." ) + report.add_figure( + fig=fig, + title=f"t-values across time: {cond_1} vs. {cond_2}", + caption=caption, + section=section, + tags=tags, + replace=True, + ) + plt.close(fig) - contrast_score_stats.update( - { - "cluster_all_times": cluster_permutation_times, - "cluster_all_t_values": t_vals, - "cluster_t_threshold": cluster_forming_t_threshold, - "clusters": clusters, - "cluster_n_permutations": n_perm, - } + if cfg.decoding_time_generalization: + fig = _plot_decoding_time_generalization( + decoding_data=decoding_data, + metric=cfg.decoding_metric, + kind="grand-average", + ) + caption = ( + f"Time generalization (generalization across time, GAT): " + f"each classifier is trained on each time point, and tested " + f"on all other time points. The results were averaged across " + f'N={decoding_data["N"].item()} subjects.' + ) + title = f"Time generalization: {cond_1} vs. {cond_2}" + report.add_figure( + fig=fig, + title=title, + caption=caption, + section=section, + tags=tags, + replace=True, ) + plt.close(fig) - del cluster_permutation_scores, cluster_permutation_times, n_perm + return _prep_out_files(out_files=out_files, exec_params=exec_params) - # Now we can calculate some descriptive statistics on the mean scores. - # We use the [:] here as a safeguard to ensure we don't mess up the - # dimensions. - # - # For time generalization, all values (each time point vs each other) - # are considered. 
- contrast_score_stats["mean"][:] = mean_scores.mean(axis=0) - contrast_score_stats["mean_min"][:] = mean_scores.min(axis=0) - contrast_score_stats["mean_max"][:] = mean_scores.max(axis=0) - # Finally, for each time point, bootstrap the mean, and calculate the - # SD of the bootstrapped distribution: this is the standard error of - # the mean. We also derive 95% confidence intervals. - rng = np.random.default_rng(seed=cfg.random_state) - for time_idx in range(len(times)): - if cfg.decoding_time_generalization: - data = mean_scores[:, time_idx, time_idx] - else: - data = mean_scores[:, time_idx] - scores_resampled = rng.choice( - data, size=(cfg.n_boot, len(subjects)), replace=True - ) - bootstrapped_means = scores_resampled.mean(axis=1) - - # SD of the bootstrapped distribution == SE of the metric. - se = bootstrapped_means.std(ddof=1) - ci_lower = np.quantile(bootstrapped_means, q=0.025) - ci_upper = np.quantile(bootstrapped_means, q=0.975) - - contrast_score_stats["mean_se"][time_idx] = se - contrast_score_stats["mean_ci_lower"][time_idx] = ci_lower - contrast_score_stats["mean_ci_upper"][time_idx] = ci_upper - - del bootstrapped_means, se, ci_lower, ci_upper - - fname_out = fname_mat.copy().update(subject="average") - savemat(fname_out, contrast_score_stats) - del contrast_score_stats, fname_out - - -def average_full_epochs_decoding(cfg: SimpleNamespace, session: str): - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - n_subjects = len(cfg.subjects) - - contrast_score_stats = { - "cond_1": cond_1, - "cond_2": cond_2, - "N": n_subjects, - "subjects": cfg.subjects, - "scores": np.nan, - "mean": np.nan, - "mean_min": np.nan, - "mean_max": np.nan, - "mean_se": np.nan, - "mean_ci_lower": np.nan, - "mean_ci_upper": np.nan, - } +@failsafe_run( + get_input_fnames=partial( + _get_input_fnames_decoding, + kind="FullEpochs", + ), +) +def average_full_epochs_decoding( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + 
session: Optional[str], + cond_1: str, + cond_2: str, + in_files: dict, +) -> dict: + n_subjects = len(cfg.subjects) + in_files.pop("epochs") # not used but okay to include + + contrast_score_stats = { + "cond_1": cond_1, + "cond_2": cond_2, + "N": n_subjects, + "subjects": cfg.subjects, + "scores": np.nan, + "mean": np.nan, + "mean_min": np.nan, + "mean_max": np.nan, + "mean_se": np.nan, + "mean_ci_lower": np.nan, + "mean_ci_upper": np.nan, + } - processing = ( - f"{cond_1}+{cond_2}+FullEpochs+{cfg.decoding_metric}".replace(op.sep, "") - .replace("_", "-") - .replace("-", "") - ) + # Extract mean CV scores from all subjects. + mean_scores = np.empty(n_subjects) + for sub_idx, key in enumerate(list(in_files)): + decoding_data = loadmat(in_files.pop(key)) + mean_scores[sub_idx] = decoding_data["scores"].mean() + + # Now we can calculate some descriptive statistics on the mean scores. + # We use the [:] here as a safeguard to ensure we don't mess up the + # dimensions. + contrast_score_stats["scores"] = mean_scores + contrast_score_stats["mean"] = mean_scores.mean() + contrast_score_stats["mean_min"] = mean_scores.min() + contrast_score_stats["mean_max"] = mean_scores.max() + + # Finally, bootstrap the mean, and calculate the + # SD of the bootstrapped distribution: this is the standard error of + # the mean. We also derive 95% confidence intervals. + rng = np.random.default_rng(seed=cfg.random_state) + scores_resampled = rng.choice( + mean_scores, size=(cfg.n_boot, n_subjects), replace=True + ) + bootstrapped_means = scores_resampled.mean(axis=1) - # Extract mean CV scores from all subjects. 
- mean_scores = np.empty(n_subjects) - for sub_idx, subject in enumerate(cfg.subjects): - fname_mat = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - processing=processing, - suffix="decoding", - extension=".mat", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) + # SD of the bootstrapped distribution == SE of the metric. + se = bootstrapped_means.std(ddof=1) + ci_lower = np.quantile(bootstrapped_means, q=0.025) + ci_upper = np.quantile(bootstrapped_means, q=0.975) - decoding_data = loadmat(fname_mat) - mean_scores[sub_idx] = decoding_data["scores"].mean() + contrast_score_stats["mean_se"] = se + contrast_score_stats["mean_ci_lower"] = ci_lower + contrast_score_stats["mean_ci_upper"] = ci_upper - # Now we can calculate some descriptive statistics on the mean scores. - # We use the [:] here as a safeguard to ensure we don't mess up the - # dimensions. - contrast_score_stats["scores"] = mean_scores - contrast_score_stats["mean"] = mean_scores.mean() - contrast_score_stats["mean_min"] = mean_scores.min() - contrast_score_stats["mean_max"] = mean_scores.max() + del bootstrapped_means, se, ci_lower, ci_upper - # Finally, bootstrap the mean, and calculate the - # SD of the bootstrapped distribution: this is the standard error of - # the mean. We also derive 95% confidence intervals. 
- rng = np.random.default_rng(seed=cfg.random_state) - scores_resampled = rng.choice( - mean_scores, size=(cfg.n_boot, n_subjects), replace=True - ) - bootstrapped_means = scores_resampled.mean(axis=1) + out_files = dict() + fname_out = out_files["mat"] = _decoding_out_fname( + cfg=cfg, + subject=subject, + session=session, + cond_1=cond_1, + cond_2=cond_2, + kind="FullEpochs", + ) + if not fname_out.fpath.parent.exists(): + os.makedirs(fname_out.fpath.parent) + savemat(fname_out, contrast_score_stats) + return _prep_out_files(out_files=out_files, exec_params=exec_params) - # SD of the bootstrapped distribution == SE of the metric. - se = bootstrapped_means.std(ddof=1) - ci_lower = np.quantile(bootstrapped_means, q=0.025) - ci_upper = np.quantile(bootstrapped_means, q=0.975) - contrast_score_stats["mean_se"] = se - contrast_score_stats["mean_ci_lower"] = ci_lower - contrast_score_stats["mean_ci_upper"] = ci_upper +def get_input_files_average_full_epochs_report( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + decoding_contrasts: List[List[str]], +) -> dict: + in_files = dict() + for contrast in decoding_contrasts: + in_files[f"decoding-full-epochs-{contrast}"] = _decoding_out_fname( + cfg=cfg, + subject=subject, + session=session, + cond_1=contrast[0], + cond_2=contrast[1], + kind="FullEpochs", + ) + return in_files - del bootstrapped_means, se, ci_lower, ci_upper - fname_out = fname_mat.copy().update(subject="average") - if not fname_out.fpath.parent.exists(): - os.makedirs(fname_out.fpath.parent) - savemat(fname_out, contrast_score_stats) - del contrast_score_stats, fname_out +@failsafe_run( + get_input_fnames=get_input_files_average_full_epochs_report, +) +def average_full_epochs_report( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: Optional[str], + decoding_contrasts: List[List[str]], + in_files: dict, +) -> dict: + """Add decoding results to the grand average report.""" + with _open_report( + 
cfg=cfg, exec_params=exec_params, subject=subject, session=session + ) as report: + import matplotlib.pyplot as plt # nested import to help joblib + + logger.info( + **gen_log_kwargs(message="Adding full-epochs decoding results to report") + ) + + # Full-epochs decoding + all_decoding_scores = [] + for key in list(in_files): + if not key.startswith("decoding-full-epochs-"): + continue + decoding_data = loadmat(in_files.pop(key)) + all_decoding_scores.append(np.atleast_1d(decoding_data["scores"].squeeze())) + del decoding_data + + fig, caption = _plot_full_epochs_decoding_scores( + contrast_names=_contrasts_to_names(decoding_contrasts), + scores=all_decoding_scores, + metric=cfg.decoding_metric, + kind="grand-average", + ) + report.add_figure( + fig=fig, + title="Full-epochs decoding", + section="Decoding: full-epochs", + caption=caption, + tags=( + "epochs", + "contrast", + "decoding", + *[ + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}" + for cond_1, cond_2 in cfg.decoding_contrasts + ], + ), + replace=True, + ) + # close figure to save memory + plt.close(fig) + return _prep_out_files(exec_params=exec_params, out_files=dict()) +@failsafe_run( + get_input_fnames=partial( + _get_input_fnames_decoding, + kind="CSP", + extension=".xlsx", + ), +) def average_csp_decoding( + *, cfg: SimpleNamespace, - session: str, + exec_params: SimpleNamespace, subject: str, - condition_1: str, - condition_2: str, + session: Optional[str], + cond_1: str, + cond_2: str, + in_files: dict, ): - msg = f"Summarizing CSP results: {condition_1} - {condition_2}." + msg = f"Summarizing CSP results: {cond_1} - {cond_2}." logger.info(**gen_log_kwargs(message=msg)) - - # Extract mean CV scores from all subjects. 
- a_vs_b = f"{condition_1}+{condition_2}".replace(op.sep, "") - processing = f"{a_vs_b}+CSP+{cfg.decoding_metric}" - processing = processing.replace("_", "-").replace("-", "") + in_files.pop("epochs") all_decoding_data_freq = [] all_decoding_data_time_freq = [] - - # First load the data. - fname_out = BIDSPath( - subject="average", - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - processing=processing, - suffix="decoding", - extension=".xlsx", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - for subject in cfg.subjects: - fname_xlsx = fname_out.copy().update(subject=subject) + for key in list(in_files): + fname_xlsx = in_files.pop(key) decoding_data_freq = pd.read_excel( fname_xlsx, sheet_name="CSP Frequency", @@ -438,14 +734,28 @@ def average_csp_decoding( # Now calculate descriptes and bootstrap CIs. grand_average_freq = _average_csp_time_freq( cfg=cfg, + subject=subject, + session=session, data=all_decoding_data_freq, ) grand_average_time_freq = _average_csp_time_freq( cfg=cfg, + subject=subject, + session=session, data=all_decoding_data_time_freq, ) - with pd.ExcelWriter(fname_out) as w: + out_files = dict() + out_files["freq"] = _decoding_out_fname( + cfg=cfg, + subject=subject, + session=session, + cond_1=cond_1, + cond_2=cond_2, + kind="CSP", + extension=".xlsx", + ) + with pd.ExcelWriter(out_files["freq"]) as w: grand_average_freq.to_excel(w, sheet_name="CSP Frequency", index=False) grand_average_time_freq.to_excel( w, sheet_name="CSP Time-Frequency", index=False @@ -526,20 +836,37 @@ def average_csp_decoding( "freq_bin_edges": cfg.decoding_csp_freqs[freq_range_name], } - fname_out.update(extension=".mat") - savemat(file_name=fname_out, mdict=cluster_permutation_results) + out_files["cluster"] = out_files["cluster"].copy().update(extension=".mat") + savemat(file_name=out_files["cluster"], mdict=cluster_permutation_results) + + with _open_report( + cfg=cfg, 
exec_params=exec_params, subject=subject, session=session + ) as report: + add_csp_grand_average( + cfg=cfg, + subject=subject, + session=session, + report=report, + cond_1=cond_1, + cond_2=cond_2, + fname_csp_freq_results=out_files["freq"], + fname_csp_cluster_results=out_files["cluster"], + ) + return _prep_out_files(out_files=out_files, exec_params=exec_params) def _average_csp_time_freq( *, cfg: SimpleNamespace, + subject: str, + session: Optional[str], data: pd.DataFrame, ) -> pd.DataFrame: # Prepare a dataframe for storing the results. grand_average = data[0].copy() del grand_average["mean_crossval_score"] - grand_average["subject"] = "average" + grand_average["subject"] = subject grand_average["mean"] = np.nan grand_average["mean_se"] = np.nan grand_average["mean_ci_lower"] = np.nan @@ -628,67 +955,92 @@ def get_config( return cfg -@failsafe_run() -def run_group_average_sensor( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, -) -> None: - if cfg.task_is_rest: +def main(*, config: SimpleNamespace) -> None: + if config.task_is_rest: msg = ' … skipping: for "rest" task.' logger.info(**gen_log_kwargs(message=msg)) return - - sessions = get_sessions(cfg) - if not sessions: - sessions = [None] - + cfg = get_config( + config=config, + ) + exec_params = config.exec_params + subject = "average" + sessions = get_sessions(config=config) + if cfg.decode or cfg.decoding_csp: + decoding_contrasts = get_decoding_contrasts(config=cfg) + else: + decoding_contrasts = [] + logs = list() with get_parallel_backend(exec_params): - for session in sessions: - evokeds = average_evokeds( + # 1. Evoked data + logs += [ + average_evokeds( cfg=cfg, + exec_params=exec_params, subject=subject, session=session, ) - if exec_params.interactive: - for evoked in evokeds: - evoked.plot() + for session in sessions + ] + + # 2. 
Time decoding + if cfg.decode: + # Full epochs (single report function plots across all contrasts + # so it's a separate cached step) + logs += [ + average_full_epochs_decoding( + cfg=cfg, + subject=subject, + session=session, + cond_1=contrast[0], + cond_2=contrast[1], + exec_params=exec_params, + ) + for session in sessions + for contrast in decoding_contrasts + ] + logs += [ + average_full_epochs_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + decoding_contrasts=decoding_contrasts, + ) + for session in sessions + ] + # Time-by-time + parallel, run_func = parallel_func( + average_time_by_time_decoding, exec_params=exec_params + ) + logs += parallel( + run_func( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + cond_1=contrast[0], + cond_2=contrast[1], + ) + for session in sessions + for contrast in decoding_contrasts + ) - if cfg.decode: - average_full_epochs_decoding(cfg, session) - average_time_by_time_decoding(cfg, session) + # 3. 
CSP if cfg.decoding_csp: parallel, run_func = parallel_func( average_csp_decoding, exec_params=exec_params ) - parallel( + logs += parallel( run_func( cfg=cfg, - session=session, subject=subject, - condition_1=contrast[0], - condition_2=contrast[1], + session=session, + cond_1=contrast[0], + cond_2=contrast[1], ) - for session in get_sessions(config=cfg) for contrast in get_decoding_contrasts(config=cfg) + for session in sessions ) - for session in sessions: - run_report_average_sensor( - cfg=cfg, - exec_params=exec_params, - subject=subject, - session=session, - ) - - -def main(*, config: SimpleNamespace) -> None: - log = run_group_average_sensor( - cfg=get_config( - config=config, - ), - exec_params=config.exec_params, - subject="average", - ) - save_logs(config=config, logs=[log]) + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/tests/conftest.py b/mne_bids_pipeline/tests/conftest.py index 020d292d9..a4f6fe6eb 100644 --- a/mne_bids_pipeline/tests/conftest.py +++ b/mne_bids_pipeline/tests/conftest.py @@ -36,6 +36,9 @@ def pytest_configure(config): ignore:The \S+_cmap function was deprecated.*:DeprecationWarning # Dask distributed with jsonschema 4.18 ignore:jsonschema\.RefResolver is deprecated.*:DeprecationWarning + # seaborn->pandas + ignore:is_categorical_dtype is deprecated.*:FutureWarning + ignore:use_inf_as_na option is deprecated.*:FutureWarning """ for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip() From e23018c873cfe4be746aa561c163e0aa77008874 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 07:58:36 -0400 Subject: [PATCH 02/15] FIX: Sensor working maybe --- .circleci/config.yml | 20 +++++----- .circleci/run_dataset_and_copy_files.sh | 38 +++++++++++++------ .../steps/sensor/_99_group_average.py | 6 ++- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 1427dda16..6768ae3c1 100644 --- a/.circleci/config.yml +++ 
b/.circleci/config.yml @@ -529,7 +529,7 @@ jobs: - data-cache-ds000248-4 - run: name: test BEM from FLASH - command: $RUN_TESTS ds000248_FLASH_BEM + command: $RUN_TESTS ds000248_FLASH_BEM -r - codecov/upload - store_test_results: path: ./test-results @@ -556,7 +556,7 @@ jobs: - run: name: test BEM from T1 (watershed) no_output_timeout: 20m - command: $RUN_TESTS ds000248_T1_BEM + command: $RUN_TESTS ds000248_T1_BEM -r - codecov/upload - store_test_results: path: ./test-results @@ -582,7 +582,7 @@ jobs: - data-cache-ds000248-4 - run: name: test head surface creation for MNE coregistration - command: $RUN_TESTS ds000248_coreg_surfaces ds000248_coreg_surfaces --no-copy + command: $RUN_TESTS ds000248_coreg_surfaces -c -r - codecov/upload - store_test_results: path: ./test-results @@ -772,7 +772,7 @@ jobs: google-chrome --version - run: name: test ERP CORE N400 - command: $RUN_TESTS ERP_CORE_N400 ERP_CORE + command: $RUN_TESTS ERP_CORE_N400 - codecov/upload - store_test_results: path: ./test-results @@ -802,7 +802,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE ERN - command: $RUN_TESTS ERP_CORE_ERN ERP_CORE + command: $RUN_TESTS ERP_CORE_ERN - codecov/upload - store_test_results: path: ./test-results @@ -832,7 +832,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE LRP - command: $RUN_TESTS ERP_CORE_LRP ERP_CORE + command: $RUN_TESTS ERP_CORE_LRP - codecov/upload - store_test_results: path: ./test-results @@ -862,7 +862,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE MMN - command: $RUN_TESTS ERP_CORE_MMN ERP_CORE + command: $RUN_TESTS ERP_CORE_MMN - codecov/upload - store_test_results: path: ./test-results @@ -892,7 +892,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE N2pc - command: $RUN_TESTS ERP_CORE_N2pc ERP_CORE + command: $RUN_TESTS ERP_CORE_N2pc - codecov/upload - store_test_results: path: 
./test-results @@ -922,7 +922,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE N170 - command: $RUN_TESTS ERP_CORE_N170 ERP_CORE + command: $RUN_TESTS ERP_CORE_N170 - codecov/upload - store_test_results: path: ./test-results @@ -952,7 +952,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE P3 - command: $RUN_TESTS ERP_CORE_P3 ERP_CORE + command: $RUN_TESTS ERP_CORE_P3 - codecov/upload - store_test_results: path: ./test-results diff --git a/.circleci/run_dataset_and_copy_files.sh b/.circleci/run_dataset_and_copy_files.sh index 16f926481..3be835ac5 100755 --- a/.circleci/run_dataset_and_copy_files.sh +++ b/.circleci/run_dataset_and_copy_files.sh @@ -2,27 +2,43 @@ set -eo pipefail +COPY_FILES="true" +RERUN_TEST="true" +while getopts "cr" option; do + echo $option + case $option in + c) + COPY_FILES="false";; + r) + RERUN_TEST="false";; + esac +done +shift "$(($OPTIND -1))" + DS_RUN=$1 -if [[ "$2" == "" ]]; then - DS="$DS_RUN" -else - DS="$2" +if [[ -z $1 ]]; then + echo "Missing dataset argument" + exit 1 fi -if [[ "$3" == "--no-copy" ]]; then - COPY_FILES="false" +if [[ "$DS_RUN" == "ERP_CORE_"* ]]; then + DS="ERP_CORE" else - COPY_FILES="true" + DS="$1" fi SECONDS=0 pytest mne_bids_pipeline --junit-xml=test-results/junit-results.xml -k ${DS_RUN} echo "Runtime: ${SECONDS} seconds" -# rerun test! +# rerun test (check caching)! 
SECONDS=0 -pytest mne_bids_pipeline -k $DS_RUN -RUN_TIME=$SECONDS -echo "Runtime: ${RUN_TIME} seconds (should be < 10)" +if [[ "$RERUN_TEST" == "false" ]]; then + RUN_TIME=0 +else + pytest mne_bids_pipeline -k $DS_RUN + RUN_TIME=$SECONDS + echo "Runtime: ${RUN_TIME} seconds (should be < 10)" +fi test $RUN_TIME -lt 10 if [[ "$COPY_FILES" == "false" ]]; then diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index 45bacb257..b0920af2a 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -24,6 +24,7 @@ get_decoding_contrasts, get_all_contrasts, _bids_kwargs, + _restrict_analyze_channels, ) from ..._decoding import _handle_csp_args from ..._logging import gen_log_kwargs, logger @@ -127,6 +128,7 @@ def average_evokeds( evoked.plot() # Reporting + evokeds = [_restrict_analyze_channels(evoked, cfg) for evoked in evokeds] with _open_report( cfg=cfg, exec_params=exec_params, subject=subject, session=session ) as report: @@ -984,7 +986,7 @@ def main(*, config: SimpleNamespace) -> None: ] # 2. Time decoding - if cfg.decode: + if cfg.decode and decoding_contrasts: # Full epochs (single report function plots across all contrasts # so it's a separate cached step) logs += [ @@ -1027,7 +1029,7 @@ def main(*, config: SimpleNamespace) -> None: ) # 3. 
CSP - if cfg.decoding_csp: + if cfg.decoding_csp and decoding_contrasts: parallel, run_func = parallel_func( average_csp_decoding, exec_params=exec_params ) From ad2c258faa3ea457419687730f1268e96e62bd57 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 08:49:43 -0400 Subject: [PATCH 03/15] FIX --- .circleci/config.yml | 6 +++--- .circleci/run_dataset_and_copy_files.sh | 6 ++++-- mne_bids_pipeline/steps/sensor/_99_group_average.py | 3 ++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 6768ae3c1..1fcddea70 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -529,7 +529,7 @@ jobs: - data-cache-ds000248-4 - run: name: test BEM from FLASH - command: $RUN_TESTS ds000248_FLASH_BEM -r + command: $RUN_TESTS -r ds000248_FLASH_BEM - codecov/upload - store_test_results: path: ./test-results @@ -556,7 +556,7 @@ jobs: - run: name: test BEM from T1 (watershed) no_output_timeout: 20m - command: $RUN_TESTS ds000248_T1_BEM -r + command: $RUN_TESTS -r ds000248_T1_BEM - codecov/upload - store_test_results: path: ./test-results @@ -582,7 +582,7 @@ jobs: - data-cache-ds000248-4 - run: name: test head surface creation for MNE coregistration - command: $RUN_TESTS ds000248_coreg_surfaces -c -r + command: $RUN_TESTS -c -r ds000248_coreg_surfaces - codecov/upload - store_test_results: path: ./test-results diff --git a/.circleci/run_dataset_and_copy_files.sh b/.circleci/run_dataset_and_copy_files.sh index 3be835ac5..2987cb449 100755 --- a/.circleci/run_dataset_and_copy_files.sh +++ b/.circleci/run_dataset_and_copy_files.sh @@ -33,15 +33,17 @@ echo "Runtime: ${SECONDS} seconds" # rerun test (check caching)! 
SECONDS=0 if [[ "$RERUN_TEST" == "false" ]]; then + echo "Skipping rerun test" RUN_TIME=0 else pytest mne_bids_pipeline -k $DS_RUN RUN_TIME=$SECONDS - echo "Runtime: ${RUN_TIME} seconds (should be < 10)" + echo "Runtime: ${RUN_TIME} seconds (should be < 20)" fi -test $RUN_TIME -lt 10 +test $RUN_TIME -lt 20 if [[ "$COPY_FILES" == "false" ]]; then + echo "Not copying files" exit 0 fi mkdir -p ~/reports/${DS} diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index b0920af2a..f284ffee4 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -838,7 +838,7 @@ def average_csp_decoding( "freq_bin_edges": cfg.decoding_csp_freqs[freq_range_name], } - out_files["cluster"] = out_files["cluster"].copy().update(extension=".mat") + out_files["cluster"] = out_files["freq"].copy().update(extension=".mat") savemat(file_name=out_files["cluster"], mdict=cluster_permutation_results) with _open_report( @@ -1036,6 +1036,7 @@ def main(*, config: SimpleNamespace) -> None: logs += parallel( run_func( cfg=cfg, + exec_params=exec_params, subject=subject, session=session, cond_1=contrast[0], From 1f4c91f517a3cb0ac8ebe11e5070e9662b16551b Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 09:03:02 -0400 Subject: [PATCH 04/15] FIX: Check --- mne_bids_pipeline/tests/test_documented.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_bids_pipeline/tests/test_documented.py b/mne_bids_pipeline/tests/test_documented.py index 727dcf6bc..20cb37349 100644 --- a/mne_bids_pipeline/tests/test_documented.py +++ b/mne_bids_pipeline/tests/test_documented.py @@ -121,7 +121,7 @@ def test_datasets_in_doc(): n_found = len(pw.findall(circle_yaml_src)) assert n_found == this_count, f"{pw} ({n_found} != {this_count})" # jobs: test_*: steps: run test - cp = re.compile(rf" command: \$RUN_TESTS {name}.*") + cp = re.compile(rf" command: \$RUN_TESTS[ 
-rc]*{name}.*") n_found = len(cp.findall(circle_yaml_src)) assert n_found == count, f"{cp} ({n_found} != {count})" From 769a1c3cdff6fdb97262ff949c57f3d3ac6b12a0 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 09:03:27 -0400 Subject: [PATCH 05/15] FIX: Plus --- mne_bids_pipeline/tests/test_documented.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_bids_pipeline/tests/test_documented.py b/mne_bids_pipeline/tests/test_documented.py index 20cb37349..097fc1032 100644 --- a/mne_bids_pipeline/tests/test_documented.py +++ b/mne_bids_pipeline/tests/test_documented.py @@ -121,7 +121,7 @@ def test_datasets_in_doc(): n_found = len(pw.findall(circle_yaml_src)) assert n_found == this_count, f"{pw} ({n_found} != {this_count})" # jobs: test_*: steps: run test - cp = re.compile(rf" command: \$RUN_TESTS[ -rc]*{name}.*") + cp = re.compile(rf" command: \$RUN_TESTS[ -rc]+{name}.*") n_found = len(cp.findall(circle_yaml_src)) assert n_found == count, f"{cp} ({n_found} != {count})" From 04393b6bc26ae5de0dc0b68daa1ef2eb997e7eed Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 09:20:24 -0400 Subject: [PATCH 06/15] FIX: Codecov --- .circleci/config.yml | 32 ++++++++++++++++++++++---------- .circleci/setup_bash.sh | 2 +- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 1fcddea70..5834f692c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -82,6 +82,7 @@ jobs: name: Get ds000117 command: | $DOWNLOAD_DATA ds000117 + - codecov/upload - save_cache: key: data-cache-ds000117-2 paths: @@ -118,6 +119,7 @@ jobs: name: Get ds001971 command: | $DOWNLOAD_DATA ds001971 + - codecov/upload - save_cache: key: data-cache-ds001971-2 paths: @@ -136,6 +138,7 @@ jobs: name: Get ds004107 command: | $DOWNLOAD_DATA ds004107 + - codecov/upload - save_cache: key: data-cache-ds004107-2 paths: @@ -154,6 +157,7 @@ jobs: name: Get ds000246 command: | $DOWNLOAD_DATA ds000246 + - 
codecov/upload - save_cache: key: data-cache-ds000246-2 paths: @@ -172,6 +176,7 @@ jobs: name: Get ds000247 command: | $DOWNLOAD_DATA ds000247 + - codecov/upload - save_cache: key: data-cache-ds000247-2 paths: @@ -190,6 +195,7 @@ jobs: name: Get ds000248 command: | $DOWNLOAD_DATA ds000248 + - codecov/upload - save_cache: key: data-cache-ds000248-4 paths: @@ -208,6 +214,7 @@ jobs: name: Get ds001810 command: | $DOWNLOAD_DATA ds001810 + - codecov/upload - save_cache: key: data-cache-ds001810-2 paths: @@ -226,6 +233,7 @@ jobs: name: Get ds003104 command: | $DOWNLOAD_DATA ds003104 + - codecov/upload - save_cache: key: data-cache-ds003104-2 paths: @@ -244,6 +252,7 @@ jobs: name: Get ds003392 command: | $DOWNLOAD_DATA ds003392 + - codecov/upload - save_cache: key: data-cache-ds003392-2 paths: @@ -262,6 +271,7 @@ jobs: name: Get ds004229 command: | $DOWNLOAD_DATA ds004229 + - codecov/upload - save_cache: key: data-cache-ds004229-2 paths: @@ -281,6 +291,7 @@ jobs: name: Get eeg_matchingpennies command: | $DOWNLOAD_DATA eeg_matchingpennies + - codecov/upload - save_cache: key: data-cache-eeg_matchingpennies-1 paths: @@ -299,6 +310,7 @@ jobs: name: Get ERP_CORE command: | $DOWNLOAD_DATA ERP_CORE + - codecov/upload - save_cache: key: data-cache-ERP_CORE-1 paths: @@ -529,7 +541,7 @@ jobs: - data-cache-ds000248-4 - run: name: test BEM from FLASH - command: $RUN_TESTS -r ds000248_FLASH_BEM + command: $RUN_TESTS ds000248_FLASH_BEM - codecov/upload - store_test_results: path: ./test-results @@ -556,7 +568,7 @@ jobs: - run: name: test BEM from T1 (watershed) no_output_timeout: 20m - command: $RUN_TESTS -r ds000248_T1_BEM + command: $RUN_TESTS ds000248_T1_BEM - codecov/upload - store_test_results: path: ./test-results @@ -582,7 +594,7 @@ jobs: - data-cache-ds000248-4 - run: name: test head surface creation for MNE coregistration - command: $RUN_TESTS -c -r ds000248_coreg_surfaces + command: $RUN_TESTS ds000248_coreg_surfaces ds000248_coreg_surfaces --no-copy - codecov/upload - 
store_test_results: path: ./test-results @@ -772,7 +784,7 @@ jobs: google-chrome --version - run: name: test ERP CORE N400 - command: $RUN_TESTS ERP_CORE_N400 + command: $RUN_TESTS ERP_CORE_N400 ERP_CORE - codecov/upload - store_test_results: path: ./test-results @@ -802,7 +814,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE ERN - command: $RUN_TESTS ERP_CORE_ERN + command: $RUN_TESTS ERP_CORE_ERN ERP_CORE - codecov/upload - store_test_results: path: ./test-results @@ -832,7 +844,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE LRP - command: $RUN_TESTS ERP_CORE_LRP + command: $RUN_TESTS ERP_CORE_LRP ERP_CORE - codecov/upload - store_test_results: path: ./test-results @@ -862,7 +874,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE MMN - command: $RUN_TESTS ERP_CORE_MMN + command: $RUN_TESTS ERP_CORE_MMN ERP_CORE - codecov/upload - store_test_results: path: ./test-results @@ -892,7 +904,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE N2pc - command: $RUN_TESTS ERP_CORE_N2pc + command: $RUN_TESTS ERP_CORE_N2pc ERP_CORE - codecov/upload - store_test_results: path: ./test-results @@ -922,7 +934,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE N170 - command: $RUN_TESTS ERP_CORE_N170 + command: $RUN_TESTS ERP_CORE_N170 ERP_CORE - codecov/upload - store_test_results: path: ./test-results @@ -952,7 +964,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE P3 - command: $RUN_TESTS ERP_CORE_P3 + command: $RUN_TESTS ERP_CORE_P3 ERP_CORE - codecov/upload - store_test_results: path: ./test-results diff --git a/.circleci/setup_bash.sh b/.circleci/setup_bash.sh index 073a45b77..ee44b317b 100755 --- a/.circleci/setup_bash.sh +++ b/.circleci/setup_bash.sh @@ -36,7 +36,7 @@ sudo ln -s /usr/lib/x86_64-linux-gnu/libxcb-util.so.0 
/usr/lib/x86_64-linux-gnu/ wget -q -O- http://neuro.debian.net/lists/focal.us-tn.libre | sudo tee /etc/apt/sources.list.d/neurodebian.sources.list sudo apt-key adv --recv-keys --keyserver hkps://keyserver.ubuntu.com 0xA5D32F012649A5A9 echo "export RUN_TESTS=\".circleci/run_dataset_and_copy_files.sh\"" >> "$BASH_ENV" -echo "export DOWNLOAD_DATA=\"python -m mne_bids_pipeline._download\"" >> "$BASH_ENV" +echo "export DOWNLOAD_DATA=\"coverage run -m mne_bids_pipeline._download\"" >> "$BASH_ENV" # Similar CircleCI setup to mne-python (Xvfb, venv, minimal commands, env vars) wget -q https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/setup_xvfb.sh From 798728480d256b120e10dd0a6fd4555020a7ebae Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 09:21:02 -0400 Subject: [PATCH 07/15] FIX: Cover --- mne_bids_pipeline/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_bids_pipeline/__init__.py b/mne_bids_pipeline/__init__.py index 39d9f4177..2826b97e6 100644 --- a/mne_bids_pipeline/__init__.py +++ b/mne_bids_pipeline/__init__.py @@ -2,6 +2,6 @@ try: __version__ = version("mne_bids_pipeline") -except PackageNotFoundError: +except PackageNotFoundError: # pragma: no cover # package is not installed __version__ = "0.0.0" From c424515eb4dca5447d57e9d80b98b51b98b0ff60 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 09:26:41 -0400 Subject: [PATCH 08/15] FIX: n_jobs --- mne_bids_pipeline/tests/configs/config_ds000248_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_base.py b/mne_bids_pipeline/tests/configs/config_ds000248_base.py index 795ed618b..4110ecfe6 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_base.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_base.py @@ -44,7 +44,7 @@ def noise_cov(bp): bem_mri_images = "FLASH" recreate_bem = True -n_jobs = 2 +n_jobs = 1 def mri_t1_path_generator(bids_path): From 
209d0047ec3c1e5a6c6447418a6a15f27f77fc77 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 09:27:14 -0400 Subject: [PATCH 09/15] FIX: Comment --- mne_bids_pipeline/tests/configs/config_ds000248_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_base.py b/mne_bids_pipeline/tests/configs/config_ds000248_base.py index 4110ecfe6..8a7155776 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_base.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_base.py @@ -44,6 +44,7 @@ def noise_cov(bp): bem_mri_images = "FLASH" recreate_bem = True +# use n_jobs=1 here to ensure that we get coverage for metadata_query, etc. n_jobs = 1 From ee2a7baa6862bf8db5b8ee95a933616ef0a7b945 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 11:19:31 -0400 Subject: [PATCH 10/15] FIX: More cache --- .circleci/config.yml | 20 +++++++++---------- .../steps/preprocessing/_07a_apply_ica.py | 4 +++- .../steps/sensor/_02_decoding_full_epochs.py | 4 ++-- .../steps/sensor/_03_decoding_time_by_time.py | 4 ++-- .../steps/sensor/_04_time_frequency.py | 2 +- .../steps/sensor/_99_group_average.py | 6 ++++-- .../tests/configs/config_ds001971.py | 10 ++++++++++ mne_bids_pipeline/tests/conftest.py | 3 +++ 8 files changed, 35 insertions(+), 18 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 5834f692c..f64bfbd31 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -541,7 +541,7 @@ jobs: - data-cache-ds000248-4 - run: name: test BEM from FLASH - command: $RUN_TESTS ds000248_FLASH_BEM + command: $RUN_TESTS -r ds000248_FLASH_BEM - codecov/upload - store_test_results: path: ./test-results @@ -568,7 +568,7 @@ jobs: - run: name: test BEM from T1 (watershed) no_output_timeout: 20m - command: $RUN_TESTS ds000248_T1_BEM + command: $RUN_TESTS -r ds000248_T1_BEM - codecov/upload - store_test_results: path: ./test-results @@ -594,7 +594,7 @@ jobs: - data-cache-ds000248-4 - run: name: test 
head surface creation for MNE coregistration - command: $RUN_TESTS ds000248_coreg_surfaces ds000248_coreg_surfaces --no-copy + command: $RUN_TESTS -c -r ds000248_coreg_surfaces - codecov/upload - store_test_results: path: ./test-results @@ -784,7 +784,7 @@ jobs: google-chrome --version - run: name: test ERP CORE N400 - command: $RUN_TESTS ERP_CORE_N400 ERP_CORE + command: $RUN_TESTS ERP_CORE_N400 - codecov/upload - store_test_results: path: ./test-results @@ -814,7 +814,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE ERN - command: $RUN_TESTS ERP_CORE_ERN ERP_CORE + command: $RUN_TESTS ERP_CORE_ERN - codecov/upload - store_test_results: path: ./test-results @@ -844,7 +844,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE LRP - command: $RUN_TESTS ERP_CORE_LRP ERP_CORE + command: $RUN_TESTS ERP_CORE_LRP - codecov/upload - store_test_results: path: ./test-results @@ -874,7 +874,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE MMN - command: $RUN_TESTS ERP_CORE_MMN ERP_CORE + command: $RUN_TESTS ERP_CORE_MMN - codecov/upload - store_test_results: path: ./test-results @@ -904,7 +904,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE N2pc - command: $RUN_TESTS ERP_CORE_N2pc ERP_CORE + command: $RUN_TESTS ERP_CORE_N2pc - codecov/upload - store_test_results: path: ./test-results @@ -934,7 +934,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE N170 - command: $RUN_TESTS ERP_CORE_N170 ERP_CORE + command: $RUN_TESTS ERP_CORE_N170 - codecov/upload - store_test_results: path: ./test-results @@ -964,7 +964,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE P3 - command: $RUN_TESTS ERP_CORE_P3 ERP_CORE + command: $RUN_TESTS ERP_CORE_P3 - codecov/upload - store_test_results: path: ./test-results diff --git 
a/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py b/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py index effc99f68..f36ce8e11 100644 --- a/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py @@ -151,11 +151,13 @@ def apply_ica( assert len(in_files) == 0, in_files.keys() # Report + kwargs = dict() if ica.exclude: msg = "Adding ICA to report." else: msg = "Skipping ICA addition to report, no components marked as bad." - logger.info(**gen_log_kwargs(message=msg)) + kwargs["emoji"] = "skip" + logger.info(**gen_log_kwargs(message=msg, **kwargs)) if ica.exclude: with _open_report( cfg=cfg, exec_params=exec_params, subject=subject, session=session diff --git a/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py b/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py index 8fdd863fe..287ace7bc 100644 --- a/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py +++ b/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py @@ -237,12 +237,12 @@ def main(*, config: SimpleNamespace) -> None: """Run time-by-time decoding.""" if not config.contrasts: msg = "No contrasts specified; not performing decoding." - logger.info(**gen_log_kwargs(message=msg)) + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) return if not config.decode: msg = "No decoding requested by user." 
- logger.info(**gen_log_kwargs(message=msg)) + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) return with get_parallel_backend(config.exec_params): diff --git a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py index 867b48030..f78e1d0cf 100644 --- a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py +++ b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py @@ -314,12 +314,12 @@ def main(*, config: SimpleNamespace) -> None: """Run time-by-time decoding.""" if not config.contrasts: msg = "No contrasts specified; not performing decoding." - logger.info(**gen_log_kwargs(message=msg)) + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) return if not config.decode: msg = "No decoding requested by user." - logger.info(**gen_log_kwargs(message=msg)) + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) return # Here we go parallel inside the :class:`mne.decoding.SlidingEstimator` diff --git a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py index f5b7c3381..a2e03f2c2 100644 --- a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py +++ b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py @@ -182,7 +182,7 @@ def main(*, config: SimpleNamespace) -> None: """Run Time-frequency decomposition.""" if not config.time_frequency_conditions: msg = "Skipping …" - logger.info(**gen_log_kwargs(message=msg)) + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) return parallel, run_func = parallel_func( diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index f284ffee4..04ac5daa7 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -788,9 +788,9 @@ def average_csp_decoding( ["subject", "freq_range_name", "t_min", "t_max"] ) - for (subject, freq_range_name, t_min, t_max), df 
in g: + for (subject_, freq_range_name, t_min, t_max), df in g: scores = df["mean_crossval_score"] - sub_idx = subjects.index(subject) + sub_idx = subjects.index(subject_) time_bin_idx = time_bins.loc[ (np.isclose(time_bins["t_min"], t_min)) & (np.isclose(time_bins["t_max"], t_max)), @@ -810,6 +810,7 @@ def average_csp_decoding( cluster_forming_t_threshold = cfg.cluster_forming_t_threshold cluster_permutation_results = {} + # TODO: Do something better when there is 1 subject for freq_range_name, X in data_for_clustering.items(): ( t_vals, @@ -841,6 +842,7 @@ def average_csp_decoding( out_files["cluster"] = out_files["freq"].copy().update(extension=".mat") savemat(file_name=out_files["cluster"], mdict=cluster_permutation_results) + assert subject == "average" with _open_report( cfg=cfg, exec_params=exec_params, subject=subject, session=session ) as report: diff --git a/mne_bids_pipeline/tests/configs/config_ds001971.py b/mne_bids_pipeline/tests/configs/config_ds001971.py index 2f3307c85..34fa812eb 100644 --- a/mne_bids_pipeline/tests/configs/config_ds001971.py +++ b/mne_bids_pipeline/tests/configs/config_ds001971.py @@ -13,6 +13,16 @@ ch_types = ["eeg"] reject = {"eeg": 150e-6} conditions = ["AdvanceTempo", "DelayTempo"] +contrasts = [("AdvanceTempo", "DelayTempo")] subjects = ["001"] runs = ["01"] +epochs_decim = 5 # to 100 Hz + +# This is mostly for testing purposes! 
+decode = False +decoding_csp = True +decoding_csp_freqs = { + "beta": [13, 20, 30], +} +decoding_csp_times = [-0.2, 0.0, 0.2, 0.4] diff --git a/mne_bids_pipeline/tests/conftest.py b/mne_bids_pipeline/tests/conftest.py index a4f6fe6eb..66d852b58 100644 --- a/mne_bids_pipeline/tests/conftest.py +++ b/mne_bids_pipeline/tests/conftest.py @@ -39,6 +39,9 @@ def pytest_configure(config): # seaborn->pandas ignore:is_categorical_dtype is deprecated.*:FutureWarning ignore:use_inf_as_na option is deprecated.*:FutureWarning + # TODO: Fix decoding clustering for n_subjects=1 + ignore:Degrees of freedom <= 0 for slice:RuntimeWarning + ignore:invalid value encountered in divide:RuntimeWarning """ for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip() From c5a6e974ab030a6bccef1d14939e43e9e80be0fe Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 13:39:42 -0400 Subject: [PATCH 11/15] ENH: Allow extra specs --- mne_bids_pipeline/_config_import.py | 60 +++++++++++++------ mne_bids_pipeline/_logging.py | 4 ++ mne_bids_pipeline/_main.py | 4 +- mne_bids_pipeline/_parallel.py | 25 +++++++- mne_bids_pipeline/_run.py | 5 +- .../steps/sensor/_04_time_frequency.py | 10 ++-- .../steps/sensor/_99_group_average.py | 33 +++++----- .../tests/configs/config_ds000248_base.py | 5 +- .../tests/configs/config_ds001971.py | 4 +- mne_bids_pipeline/tests/conftest.py | 3 - mne_bids_pipeline/tests/test_run.py | 41 +++++++++---- 11 files changed, 129 insertions(+), 65 deletions(-) diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index ca507ba02..06c1d2e75 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -36,6 +36,17 @@ def _import_config( log=log, ) + extra_exec_params_keys = () + extra_config = os.getenv("_MNE_BIDS_STUDY_TESTING_EXTRA_CONFIG", "") + if extra_config: + msg = f"With testing config: {extra_config}" + logger.info(**gen_log_kwargs(message=msg, emoji="🧪")) + 
_update_config_from_path( + config=config, + config_path=extra_config, + ) + extra_exec_params_keys = ("_n_jobs",) + # Check it if check: _check_config(config) @@ -69,7 +80,7 @@ def _import_config( # Misc "deriv_root", "config_path", - ) + ) + extra_exec_params_keys in_both = {"deriv_root"} exec_params = SimpleNamespace(**{k: getattr(config, k) for k in keys}) for k in keys: @@ -102,6 +113,32 @@ def _get_default_config(): return config +def _update_config_from_path( + *, + config: SimpleNamespace, + config_path: PathLike, +): + user_names = list() + config_path = pathlib.Path(config_path).expanduser().resolve(strict=True) + # Import configuration from an arbitrary path without having to fiddle + # with `sys.path`. + spec = importlib.util.spec_from_file_location( + name="custom_config", location=config_path + ) + custom_cfg = importlib.util.module_from_spec(spec) + spec.loader.exec_module(custom_cfg) + for key in dir(custom_cfg): + if not key.startswith("__"): + # don't validate private vars, but do add to config + # (e.g., so that our hidden _raw_split_size is included) + if not key.startswith("_"): + user_names.append(key) + val = getattr(custom_cfg, key) + logger.debug("Overwriting: %s -> %s" % (key, val)) + setattr(config, key, val) + return user_names + + def _update_with_user_config( *, config: SimpleNamespace, # modified in-place @@ -121,23 +158,12 @@ def _update_with_user_config( # 2. User config user_names = list() if config_path is not None: - config_path = pathlib.Path(config_path).expanduser().resolve(strict=True) - # Import configuration from an arbitrary path without having to fiddle - # with `sys.path`. 
- spec = importlib.util.spec_from_file_location( - name="custom_config", location=config_path + user_names.extend( + _update_config_from_path( + config=config, + config_path=config_path, + ) ) - custom_cfg = importlib.util.module_from_spec(spec) - spec.loader.exec_module(custom_cfg) - for key in dir(custom_cfg): - if not key.startswith("__"): - # don't validate private vars, but do add to config - # (e.g., so that our hidden _raw_split_size is included) - if not key.startswith("_"): - user_names.append(key) - val = getattr(custom_cfg, key) - logger.debug("Overwriting: %s -> %s" % (key, val)) - setattr(config, key, val) config.config_path = config_path # 3. Overrides via command-line switches diff --git a/mne_bids_pipeline/_logging.py b/mne_bids_pipeline/_logging.py index 56eea901f..6bcb21d73 100644 --- a/mne_bids_pipeline/_logging.py +++ b/mne_bids_pipeline/_logging.py @@ -154,3 +154,7 @@ def gen_log_kwargs( def _linkfile(uri): return f"[link=file://{uri}]{uri}[/link]" + + +def _is_testing() -> bool: + return os.getenv("_MNE_BIDS_STUDY_TESTING", "") == "true" diff --git a/mne_bids_pipeline/_main.py b/mne_bids_pipeline/_main.py index 3c54a177d..7e5febd5a 100755 --- a/mne_bids_pipeline/_main.py +++ b/mne_bids_pipeline/_main.py @@ -196,12 +196,12 @@ def main(): logger.title("Welcome aboard MNE-BIDS-Pipeline! 
👋") msg = f"Using configuration: {config}" logger.info(**gen_log_kwargs(message=msg, emoji="📝")) - logger.end() - config_imported = _import_config( config_path=config_path, overrides=overrides, ) + logger.end() + for step_module in step_modules: start = time.time() step = _short_step_path(pathlib.Path(step_module.__file__)) diff --git a/mne_bids_pipeline/_parallel.py b/mne_bids_pipeline/_parallel.py index c2f9430ae..bf62a6df6 100644 --- a/mne_bids_pipeline/_parallel.py +++ b/mne_bids_pipeline/_parallel.py @@ -5,7 +5,7 @@ import joblib -from ._logging import logger +from ._logging import logger, gen_log_kwargs, _is_testing def get_n_jobs(*, exec_params: SimpleNamespace) -> int: @@ -14,6 +14,12 @@ def get_n_jobs(*, exec_params: SimpleNamespace) -> int: n_cores = joblib.cpu_count() n_jobs = min(n_cores + n_jobs + 1, n_cores) + # Shim to allow overriding n_jobs for specific steps + if _is_testing() and hasattr(exec_params, "_n_jobs"): + from ._run import _get_step_path, _short_step_path + + step_path = _short_step_path(_get_step_path()) + n_jobs = exec_params._n_jobs.get(step_path, n_jobs) return n_jobs @@ -82,17 +88,30 @@ def get_parallel_backend_name( exec_params.parallel_backend == "loky" or get_n_jobs(exec_params=exec_params) == 1 ): - return "loky" + backend = "loky" elif exec_params.parallel_backend == "dask": # Disable interactive plotting backend import matplotlib matplotlib.use("Agg") - return "dask" + backend = "dask" else: # TODO: Move to value validation step raise ValueError(f"Unknown parallel backend: {exec_params.parallel_backend}") + # Shim to allow changing the backend on a per-step basis for testing + if _is_testing() and hasattr(exec_params, "_parallel_backend"): + from ._run import _get_step_path, _short_step_path + + step_path = _short_step_path(_get_step_path()) + old_backend = backend + backend = exec_params._parallel_backend.get(step_path, backend) + msg = f"Overriding parallel backend {old_backend}→{backend}" + 
logger.info(**gen_log_kwargs(message=msg)) + raise RuntimeError(backend) + + return backend + def get_parallel_backend(exec_params: SimpleNamespace) -> joblib.parallel_backend: import joblib diff --git a/mne_bids_pipeline/_run.py b/mne_bids_pipeline/_run.py index 0e02c8235..db0dcdc23 100644 --- a/mne_bids_pipeline/_run.py +++ b/mne_bids_pipeline/_run.py @@ -4,7 +4,6 @@ import functools import hashlib import inspect -import os import pathlib import pdb import sys @@ -20,7 +19,7 @@ from mne_bids import BIDSPath from ._config_utils import get_task -from ._logging import logger, gen_log_kwargs +from ._logging import logger, gen_log_kwargs, _is_testing def failsafe_run( @@ -85,7 +84,7 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs): if on_error == "abort": message += f"\n\nAborting pipeline run. The traceback is:\n\n{tb}" - if os.getenv("_MNE_BIDS_STUDY_TESTING", "") == "true": + if _is_testing(): raise logger.error( **gen_log_kwargs(message=message, **kwargs_copy, emoji="❌") diff --git a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py index a2e03f2c2..0e7b9b720 100644 --- a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py +++ b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py @@ -69,11 +69,12 @@ def run_time_frequency( ) -> dict: import matplotlib.pyplot as plt - msg = f'Input: {in_files["epochs"].basename}' + epochs_path = in_files.pop("epochs") + msg = f"Reading {epochs_path.basename}" logger.info(**gen_log_kwargs(message=msg)) - bids_path = in_files["epochs"].copy().update(processing=None) - - epochs = mne.read_epochs(in_files.pop("epochs")) + epochs = mne.read_epochs(epochs_path) + bids_path = epochs_path.copy().update(processing=None) + del epochs_path _restrict_analyze_channels(epochs, cfg) if cfg.time_frequency_subtract_evoked: @@ -87,6 +88,7 @@ def run_time_frequency( out_files = dict() for condition in cfg.time_frequency_conditions: + 
logger.info(**gen_log_kwargs(message=f"Computing TFR for {condition}")) this_epochs = epochs[condition] power, itc = mne.time_frequency.tfr_morlet( this_epochs, freqs=freqs, return_itc=True, n_cycles=time_frequency_cycles diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index 04ac5daa7..196e3076d 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -810,22 +810,25 @@ def average_csp_decoding( cluster_forming_t_threshold = cfg.cluster_forming_t_threshold cluster_permutation_results = {} - # TODO: Do something better when there is 1 subject for freq_range_name, X in data_for_clustering.items(): - ( - t_vals, - all_clusters, - cluster_p_vals, - H0, - ) = mne.stats.permutation_cluster_1samp_test( # noqa: E501 - X=X - 0.5, # One-sample test against zero. - threshold=cluster_forming_t_threshold, - n_permutations=cfg.cluster_n_permutations, - adjacency=None, # each time & freq bin connected to its neighbors - out_type="mask", - tail=1, # one-sided: significantly above chance level - seed=cfg.random_state, - ) + if len(X) < 2: + t_vals = np.full(X.shape[1:], np.nan) + H0 = all_clusters = cluster_p_vals = np.array([]) + else: + ( + t_vals, + all_clusters, + cluster_p_vals, + H0, + ) = mne.stats.permutation_cluster_1samp_test( # noqa: E501 + X=X - 0.5, # One-sample test against zero. 
+ threshold=cluster_forming_t_threshold, + n_permutations=cfg.cluster_n_permutations, + adjacency=None, # each time & freq bin connected to its neighbors + out_type="mask", + tail=1, # one-sided: significantly above chance level + seed=cfg.random_state, + ) n_permutations = H0.size - 1 all_clusters = np.array(all_clusters) # preserve "empty" 0th dimension cluster_permutation_results[freq_range_name] = { diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_base.py b/mne_bids_pipeline/tests/configs/config_ds000248_base.py index 8a7155776..22f29b35f 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_base.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_base.py @@ -21,8 +21,6 @@ find_flat_channels_meg = True find_noisy_channels_meg = True use_maxwell_filter = True -_raw_split_size = "60MB" # hits both task-noise and task-audiovisual -_epochs_split_size = "30MB" def noise_cov(bp): @@ -44,8 +42,7 @@ def noise_cov(bp): bem_mri_images = "FLASH" recreate_bem = True -# use n_jobs=1 here to ensure that we get coverage for metadata_query, etc. -n_jobs = 1 +n_jobs = 2 def mri_t1_path_generator(bids_path): diff --git a/mne_bids_pipeline/tests/configs/config_ds001971.py b/mne_bids_pipeline/tests/configs/config_ds001971.py index 34fa812eb..c78d3c858 100644 --- a/mne_bids_pipeline/tests/configs/config_ds001971.py +++ b/mne_bids_pipeline/tests/configs/config_ds001971.py @@ -20,7 +20,9 @@ epochs_decim = 5 # to 100 Hz # This is mostly for testing purposes! 
-decode = False +decode = True +decoding_time_generalization = True +decoding_time_generalization_decim = 2 decoding_csp = True decoding_csp_freqs = { "beta": [13, 20, 30], diff --git a/mne_bids_pipeline/tests/conftest.py b/mne_bids_pipeline/tests/conftest.py index 66d852b58..a4f6fe6eb 100644 --- a/mne_bids_pipeline/tests/conftest.py +++ b/mne_bids_pipeline/tests/conftest.py @@ -39,9 +39,6 @@ def pytest_configure(config): # seaborn->pandas ignore:is_categorical_dtype is deprecated.*:FutureWarning ignore:use_inf_as_na option is deprecated.*:FutureWarning - # TODO: Fix decoding clustering for n_subjects=1 - ignore:Degrees of freedom <= 0 for slice:RuntimeWarning - ignore:invalid value encountered in divide:RuntimeWarning """ for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip() diff --git a/mne_bids_pipeline/tests/test_run.py b/mne_bids_pipeline/tests/test_run.py index 593a5968e..23a92e51e 100644 --- a/mne_bids_pipeline/tests/test_run.py +++ b/mne_bids_pipeline/tests/test_run.py @@ -32,6 +32,7 @@ class _TestOptionsT(TypedDict, total=False): # 'steps': ('preprocessing', 'sensor'), # 'env': {}, # 'task': None, +# "extra_config": "", # } # TEST_SUITE: Dict[str, _TestOptionsT] = { @@ -56,6 +57,12 @@ class _TestOptionsT(TypedDict, total=False): }, "ds000248_base": { "steps": ("preprocessing", "sensor", "source"), + "extra_config": """ +_raw_split_size = "60MB" # hits both task-noise and task-audiovisual +_epochs_split_size = "30MB" +# use n_jobs=1 here to ensure that we get coverage for metadata_query +_n_jobs = {"preprocessing/_05_make_epochs": 1} +""", }, "ds000248_ica": {}, "ds000248_T1_BEM": { @@ -85,6 +92,17 @@ class _TestOptionsT(TypedDict, total=False): "dataset": "ERP_CORE", "config": "config_ERP_CORE.py", "task": "ERN", + "extra_config": """ +# use n_jobs = 1 with loky to ensure that the CSP steps get proper coverage +_n_jobs = { + "sensor/_05_decoding_csp": 1, + "sensor/_99_group_average": 1, +} +_parallel_backend = { + 
"sensor/_05_decoding_csp": "loky", + "sensor/_99_group_average": "loky", +} +""", }, "ERP_CORE_LRP": { "dataset": "ERP_CORE", @@ -129,33 +147,30 @@ def dataset_test(request): @pytest.mark.dataset_test @pytest.mark.parametrize("dataset", list(TEST_SUITE)) -def test_run(dataset, monkeypatch, dataset_test, capsys): +def test_run(dataset, monkeypatch, dataset_test, capsys, tmp_path): """Test running a dataset.""" test_options = TEST_SUITE[dataset] - - # export the environment variables - monkeypatch.setenv("DATASET", dataset) - for key, value in test_options.get("env", {}).items(): - monkeypatch.setenv(key, value) - config = test_options.get("config", f"config_{dataset}.py") config_path = BIDS_PIPELINE_DIR / "tests" / "configs" / config + extra_config = TEST_SUITE[dataset].get("extra_config", "") + if extra_config: + extra_path = tmp_path / "extra_config.py" + extra_path.write_text(extra_config) + monkeypatch.setenv("_MNE_BIDS_STUDY_TESTING_EXTRA_CONFIG", str(extra_path)) # XXX Workaround for buggy date in ds000247. Remove this and the # XXX file referenced here once fixed!!! fix_path = Path(__file__).parent if dataset == "ds000247": - shutil.copy( - src=fix_path / "ds000247_scans.tsv", - dst=Path( - "~/mne_data/ds000247/sub-0002/ses-01/" "sub-0002_ses-01_scans.tsv" - ).expanduser(), + dst = ( + DATA_DIR / "ds000247" / "sub-0002" / "ses-01" / "sub-0002_ses-01_scans.tsv" ) + shutil.copy(src=fix_path / "ds000247_scans.tsv", dst=dst) # XXX Workaround for buggy participant_id in ds001971 elif dataset == "ds001971": shutil.copy( src=fix_path / "ds001971_participants.tsv", - dst=Path("~/mne_data/ds001971/participants.tsv").expanduser(), + dst=DATA_DIR / "ds001971" / "participants.tsv", ) # Run the tests. 
From 7587e6978a101ed4711f6ccb7cd713d078115719 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 14:52:37 -0400 Subject: [PATCH 12/15] ENH: Sensor working I think --- mne_bids_pipeline/_config_import.py | 2 +- mne_bids_pipeline/_main.py | 6 +++ mne_bids_pipeline/_parallel.py | 39 +++++++++---------- mne_bids_pipeline/_run.py | 4 +- .../steps/sensor/_99_group_average.py | 8 ++-- mne_bids_pipeline/tests/test_run.py | 4 -- 6 files changed, 34 insertions(+), 29 deletions(-) diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index 06c1d2e75..36568b1f2 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -40,7 +40,7 @@ def _import_config( extra_config = os.getenv("_MNE_BIDS_STUDY_TESTING_EXTRA_CONFIG", "") if extra_config: msg = f"With testing config: {extra_config}" - logger.info(**gen_log_kwargs(message=msg, emoji="🧪")) + logger.info(**gen_log_kwargs(message=msg, emoji="override")) _update_config_from_path( config=config, config_path=extra_config, diff --git a/mne_bids_pipeline/_main.py b/mne_bids_pipeline/_main.py index 7e5febd5a..cd9ee1037 100755 --- a/mne_bids_pipeline/_main.py +++ b/mne_bids_pipeline/_main.py @@ -11,6 +11,7 @@ from ._config_import import _import_config from ._config_template import create_template_config from ._logging import logger, gen_log_kwargs +from ._parallel import get_parallel_backend from ._run import _short_step_path @@ -195,11 +196,16 @@ def main(): logger.title("Welcome aboard MNE-BIDS-Pipeline! 
👋") msg = f"Using configuration: {config}" + __mne_bids_pipeline_step__ = pathlib.Path(__file__) # used for logging logger.info(**gen_log_kwargs(message=msg, emoji="📝")) config_imported = _import_config( config_path=config_path, overrides=overrides, ) + # Initialize dask now + with get_parallel_backend(config_imported.exec_params): + pass + del __mne_bids_pipeline_step__ logger.end() for step_module in step_modules: diff --git a/mne_bids_pipeline/_parallel.py b/mne_bids_pipeline/_parallel.py index bf62a6df6..12f70cf57 100644 --- a/mne_bids_pipeline/_parallel.py +++ b/mne_bids_pipeline/_parallel.py @@ -8,7 +8,7 @@ from ._logging import logger, gen_log_kwargs, _is_testing -def get_n_jobs(*, exec_params: SimpleNamespace) -> int: +def get_n_jobs(*, exec_params: SimpleNamespace, log_override: bool = False) -> int: n_jobs = exec_params.n_jobs if n_jobs < 0: n_cores = joblib.cpu_count() @@ -19,7 +19,11 @@ def get_n_jobs(*, exec_params: SimpleNamespace) -> int: from ._run import _get_step_path, _short_step_path step_path = _short_step_path(_get_step_path()) + orig_n_jobs = n_jobs n_jobs = exec_params._n_jobs.get(step_path, n_jobs) + if log_override and n_jobs != orig_n_jobs: + msg = f"Overriding n_jobs: {orig_n_jobs}→{n_jobs}" + logger.info(**gen_log_kwargs(message=msg, emoji="override")) return n_jobs @@ -36,14 +40,16 @@ def setup_dask_client(*, exec_params: SimpleNamespace) -> None: return n_workers = get_n_jobs(exec_params=exec_params) - logger.info(f"👾 Initializing Dask client with {n_workers} workers …") + msg = f"Dask initializing with {n_workers} workers …" + logger.info(**gen_log_kwargs(message=msg, emoji="👾")) if exec_params.dask_temp_dir is None: this_dask_temp_dir = exec_params.deriv_root / ".dask-worker-space" else: this_dask_temp_dir = exec_params.dask_temp_dir - logger.info(f"📂 Temporary directory is: {this_dask_temp_dir}") + msg = f"Dask temporary directory: {this_dask_temp_dir}" + logger.info(**gen_log_kwargs(message=msg, emoji="📂")) dask.config.set( { 
"temporary-directory": this_dask_temp_dir, @@ -67,10 +73,8 @@ def setup_dask_client(*, exec_params: SimpleNamespace) -> None: client.auto_restart = False # don't restart killed workers dashboard_url = client.dashboard_link - logger.info( - f"⏱ The Dask client is ready. Open {dashboard_url} " - f"to monitor the workers.\n" - ) + msg = "Dask client dashboard: " f"[link={dashboard_url}]{dashboard_url}[/link]" + logger.info(**gen_log_kwargs(message=msg, emoji="🌎")) if exec_params.dask_open_dashboard: import webbrowser @@ -82,7 +86,8 @@ def setup_dask_client(*, exec_params: SimpleNamespace) -> None: def get_parallel_backend_name( - *, exec_params: SimpleNamespace + *, + exec_params: SimpleNamespace, ) -> Literal["dask", "loky"]: if ( exec_params.parallel_backend == "loky" @@ -99,17 +104,6 @@ def get_parallel_backend_name( # TODO: Move to value validation step raise ValueError(f"Unknown parallel backend: {exec_params.parallel_backend}") - # Shim to allow changing the backend on a per-step basis for testing - if _is_testing() and hasattr(exec_params, "_parallel_backend"): - from ._run import _get_step_path, _short_step_path - - step_path = _short_step_path(_get_step_path()) - old_backend = backend - backend = exec_params._parallel_backend.get(step_path, backend) - msg = f"Overriding parallel backend {old_backend}→{backend}" - logger.info(**gen_log_kwargs(message=msg)) - raise RuntimeError(backend) - return backend @@ -117,7 +111,12 @@ def get_parallel_backend(exec_params: SimpleNamespace) -> joblib.parallel_backen import joblib backend = get_parallel_backend_name(exec_params=exec_params) - kwargs = {"n_jobs": get_n_jobs(exec_params=exec_params)} + kwargs = { + "n_jobs": get_n_jobs( + exec_params=exec_params, + log_override=True, + ) + } if backend == "loky": kwargs["inner_max_num_threads"] = 1 diff --git a/mne_bids_pipeline/_run.py b/mne_bids_pipeline/_run.py index 428b1c1de..5908e4b0c 100644 --- a/mne_bids_pipeline/_run.py +++ b/mne_bids_pipeline/_run.py @@ -356,8 
+356,10 @@ def _get_step_path( if "steps" in fname.parts: return fname else: # pragma: no cover - if frame.function == "__mne_bids_pipeline_failsafe_wrapper__": + try: return frame.frame.f_locals["__mne_bids_pipeline_step__"] + except KeyError: + pass else: # pragma: no cover paths = "\n".join(paths) raise RuntimeError(f"Could not find step path in call stack:\n{paths}") diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index 196e3076d..a71eb4321 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -196,7 +196,7 @@ def _decoding_cluster_permutation_test( out_type="mask", tail=1, # one-sided: significantly above chance level seed=random_seed, - verbose=True, + verbose="error", # ignore No clusters found ) n_permutations = H0.size - 1 @@ -280,7 +280,7 @@ def _get_input_fnames_decoding( ) -> dict: in_files = _get_epochs_in_files(cfg=cfg, subject=subject, session=session) for this_subject in cfg.subjects: - in_files[f"scores-{subject}"] = _decoding_out_fname( + in_files[f"scores-{this_subject}"] = _decoding_out_fname( cfg=cfg, subject=this_subject, session=session, @@ -347,6 +347,7 @@ def average_time_by_time_decoding( mean_scores = np.empty((n_subjects, *time_points_shape)) # Remaining in_files are all decoding data + assert len(in_files) == n_subjects, list(in_files.keys()) for sub_idx, key in enumerate(list(in_files)): decoding_data = loadmat(in_files.pop(key)) mean_scores[sub_idx, :] = decoding_data["scores"].mean(axis=0) @@ -901,7 +902,8 @@ def _average_csp_time_freq( bootstrapped_means = scores_resampled.mean(axis=1) # SD of the bootstrapped distribution == SE of the metric. 
- se = bootstrapped_means.std(ddof=1) + with np.errstate(over="raise"): + se = bootstrapped_means.std(ddof=1) ci_lower = np.quantile(bootstrapped_means, q=0.025) ci_upper = np.quantile(bootstrapped_means, q=0.975) diff --git a/mne_bids_pipeline/tests/test_run.py b/mne_bids_pipeline/tests/test_run.py index 59a475463..909d6d091 100644 --- a/mne_bids_pipeline/tests/test_run.py +++ b/mne_bids_pipeline/tests/test_run.py @@ -105,10 +105,6 @@ class _TestOptionsT(TypedDict, total=False): "sensor/_05_decoding_csp": 1, "sensor/_99_group_average": 1, } -_parallel_backend = { - "sensor/_05_decoding_csp": "loky", - "sensor/_99_group_average": "loky", -} """, }, "ERP_CORE_LRP": { From ed86ba14c9f271911268433f2ed7f322da197c7d Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 16:00:11 -0400 Subject: [PATCH 13/15] ENH: Maybe source, too? --- .circleci/config.yml | 3 +- .circleci/run_dataset_and_copy_files.sh | 2 +- mne_bids_pipeline/_report.py | 68 +---- .../steps/sensor/_06_make_cov.py | 10 +- .../steps/sensor/_99_group_average.py | 4 +- .../steps/source/_05_make_inverse.py | 30 +- .../steps/source/_99_group_average.py | 263 ++++++++++-------- 7 files changed, 178 insertions(+), 202 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index f64bfbd31..03c9de23e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -486,7 +486,8 @@ jobs: - data-cache-ds000248-4 - run: name: test ds000248_base - command: $RUN_TESTS ds000248_base + # Forces rerunning (cov and FLASH BEM) so don't check + command: $RUN_TESTS -r ds000248_base - codecov/upload - store_test_results: path: ./test-results diff --git a/.circleci/run_dataset_and_copy_files.sh b/.circleci/run_dataset_and_copy_files.sh index 2987cb449..34dcfa14f 100755 --- a/.circleci/run_dataset_and_copy_files.sh +++ b/.circleci/run_dataset_and_copy_files.sh @@ -36,7 +36,7 @@ if [[ "$RERUN_TEST" == "false" ]]; then echo "Skipping rerun test" RUN_TIME=0 else - pytest mne_bids_pipeline -k $DS_RUN + 
pytest mne_bids_pipeline --cov-append -k $DS_RUN RUN_TIME=$SECONDS echo "Runtime: ${RUN_TIME} seconds (should be < 20)" fi diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py index 9719cea6e..3d9736c95 100644 --- a/mne_bids_pipeline/_report.py +++ b/mne_bids_pipeline/_report.py @@ -1,7 +1,6 @@ import contextlib from functools import lru_cache from io import StringIO -from pathlib import Path from typing import Optional, List, Literal from types import SimpleNamespace @@ -17,7 +16,7 @@ from mne_bids import BIDSPath from mne_bids.stats import count_events -from ._config_utils import sanitize_cond_name +from ._config_utils import get_all_contrasts from ._decoding import _handle_csp_args from ._logging import logger, gen_log_kwargs, _linkfile @@ -549,72 +548,11 @@ def _all_conditions(*, cfg): conditions = list(cfg.conditions.keys()) else: conditions = cfg.conditions.copy() - conditions.extend([contrast["name"] for contrast in cfg.all_contrasts]) + all_contrasts = get_all_contrasts(cfg) + conditions.extend([contrast["name"] for contrast in all_contrasts]) return conditions -def run_report_average_source( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, - session: Optional[str], -) -> None: - ####################################################################### - # - # Visualize forward solution, inverse operator, and inverse solutions. - # - evoked_fname = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix="ave", - extension=".fif", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - evokeds = mne.read_evokeds(evoked_fname) - method = cfg.inverse_method - inverse_str = method - hemi_str = "hemi" # MNE will auto-append '-lh' and '-rh'. 
- morph_str = "morph2fsaverage" - conditions = _all_conditions(cfg=cfg) - with _open_report( - cfg=cfg, exec_params=exec_params, subject=subject, session=session - ) as report: - for condition, evoked in zip(conditions, evokeds): - tags = ( - "source-estimate", - _sanitize_cond_tag(condition), - ) - if condition in cfg.conditions: - title = f"Average: {condition}" - else: # It's a contrast of two conditions. - title = f"Average contrast: {condition}" - tags = tags + ("contrast",) - cond_str = sanitize_cond_name(condition) - fname_stc_avg = evoked_fname.copy().update( - suffix=f"{cond_str}+{inverse_str}+{morph_str}+{hemi_str}", - extension=None, - ) - if not Path(f"{fname_stc_avg.fpath}-lh.stc").exists(): - continue - report.add_stc( - stc=fname_stc_avg, - title=title, - subject="fsaverage", - subjects_dir=cfg.fs_subjects_dir, - n_time_points=cfg.report_stc_n_time_points, - tags=tags, - replace=True, - ) - - def _sanitize_cond_tag(cond): return str(cond).lower().replace(" ", "-") diff --git a/mne_bids_pipeline/steps/sensor/_06_make_cov.py b/mne_bids_pipeline/steps/sensor/_06_make_cov.py index c32f16cdf..27323d680 100644 --- a/mne_bids_pipeline/steps/sensor/_06_make_cov.py +++ b/mne_bids_pipeline/steps/sensor/_06_make_cov.py @@ -16,7 +16,7 @@ _bids_kwargs, ) from ..._config_import import _import_config -from ..._config_utils import _restrict_analyze_channels, get_all_contrasts +from ..._config_utils import _restrict_analyze_channels from ..._logging import gen_log_kwargs, logger from ..._parallel import get_parallel_backend, parallel_func from ..._report import _open_report, _sanitize_cond_tag, _all_conditions @@ -258,10 +258,8 @@ def run_covariance( for evoked, condition in zip(all_evoked, conditions): _restrict_analyze_channels(evoked, cfg) tags = ("evoked", "covariance", _sanitize_cond_tag(condition)) - if condition in cfg.conditions: - title = f"Whitening: {condition}" - else: # It's a contrast of two conditions. 
- title = f"Whitening: {condition}" + title = f"Whitening: {condition}" + if condition not in cfg.conditions: tags = tags + ("contrast",) fig = evoked.plot_white(cov, verbose="error") report.add_figure( @@ -287,7 +285,7 @@ def get_config( run_source_estimation=config.run_source_estimation, noise_cov=_sanitize_callable(config.noise_cov), conditions=config.conditions, - all_contrasts=get_all_contrasts(config), + contrasts=config.contrasts, analyze_channels=config.analyze_channels, **_bids_kwargs(config=config), ) diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index a71eb4321..00b17372f 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -22,7 +22,6 @@ get_subjects, get_eeg_reference, get_decoding_contrasts, - get_all_contrasts, _bids_kwargs, _restrict_analyze_channels, ) @@ -934,7 +933,7 @@ def get_config( subjects=get_subjects(config), task_is_rest=config.task_is_rest, conditions=config.conditions, - contrasts=get_all_contrasts(config), + contrasts=config.contrasts, decode=config.decode, decoding_metric=config.decoding_metric, decoding_n_splits=config.decoding_n_splits, @@ -954,7 +953,6 @@ def get_config( eeg_reference=get_eeg_reference(config), sessions=get_sessions(config), exclude_subjects=config.exclude_subjects, - all_contrasts=get_all_contrasts(config), report_evoked_n_time_points=config.report_evoked_n_time_points, cluster_permutation_p_threshold=config.cluster_permutation_p_threshold, # TODO: needed because get_datatype gets called again... diff --git a/mne_bids_pipeline/steps/source/_05_make_inverse.py b/mne_bids_pipeline/steps/source/_05_make_inverse.py index 161fc9a06..6e96e13ef 100644 --- a/mne_bids_pipeline/steps/source/_05_make_inverse.py +++ b/mne_bids_pipeline/steps/source/_05_make_inverse.py @@ -3,7 +3,6 @@ Compute and apply an inverse solution for each evoked data set. 
""" -import pathlib from types import SimpleNamespace from typing import Optional @@ -26,7 +25,7 @@ ) from ..._logging import logger, gen_log_kwargs from ..._parallel import get_parallel_backend, parallel_func -from ..._report import _open_report, _sanitize_cond_tag +from ..._report import _open_report, _sanitize_cond_tag, _all_conditions from ..._run import failsafe_run, save_logs, _sanitize_callable, _prep_out_files @@ -97,22 +96,18 @@ def run_inverse( # Apply inverse snr = 3.0 lambda2 = 1.0 / snr**2 - - if isinstance(cfg.conditions, dict): - conditions = list(cfg.conditions.keys()) - else: - conditions = cfg.conditions - + conditions = _all_conditions(cfg=cfg) method = cfg.inverse_method if "evoked" in in_files: fname_ave = in_files.pop("evoked") evokeds = mne.read_evokeds(fname_ave) for condition, evoked in zip(conditions, evokeds): - pick_ori = None - cond_str = sanitize_cond_name(condition) - key = f"{cond_str}+{method}+hemi" - out_files[key] = fname_ave.copy().update(suffix=key, extension=None) + suffix = f"{sanitize_cond_name(condition)}+{method}+hemi" + out_files[condition] = fname_ave.copy().update( + suffix=suffix, + extension=".h5", + ) if "eeg" in cfg.ch_types: evoked.set_eeg_reference("average", projection=True) @@ -122,10 +117,9 @@ def run_inverse( inverse_operator=inverse_operator, lambda2=lambda2, method=method, - pick_ori=pick_ori, + pick_ori=None, ) - stc.save(out_files[key], overwrite=True) - out_files[key] = pathlib.Path(str(out_files[key]) + "-lh.stc") + stc.save(out_files[condition], ftype="h5", overwrite=True) with _open_report( cfg=cfg, exec_params=exec_params, subject=subject, session=session @@ -139,10 +133,11 @@ def run_inverse( continue msg = f"Rendering inverse solution for {condition}" logger.info(**gen_log_kwargs(message=msg)) - fname_stc = out_files[key] tags = ("source-estimate", _sanitize_cond_tag(condition)) + if condition not in cfg.conditions: + tags = tags + ("contrast",) report.add_stc( - stc=fname_stc, + 
stc=out_files[condition], title=f"Source: {condition}", subject=cfg.fs_subject, subjects_dir=cfg.fs_subjects_dir, @@ -165,6 +160,7 @@ def get_config( inverse_targets=config.inverse_targets, ch_types=config.ch_types, conditions=config.conditions, + contrasts=config.contrasts, loose=config.loose, depth=config.depth, inverse_method=config.inverse_method, diff --git a/mne_bids_pipeline/steps/source/_99_group_average.py b/mne_bids_pipeline/steps/source/_99_group_average.py index 3212e0249..1ea7fd5d1 100644 --- a/mne_bids_pipeline/steps/source/_99_group_average.py +++ b/mne_bids_pipeline/steps/source/_99_group_average.py @@ -4,7 +4,7 @@ """ from types import SimpleNamespace -from typing import Optional, List +from typing import Optional import numpy as np @@ -17,17 +17,28 @@ sanitize_cond_name, get_fs_subject, get_sessions, - get_all_contrasts, _bids_kwargs, ) from ..._logging import logger, gen_log_kwargs from ..._parallel import get_parallel_backend, parallel_func -from ..._report import run_report_average_source -from ..._run import failsafe_run, save_logs +from ..._report import _all_conditions, _open_report +from ..._run import failsafe_run, save_logs, _prep_out_files -def morph_stc(cfg, subject, fs_subject, session=None): - bids_path = BIDSPath( +def _stc_path( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], + condition: str, + morphed: bool, +) -> BIDSPath: + cond_str = sanitize_cond_name(condition) + suffix = [cond_str, cfg.inverse_method, "hemi"] + if morphed: + suffix.insert(2, "morph2fsaverage") + suffix = "+".join(suffix) + return BIDSPath( subject=subject, session=session, task=cfg.task, @@ -37,35 +48,47 @@ def morph_stc(cfg, subject, fs_subject, session=None): space=cfg.space, datatype=cfg.datatype, root=cfg.deriv_root, + suffix=suffix, + extension=".h5", check=False, ) - morphed_stcs = [] - - if cfg.task_is_rest: - conditions = [cfg.task.lower()] - else: - if isinstance(cfg.conditions, dict): - conditions = 
list(cfg.conditions.keys()) - else: - conditions = cfg.conditions - for condition in conditions: - method = cfg.inverse_method - cond_str = sanitize_cond_name(condition) - inverse_str = method - hemi_str = "hemi" # MNE will auto-append '-lh' and '-rh'. - morph_str = "morph2fsaverage" - - fname_stc = bids_path.copy().update( - suffix=f"{cond_str}+{inverse_str}+{hemi_str}" - ) - fname_stc_fsaverage = bids_path.copy().update( - suffix=f"{cond_str}+{inverse_str}+{morph_str}+{hemi_str}" +def get_input_fnames_morph_stc( + *, + cfg: SimpleNamespace, + subject: str, + fs_subject: str, + session: Optional[str], +) -> dict: + in_files = dict() + for condition in _all_conditions(cfg=cfg): + in_files[f"original-{condition}"] = _stc_path( + cfg=cfg, + subject=subject, + session=session, + condition=condition, + morphed=False, ) + return in_files - stc = mne.read_source_estimate(fname_stc) +@failsafe_run( + get_input_fnames=get_input_fnames_morph_stc, +) +def morph_stc( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + fs_subject: str, + session: Optional[str], + in_files: dict, +) -> dict: + out_files = dict() + for condition in _all_conditions(cfg=cfg): + fname_stc = in_files.pop(f"original-{condition}") + stc = mne.read_source_estimate(fname_stc) morph = mne.compute_source_morph( stc, subject_from=fs_subject, @@ -73,51 +96,98 @@ def morph_stc(cfg, subject, fs_subject, session=None): subjects_dir=cfg.fs_subjects_dir, ) stc_fsaverage = morph.apply(stc) - stc_fsaverage.save(fname_stc_fsaverage, overwrite=True) - morphed_stcs.append(stc_fsaverage) + key = f"morphed-{condition}" + out_files[key] = _stc_path( + cfg=cfg, + subject=subject, + session=session, + condition=condition, + morphed=True, + ) + stc_fsaverage.save(out_files[key], ftype="h5", overwrite=True) + + assert len(in_files) == 0, in_files + return _prep_out_files(out_files=out_files, exec_params=exec_params) - del fname_stc, fname_stc_fsaverage - return morphed_stcs +def 
get_input_fnames_run_average( + *, + cfg: SimpleNamespace, + subject: str, + session: Optional[str], +) -> dict: + in_files = dict() + assert subject == "average" + for condition in _all_conditions(cfg=cfg): + for this_subject in cfg.subjects: + in_files[f"{this_subject}-{condition}"] = _stc_path( + cfg=cfg, + subject=this_subject, + session=session, + condition=condition, + morphed=True, + ) + return in_files +@failsafe_run( + get_input_fnames=get_input_fnames_run_average, +) def run_average( *, cfg: SimpleNamespace, + exec_params: SimpleNamespace, subject: str, session: Optional[str], - mean_morphed_stcs: List[mne.SourceEstimate], + in_files: dict, ): - bids_path = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - processing=cfg.proc, - recording=cfg.rec, - space=cfg.space, - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - - if isinstance(cfg.conditions, dict): - conditions = list(cfg.conditions.keys()) - else: - conditions = cfg.conditions - - for condition, stc in zip(conditions, mean_morphed_stcs): - method = cfg.inverse_method - cond_str = sanitize_cond_name(condition) - inverse_str = method - hemi_str = "hemi" # MNE will auto-append '-lh' and '-rh'. 
- morph_str = "morph2fsaverage" - - fname_stc_avg = bids_path.copy().update( - suffix=f"{cond_str}+{inverse_str}+{morph_str}+{hemi_str}" + assert subject == "average" + out_files = dict() + conditions = _all_conditions(cfg=cfg) + for condition in conditions: + stc = np.array( + [ + mne.read_source_estimate(in_files.pop(f"{this_subject}-{condition}")) + for this_subject in cfg.subjects + ] + ).mean(axis=0) + out_files[condition] = _stc_path( + cfg=cfg, + subject=subject, + session=session, + condition=condition, + morphed=True, ) - stc.save(fname_stc_avg, overwrite=True) + stc.save(out_files[condition], ftype="h5", overwrite=True) + + ####################################################################### + # + # Visualize forward solution, inverse operator, and inverse solutions. + # + with _open_report( + cfg=cfg, exec_params=exec_params, subject=subject, session=session + ) as report: + for condition in conditions: + msg = f"Rendering inverse solution for {condition}" + logger.info(**gen_log_kwargs(message=msg)) + cond_str = sanitize_cond_name(condition) + tags = ("source-estimate", cond_str) + if condition in cfg.conditions: + title = f"Average: {condition}" + else: # It's a contrast of two conditions. 
+ title = f"Average contrast: {condition}" + tags = tags + ("contrast",) + report.add_stc( + stc=out_files[condition], + title=title, + subject="fsaverage", + subjects_dir=cfg.fs_subjects_dir, + n_time_points=cfg.report_stc_n_time_points, + tags=tags, + replace=True, + ) + assert len(in_files) == 0, in_files + return _prep_out_files(out_files=out_files, exec_params=exec_params) def get_config( @@ -131,11 +201,11 @@ def get_config( fs_subjects_dir=get_fs_subjects_dir(config), subjects_dir=get_fs_subjects_dir(config), ch_types=config.ch_types, - subjects=config.subjects, + subjects=get_subjects(config=config), exclude_subjects=config.exclude_subjects, sessions=get_sessions(config), use_template_mri=config.use_template_mri, - all_contrasts=get_all_contrasts(config), + contrasts=config.contrasts, report_stc_n_time_points=config.report_stc_n_time_points, # TODO: needed because get_datatype gets called again... data_type=config.data_type, @@ -144,64 +214,39 @@ def get_config( return cfg -# pass 'average' subject for logging -@failsafe_run() -def run_group_average_source( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, -) -> None: - """Run group average in source space""" +def main(*, config: SimpleNamespace) -> None: + if not config.run_source_estimation: + msg = "Skipping, run_source_estimation is set to False …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return - mne.datasets.fetch_fsaverage(subjects_dir=get_fs_subjects_dir(cfg)) + mne.datasets.fetch_fsaverage(subjects_dir=get_fs_subjects_dir(config)) + cfg = get_config(config=config) + exec_params = config.exec_params + subjects = get_subjects(config) + sessions = get_sessions(config) + logs = list() with get_parallel_backend(exec_params): parallel, run_func = parallel_func(morph_stc, exec_params=exec_params) - all_morphed_stcs = parallel( + logs += parallel( run_func( cfg=cfg, + exec_params=exec_params, subject=subject, fs_subject=get_fs_subject(config=cfg, 
subject=subject), session=session, ) - for subject in get_subjects(cfg) - for session in get_sessions(cfg) + for subject in subjects + for session in sessions ) - mean_morphed_stcs = np.array(all_morphed_stcs).mean(axis=0) - - # XXX to fix - sessions = get_sessions(cfg) - if sessions: - session = sessions[0] - else: - session = None - + logs += [ run_average( - cfg=cfg, - session=session, - subject=subject, - mean_morphed_stcs=mean_morphed_stcs, - ) - run_report_average_source( cfg=cfg, exec_params=exec_params, - subject=subject, session=session, + subject="average", ) - - -def main(*, config: SimpleNamespace) -> None: - if not config.run_source_estimation: - msg = "Skipping, run_source_estimation is set to False …" - logger.info(**gen_log_kwargs(message=msg, emoji="skip")) - return - - log = run_group_average_source( - cfg=get_config( - config=config, - ), - exec_params=config.exec_params, - subject="average", - ) - save_logs(config=config, logs=[log]) + for session in sessions + ] + save_logs(config=config, logs=logs) From a54d96220eea59d3f328e5b9dc281e7df33af7e1 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 16:23:57 -0400 Subject: [PATCH 14/15] FIX: Ignore --- mne_bids_pipeline/tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_bids_pipeline/tests/conftest.py b/mne_bids_pipeline/tests/conftest.py index 1234300f4..b75d009ec 100644 --- a/mne_bids_pipeline/tests/conftest.py +++ b/mne_bids_pipeline/tests/conftest.py @@ -42,6 +42,7 @@ def pytest_configure(config): # seaborn->pandas ignore:is_categorical_dtype is deprecated.*:FutureWarning ignore:use_inf_as_na option is deprecated.*:FutureWarning + ignore:All-NaN axis encountered.*:RuntimeWarning """ for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip() From 0ff64933136b4ff25f369959b2920b154607cf1b Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 18 Jul 2023 20:51:35 -0400 Subject: [PATCH 15/15] FIX: Titles --- mne_bids_pipeline/_config.py | 2 
+- .../steps/sensor/_01_make_evoked.py | 12 +++++-- .../steps/sensor/_99_group_average.py | 35 ++++++++++--------- .../steps/source/_99_group_average.py | 4 +-- 4 files changed, 32 insertions(+), 21 deletions(-) diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index 812248ee4..85b0335e1 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1572,7 +1572,7 @@ time_frequency_subtract_evoked: bool = False """ -Whether to subtract the evoked signal (averaged across all epochs) from the +Whether to subtract the evoked response (averaged across all epochs) from the epochs before passing them to time-frequency analysis. Set this to `True` to highlight induced activity. diff --git a/mne_bids_pipeline/steps/sensor/_01_make_evoked.py b/mne_bids_pipeline/steps/sensor/_01_make_evoked.py index 00c6c64ef..a0f7d1e3e 100644 --- a/mne_bids_pipeline/steps/sensor/_01_make_evoked.py +++ b/mne_bids_pipeline/steps/sensor/_01_make_evoked.py @@ -12,10 +12,11 @@ get_all_contrasts, _bids_kwargs, _restrict_analyze_channels, + _pl, ) from ..._logging import gen_log_kwargs, logger from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report, _sanitize_cond_tag +from ..._report import _open_report, _sanitize_cond_tag, _all_conditions from ..._run import failsafe_run, save_logs, _sanitize_callable, _prep_out_files @@ -98,10 +99,17 @@ def run_evoked( # Report if evokeds: - msg = f"Adding {len(evokeds)} evoked signals and contrasts to the " f"report." + n_contrasts = len(cfg.contrasts) + n_signals = len(evokeds) - n_contrasts + msg = ( + f"Adding {n_signals} evoked response{_pl(n_signals)} and " + f"{n_contrasts} contrast{_pl(n_contrasts)} to the report." + ) else: msg = "No evoked conditions or contrasts found." 
logger.info(**gen_log_kwargs(message=msg)) + all_conditions = _all_conditions(cfg=cfg) + assert list(all_conditions) == list(all_evoked) # otherwise we have a bug with _open_report( cfg=cfg, exec_params=exec_params, subject=subject, session=session ) as report: diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index 00b17372f..008e8b7dd 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -5,7 +5,6 @@ import os import os.path as op -from collections import defaultdict from functools import partial from typing import Optional, TypedDict, List, Tuple from types import SimpleNamespace @@ -24,6 +23,7 @@ get_decoding_contrasts, _bids_kwargs, _restrict_analyze_channels, + _pl, ) from ..._decoding import _handle_csp_args from ..._logging import gen_log_kwargs, logger @@ -39,6 +39,7 @@ plot_time_by_time_decoding_t_values, _plot_decoding_time_generalization, _contrasts_to_names, + _all_conditions, ) @@ -80,23 +81,24 @@ def average_evokeds( ) -> dict: logger.info(**gen_log_kwargs(message="Creating grand averages")) # Container for all conditions: - all_evokeds = defaultdict(list) + conditions = _all_conditions(cfg=cfg) + evokeds = [list() for _ in range(len(conditions))] keys = list(in_files) for key in keys: if not key.startswith("evoked-"): continue fname_in = in_files.pop(key) - evokeds = mne.read_evokeds(fname_in) - for idx, evoked in enumerate(evokeds): - all_evokeds[idx].append(evoked) # Insert into the container + these_evokeds = mne.read_evokeds(fname_in) + for idx, evoked in enumerate(these_evokeds): + evokeds[idx].append(evoked) # Insert into the container - for idx, evokeds in all_evokeds.items(): - all_evokeds[idx] = mne.grand_average( - evokeds, interpolate_bads=cfg.interpolate_bads_grand_average + for idx, these_evokeds in enumerate(evokeds): + evokeds[idx] = mne.grand_average( + these_evokeds, 
interpolate_bads=cfg.interpolate_bads_grand_average ) # Combine subjects # Keep condition in comment - all_evokeds[idx].comment = "Grand average: " + evokeds[0].comment + evokeds[idx].comment = "Grand average: " + these_evokeds[0].comment out_files = dict() fname_out = out_files["evokeds"] = BIDSPath( @@ -120,7 +122,6 @@ def average_evokeds( msg = f"Saving grand-averaged evoked sensor data: {fname_out.basename}" logger.info(**gen_log_kwargs(message=msg)) - evokeds = list(all_evokeds.values()) mne.write_evokeds(fname_out, evokeds, overwrite=True) if exec_params.interactive: for evoked in evokeds: @@ -140,20 +141,22 @@ def average_evokeds( ) # Evoked responses - if all_evokeds: + if evokeds: + n_contrasts = len(cfg.contrasts) + n_signals = len(evokeds) - n_contrasts msg = ( - f"Adding {len(all_evokeds)} evoked signals and contrasts to " - "the report." + f"Adding {n_signals} evoked response{_pl(n_signals)} and " + f"{n_contrasts} contrast{_pl(n_contrasts)} to the report." ) else: msg = "No evoked conditions or contrasts found." logger.info(**gen_log_kwargs(message=msg)) - for condition, evoked in all_evokeds.items(): + for condition, evoked in zip(conditions, evokeds): tags = ("evoked", _sanitize_cond_tag(condition)) if condition in cfg.conditions: - title = f"Condition: {condition}" + title = f"Average (sensor): {condition}" else: # It's a contrast of two conditions. 
- title = f"Contrast: {condition}" + title = f"Average (sensor) contrast: {condition}" tags = tags + ("contrast",) report.add_evokeds( diff --git a/mne_bids_pipeline/steps/source/_99_group_average.py b/mne_bids_pipeline/steps/source/_99_group_average.py index 1ea7fd5d1..9e855d6df 100644 --- a/mne_bids_pipeline/steps/source/_99_group_average.py +++ b/mne_bids_pipeline/steps/source/_99_group_average.py @@ -173,9 +173,9 @@ def run_average( cond_str = sanitize_cond_name(condition) tags = ("source-estimate", cond_str) if condition in cfg.conditions: - title = f"Average: {condition}" + title = f"Average (source): {condition}" else: # It's a contrast of two conditions. - title = f"Average contrast: {condition}" + title = f"Average (source) contrast: {condition}" tags = tags + ("contrast",) report.add_stc( stc=out_files[condition],