-
Notifications
You must be signed in to change notification settings - Fork 183
/
plot_cross_session_motor_imagery.py
127 lines (102 loc) · 3.96 KB
/
plot_cross_session_motor_imagery.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
===========================
Cross-Session Motor Imagery
===========================
This example shows how to perform a cross session motor imagery analysis on the
very popular dataset 2a from the BCI competition IV.
We will compare two pipelines :
- CSP+LDA
- Riemannian Geometry+Logistic Regression
We will use the LeftRightImagery paradigm. This will restrict the analysis
to two classes (left hand versus right hand) and use AUC as metric.
The cross-session evaluation context will evaluate performance using a
leave-one-session-out cross-validation. For each session in the dataset, a
model is trained on every other session and performance is evaluated on the
current session.
"""
# Authors: Alexandre Barachant <[email protected]>
# Sylvain Chevallier <[email protected]>
#
# License: BSD (3-clause)
import matplotlib.pyplot as plt
import seaborn as sns
from mne.decoding import CSP
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
import moabb
from moabb.datasets import BNCI2014_001
from moabb.evaluations import CrossSessionEvaluation
from moabb.paradigms import LeftRightImagery
moabb.set_log_level("info")
##############################################################################
# Create Pipelines
# ----------------
#
# Pipelines must be given as a dict mapping a name to a sklearn pipeline.
#
# The CSP implementation is based on the MNE implementation. We selected 8 CSP
# components, as usually done in the literature.
#
# The Riemannian geometry pipeline consists of covariance estimation, tangent
# space mapping and finally a logistic regression for the classification.
# Map each pipeline's display name to its scikit-learn pipeline. The names
# are reused later as column labels when the scores are plotted.
pipelines = {
    # Spatial filtering with 8 CSP components, followed by an LDA classifier.
    "CSP+LDA": make_pipeline(CSP(n_components=8), LDA()),
    # Covariance estimation -> tangent space projection -> logistic regression.
    "RG+LR": make_pipeline(
        Covariances(), TangentSpace(), LogisticRegression(solver="lbfgs")
    ),
}
##############################################################################
# Evaluation
# ----------
#
# We define the paradigm (LeftRightImagery) and the dataset (BNCI2014_001).
# The evaluation will return a DataFrame containing a single AUC score for
# each subject / session of the dataset, and for each pipeline.
#
# Results are saved into the database, so that if you add a new pipeline,
# previously evaluated pipelines are not run again unless a parameter has
# changed. Results can be overwritten if necessary.
paradigm = LeftRightImagery()

# This example is auto-generated, so only the first 2 subjects are used.
dataset = BNCI2014_001()
dataset.subject_list = dataset.subject_list[:2]
datasets = [dataset]

# Set to True to overwrite results that have already been cached.
overwrite = False

evaluation = CrossSessionEvaluation(
    paradigm=paradigm,
    datasets=datasets,
    suffix="examples",
    overwrite=overwrite,
)
results = evaluation.process(pipelines)

# Show the first rows of the results table.
print(results.head(n=5))
##############################################################################
# Plot Results
# ----------------
#
# Here we plot the results. We first make a pointplot with the average
# performance of each pipeline across sessions and subjects.
# The second plot is a paired scatter plot. Each point represents the score
# of a single session. An algorithm will outperform another if most of the
# points are on its side of the diagonal.
fig, axes = plt.subplots(1, 2, figsize=[8, 4], sharey=True)
ax_summary, ax_paired = axes

# Left panel: individual scores (strip plot) overlaid with the per-pipeline
# point estimate (point plot). Both layers share the same data mapping.
summary_kwargs = dict(
    data=results,
    y="score",
    x="pipeline",
    ax=ax_summary,
    palette="Set1",
)
sns.stripplot(jitter=True, alpha=0.5, zorder=1, **summary_kwargs)
sns.pointplot(**summary_kwargs)

ax_summary.set_ylabel("ROC AUC")
# The y-axis is shared (sharey=True), so this range applies to both panels.
ax_summary.set_ylim(0.5, 1)

# Right panel: one row per (subject, session) with one score column per
# pipeline, so the two pipelines can be compared point by point.
paired = results.pivot_table(
    values="score", columns="pipeline", index=["subject", "session"]
).reset_index()

sns.regplot(data=paired, y="RG+LR", x="CSP+LDA", ax=ax_paired, fit_reg=False)
# Dashed diagonal: equal score for both pipelines. Points above it are
# sessions where RG+LR scored higher than CSP+LDA.
ax_paired.plot([0, 1], [0, 1], ls="--", c="k")
ax_paired.set_xlim(0.5, 1)

plt.show()