Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 83 additions & 8 deletions moabb/paradigms/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import abc
import logging
from operator import methodcaller
from typing import List, Optional, Tuple
from typing import List, Literal, Optional, Tuple

import mne
import numpy as np
import pandas as pd
from mne_bids.path import _find_matching_sidecar
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import FunctionTransformer

from moabb.datasets.base import BaseDataset
from moabb.datasets.base import BaseBIDSDataset, BaseDataset
from moabb.datasets.bids_interface import StepType
from moabb.datasets.preprocessing import (
EpochsToEvents,
Expand Down Expand Up @@ -232,6 +233,7 @@ def get_data( # noqa: C901
return_raws=False,
cache_config=None,
postprocess_pipeline=None,
additional_metadata: Literal["default", "all"] | list[str] = "default",
):
"""
Return the data for a list of subject.
Expand Down Expand Up @@ -265,6 +267,13 @@ def get_data( # noqa: C901
This pipeline must return an ``np.ndarray``.
This pipeline must be "fixed" because it will not be trained,
i.e. no call to ``fit`` will be made.
additional_metadata: Literal["default", "all"] | list[str]
Additional metadata to be loaded if return_epochs=True.
If "default", the default metadata will be loaded containing
`subject`, `session` and `run`. If "all", all columns of the `events.tsv`
file will be loaded. A list of column names can be passed to just
select these columns in addition to the three default values mentioned
before.

Returns
-------
Expand Down Expand Up @@ -306,6 +315,7 @@ def get_data( # noqa: C901
for session, runs in sessions.items():
for run in runs.keys():
proc = [data_i[subject][session][run] for data_i in data]

if any(obj is None for obj in proc):
# this mean the run did not contain any selected event
# go to next
Expand All @@ -321,6 +331,38 @@ def get_data( # noqa: C901
if len(self.filters) == 1
else mne.concatenate_epochs(proc)
)

# prepare additional metadata
if additional_metadata != "default":
if not isinstance(dataset, BaseBIDSDataset):
raise TypeError(
"Additional_metadata can only be used with BIDS datasets."
)

dm = load_bids_event_metadata(
dataset, subject=subject, session=session, run=run
)

# stack for multiple bandpass filtered versions
dm = pd.concat(
[
dm.copy().assign(filter=i)
for i in range(len(self.filters))
],
ignore_index=True,
)

if additional_metadata == "all":
pass
elif isinstance(additional_metadata, list):
dm = dm[
["session", "subject", "run"] + additional_metadata
]
else:
raise ValueError(
"Additional_metadata must be 'default', 'all' or a list of column names"
)

elif return_raws:
assert all(len(proc[0]) == len(p) for p in proc[1:])
n = 1
Expand Down Expand Up @@ -350,16 +392,22 @@ def get_data( # noqa: C901
met["subject"] = subject
met["session"] = session
met["run"] = run

metadata.append(met)

if return_epochs:
x.metadata = (
met.copy()
if len(self.filters) == 1
else pd.concat(
[met.copy()] * len(self.filters), ignore_index=True
if additional_metadata == "default":
x.metadata = (
met.copy()
if len(self.filters) == 1
else pd.concat(
[met.copy()] * len(self.filters), ignore_index=True
)
)
)
else:
x.metadata = dm
# also overwrite in the metadata list
metadata[-1] = dm
X.append(x)
labels.append(lbs)

Expand Down Expand Up @@ -556,3 +604,30 @@ def scoring(self):
def _get_events_pipeline(self, dataset):
event_id = self.used_events(dataset)
return RawToEvents(event_id=event_id, interval=dataset.interval)


def load_bids_event_metadata(
    data_set: "BaseBIDSDataset", subject: str, session: str, run: str
) -> pd.DataFrame:
    """Load the ``events.tsv`` sidecar of one subject/session/run as a DataFrame.

    Parameters
    ----------
    data_set : BaseBIDSDataset
        Dataset providing ``bids_paths(subject)``.
    subject : str
        Subject identifier (without the ``sub-`` prefix).
    session : str
        Session identifier (without the ``ses-`` prefix).
    run : str
        Run identifier (without the ``run-`` prefix).

    Returns
    -------
    pd.DataFrame
        All columns of the matching ``events.tsv`` file, plus constant
        ``subject``, ``session`` and ``run`` columns.

    Raises
    ------
    ValueError
        If zero or more than one BIDS path matches the session/run.
    FileNotFoundError
        If no ``events.tsv`` sidecar exists for the matched path.
    """
    bids_paths = data_set.bids_paths(subject)

    # Keep only the paths whose basename encodes the requested session and run.
    bids_path_selected = [
        pth
        for pth in bids_paths
        if f"ses-{session}" in pth.basename and f"run-{run}" in pth.basename
    ]

    # Fail loudly instead of raising a bare IndexError below.
    if not bids_path_selected:
        raise ValueError(
            f"No matching BIDS path found for subject={subject!r}, "
            f"session={session!r}, run={run!r}."
        )
    if len(bids_path_selected) > 1:
        raise ValueError("More than one matching BIDS path found.")
    bids_path = bids_path_selected[0]

    events_fname = _find_matching_sidecar(
        bids_path, suffix="events", extension=".tsv", on_error="warn"
    )
    # With on_error="warn", mne_bids returns None when no sidecar is found;
    # turn that into an explicit error instead of a confusing read_csv crash.
    if events_fname is None:
        raise FileNotFoundError(
            f"No events.tsv sidecar found for {bids_path.basename}."
        )

    dm = pd.read_csv(events_fname, sep="\t").assign(
        subject=subject, session=session, run=run
    )

    return dm
75 changes: 75 additions & 0 deletions moabb/tests/test_paradigms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from mne.io import BaseRaw

from moabb.datasets import BNCI2014_001
from moabb.datasets.base import (
LocalBIDSDataset,
)
from moabb.datasets.fake import FakeDataset
from moabb.paradigms import (
CVEP,
Expand Down Expand Up @@ -1237,3 +1240,75 @@ def test_epochs(self, epochs_labels_metadata, dataset):
np.testing.assert_array_almost_equal(
epo.get_data()[0, :, 0] * dataset.unit_factor, X
)


class TestMetadata:
    """Integration tests for ``get_data(additional_metadata=...)``.

    Verifies that requesting extra ``events.tsv`` columns (``"all"`` or an
    explicit list) leaves the returned epochs and labels unchanged while
    adding exactly the requested columns to the metadata frame.
    """

    @pytest.fixture(scope="class")
    def cached_dataset_root(self, tmpdir_factory):
        # Build a FakeDataset once per test class and cache it to disk in
        # BIDS layout, so the tests can re-open it via LocalBIDSDataset.
        root = tmpdir_factory.mktemp("fake_bids")
        dataset = FakeDataset(
            event_list=["fake1", "fake2"], n_sessions=2, n_subjects=2, n_runs=1
        )
        dataset.get_data(cache_config=dict(save_raw=True, overwrite_raw=False, path=root))
        # NOTE(review): this hard-coded folder name must match FakeDataset's
        # cache naming scheme — confirm it stays in sync if the fake dataset's
        # parameters or naming convention change.
        return root / "MNE-BIDS-fake-dataset-imagery-2-2--60--120--fake1-fake2--c3-cz-c4"

    def test_additional_metadata_extracts(self, cached_dataset_root):

        # --- The tsv files have metadata which would contain the following
        #
        # onset duration trial_type value sample
        # 0.0078125 3.0 fake1 1 1
        # 1.984375 3.0 fake2 2 254
        # 3.96875 3.0 fake1 1 508
        # 5.953125 3.0 fake2 2 762
        #
        # --- While onset, duration and trial_type, are implicitly available
        # --- by the epoch design, we could want `value` and or `sample` as well

        # Re-open the cached fake data as a BIDS dataset on disk.
        dataset = LocalBIDSDataset(
            cached_dataset_root,
            events={"fake1": 1, "fake2": 2},
            interval=[0, 3],
            paradigm="imagery",
        )
        paradigm = MotorImagery()

        # Baseline: default metadata (subject/session/run only).
        epo1, labels1, metadata1 = paradigm.get_data(
            dataset=dataset,
            subjects=["1"],
            return_epochs=True,
        )

        # Every column of events.tsv.
        epo2, labels2, metadata2 = paradigm.get_data(
            dataset=dataset,
            subjects=["1"],
            return_epochs=True,
            additional_metadata="all",
        )

        # A single explicitly requested extra column.
        epo3, labels3, metadata3 = paradigm.get_data(
            dataset=dataset,
            subjects=["1"],
            return_epochs=True,
            additional_metadata=["value"],
        )

        # Multiple explicitly requested extra columns.
        epo4, labels4, metadata4 = paradigm.get_data(
            dataset=dataset,
            subjects=["1"],
            return_epochs=True,
            additional_metadata=["value", "duration"],
        )

        # NOTE(review): ``==`` on mne.Epochs may fall back to object identity
        # unless Epochs defines __eq__ — confirm this actually compares
        # content; otherwise compare get_data() arrays instead. epo4/labels4
        # are only checked via their metadata columns below.
        assert epo1 == epo2 == epo3
        assert (labels1 == labels2).all()
        assert (labels2 == labels3).all()

        # "all" exposes every events.tsv column; list selections expose only
        # the requested ones (plus the default subject/session/run).
        assert "value" in metadata2.columns
        assert "sample" in metadata2.columns
        assert "value" in metadata3.columns
        assert "value" in metadata4.columns
        assert "duration" in metadata4.columns
        assert "sample" not in metadata3.columns
        assert "sample" not in metadata4.columns
Loading