-
Notifications
You must be signed in to change notification settings - Fork 2
/
data_utils.py
116 lines (99 loc) · 3.6 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
import random
import os
import sys
sys.path.append('../')
from scipy.fftpack import fft
from scipy.signal import resample, correlate
def computeFFT(signals, n):
    """
    Compute log-amplitude and phase spectra of the positive-frequency FFT terms.

    Args:
        signals: EEG signals, (number of channels, number of data points);
            any array whose last axis is time is accepted (1D included)
        n: FFT length passed to fft (the signal is zero-padded or truncated
            to n points along the last axis); if None, the length of the
            last axis is used
    Returns:
        FT: log amplitude of FFT of signals, (number of channels, n // 2)
        P: phase spectrum of FFT of signals, (number of channels, n // 2)
    """
    if n is None:
        n = np.shape(signals)[-1]
    # fourier transform on the last dimension
    fourier_signal = fft(signals, n=n, axis=-1)
    # only take the positive freq part (first n // 2 bins, DC included);
    # `...` keeps the slice valid for 1D and higher-rank inputs too
    idx_pos = n // 2
    fourier_signal = fourier_signal[..., :idx_pos]

    amp = np.abs(fourier_signal)
    amp[amp == 0.] = 1e-8  # avoid log of 0
    FT = np.log(amp)
    P = np.angle(fourier_signal)
    return FT, P
def getOrderedChannels(file_name, verbose, labels_object, channel_names):
    """
    Map each requested channel name to its index in a recording's label list.

    Each label is reduced to the text before its first '-' before matching
    (e.g. a label "EEG FP1-REF" is matched as "EEG FP1"). If several labels
    reduce to the same name, the first one wins.

    Args:
        file_name: name of the recording; only used in the verbose message
        verbose: if True, print which channel was missing before raising
        labels_object: iterable of channel label strings
        channel_names: channel names to look up, in the desired output order
    Returns:
        ordered_channels: list of indices into labels_object, one per entry
            of channel_names, in the same order
    Raises:
        Exception: "channel not match" if a name in channel_names has no
            matching label
    """
    # Build the lookup once; setdefault keeps the first occurrence, matching
    # the semantics of list.index.
    label_to_index = {}
    for idx, label in enumerate(labels_object):
        label_to_index.setdefault(label.split('-')[0], idx)

    ordered_channels = []
    for ch in channel_names:
        try:
            ordered_channels.append(label_to_index[ch])
        except KeyError as err:
            if verbose:
                print(file_name + " failed to get channel " + ch)
            raise Exception("channel not match") from err
    return ordered_channels
def resampleData(signals, to_freq=200, window_size=4):
    """
    Resample signals from its original sampling freq to another freq.

    Args:
        signals: EEG signal slice, (num_channels, num_data_points)
        to_freq: Re-sampled frequency in Hz
        window_size: time window in seconds
    Returns:
        resampled: (num_channels, resampled_data_points), resampled along
            axis 1 to to_freq * window_size points (truncated to an int)
    """
    num = to_freq * window_size
    nearest = round(num)
    # Plain int() truncation turns float noise such as
    # 100 * 0.29 == 28.999999999999996 into 28 samples. Snap to the nearest
    # integer only when the product is integral up to rounding error;
    # genuinely fractional products keep the original truncation.
    if abs(num - nearest) <= 1e-9 * max(1.0, abs(num)):
        num = int(nearest)
    else:
        num = int(num)
    resampled = resample(signals, num=num, axis=1)
    return resampled
######## Graph related data utils ########
def keep_topk(adj_mat, top_k=3, directed=True):
    """
    Sparsify an adjacency matrix by keeping each node's top-k neighbors.

    For every row i, the top_k largest off-diagonal entries are kept, the
    diagonal (self-edge) is always kept, and every other entry is set to 0.

    Args:
        adj_mat: adjacency matrix, shape (num_nodes, num_nodes)
        top_k: int, number of neighbors (excluding self) kept per node
        directed: if False, entry (j, i) is also kept whenever (i, j) is,
            so the kept pattern is symmetric
    Returns:
        adj_mat: new array, same shape and dtype as the input, with the
            non-kept entries set to 0
    """
    num_nodes = adj_mat.shape[0]
    # Rank neighbors on a float copy whose diagonal is -inf so a self-edge can
    # never take a top-k slot. Setting it to 0 did exactly that whenever a
    # row's off-diagonal weights were negative.
    scores = np.array(adj_mat, dtype=float)
    np.fill_diagonal(scores, -np.inf)
    # Stable sort: equal weights are resolved by lower column index.
    top_k_idx = np.argsort(-scores, axis=-1, kind='stable')[:, :top_k]

    mask = np.eye(num_nodes, dtype=bool)
    mask[np.arange(num_nodes)[:, None], top_k_idx] = True
    if not directed:
        mask = mask | mask.T  # symmetric
    return mask * adj_mat
def comp_xcorr(x, y, mode='valid', normalize=True):
    """
    Compute cross-correlation between 2 1D signals x, y.

    Args:
        x: 1D array
        y: 1D array
        mode: 'valid', 'full' or 'same',
            refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.correlate.html
        normalize: If True, divide the cross-correlation by
            sqrt(sum(|x|**2) * sum(|y|**2)) (the scaling used by MATLAB's
            xcorr normalization). Skipped when either signal is all zeros.
    Returns:
        xcorr: cross-correlation of x and y
    """
    xcorr = correlate(x, y, mode=mode)
    if not normalize:
        return xcorr
    # the below normalization code refers to matlab xcorr function
    cxx0 = np.sum(np.absolute(x) ** 2)
    cyy0 = np.sum(np.absolute(y) ** 2)
    if cxx0 == 0 or cyy0 == 0:
        # zero-energy signal: scale would be 0, so return unnormalized
        return xcorr
    scale = (cxx0 * cyy0) ** 0.5
    # Out-of-place division: `xcorr /= scale` raises a casting error when
    # correlate returns an integer array (integer inputs).
    return xcorr / scale
######## Graph related data utils ########