Skip to content

Commit

Permalink
FIX: Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Sep 12, 2024
1 parent 7b2039a commit dc037e9
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .github/pull_request_template.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
### Before merging …

- [ ] Changelog has been updated (`docs/source/changes.md`)
- [ ] Changelog has been updated (`docs/source/vX.Y.md.inc`)
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ build/
.hypothesis/
.coverage*
junit-results.xml
.cache/
9 changes: 5 additions & 4 deletions docs/source/examples/gen_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import shutil
import sys
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Generator, Iterable
from pathlib import Path
from typing import Any

from tqdm import tqdm

Expand All @@ -23,15 +24,15 @@
logger = logging.getLogger()


def _bool_to_icon(x: bool | Iterable) -> str:
def _bool_to_icon(x: bool | Iterable[Any]) -> str:
if x:
return "✅"
else:
return "❌"


@contextlib.contextmanager
def _task_context(task):
def _task_context(task: str | None) -> Generator[None, None, None]:
old_argv = sys.argv
if task:
sys.argv = [sys.argv[0], f"--task={task}"]
Expand All @@ -41,7 +42,7 @@ def _task_context(task):
sys.argv = old_argv


def _gen_demonstrated_funcs(example_config_path: Path) -> dict:
def _gen_demonstrated_funcs(example_config_path: Path) -> dict[str, bool]:
"""Generate dict of demonstrated functionality based on config."""
# Here we use a defaultdict, and for keys that might vary across configs
# we should use an `funcs[key] = funcs[key] or ...` so that we effectively
Expand Down
1 change: 0 additions & 1 deletion docs/source/features/gen_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@
if not isinstance(a_b, str):
all_steps_list.extend(a_b)
all_steps = set(all_steps_list)
assert len(all_steps) == len(all_steps_list)
assert mapped == all_steps, all_steps.symmetric_difference(mapped)
overview_lines.append("```\n\n</details>\n")

Expand Down
1 change: 1 addition & 0 deletions docs/source/v1.9.md.inc
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@
- Use GitHub's `dependabot` service to automatically keep GitHub Actions up-to-date. (#893 by @hoechenberger)
- Clean up some strings that our autoformatter failed to correctly merge. (#965 by @drammock)
- Type hints are now checked using `mypy`. (#995 by @larsoner)
22 changes: 16 additions & 6 deletions mne_bids_pipeline/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import re
from collections import defaultdict
from pathlib import Path
from types import FunctionType
from typing import Any

from tqdm import tqdm

Expand Down Expand Up @@ -104,11 +106,11 @@


class _ParseConfigSteps:
def __init__(self, force_empty=None):
def __init__(self, force_empty: tuple[str, ...] | None = None) -> None:
self._force_empty = _FORCE_EMPTY if force_empty is None else force_empty
self.steps = defaultdict(list)
self.steps: dict[str, Any] = defaultdict(list)
# Add a few helper functions
for func in (
for func_extra in (
_config_utils.get_eeg_reference,
_config_utils.get_all_contrasts,
_config_utils.get_decoding_contrasts,
Expand All @@ -117,15 +119,16 @@ def __init__(self, force_empty=None):
_config_utils.get_mf_cal_fname,
_config_utils.get_mf_ctc_fname,
):
this_list = []
for attr in ast.walk(ast.parse(inspect.getsource(func))):
this_list: list[str] = []
assert isinstance(func_extra, FunctionType)
for attr in ast.walk(ast.parse(inspect.getsource(func_extra))):
if not isinstance(attr, ast.Attribute):
continue
if not (isinstance(attr.value, ast.Name) and attr.value.id == "config"):
continue
if attr.attr not in this_list:
this_list.append(attr.attr)
_MANUAL_KWS[func.__name__] = tuple(this_list)
_MANUAL_KWS[func_extra.__name__] = tuple(this_list)

for module in tqdm(
sum(_config_utils._get_step_modules().values(), tuple()),
Expand All @@ -147,6 +150,7 @@ def __init__(self, force_empty=None):
for keyword in call.keywords:
if not isinstance(keyword.value, ast.Attribute):
continue
assert isinstance(keyword.value.value, ast.Name)
if keyword.value.value.id != "config":
continue
if keyword.value.attr in ("exec_params",):
Expand All @@ -165,6 +169,7 @@ def __init__(self, force_empty=None):
for attr in ast.walk(cond.test):
if not isinstance(attr, ast.Attribute):
continue
assert isinstance(attr.value, ast.Name)
if attr.value.id != "config":
continue
self._add_step_option(step, attr.attr)
Expand All @@ -175,6 +180,7 @@ def __init__(self, force_empty=None):
for call in ast.walk(func):
if not isinstance(call, ast.Call):
continue
assert isinstance(call.func, ast.Name)
if call.func.id != "SimpleNamespace":
continue
break
Expand All @@ -183,6 +189,7 @@ def __init__(self, force_empty=None):
assert call.args == []
for keyword in call.keywords:
if isinstance(keyword.value, ast.Call):
assert isinstance(keyword.value.func, ast.Name)
key = keyword.value.func.id
if key in _MANUAL_KWS:
for option in _MANUAL_KWS[key]:
Expand All @@ -191,6 +198,7 @@ def __init__(self, force_empty=None):
if keyword.value.func.id == "_sanitize_callable":
assert len(keyword.value.args) == 1
assert isinstance(keyword.value.args[0], ast.Attribute)
assert isinstance(keyword.value.args[0].value, ast.Name)
assert keyword.value.args[0].value.id == "config"
self._add_step_option(step, keyword.value.args[0].attr)
continue
Expand All @@ -213,6 +221,7 @@ def __init__(self, force_empty=None):
for func_name in _EXTRA_FUNCS.get(key, ()):
funcs.append(getattr(_config_utils, func_name))
for fi, func in enumerate(funcs):
assert isinstance(func, FunctionType), func
source = inspect.getsource(func)
assert "config: SimpleNamespace" in source, key
if fi == 0:
Expand Down Expand Up @@ -240,6 +249,7 @@ def __init__(self, force_empty=None):
option = keyword.value.attr
if option in _IGNORE_OPTIONS:
continue
assert isinstance(keyword.value.value, ast.Name)
assert keyword.value.value.id == "config", f"{where} {keyword.value.value.id}" # noqa: E501 # fmt: skip
self._add_step_option(step, option)
if step in _NO_CONFIG:
Expand Down
45 changes: 22 additions & 23 deletions mne_bids_pipeline/steps/sensor/_05_decoding_csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ def one_subject_decoding(
ignore_index=True,
)
del freq_decoding_table_rows
assert isinstance(freq_decoding_table, pd.DataFrame)

def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=None):
msg = (
Expand All @@ -235,16 +234,12 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
return msg

for idx, row in freq_decoding_table.iterrows():
assert isinstance(row["f_min"], pd.Series)
fmin = row["f_min"].item()
assert isinstance(row["f_max"], pd.Series)
fmax = row["f_max"].item()
assert isinstance(row["cond_1"], pd.Series)
cond1 = row["cond_1"].item()
assert isinstance(row["cond_2"], pd.Series)
cond2 = row["cond_2"].item()
assert isinstance(row["freq_range_name"], pd.Series)
freq_range_name = row["freq_range_name"].item()
assert isinstance(row, pd.Series)
fmin = row["f_min"]
fmax = row["f_max"]
cond1 = row["cond_1"]
cond2 = row["cond_2"]
freq_range_name = row["freq_range_name"]

msg = _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name)
logger.info(**gen_log_kwargs(msg))
Expand All @@ -268,6 +263,7 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
)
freq_decoding_table.loc[idx, "mean_crossval_score"] = cv_scores.mean()
freq_decoding_table.at[idx, "scores"] = cv_scores
del fmin, fmax, cond1, cond2, freq_range_name

# Loop over times x frequencies
#
Expand Down Expand Up @@ -308,6 +304,7 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
for idx, row in tf_decoding_table.iterrows():
if len(row) == 0:
break # no data
assert isinstance(row, pd.Series)
tmin = row["t_min"]
tmax = row["t_max"]
fmin = row["f_min"]
Expand Down Expand Up @@ -339,6 +336,7 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
msg = _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin, tmax)
msg += f": {cfg.decoding_metric}={score:0.3f}"
logger.info(**gen_log_kwargs(msg))
del tmin, tmax, fmin, fmax, cond1, cond2, freq_range_name

# Write each DataFrame to a different Excel worksheet.
a_vs_b = f"{condition1}+{condition2}".replace(op.sep, "")
Expand Down Expand Up @@ -441,14 +439,15 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
)
results = all_csp_tf_results[contrast]
mean_crossval_scores: list[float] = list()
tmin, tmax, fmin, fmax = list(), list(), list(), list()
tmin_list, tmax_list = list(), list()
fmin_list, fmax_list = list(), list()
mean_crossval_scores.extend(
results["mean_crossval_score"].to_numpy().ravel().tolist()
)
tmin.extend(results["t_min"].to_numpy().ravel())
tmax.extend(results["t_max"].to_numpy().ravel())
fmin.extend(results["f_min"].to_numpy().ravel())
fmax.extend(results["f_max"].to_numpy().ravel())
tmin_list.extend(results["t_min"].to_numpy().ravel())
tmax_list.extend(results["t_max"].to_numpy().ravel())
fmin_list.extend(results["f_min"].to_numpy().ravel())
fmax_list.extend(results["f_max"].to_numpy().ravel())
mean_crossval_scores_array = np.array(mean_crossval_scores, float)
del mean_crossval_scores
fig, ax = plt.subplots(constrained_layout=True)
Expand All @@ -466,10 +465,10 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
img = _imshow_tf(
mean_crossval_scores_array,
ax,
tmin=tmin,
tmax=tmax,
fmin=fmin,
fmax=fmax,
tmin=tmin_list,
tmax=tmax_list,
fmin=fmin_list,
fmax=fmax_list,
vmin=vmin,
vmax=vmax,
)
Expand All @@ -478,16 +477,16 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non
)
for freq_range_name, bins in freq_name_to_bins_map.items():
ax.text(
tmin[0],
tmin_list[0],
0.5 * bins[0][0] + 0.5 * bins[-1][1],
freq_range_name,
transform=offset,
ha="left",
va="center",
rotation=90,
)
ax.set_xlim([np.min(tmin), np.max(tmax)])
ax.set_ylim([np.min(fmin), np.max(fmax)])
ax.set_xlim([np.min(tmin_list), np.max(tmax_list)])
ax.set_ylim([np.min(fmin_list), np.max(fmax_list)])
ax.set_xlabel("Time (s)")
ax.set_ylabel("Frequency (Hz)")
cbar = fig.colorbar(
Expand Down
11 changes: 7 additions & 4 deletions mne_bids_pipeline/tests/test_documented.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import sys
from pathlib import Path
from types import CodeType

import yaml

Expand All @@ -22,11 +23,13 @@ def test_options_documented():
with open(root_path / "_config.py") as fid:
contents = fid.read()
contents = ast.parse(contents)
in_config = [
assert isinstance(contents, CodeType)
in_config_list = [
item.target.id for item in contents.body if isinstance(item, ast.AnnAssign)
]
assert len(set(in_config)) == len(in_config)
in_config = set(in_config)
assert len(set(in_config_list)) == len(in_config_list)
in_config = set(in_config_list)
del in_config_list
# ensure we clean our namespace correctly
config = _get_default_config()
config_names = set(d for d in dir(config) if not d.startswith("_"))
Expand Down Expand Up @@ -156,7 +159,7 @@ def test_datasets_in_doc():
assert n_found == count, f"{cp} ({n_found} != {count})"

# 3. Read examples from docs (being careful about tags we can't read)
class SafeLoaderIgnoreUnknown(yaml.SafeLoader):
class SafeLoaderIgnoreUnknown(yaml.SafeLoader): # type: ignore[misc]
def ignore_unknown(self, node):
return None

Expand Down
23 changes: 19 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,27 @@ convention = "numpy"
ignore_errors = false
scripts_are_modules = true
ignore_missing_imports = true
strict = false
modules = ["mne_bids_pipeline"]
disable_error_code = [
"import-untyped",
]

# TODO: Add these
[[tool.mypy.overrides]]
module = [
"mne_bids_pipeline.steps.*",
"mne_bids_pipeline.tests.*",
"docs.source.examples.*",
]
disable_error_code = [
"assignment",
"no-untyped-call",
"no-untyped-def",
"union-attr",
"type-arg",
"no-any-return",
]

[[tool.mypy.overrides]]
module = ["mne_bids_pipeline.steps.freesurfer.contrib.*"]
ignore_errors = true
Expand All @@ -147,11 +162,11 @@ ignore_errors = true
module = ['mne_bids_pipeline.tests.configs.*']
# Ignare: Need type annotation for "plot_psd_for_runs" [var-annotated]
disable_error_code = [
"var-annotated",
"var-annotated",
]

[[tool.mypy.overrides]]
module = ["mne_bids_pipeline.tests.configs.config_ERP_CORE"]
disable_error_code = [
"assignment",
]
"assignment",
]

0 comments on commit dc037e9

Please sign in to comment.