import re

import torch
import torchaudio
print(torch.__version__)
print(torchaudio.__version__)

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# wav2vec 2.0 base model, fine-tuned for ASR on 960 hours of LibriSpeech.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
# print("Sample Rate:", bundle.sample_rate)
# print("Labels:", bundle.get_labels())
model = bundle.get_model().to(device)
# print(model.__class__)


class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, ignore):
        super().__init__()
        self.labels = labels
        self.ignore = ignore

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string.

        Args:
            emission (Tensor): Logit tensor. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript.
        """
        # Greedy decoding: take the most likely label at each time step,
        # collapse consecutive repeats, then drop the ignored token indices.
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i not in self.ignore]
        return ''.join([self.labels[i] for i in indices])

def audiototext(speech_file):
    waveform, sample_rate = torchaudio.load(speech_file)
    waveform = waveform.to(device)
    # Resample if the file's rate differs from the rate the model expects.
    if sample_rate != bundle.sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
    # Intermediate transformer features are available if needed, but they are
    # not required for transcription:
    # with torch.inference_mode():
    #     features, _ = model.extract_features(waveform)
    with torch.inference_mode():
        emission, _ = model(waveform)
    # Indices 0-3 are the blank/special tokens in this bundle's label set.
    decoder = GreedyCTCDecoder(labels=bundle.get_labels(), ignore=(0, 1, 2, 3))
    transcript = decoder(emission[0])
    # Keep letters, digits, spaces, newlines, and periods; the model's '|'
    # word separator becomes a space. Lowercase the result.
    transcript = re.sub(r'[^a-zA-Z0-9 \n.]', ' ', transcript).lower()
    return transcript