-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
441 lines (329 loc) · 19 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
import numpy as np
import random
import torch
from torch.utils.data import Dataset
from moabb.datasets import AlexMI, BNCI2014001, BNCI2014004, BNCI2015001, BNCI2015004, Cho2017, Lee2019_MI, PhysionetMI
from moabb.paradigms import MotorImagery
# ----------------------------------------------------------------------------------------------------------------------
# Load data from MOABB
def get_AlexMI(subject=None, freq_min=8, freq_max=45, resample=250, channels=None, n_classes=None):
    """Fetch the AlexMI data set through MOABB and encode its labels as integers.

    References:
        PhD thesis (French): https://theses.hal.science/tel-01196752
        Data: https://zenodo.org/records/806023

    Args:
        subject: Subjects passed to ``paradigm.get_data``; ``None`` selects 1..8.
        freq_min: Forwarded to ``MotorImagery`` as ``fmin``.
        freq_max: Forwarded to ``MotorImagery`` as ``fmax``.
        resample: Forwarded to ``MotorImagery`` as ``resample``.
        channels: Channel names forwarded to ``MotorImagery``; ``None`` selects
            the 16-electrode list below (international 10-20 system).
        n_classes: Forwarded to ``MotorImagery``; ``None`` selects 3.

    Returns:
        ``(data, labels, meta, channels)``: ``data`` and ``meta`` exactly as
        returned by ``paradigm.get_data``, ``labels`` converted to an integer
        array and ``channels`` the channel list that was actually used.
    """
    if channels is None:
        channels = ["Fpz", "F7", "F3", "Fz", "F4", "F8", "T7", "C3", "Cz", "C4", "T8", "P7", "P3", "Pz", "P4", "P8"]
    n_classes = 3 if n_classes is None else n_classes
    subject = list(range(1, 9)) if subject is None else subject

    moabb_dataset = AlexMI()
    mi_paradigm = MotorImagery(n_classes=n_classes, fmin=freq_min, fmax=freq_max, channels=channels,
                               resample=resample)
    data, labels, meta = mi_paradigm.get_data(dataset=moabb_dataset, subjects=subject)

    # Label names -> integer codes (right_hand=1, feet=2, rest=4); the codes are
    # written into the label array first and converted to int afterwards.
    for label_name, label_code in (('right_hand', 1), ('feet', 2), ('rest', 4)):
        labels[labels == label_name] = label_code
    return data, labels.astype(int), meta, channels
def get_BNCI2014001(subject=None, freq_min=8, freq_max=45, resample=250, channels=None, n_classes=None):
    """Fetch BNCI2014001 through MOABB and encode its labels as integers.

    This four-class motor-imagery data set was originally released as data
    set 2a of the BCI Competition IV.
    Description: https://lampx.tugraz.at/~bci/database/001-2014/description.pdf

    Args:
        subject: Subjects passed to ``paradigm.get_data``; ``None`` selects 1..9.
        freq_min: Forwarded to ``MotorImagery`` as ``fmin``.
        freq_max: Forwarded to ``MotorImagery`` as ``fmax``.
        resample: Forwarded to ``MotorImagery`` as ``resample``.
        channels: Channel names forwarded to ``MotorImagery``; ``None`` selects
            the 22-electrode list below (international 10-20 system).
        n_classes: Forwarded to ``MotorImagery``; ``None`` selects 4.

    Returns:
        ``(data, labels, meta, channels)``: ``data`` and ``meta`` exactly as
        returned by ``paradigm.get_data``, ``labels`` converted to an integer
        array and ``channels`` the channel list that was actually used.
    """
    if channels is None:
        channels = [
            "Fz", "FC3", "FC1", "FCz", "FC2", "FC4", "C5", "C3", "C1", "Cz", "C2",
            "C4", "C6", "CP3", "CP1", "CPz", "CP2", "CP4", "P1", "Pz", "P2", "POz",
        ]
    n_classes = 4 if n_classes is None else n_classes
    subject = list(range(1, 10)) if subject is None else subject

    moabb_dataset = BNCI2014001()
    mi_paradigm = MotorImagery(n_classes=n_classes, fmin=freq_min, fmax=freq_max, channels=channels,
                               resample=resample)
    data, labels, meta = mi_paradigm.get_data(dataset=moabb_dataset, subjects=subject)

    # Label names -> integer codes (left_hand=0, right_hand=1, feet=2, tongue=3).
    for label_name, label_code in (('left_hand', 0), ('right_hand', 1), ('feet', 2), ('tongue', 3)):
        labels[labels == label_name] = label_code
    return data, labels.astype(int), meta, channels
def get_BNCI2014004(subject=None, freq_min=8, freq_max=45, resample=250, channels=None, n_classes=None):
    """Fetch BNCI2014004 through MOABB and encode its labels as integers.

    References:
        https://lampx.tugraz.at/~bci/database/004-2014/description.pdf
        https://ieeexplore.ieee.org/document/4359220

    Args:
        subject: Subjects passed to ``paradigm.get_data``; ``None`` selects 1..9.
        freq_min: Forwarded to ``MotorImagery`` as ``fmin``.
        freq_max: Forwarded to ``MotorImagery`` as ``fmax``.
        resample: Forwarded to ``MotorImagery`` as ``resample``.
        channels: Channel names forwarded to ``MotorImagery``; ``None`` selects
            the 3 bipolar channels C3, Cz, C4 (extended 10-20 system).
        n_classes: Forwarded to ``MotorImagery``; ``None`` selects 2.

    Returns:
        ``(data, labels, meta, channels)``: ``data`` and ``meta`` exactly as
        returned by ``paradigm.get_data``, ``labels`` converted to an integer
        array and ``channels`` the channel list that was actually used.
    """
    if channels is None:
        channels = ["C3", "Cz", "C4"]
    n_classes = 2 if n_classes is None else n_classes
    subject = list(range(1, 10)) if subject is None else subject

    moabb_dataset = BNCI2014004()
    mi_paradigm = MotorImagery(n_classes=n_classes, fmin=freq_min, fmax=freq_max, channels=channels,
                               resample=resample)
    data, labels, meta = mi_paradigm.get_data(dataset=moabb_dataset, subjects=subject)

    # Label names -> integer codes (left_hand=0, right_hand=1).
    for label_name, label_code in (('left_hand', 0), ('right_hand', 1)):
        labels[labels == label_name] = label_code
    return data, labels.astype(int), meta, channels
def get_BNCI2015001(subject=None, freq_min=8, freq_max=45, resample=250, channels=None, n_classes=None):
    """Fetch BNCI2015001 through MOABB and encode its labels as integers.

    Description: https://lampx.tugraz.at/~bci/database/001-2015/description.pdf

    Args:
        subject: Subjects passed to ``paradigm.get_data``; ``None`` selects 1..12.
        freq_min: Forwarded to ``MotorImagery`` as ``fmin``.
        freq_max: Forwarded to ``MotorImagery`` as ``fmax``.
        resample: Forwarded to ``MotorImagery`` as ``resample``.
        channels: Channel names forwarded to ``MotorImagery``; ``None`` selects
            the 13-electrode list below (10-20 system).
        n_classes: Forwarded to ``MotorImagery``; ``None`` selects 2.

    Returns:
        ``(data, labels, meta, channels)``: ``data`` and ``meta`` exactly as
        returned by ``paradigm.get_data``, ``labels`` converted to an integer
        array and ``channels`` the channel list that was actually used.
    """
    if channels is None:
        channels = ["FC3", "FCz", "FC4", "C5", "C3", "C1", "Cz", "C2", "C4", "C6", "CP3", "CPz", "CP4"]
    n_classes = 2 if n_classes is None else n_classes
    subject = list(range(1, 13)) if subject is None else subject

    moabb_dataset = BNCI2015001()
    mi_paradigm = MotorImagery(n_classes=n_classes, fmin=freq_min, fmax=freq_max, channels=channels,
                               resample=resample)
    data, labels, meta = mi_paradigm.get_data(dataset=moabb_dataset, subjects=subject)

    # Label names -> integer codes (right_hand=1, feet=2).
    for label_name, label_code in (('right_hand', 1), ('feet', 2)):
        labels[labels == label_name] = label_code
    return data, labels.astype(int), meta, channels
def get_BNCI2015004(subject=None, freq_min=8, freq_max=45, resample=250, channels=None, n_classes=None):
    """Fetch BNCI2015004 through MOABB, keep only feet/right_hand trials, encode labels.

    Description: https://lampx.tugraz.at/~bci/database/004-2015/description.pdf

    Trials carrying any other label (e.g. 'word_ass', 'subtraction',
    'navigation') are dropped. The kept trials are returned with all 'feet'
    trials first, followed by all 'right_hand' trials; ``data``, ``labels``
    and ``meta`` are reordered consistently.

    Args:
        subject: Subjects passed to ``paradigm.get_data``; ``None`` selects 1..9.
        freq_min: Forwarded to ``MotorImagery`` as ``fmin``.
        freq_max: Forwarded to ``MotorImagery`` as ``fmax``.
        resample: Forwarded to ``MotorImagery`` as ``resample``.
        channels: Channel names forwarded to ``MotorImagery``; ``None`` selects
            the 30-electrode list below.
        n_classes: Forwarded to ``MotorImagery``; ``None`` selects 2.

    Returns:
        ``(data, labels, meta, channels)`` restricted to the kept trials, with
        ``labels`` converted to an integer array.
    """
    if channels is None:
        channels = ["AFz", "F7", "F3", "Fz", "F4", "F8", "FC3", "FCz", "FC4", "T3", "C3", "Cz", "C4", "T4", "CP3",
                    "CPz", "CP4", "P7", "P5", "P3", "P1", "Pz", "P2", "P4", "P6", "P8", "PO3", "PO4", "O1", "O2"]
    n_classes = 2 if n_classes is None else n_classes
    subject = list(range(1, 10)) if subject is None else subject

    moabb_dataset = BNCI2015004()
    mi_paradigm = MotorImagery(n_classes=n_classes, fmin=freq_min, fmax=freq_max, channels=channels,
                               resample=resample)
    data, labels, meta = mi_paradigm.get_data(dataset=moabb_dataset, subjects=subject)

    # Indices of the trials to keep: every 'feet' trial, then every 'right_hand' trial.
    feet_trials = np.where(labels == 'feet')[0]
    right_hand_trials = np.where(labels == 'right_hand')[0]
    keep = np.concatenate((feet_trials, right_hand_trials))
    data, labels, meta = data[keep], labels[keep], meta.iloc[keep]

    # Label names -> integer codes (right_hand=1, feet=2).
    for label_name, label_code in (('right_hand', 1), ('feet', 2)):
        labels[labels == label_name] = label_code
    return data, labels.astype(int), meta, channels
def get_Cho2017(subject=None, freq_min=8, freq_max=45, resample=250, channels=None, n_classes=None):
    """Fetch Cho2017 through MOABB and encode its labels as integers.

    Paper: https://academic.oup.com/gigascience/article/6/7/gix034/3796323

    Args:
        subject: Subjects passed to ``paradigm.get_data``. ``None`` selects
            1..52 without 32, 46 and 49 (the original note records
            "ValueError: Invalid subject 32, 46, 49 given" for those).
        freq_min: Forwarded to ``MotorImagery`` as ``fmin``.
        freq_max: Forwarded to ``MotorImagery`` as ``fmax``.
        resample: Forwarded to ``MotorImagery`` as ``resample``.
        channels: Channel names forwarded to ``MotorImagery``; ``None`` selects
            the 64-electrode list below (10-10 system).
        n_classes: Forwarded to ``MotorImagery``; ``None`` selects 2.

    Returns:
        ``(data, labels, meta, channels)``: ``data`` and ``meta`` exactly as
        returned by ``paradigm.get_data``, ``labels`` converted to an integer
        array and ``channels`` the channel list that was actually used.
    """
    if channels is None:
        channels = ["Fp1", "AF7", "AF3", "F1", "F3", "F5", "F7", "FT7", "FC5", "FC3", "FC1", "C1", "C3", "C5", "T7",
                    "TP7", "CP5", "CP3", "CP1", "P1", "P3", "P5", "P7", "P9", "PO7", "PO3", "O1", "Iz", "Oz", "POz",
                    "Pz", "CPz", "Fpz", "Fp2", "AF8", "AF4", "AFz", "Fz", "F2", "F4", "F6", "F8", "FT8", "FC6", "FC4",
                    "FC2", "FCz", "Cz", "C2", "C4", "C6", "T8", "TP8", "CP6", "CP4", "CP2", "P2", "P4", "P6", "P8",
                    "P10", "PO8", "PO4", "O2"]
    n_classes = 2 if n_classes is None else n_classes
    if subject is None:
        # NOTE(review): the exclusion is applied to the default list only, i.e.
        # the three remove() calls are read as part of this branch — confirm.
        excluded_subjects = (32, 46, 49)
        subject = [sub for sub in range(1, 53) if sub not in excluded_subjects]

    moabb_dataset = Cho2017()
    mi_paradigm = MotorImagery(n_classes=n_classes, fmin=freq_min, fmax=freq_max, channels=channels,
                               resample=resample)
    data, labels, meta = mi_paradigm.get_data(dataset=moabb_dataset, subjects=subject)

    # Label names -> integer codes (left_hand=0, right_hand=1).
    for label_name, label_code in (('left_hand', 0), ('right_hand', 1)):
        labels[labels == label_name] = label_code
    return data, labels.astype(int), meta, channels
def get_Lee2019_MI(subject=None, freq_min=8, freq_max=45, resample=250, channels=None, n_classes=None):
    """Fetch Lee2019_MI through MOABB and encode its labels as integers.

    Reference given in the original source:
    https://academic.oup.com/gigascience/article/6/7/gix034/3796323
    NOTE(review): this is the same URL as used for Cho2017 — probably a
    copy-paste; confirm the correct reference for Lee2019.

    Args:
        subject: Subjects passed to ``paradigm.get_data``; ``None`` selects 1..54.
        freq_min: Forwarded to ``MotorImagery`` as ``fmin``.
        freq_max: Forwarded to ``MotorImagery`` as ``fmax``.
        resample: Forwarded to ``MotorImagery`` as ``resample``.
        channels: Channel names forwarded to ``MotorImagery``; ``None`` selects
            the list below (10-10 system names).
        n_classes: Forwarded to ``MotorImagery``; ``None`` selects 2.

    Returns:
        ``(data, labels, meta, channels)``: ``data`` and ``meta`` exactly as
        returned by ``paradigm.get_data``, ``labels`` converted to an integer
        array and ``channels`` the channel list that was actually used.
    """
    if channels is None:
        channels = ['AF3', 'AF4', 'AF7', 'AF8', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CP1', 'CP2', 'CP3', 'CP4', 'CP5',
                    'CP6', 'CPz', 'Cz', 'F10', 'F3', 'F4', 'F7', 'F8', 'F9', 'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6',
                    'FT10', 'FT9', 'Fp1', 'Fp2', 'Fz', 'O1', 'O2', 'Oz', 'P1', 'P2', 'P3', 'P4', 'P7', 'P8', 'PO10',
                    'PO3', 'PO4', 'PO9', 'POz', 'Pz', 'T7', 'T8', 'TP10', 'TP7', 'TP9']
    n_classes = 2 if n_classes is None else n_classes
    subject = list(range(1, 55)) if subject is None else subject

    moabb_dataset = Lee2019_MI()
    mi_paradigm = MotorImagery(n_classes=n_classes, fmin=freq_min, fmax=freq_max, channels=channels,
                               resample=resample)
    data, labels, meta = mi_paradigm.get_data(dataset=moabb_dataset, subjects=subject)

    # Label names -> integer codes (left_hand=0, right_hand=1).
    for label_name, label_code in (('left_hand', 0), ('right_hand', 1)):
        labels[labels == label_name] = label_code
    return data, labels.astype(int), meta, channels
def get_PhysionetMI(subject=None, freq_min=8, freq_max=45, resample=250, channels=None, n_classes=None):
    """Fetch PhysionetMI through MOABB and encode its labels as integers.

    Reference given in the original source:
    https://academic.oup.com/gigascience/article/6/7/gix034/3796323
    NOTE(review): this is the same URL as used for Cho2017 — probably a
    copy-paste; confirm the correct reference for PhysionetMI.

    Label codes: left_hand=0, right_hand=1, feet=2, rest=4, hands=5.
    Open question kept from the original: should 'hands' (both hands) rather
    become a soft label such as [.5, .5, 0, 0, 0]?

    Args:
        subject: Subjects passed to ``paradigm.get_data``; ``None`` selects 1..109.
        freq_min: Forwarded to ``MotorImagery`` as ``fmin``.
        freq_max: Forwarded to ``MotorImagery`` as ``fmax``.
        resample: Forwarded to ``MotorImagery`` as ``resample``.
        channels: Channel names forwarded to ``MotorImagery``; ``None`` selects
            the 64-electrode list below (10-10 system, excluding Nz, F9, F10,
            FT9, FT10, A1, A2, TP9, TP10, P9 and P10).
        n_classes: Forwarded to ``MotorImagery``; ``None`` selects 4.

    Returns:
        ``(data, labels, meta, channels)``: ``data`` and ``meta`` exactly as
        returned by ``paradigm.get_data``, ``labels`` converted to an integer
        array and ``channels`` the channel list that was actually used.
    """
    if channels is None:
        channels = ['Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4',
                    'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'T9', 'T7', 'C5', 'C3',
                    'C1', 'Cz', 'C2', 'C4', 'C6', 'T8', 'T10', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6',
                    'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO3', 'POz', 'PO4', 'PO8',
                    'O1', 'Oz', 'O2', 'Iz']
    n_classes = 4 if n_classes is None else n_classes
    subject = list(range(1, 110)) if subject is None else subject

    moabb_dataset = PhysionetMI()
    mi_paradigm = MotorImagery(n_classes=n_classes, fmin=freq_min, fmax=freq_max, channels=channels,
                               resample=resample)
    data, labels, meta = mi_paradigm.get_data(dataset=moabb_dataset, subjects=subject)

    label_codes = (('left_hand', 0), ('right_hand', 1), ('feet', 2), ('rest', 4), ('hands', 5))
    for label_name, label_code in label_codes:
        labels[labels == label_name] = label_code
    # Disabled channel sub-selection kept from the original:
    # data[:, np.array([19, 21, 23, 28, 29, 30, 31, 32, 33, 34, 39, 41, 43])]
    return data, labels.astype(int), meta, channels
# ----------------------------------------------------------------------------------------------------------------------
# Data loader
class SeqDataset(Dataset):
    """In-memory dataset that cuts multi-channel trials into fixed-size tokens.

    Several data sets (each ``Trials x Channels x Time``) can be appended.
    After ``prepare_data_set`` every item is a token sequence that starts with
    a CLS token followed by ``num_tokens_per_channel`` tokens of width
    ``dim_token`` for each channel, plus an integer position per token that
    identifies the channel (by name, shared across data sets) and the token
    slot within that channel. Position 0 is reserved for the CLS token.
    """

    _SUPPORTED_AUGMENTATIONS = frozenset({'time_shifts', 'DC_shifts', 'amplitude_scaling', 'noise'})

    def __init__(self, dim_token, num_tokens_per_channel, reduce_num_chs_to=False, augmentation=None):
        """Create an empty dataset.

        Args:
            dim_token: Number of time samples per token.
            num_tokens_per_channel: Number of tokens cut from every channel.
            reduce_num_chs_to: If truthy, ``__getitem__`` returns only this
                many randomly chosen channels (when the trial has at least
                that many).
            augmentation: Iterable with any of 'time_shifts', 'DC_shifts',
                'amplitude_scaling', 'noise'; ``None`` disables augmentation.

        Raises:
            ValueError: If ``augmentation`` contains an unsupported name.
        """
        # None instead of a shared mutable list as default argument.
        augmentation = [] if augmentation is None else augmentation

        self.num_tokens_per_channel = num_tokens_per_channel
        self.dim_token = dim_token
        self.list_data_sets = []
        self.list_channel_names = []
        self.list_labels = []
        # list of tuples with (trial_data, trial_label, index_data_set)
        self.list_trials = []
        self.int_pos_channels_per_data_set = []
        self.dict_channels = {}
        # if cls-token should be learnable or not zero it can be overwritten by the model
        self.cls = torch.zeros(1, dim_token)
        # drop random input tokens
        self.reduce_num_chs_to = reduce_num_chs_to

        unsupported = set(augmentation) - self._SUPPORTED_AUGMENTATIONS
        if unsupported:
            raise ValueError(str(unsupported) + ' is not supported as data augmentation')
        self.augmentation = augmentation

    def append_data_set(self, data_set, channel_names, label):
        """Register one data set; call ``prepare_data_set`` afterwards.

        Note: All data is loaded into RAM, which can be a problem with large
        amounts of data. If it fits, it's faster.

        Args:
            data_set: np.array of size Trials x Channels x Time.
            channel_names: list with one name per channel of ``data_set``.
            label: np.array of size Trials.

        Raises:
            ValueError: If the number of trials or channels does not match.
        """
        trials_match = data_set.shape[0] == label.shape[0]
        channels_match = data_set.shape[1] == len(channel_names)
        if not (trials_match and channels_match):
            raise ValueError('Append data set is not possible, size does not match!')
        self.list_data_sets.append(data_set)
        self.list_channel_names.append(channel_names)
        self.list_labels.append(label)

    def prepare_data_set(self, set_pos_channels=None):
        """Build the trial list and the integer token positions.

        Args:
            set_pos_channels: Optional ``dict_channels`` of another
                ``SeqDataset`` (channel name -> positions) to reuse, e.g. to
                ensure train and test datasets return the same positions.

        Raises:
            ValueError: If ``set_pos_channels`` is given but lacks a channel
                that occurs in the appended data sets.
        """
        # list of tuples with (trial_data, trial_label, index_data_set), all as tensors
        self.list_trials = [
            (torch.from_numpy(data[idx]).float(), torch.LongTensor([label[idx]]), torch.LongTensor([idx_ds]))
            for idx_ds, (data, label) in enumerate(zip(self.list_data_sets, self.list_labels))
            for idx in range(data.shape[0])
        ]

        all_names = [name for names in self.list_channel_names for name in names]
        unique_channel_names = list(np.unique(all_names))

        if set_pos_channels is not None:
            # check if there are new channels:
            new_channels = list(set(unique_channel_names) - set(set_pos_channels.keys()))
            if new_channels:
                print('Following new channels are added: ' + str(new_channels))
                raise ValueError('There are some new Channels')
            self.dict_channels = set_pos_channels
        else:
            # CLS token has always position 0 -> pos channel start at 1
            n_tok = self.num_tokens_per_channel
            self.dict_channels = {
                key: torch.arange(i * n_tok + 1, (i + 1) * n_tok + 1, dtype=torch.int32)
                for i, key in enumerate(unique_channel_names)
            }

        # One (Channels x positions) tensor per data set, rows in that data set's channel order.
        self.int_pos_channels_per_data_set = [
            torch.stack([self.dict_channels[key] for key in channel_names], dim=0)
            for channel_names in self.list_channel_names
        ]

        # Report how many trials each label has.
        labels = np.array([int(trial[1]) for trial in self.list_trials])
        num_labels = [(i, np.where(labels == i)[0].shape) for i in set(labels)]
        print(num_labels)
        # free some memory
        # self.list_data_sets, self.list_channel_names, self.list_labels = None, None, None

    def __len__(self):
        """Return the number of prepared trials."""
        return len(self.list_trials)

    def __getitem__(self, idx):
        """Return ``(data, label, int_pos)`` for trial ``idx``.

        Returns:
            data: tensor of size #token x dim_token, row 0 is the CLS token.
            label: tensor of size 1.
            int_pos: tensor of size #token, entry 0 is 0 (CLS position).

        Raises:
            ValueError: If the trial has fewer than
                ``num_tokens_per_channel * dim_token`` time samples.
        """
        trial, label, idx_data_set = self.list_trials[idx]
        n_tok = self.num_tokens_per_channel
        dim_time = n_tok * self.dim_token
        num_channels, num_samples = trial.shape[0], trial.shape[1]
        if num_samples < dim_time:
            raise ValueError('Trial has ' + str(num_samples) + ' time samples but ' + str(dim_time)
                             + ' are required')

        # Latest allowed window start (historical bound T - dim_time - 1 kept).
        # Clamped at 0 so a trial of exactly dim_time samples yields a full
        # window instead of an empty one.
        max_start = max(num_samples - dim_time - 1, 0)
        if 'time_shifts' in self.augmentation:
            start = random.randint(0, max_start)
        else:
            start = max_start // 2
        tokens = trial[:, start:start + dim_time].reshape(-1, self.dim_token)
        data = torch.cat((self.cls, tokens), dim=0)

        if 'DC_shifts' in self.augmentation:
            data += (torch.rand(1) * 0.2 - 0.1)
        if 'amplitude_scaling' in self.augmentation:
            data *= (torch.rand(1) * 0.2 + 0.9)
        if 'noise' in self.augmentation:
            data += torch.normal(mean=0, std=0.1, size=data.size())

        # cls-token has pos. 0
        channel_pos = self.int_pos_channels_per_data_set[int(idx_data_set.item())]
        int_pos = torch.cat((torch.IntTensor([0]), channel_pos[:, :n_tok].flatten()), dim=0)

        if self.reduce_num_chs_to and num_channels >= self.reduce_num_chs_to:
            chosen = np.arange(num_channels)
            np.random.shuffle(chosen)
            chosen = chosen[:self.reduce_num_chs_to]
            # Row 0 is CLS, so channel c owns rows 1 + c*n_tok ... c*n_tok + n_tok.
            token_rows = (1 + chosen[:, None] * n_tok + np.arange(n_tok)[None, :]).reshape(-1)
            keep = torch.as_tensor(np.concatenate((np.array([0]), token_rows)), dtype=torch.long)
            return data[keep], label, int_pos[keep]
        # data: tensor, label: tensor, int_pos: tensor
        return data, label, int_pos

    @staticmethod
    def my_collate(batch):
        """Group items of equal token count into separate mini-batches.

        Converts the output of the generator into the appropriate form, see
        https://discuss.pytorch.org/t/how-to-create-a-dataloader-with-variable-size-input/8278/3

        Args:
            batch: list of ``(data, label, int_pos)`` items.

        Returns:
            dict with keys 'patched_eeg_token', 'labels' and 'pos_as_int';
            every value is a list with one tensor per distinct token count,
            ordered by ascending token count.
        """
        num_token_per_trial = [item[0].size(0) for item in batch]
        dim_token = batch[0][0].size(1)

        data, label, int_pos = [], [], []
        for num_tok in sorted(set(num_token_per_trial)):
            members = [item for item, count in zip(batch, num_token_per_trial) if count == num_tok]
            # NOTE(review): every group is seeded with an empty default-dtype
            # tensor, as before; concatenating integer labels/positions onto
            # it presumably promotes them to that float dtype — confirm that
            # callers cast where integers are required.
            data.append(torch.cat([torch.empty((0, num_tok, dim_token))]
                                  + [item[0].unsqueeze(0) for item in members], dim=0))
            label.append(torch.cat([torch.empty(0)] + [item[1] for item in members], dim=0))
            int_pos.append(torch.cat([torch.empty((0, num_tok))]
                                     + [item[2].unsqueeze(0) for item in members], dim=0))
        return {'patched_eeg_token': data, 'labels': label, 'pos_as_int': int_pos}
# ----------------------------------------------------------------------------------------------------------------------
# scaling methods
def scale(mne_epochs):
    """Center every channel of every trial and divide by its value range.

    For each (trial, channel) pair the mean over axis 2 is subtracted and the
    result is divided by ``max - min`` over axis 2.

    Args:
        mne_epochs: Array with at least three axes; statistics are taken over
            axis 2 (assumed to be Trials x Channels x Time — TODO confirm).

    Returns:
        Array of the same shape. A channel that is constant over axis 2 comes
        out as all zeros instead of NaN (its range of 0 is replaced by 1).
    """
    centered = mne_epochs - np.mean(mne_epochs, axis=2, keepdims=True)
    value_range = np.max(mne_epochs, axis=2, keepdims=True) - np.min(mne_epochs, axis=2, keepdims=True)
    # Avoid 0/0 for flat channels; the centered values are already 0 there.
    safe_range = np.where(value_range == 0, 1, value_range)
    return centered / safe_range
def zero_mean_unit_var(mne_epochs, meta_data):
    """Standardize every channel per (subject, session) group, in place.

    For each combination of subject and session that actually has trials, the
    mean and standard deviation of every channel are computed over all trials
    and time samples of that group and the group is rescaled to zero mean and
    unit variance.

    Args:
        mne_epochs: Array of size Trials x Channels x Time; it is modified in
            place.
        meta_data: Mapping/DataFrame whose 'subject' and 'session' entries
            have one value per trial.

    Returns:
        The same ``mne_epochs`` array, standardized. Channels that are
        constant within a group become all zeros instead of NaN.
    """
    subjects = meta_data['subject']
    sessions = meta_data['session']
    for sub in set(subjects):
        for session in set(sessions):
            in_group = np.asarray((subjects == sub) & (sessions == session))
            # Not every subject has every session; nothing to scale then.
            if not in_group.any():
                continue
            data = mne_epochs[in_group]
            # Statistics per channel over the trial axis (0) and the time axis (2).
            mne_mean = data.mean(axis=(0, 2), keepdims=True)
            mne_std = data.std(axis=(0, 2), keepdims=True)
            # Flat channel: avoid division by zero, centered values are 0 anyway.
            mne_std[mne_std == 0] = 1
            mne_epochs[in_group] = (data - mne_mean) / mne_std
    return mne_epochs
# ----------------------------------------------------------------------------------------------------------------------
# train test split
def train_test_split(data, labels, meta, test_size=0.05):
    """Randomly split trials into a train and a test part.

    The trial indices are shuffled with ``np.random.shuffle``; the first
    ``int(n * (1 - test_size))`` shuffled trials form the train part, the
    remaining ones the test part. ``data``, ``labels`` and ``meta`` are
    indexed with the same indices, so they stay aligned.

    Args:
        data: Array indexed by trial along axis 0.
        labels: Array with one entry per trial.
        meta: Object supporting ``.iloc`` with one row per trial.
        test_size: Fraction of the trials assigned to the test part.

    Returns:
        ``(train_data, train_labels, train_meta, test_data, test_labels, test_meta)``.
    """
    order = np.arange(data.shape[0])
    np.random.shuffle(order)
    cut = int(order.shape[0] * (1 - test_size))
    train_idx, test_idx = order[:cut], order[cut:]
    return (data[train_idx], labels[train_idx], meta.iloc[train_idx],
            data[test_idx], labels[test_idx], meta.iloc[test_idx])