diff --git a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py index 4f3cd0ae2..96f310f44 100644 --- a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py +++ b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py @@ -7,7 +7,7 @@ import mne import numpy as np import pandas as pd -from mne.decoding import CSP +from mne.decoding import CSP, LinearModel from mne_bids import BIDSPath from sklearn.model_selection import StratifiedKFold, cross_val_score from sklearn.pipeline import make_pipeline @@ -162,6 +162,10 @@ def one_subject_decoding( epochs=epochs, ) + # Create output directory if it doesn't already exist + output_dir = bids_path.fpath.parent / "CSD_output" + output_dir.mkdir(exist_ok=True) + # Classifier csp = CSP( n_components=4, # XXX revisit @@ -170,10 +174,12 @@ def one_subject_decoding( clf = make_pipeline( *preproc_steps, csp, - LogReg( - solver="liblinear", # much faster than the default - random_state=cfg.random_state, - n_jobs=1, + LinearModel( + LogReg( + solver="liblinear", # much faster than the default + random_state=cfg.random_state, + n_jobs=1, + ) ), ) cv = StratifiedKFold( @@ -239,6 +245,7 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non epochs_filt, y = prepare_epochs_and_y( epochs=epochs, contrast=contrast, fmin=fmin, fmax=fmax, cfg=cfg ) + # Get the data for all time points X = epochs_filt.get_data() @@ -253,6 +260,21 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non freq_decoding_table.loc[idx, "mean_crossval_score"] = cv_scores.mean() freq_decoding_table.at[idx, "scores"] = cv_scores + # COEFS + clf.fit(X, y) + weights_csp = mne.decoding.get_coef(clf, "patterns_", inverse_transform=True) + + # PATTERNS + csp.fit_transform(X, y) + sensor_pattern_csp = csp.patterns_ + + # save weights and patterns + csp_patterns_fname = f"{cond1}_{cond2}_{str(fmin)}_{str(fmax)}_Hz_patterns" + csp_weights_fname = f"{cond1}_{cond2}_{str(fmin)}_{str(fmax)}_Hz_weights" + + np.save(op.join(output_dir, csp_patterns_fname), sensor_pattern_csp) + np.save(op.join(output_dir, csp_weights_fname), weights_csp) + # Loop over times x frequencies # # Note: We don't support varying time ranges for different frequency @@ -306,6 +328,7 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non # Crop data to the time window of interest if tmax is not None: # avoid warnings about outside the interval tmax = min(tmax, epochs_filt.times[-1]) + X = epochs_filt.crop(tmin, tmax).get_data() del epochs_filt cv_scores = cross_val_score( @@ -323,6 +346,21 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non msg += f": {cfg.decoding_metric}={score:0.3f}" logger.info(**gen_log_kwargs(msg)) + # COEFS + clf.fit(X, y) + weights_csp = mne.decoding.get_coef(clf, "patterns_", inverse_transform=True) + + # PATTERNS + csp.fit_transform(X, y) + sensor_pattern_csp = csp.patterns_ + + # save weights and patterns + csp_patterns_fname = f"{cond1}_{cond2}_{str(fmin)}_{str(fmax)}_Hz_{str(tmin)}_{str(tmax)}_s_patterns" + csp_weights_fname = f"{cond1}_{cond2}_{str(fmin)}_{str(fmax)}_Hz_{str(tmin)}_{str(tmax)}_s_patterns" + + np.save(op.join(output_dir, csp_patterns_fname), sensor_pattern_csp) + np.save(op.join(output_dir, csp_weights_fname), weights_csp) + # Write each DataFrame to a different Excel worksheet. a_vs_b = f"{condition1}+{condition2}".replace(op.sep, "") processing = f"{a_vs_b}+CSP+{cfg.decoding_metric}"