-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
72 lines (61 loc) · 2.43 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import itertools
import numpy as np
import librosa
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# helper functions
def preproces(fn_wav):
    """Summarize an audio file as a flat dict of time-averaged features.

    The file is loaded through ``librosa.load`` as mono audio, limited to
    the first 5 seconds. Each librosa feature matrix is then collapsed to
    a single scalar with ``np.mean``.

    Args:
        fn_wav: Path (or anything ``librosa.load`` accepts) of the audio
            file to analyze.

    Returns:
        A dict mapping feature names to their mean value. The keys are, in
        insertion order: ``chroma_stft``, ``rmse``, ``spectral_centroid``,
        ``spectral_bandwidth``, ``rolloff``, ``zero_crossing_rate``,
        followed by ``mfcc1``, ``mfcc2``, ... with one entry per row
        returned by ``librosa.feature.mfcc``.
    """
    signal, rate = librosa.load(fn_wav, mono=True, duration=5)

    # (key, feature matrix) pairs; the order here defines the key order of
    # the returned dict, which callers may rely on as a column order.
    spectral_features = (
        ('chroma_stft', librosa.feature.chroma_stft(y=signal, sr=rate)),
        ('rmse', librosa.feature.rms(y=signal)),
        ('spectral_centroid',
         librosa.feature.spectral_centroid(y=signal, sr=rate)),
        ('spectral_bandwidth',
         librosa.feature.spectral_bandwidth(y=signal, sr=rate)),
        ('rolloff', librosa.feature.spectral_rolloff(y=signal, sr=rate)),
        ('zero_crossing_rate', librosa.feature.zero_crossing_rate(signal)),
    )
    coefficients = librosa.feature.mfcc(y=signal, sr=rate)

    summary = {key: np.mean(matrix) for key, matrix in spectral_features}

    # Each MFCC row is averaged separately and numbered from 1.
    summary.update(
        (f'mfcc{number}', np.mean(row))
        for number, row in enumerate(coefficients, start=1)
    )
    return summary
class CoughNet(torch.nn.Module):
    """Six-layer fully connected network ending in two output units.

    Layer widths are ``input_size -> 512 -> 256 -> 128 -> 64 -> 10 -> 2``.
    The layers are exposed as attributes ``l1`` .. ``l6``; those names are
    part of the ``state_dict`` keys and must not change.
    """

    def __init__(self, input_size):
        """Create the linear layers.

        Args:
            input_size: Number of features in each input vector.
        """
        super().__init__()
        Linear = torch.nn.Linear
        # Layers are created in l1..l6 order so that parameter
        # initialization consumes the RNG in a fixed sequence.
        self.l1 = Linear(input_size, 512)
        self.l2 = Linear(512, 256)
        self.l3 = Linear(256, 128)
        self.l4 = Linear(128, 64)
        self.l5 = Linear(64, 10)
        self.l6 = Linear(10, 2)

    def forward(self, x):
        """Run the network on ``x`` and return the raw two-unit output.

        ReLU follows each of the first five layers; the final layer's
        output is returned without any activation applied.
        """
        for hidden in (self.l1, self.l2, self.l3, self.l4, self.l5):
            x = torch.relu(hidden(x))
        return self.l6(x)
# https://deeplizard.com/learn/video/0LhiS6yu2qQ
def plot_confusion_matrix(targets, predictions, classes):
    """Draw a row-normalized confusion matrix on the current pyplot figure.

    Each row of the matrix is divided by its own sum, so a cell shows the
    fraction of samples with that true label that received each predicted
    label. The figure is neither shown nor saved here; the caller decides
    what to do with it.

    Args:
        targets: True labels, forwarded to ``confusion_matrix``.
        predictions: Predicted labels, forwarded to ``confusion_matrix``.
        classes: Display names used as tick labels on both axes.
    """
    counts = confusion_matrix(targets, predictions)
    row_totals = counts.sum(axis=1, keepdims=True)

    # ``np.float`` was removed in NumPy 1.24, so the builtin ``float`` is
    # used for the output buffer. Rows without any true samples (a label
    # that only occurs in the predictions) are left at 0 instead of turning
    # into NaN through a division by zero.
    cm = np.divide(
        counts,
        row_totals,
        out=np.zeros(counts.shape, dtype=float),
        where=row_totals != 0,
    )

    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Cells darker than half the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(
            j, i, format(cm[i, j], '.2f'),
            horizontalalignment='center',
            color='white' if cm[i, j] > thresh else 'black',
        )

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Called last so the layout pass also reserves room for the axis labels.
    plt.tight_layout()