
Commit e97f926

Remove redundant learning curve implementation paths (#963)
* Remove redundant learning curve implementation paths (#962)
* Initial plan
* Remove redundant learning curve implementation
  - Remove data_size and n_perms parameters from WithinSessionEvaluation
  - Remove get_data_size_subsets(), score_explicit(), and _evaluate_learning_curve() methods
  - Simplify evaluate() to always use the _evaluate() path
  - Update docstring to recommend cv_class=LearningCurveSplitter
  - Update all examples to use the new LearningCurveSplitter API
  - Update tests to use the new API
* Fix pre-commit linting issues
  - Apply black formatting (line length 90)
  - Fix import sorting with isort
  - Remove unused imports (Optional, StratifiedShuffleSplit)
  - Apply ruff fixes
* Apply final black formatting fixes
  - Remove extra blank lines per black style guide
* Fix isort import ordering
  - Apply isort to properly order imports in all changed files
* Resolve black/isort formatting conflict
  - Apply black formatting after isort
  - Remove extra blank line after imports
* Add single-class safeguard for LearningCurveSplitter
  - Skip splits where the training set collapses to a single class
  - Log a warning when splits are skipped due to single-class training sets
  - Fix ArrowStringArray shuffle warnings by converting to numpy arrays
  - Update tests to call process() since validation happens at evaluation time
  - Fix isort import ordering in learning curve examples
* Add whats_new entries for learning curve unification, documenting the learning curve and splitter improvements:
  - cv_class and cv_kwargs parameters for all evaluation classes
  - LearningCurveSplitter for sklearn-compatible learning curves
  - Removal of data_size and n_perms from WithinSessionEvaluation
  - Automatic metadata columns for learning curve results
  - Centralized CV resolution with the _resolve_cv() method
  - Removal of redundant learning curve methods
* Add parametrized test for LearningCurveSplitter as cv_class: verify that LearningCurveSplitter can be used as the cv_class parameter for all main splitters (WithinSessionSplitter, WithinSubjectSplitter, CrossSessionSplitter, and CrossSubjectSplitter)
* Update the Python files
* Solve a problem with new pandas
* Update the splits to ensure the logic is correct
* Iteration
* Simplify the logic
* Solve the group problem
* Iteration 2
* Simplify further
* Update whats_new.rst with _load_data, _get_nchan, and splitter hoisting: document the extraction of the _load_data() and _get_nchan() helpers into BaseEvaluation, the move of _pipeline_requires_epochs() to utils.py, and the WithinSessionSplitter creation hoisted outside the session loop

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: bruAristimunha <42702466+bruAristimunha@users.noreply.github.com>
Co-authored-by: Bru <b.aristimunha@gmail.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
1 parent dd3811a commit e97f926

15 files changed: +656 −516 lines

docs/source/whats_new.rst

Lines changed: 16 additions & 0 deletions
@@ -32,6 +32,8 @@ Enhancements
 - Ability to parameterize the scoring rule of paradigms (:gh:`948` by `Ethan Davis`_)
 - Extend scoring configuration to accept lists of metric callables, scorer objects, and tuple kwargs (e.g., `needs_proba`/`needs_threshold`) for multi-metric evaluations (:gh:`948` by `Ethan Davis`_ and `Bruno Aristimunha`_)
 - Implement :class:`moabb.evaluations.WithinSubjectSplitter` for k-fold cross-validation within each subject across all sessions (by `Bruno Aristimunha`_)
+- Add ``cv_class`` and ``cv_kwargs`` parameters to all evaluation classes (WithinSessionEvaluation, CrossSessionEvaluation, CrossSubjectEvaluation) for custom cross-validation strategies (:gh:`963` by `Bruno Aristimunha`_)
+- Implement :class:`moabb.evaluations.splitters.LearningCurveSplitter` as a dedicated sklearn-compatible cross-validator for learning curves, enabling learning curve analysis with any evaluation type (:gh:`963` by `Bruno Aristimunha`_)

 API changes
 ~~~~~~~~~~~
@@ -42,6 +44,8 @@ API changes
 - Enable choice of online or offline CodeCarbon through the parameterization of `codecarbon_config` when instantiating a :class:`moabb.evaluations.base.BaseEvaluation` child class (:gh:`956` by `Ethan Davis`_)
 - Renamed stimulus channel from ``stim`` to ``STI`` in BNCI motor imagery and error-related potential datasets for clarity and BIDS compliance (by `Bruno Aristimunha`_).
 - Added four new BNCI P300/ERP dataset classes: :class:`moabb.datasets.BNCI2015_009` (AMUSE), :class:`moabb.datasets.BNCI2015_010` (RSVP), :class:`moabb.datasets.BNCI2015_012` (PASS2D), and :class:`moabb.datasets.BNCI2015_013` (ErrP) (by `Bruno Aristimunha`_).
+- Removed ``data_size`` and ``n_perms`` parameters from :class:`moabb.evaluations.WithinSessionEvaluation`. Use ``cv_class=LearningCurveSplitter`` with ``cv_kwargs=dict(data_size=..., n_perms=...)`` instead (:gh:`963` by `Bruno Aristimunha`_)
+- Learning curve results now automatically include "data_size" and "permutation" columns when using ``LearningCurveSplitter`` (:gh:`963` by `Bruno Aristimunha`_)

 Requirements
 ~~~~~~~~~~~~
@@ -61,6 +65,8 @@ Bugs
 - Prevent Python mutable default argument when defining CodeCarbon configurations (:gh:`956` by `Ethan Davis`_)
 - Fix copytree FileExistsError in BrainInvaders2013a download by adding dirs_exist_ok=True (by `Bruno Aristimunha`_)
 - Ensure optional additional scoring columns in evaluation results (:gh:`957` by `Ethan Davis`_)
+- Fix pandas ``ArrowStringArray`` shuffle warning by converting ``.unique()`` results to numpy arrays in splitters, avoiding issues with newer pandas versions (:gh:`963` by `Bruno Aristimunha`_)
+- ``LearningCurveSplitter`` now skips training splits that collapse to a single class (e.g., with very small ``data_size``) and emits a ``RuntimeWarning`` instead of producing NaN results (:gh:`963` by `Bruno Aristimunha`_)

 Code health
 ~~~~~~~~~~~
@@ -69,6 +75,16 @@ Code health

 - Persist docs/test CI MNE dataset cache across runs to reduce cold-cache downloads (:gh:`946` by `Bruno Aristimunha`_)
 - Refactor evaluation scoring into shared utility functions for future improvements (:gh:`948` by `Bruno Aristimunha`_)
+- Centralize CV resolution in BaseEvaluation with new ``_resolve_cv()`` method for consistent cross-validation handling across all evaluation types. Add ``_build_result()`` and ``_build_scored_result()`` helpers to centralize result dict construction across WithinSession, CrossSession, and CrossSubject evaluations, replacing manual dict assembly in each (:gh:`963` by `Bruno Aristimunha`_)
+- Remove redundant learning curve methods (``get_data_size_subsets()``, ``score_explicit()``, ``_evaluate_learning_curve()``) from WithinSessionEvaluation in favor of unified splitter-based approach (:gh:`963` by `Bruno Aristimunha`_)
+- Generic metadata column registration: ``LearningCurveSplitter`` declares a ``metadata_columns`` class attribute, and ``BaseEvaluation`` auto-detects it via ``hasattr(cv_class, "metadata_columns")`` instead of hardcoding class checks, making it extensible to future custom splitters (:gh:`963` by `Bruno Aristimunha`_)
+- Fix ``get_n_splits()`` delegation in ``WithinSessionSplitter`` and ``WithinSubjectSplitter`` to properly forward to the inner ``cv_class.get_n_splits()`` instead of hardcoding ``n_folds``, giving correct split counts when using custom CV classes like ``LearningCurveSplitter`` (:gh:`963` by `Bruno Aristimunha`_)
+- Remove duplicate ``get_inner_splitter_metadata()`` from ``WithinSessionSplitter``, ``WithinSubjectSplitter``, and ``CrossSubjectSplitter``. All splitters now store a ``_current_splitter`` reference, and ``BaseEvaluation._build_scored_result()`` reads metadata generically from it (:gh:`963` by `Bruno Aristimunha`_)
+- Extract ``_fit_cv()``, ``_maybe_save_model_cv()``, and ``_attach_emissions()`` into ``BaseEvaluation``, removing duplicated model-fitting, model-saving, and carbon-tracking boilerplate from ``WithinSessionEvaluation``, ``CrossSessionEvaluation``, and ``CrossSubjectEvaluation`` (:gh:`963` by `Bruno Aristimunha`_)
+- Extract ``_load_data()`` helper into ``BaseEvaluation`` to centralize data loading logic (epoch requirement checking and ``paradigm.get_data()`` call) that was duplicated across all three evaluation classes (:gh:`963` by `Bruno Aristimunha`_)
+- Extract ``_get_nchan()`` helper into ``BaseEvaluation`` to replace repeated channel count extraction (``X.info["nchan"] if isinstance(X, BaseEpochs) else X.shape[1]``) in all evaluation classes (:gh:`963` by `Bruno Aristimunha`_)
+- Move ``_pipeline_requires_epochs()`` from ``evaluations.py`` to ``utils.py`` for shared access by ``BaseEvaluation._load_data()`` (:gh:`963` by `Bruno Aristimunha`_)
+- Move ``WithinSessionSplitter`` creation outside the per-session loop in ``WithinSessionEvaluation``, since splitter parameters do not change per session (:gh:`963` by `Bruno Aristimunha`_)

 Version 1.4.3 (Stable - PyPi)
 -------------------------------
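The changelog entries above describe ``LearningCurveSplitter`` only at a high level. As a minimal, self-contained sketch, the shape of such an sklearn-style cross-validator can be illustrated as follows. This is not MOABB's actual implementation; the class name and split semantics (``n_perms`` shuffles of a training pool, nested ``data_size`` training subsets against a fixed test block) are inferred from the changelog for illustration only.

```python
import numpy as np


class ToyLearningCurveSplitter:
    """Hypothetical sketch: for each of ``n_perms`` shuffles, yield
    progressively larger training subsets (``data_size``) against one
    fixed held-out test block, sklearn cross-validator style."""

    def __init__(self, data_size, n_perms, test_fraction=0.2, seed=0):
        self.data_size = data_size
        self.n_perms = n_perms
        self.test_fraction = test_fraction
        self.seed = seed

    def get_n_splits(self, X=None, y=None, groups=None):
        # One split per (permutation, training size) pair.
        return len(self.data_size) * self.n_perms

    def split(self, X, y=None, groups=None):
        rng = np.random.default_rng(self.seed)
        n = len(X)
        n_test = max(1, int(n * self.test_fraction))
        indices = rng.permutation(n)
        # Fixed test block; the rest is the training pool.
        test_idx, train_pool = indices[:n_test], indices[n_test:]
        for _perm in range(self.n_perms):
            shuffled = rng.permutation(train_pool)
            for size in self.data_size:
                yield shuffled[:size], test_idx


splitter = ToyLearningCurveSplitter(data_size=[4, 8], n_perms=3)
splits = list(splitter.split(np.arange(20)))
print(splitter.get_n_splits())  # 2 sizes x 3 permutations = 6
```

Because the cross-validator only yields index pairs, any evaluation loop that accepts an sklearn CV object can consume it, which is the point of the ``cv_class``/``cv_kwargs`` unification.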

examples/advanced_examples/plot_hinss2021_classification.py

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ def transform(self, X):
 # in approximately 8 times to the Cov+ElSel+TS+LDA pipeline.

 print("Averaging the session performance:")
-print(results.groupby("pipeline").mean("score")[["score", "time"]])
+print(results.groupby("pipeline")[["score", "time"]].mean())

 ###############################################################################
 # Plot Results
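The groupby change above matters because modern pandas treats the first positional argument of ``GroupBy.mean`` as ``numeric_only``, not a column name, so selecting the columns before aggregating is the unambiguous form. A small self-contained check (the toy values are invented; only the column names mirror the example):

```python
import pandas as pd

# Toy results table mimicking the example's columns (values invented).
results = pd.DataFrame(
    {
        "pipeline": ["A", "A", "B", "B"],
        "score": [0.8, 0.9, 0.6, 0.7],
        "time": [1.0, 3.0, 2.0, 4.0],
    }
)

# Select the numeric columns first, then aggregate: works the same in
# any pandas version and never passes a string where ``numeric_only``
# is expected.
avg = results.groupby("pipeline")[["score", "time"]].mean()
print(avg)  # A: score=0.85, time=2.0; B: score=0.65, time=3.0
```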

examples/external/learning_curve_p300_external.py

Lines changed: 3 additions & 2 deletions
@@ -36,6 +36,7 @@
 import moabb
 from moabb.datasets import BNCI2014_009
 from moabb.evaluations import WithinSessionEvaluation
+from moabb.evaluations.splitters import LearningCurveSplitter
 from moabb.paradigms import P300


@@ -114,8 +115,8 @@
 evaluation = WithinSessionEvaluation(
     paradigm=paradigm,
     datasets=datasets,
-    data_size=data_size,
-    n_perms=n_perms,
+    cv_class=LearningCurveSplitter,
+    cv_kwargs=dict(data_size=data_size, n_perms=n_perms),
     suffix="examples_lr",
     overwrite=overwrite,
     return_epochs=True,

examples/external/noplot_learning_curve_p300_external.py

Lines changed: 3 additions & 2 deletions
@@ -36,6 +36,7 @@
 import moabb
 from moabb.datasets import BNCI2014_009
 from moabb.evaluations import WithinSessionEvaluation
+from moabb.evaluations.splitters import LearningCurveSplitter
 from moabb.paradigms import P300


@@ -115,8 +116,8 @@
 evaluation = WithinSessionEvaluation(
     paradigm=paradigm,
     datasets=datasets,
-    data_size=data_size,
-    n_perms=n_perms,
+    cv_class=LearningCurveSplitter,
+    cv_kwargs=dict(data_size=data_size, n_perms=n_perms),
     suffix="examples_lr",
     overwrite=overwrite,
 )

examples/learning_curve/noplot_learning_curve_p300_external.py

Lines changed: 3 additions & 2 deletions
@@ -36,6 +36,7 @@
 import moabb
 from moabb.datasets import BNCI2014_009
 from moabb.evaluations import WithinSessionEvaluation
+from moabb.evaluations.splitters import LearningCurveSplitter
 from moabb.paradigms import P300


@@ -115,8 +116,8 @@
 evaluation = WithinSessionEvaluation(
     paradigm=paradigm,
     datasets=datasets,
-    data_size=data_size,
-    n_perms=n_perms,
+    cv_class=LearningCurveSplitter,
+    cv_kwargs=dict(data_size=data_size, n_perms=n_perms),
     suffix="examples_lr",
     overwrite=overwrite,
 )

examples/learning_curve/plot_learning_curve_motor_imagery.py

Lines changed: 3 additions & 2 deletions
@@ -33,6 +33,7 @@
 import moabb
 from moabb.datasets import BNCI2014_001
 from moabb.evaluations import WithinSessionEvaluation
+from moabb.evaluations.splitters import LearningCurveSplitter
 from moabb.paradigms import LeftRightImagery


@@ -86,8 +87,8 @@
     datasets=datasets,
     suffix="examples",
     overwrite=overwrite,
-    data_size=data_size,
-    n_perms=n_perms,
+    cv_class=LearningCurveSplitter,
+    cv_kwargs=dict(data_size=data_size, n_perms=n_perms),
 )

 results = evaluation.process(pipelines)

examples/learning_curve/plot_learning_curve_p300.py

Lines changed: 3 additions & 2 deletions
@@ -33,6 +33,7 @@
 import moabb
 from moabb.datasets import BNCI2014_009
 from moabb.evaluations import WithinSessionEvaluation
+from moabb.evaluations.splitters import LearningCurveSplitter
 from moabb.paradigms import P300


@@ -89,8 +90,8 @@
 evaluation = WithinSessionEvaluation(
     paradigm=paradigm,
     datasets=datasets,
-    data_size=data_size,
-    n_perms=n_perms,
+    cv_class=LearningCurveSplitter,
+    cv_kwargs=dict(data_size=data_size, n_perms=n_perms),
     suffix="examples_lr",
     overwrite=overwrite,
 )
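Per the changelog, results produced through ``LearningCurveSplitter`` carry ``data_size`` and ``permutation`` columns, so plotting a learning curve reduces to averaging over permutations at each training-set size. A sketch with fabricated scores (only the column names come from the changelog; the values are invented):

```python
import pandas as pd

# Fabricated results frame with the metadata columns the changelog says
# LearningCurveSplitter adds ("data_size", "permutation").
results = pd.DataFrame(
    {
        "data_size": [50, 50, 100, 100, 200, 200],
        "permutation": [0, 1, 0, 1, 0, 1],
        "score": [0.60, 0.62, 0.70, 0.72, 0.80, 0.78],
    }
)

# Mean over permutations at each training-set size gives the curve.
curve = results.groupby("data_size")["score"].mean()
print(curve)  # 50 -> 0.61, 100 -> 0.71, 200 -> 0.79
```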

examples/tutorials/tutorial_4_adding_a_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def _get_single_subject_data(self, subject):
         fs = data["fs"]
         ch_names = ["ch" + str(i) for i in range(8)] + ["stim"]
         ch_types = ["eeg" for i in range(8)] + ["stim"]
-        info = mne.create_info(ch_names, fs, ch_types)
+        info = mne.create_info(ch_names, float(np.squeeze(fs)), ch_types)
         raw = mne.io.RawArray(x, info)

         sessions = {}
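The ``float(np.squeeze(fs))`` change above is the usual fix when a sampling rate comes out of a MATLAB ``.mat`` file: ``scipy.io.loadmat`` wraps MATLAB scalars in 2-D arrays, while ``mne.create_info`` expects a plain number. The conversion itself can be demonstrated without MNE (the ``256`` value is illustrative):

```python
import numpy as np

# scipy.io.loadmat returns MATLAB scalars as 2-D arrays, e.g. shape (1, 1).
fs = np.array([[256]])

# Squeeze away the singleton dimensions and cast explicitly to get the
# plain Python float that scalar-expecting APIs want.
sfreq = float(np.squeeze(fs))
print(sfreq)  # 256.0
```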

moabb/analysis/plotting.py

Lines changed: 3 additions & 3 deletions
@@ -476,7 +476,7 @@ def summary_plot(sig_df, effect_df, p_threshold=0.05, simplify=True):
     if simplify:
         effect_df.columns = effect_df.columns.map(_simplify_names)
         sig_df.columns = sig_df.columns.map(_simplify_names)
-    annot_df = effect_df.copy()
+    annot_df = effect_df.copy().astype(object)
     for row in annot_df.index:
         for col in annot_df.columns:
             if effect_df.loc[row, col] > 0:
@@ -575,10 +575,10 @@ def _marker(pval):
     _min = 0
     _max = 0
     for ind, d in enumerate(dsets):
-        nsub = float(df_fw.loc[df_fw.dataset == d, "nsub"])
+        nsub = df_fw.loc[df_fw.dataset == d, "nsub"].item()
         t_dof = nsub - 1
         ci.append(t.ppf(0.95, t_dof) / np.sqrt(nsub))
-        v = float(df_fw.loc[df_fw.dataset == d, "smd"])
+        v = df_fw.loc[df_fw.dataset == d, "smd"].item()
         if v > 0:
             p = df_fw.loc[df_fw.dataset == d, "p"].item()
             if p < 0.05:
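The switch from ``float(selection)`` to ``Series.item()`` above reflects pandas deprecating implicit conversion of a one-element Series to a scalar in recent versions; ``.item()`` is the explicit extraction and raises if the selection does not contain exactly one element. A quick self-contained check (the frame below is a made-up stand-in for ``df_fw``):

```python
import pandas as pd

# Hypothetical frame shaped like the plotting data: one row per dataset.
df = pd.DataFrame({"dataset": ["d1", "d2"], "nsub": [12, 9]})

# Explicit scalar extraction from a one-row selection.
nsub = df.loc[df.dataset == "d1", "nsub"].item()
print(nsub)  # 12

# Unlike float(...), .item() fails loudly on a bad mask (0 or 2+ rows),
# which surfaces selection bugs immediately.
try:
    df["nsub"].item()
except ValueError as err:
    print("raises ValueError:", err)
```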

moabb/evaluations/base.py

Lines changed: 184 additions & 1 deletion
@@ -1,5 +1,8 @@
 import logging
+import math
 from abc import ABC, abstractmethod
+from time import perf_counter
+from uuid import uuid4
 from warnings import warn

 import pandas as pd
@@ -11,8 +14,14 @@
 from moabb.evaluations.utils import (
     Emissions,
     _convert_sklearn_params_to_optuna,
+    _create_save_path,
     _create_scorer,
     _DictScorer,
+    _ensure_fitted,
+    _get_nchan,
+    _pipeline_requires_epochs,
+    _save_model_cv,
+    _score_and_update,
     check_search_available,
 )
 from moabb.paradigms.base import BaseParadigm
@@ -144,6 +153,11 @@ def __init__(
         if additional_columns is None:
             self.additional_columns = []

+        if self.cv_class is not None and hasattr(self.cv_class, "metadata_columns"):
+            for col in self.cv_class.metadata_columns:
+                if col not in self.additional_columns:
+                    self.additional_columns.append(col)
+
         if self.optuna and not optuna_available:
             raise ImportError("Optuna is not available. Please install it first.")
         if (self.time_out != 60 * 15) and not self.optuna:
@@ -222,9 +236,178 @@ def _resolve_cv(self, default_class, default_kwargs=None):
             cv_kwargs = {} if default_kwargs is None else dict(default_kwargs)
         else:
             cv_class = self.cv_class
-            cv_kwargs = {} if self.cv_kwargs is None else dict(self.cv_kwargs)
+            cv_kwargs = dict(self.cv_kwargs)
         return cv_class, cv_kwargs

+    def _load_data(
+        self,
+        dataset,
+        run_pipes,
+        process_pipeline,
+        postprocess_pipeline,
+        subjects=None,
+    ):
+        """Load data for an evaluation, handling epoch requirements.
+
+        Parameters
+        ----------
+        dataset : BaseDataset
+            The dataset to load.
+        run_pipes : dict
+            Pipelines to run (used to check epoch requirements).
+        process_pipeline : Pipeline
+            The processing pipeline.
+        postprocess_pipeline : Pipeline | None
+            Optional post-processing pipeline.
+        subjects : list | None
+            List of subjects to load. If None, loads all subjects.
+
+        Returns
+        -------
+        X : array-like or Epochs
+            The loaded data.
+        y : array-like
+            The labels.
+        metadata : DataFrame
+            The metadata.
+        """
+        requires_epochs = any(
+            _pipeline_requires_epochs(clf) for clf in run_pipes.values()
+        )
+        return_epochs = True if requires_epochs else self.return_epochs
+        kwargs = dict(
+            dataset=dataset,
+            return_epochs=return_epochs,
+            return_raws=self.return_raws,
+            cache_config=self.cache_config,
+            postprocess_pipeline=postprocess_pipeline,
+            process_pipelines=None if requires_epochs else [process_pipeline],
+        )
+        if subjects is not None:
+            kwargs["subjects"] = subjects
+        return self.paradigm.get_data(**kwargs)
+
+    @staticmethod
+    def _get_nchan(X):
+        """Extract number of channels from data (Epochs or ndarray)."""
+        return _get_nchan(X)
+
+    def _build_scored_result(
+        self,
+        dataset,
+        subject,
+        session,
+        pipeline,
+        n_samples,
+        n_channels,
+        duration,
+        scorer,
+        model,
+        X_test,
+        y_test,
+        split_metadata=None,
+        **extra,
+    ):
+        """Build a result dict and score it in one place."""
+        metadata = {}
+        if split_metadata is None:
+            splitter = getattr(getattr(self, "cv", None), "_current_splitter", None)
+            if splitter is not None and hasattr(splitter, "get_metadata"):
+                split_metadata = splitter.get_metadata()
+        if split_metadata:
+            metadata.update(split_metadata)
+        metadata.update(extra)
+        res = self._build_result(
+            dataset,
+            subject,
+            session,
+            pipeline,
+            n_samples,
+            n_channels,
+            duration,
+            **metadata,
+        )
+        try:
+            return _score_and_update(res, scorer, model, X_test, y_test)
+        except ValueError as err:
+            if self.error_score == "raise":
+                raise err
+            res["score"] = self.error_score
+            return res
+
+    def _fit_cv(self, model, X_train, y_train, tracker=None):
+        """Fit a model for a CV fold with optional CodeCarbon tracking."""
+        task_name = None
+        emissions = math.nan
+        if tracker is not None:
+            task_name = str(uuid4())
+            tracker.start_task(task_name)
+        t_start = perf_counter()
+        model.fit(X_train, y_train)
+        duration = perf_counter() - t_start
+        if tracker is not None:
+            emissions_data = tracker.stop_task()
+            emissions = emissions_data.emissions if emissions_data else math.nan
+        _ensure_fitted(model)
+        return duration, emissions, task_name
+
+    def _maybe_save_model_cv(
+        self, model, dataset, subject, session, name, cv_ind, eval_type
+    ):
+        """Save model for a CV fold when saving is enabled."""
+        if self.hdf5_path is None or not self.save_model:
+            return
+        model_save_path = _create_save_path(
+            hdf5_path=self.hdf5_path,
+            code=dataset.code,
+            subject=subject,
+            session=session,
+            name=name,
+            grid=self.search,
+            eval_type=eval_type,
+        )
+        _save_model_cv(model=model, save_path=model_save_path, cv_index=str(cv_ind))
+
+    @staticmethod
+    def _attach_emissions(res, emissions, task_name):
+        res["carbon_emission"] = (1000 * emissions,)
+        res["codecarbon_task_name"] = task_name
+
+    def _build_result(
+        self,
+        dataset,
+        subject,
+        session,
+        pipeline,
+        n_samples,
+        n_channels,
+        duration,
+        **extra,
+    ):
+        """Build a result dictionary with all required columns.
+
+        This is the single place where the evaluation result schema is defined.
+        All evaluation subclasses should use this instead of constructing the
+        dict manually, so the schema stays consistent when columns are added
+        or evaluations are merged.
+
+        Any ``additional_columns`` not provided via *extra* are defaulted to
+        NaN so that ``Results.add()`` never fails on a missing key.
+        """
+        res = {
+            "time": duration,
+            "dataset": dataset,
+            "subject": subject,
+            "session": session,
+            "n_samples": n_samples,
+            "n_channels": n_channels,
+            "pipeline": pipeline,
+        }
+        for col in self.additional_columns:
+            if col not in res:
+                res[col] = extra.get(col, math.nan)
+        return res
+
     def process(self, pipelines, param_grid=None, postprocess_pipeline=None):
         """Runs all pipelines on all datasets.

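The ``metadata_columns`` detection added to ``__init__`` in this diff follows a simple duck-typing pattern: a splitter declares the extra result columns it produces as a class attribute, and the evaluation merges them into ``additional_columns`` without hardcoding class checks. A stripped-down sketch of that mechanism (class names here are illustrative, not MOABB's real classes):

```python
import math


class LearningCurveLikeSplitter:
    # A splitter advertises the extra result columns it will fill in.
    metadata_columns = ["data_size", "permutation"]


class MinimalEvaluation:
    def __init__(self, cv_class=None, additional_columns=None):
        self.cv_class = cv_class
        self.additional_columns = list(additional_columns or [])
        # Duck-typed detection: any cv_class exposing ``metadata_columns``
        # gets its columns registered; no isinstance checks needed.
        if self.cv_class is not None and hasattr(self.cv_class, "metadata_columns"):
            for col in self.cv_class.metadata_columns:
                if col not in self.additional_columns:
                    self.additional_columns.append(col)

    def build_result(self, **extra):
        # Missing metadata defaults to NaN so downstream result handling
        # never fails on an absent key.
        res = {"score": extra.pop("score", math.nan)}
        for col in self.additional_columns:
            res.setdefault(col, extra.get(col, math.nan))
        return res


ev = MinimalEvaluation(cv_class=LearningCurveLikeSplitter)
print(ev.additional_columns)  # ['data_size', 'permutation']
print(ev.build_result(score=0.9, data_size=100))
```

This is why the changelog can promise extensibility "to future custom splitters": any new CV class only has to declare the attribute.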