extract_spk_embedding.py
import numpy as np
from pyannote.audio import Model
from pyannote.audio import Inference
import math
import multiprocessing
from random import shuffle
import torch.multiprocessing as mp
import torch
from glob import glob
from tqdm import tqdm
import logging
from data_conf import data_root
logging.getLogger("numba").setLevel(logging.WARNING)
def process_one(file_path, inference, device):
    # Store the speaker embedding next to the audio file as <name>.spk.npy.
    spk_emb_path = file_path.replace(".wav", ".spk.npy").replace(".mp3", ".spk.npy")
    try:
        # Skip files whose embedding has already been extracted.
        np.load(spk_emb_path)
    except Exception:
        embedding = inference(file_path)
        np.save(spk_emb_path, embedding)


def process_batch(filenames):
    print("Loading models ...")
    # Assign each worker process to a GPU in round-robin order.
    process_idx = mp.current_process()._identity
    rank = process_idx[0] if len(process_idx) > 0 else 0
    gpu_id = rank % torch.cuda.device_count()
    device = torch.device(f"cuda:{gpu_id}")
    print(device)
    model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
    model = model.to(device)
    # window="whole" makes Inference return a single embedding per file.
    inference = Inference(model, window="whole")
    print("Loaded.")
    with torch.no_grad():
        for filename in tqdm(filenames):
            process_one(filename, inference, device)


in_dir = data_root

if __name__ == "__main__":
    filenames = glob(f"{in_dir}/**/*.wav", recursive=True)  # [:10]
    filenames += glob(f"{in_dir}/**/*.mp3", recursive=True)  # [:10]
    shuffle(filenames)
    multiprocessing.set_start_method("spawn", force=True)

    num_processes = 1
    # Split the shuffled file list into one chunk per worker process.
    chunk_size = int(math.ceil(len(filenames) / num_processes))
    chunks = [
        filenames[i: i + chunk_size] for i in range(0, len(filenames), chunk_size)
    ]
    print([len(c) for c in chunks])
    processes = [
        multiprocessing.Process(target=process_batch, args=(chunk,)) for chunk in chunks
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
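
The script leaves one .spk.npy file beside every processed audio file. A minimal sketch of how those saved embeddings might be consumed downstream, assuming cosine similarity is used to compare speakers; the file paths are hypothetical and this downstream step is not part of the script above:

import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity between two speaker embedding vectors.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Hypothetical paths: any two files already processed by extract_spk_embedding.py.
emb_a = np.load("speaker_a/utt_001.spk.npy")
emb_b = np.load("speaker_b/utt_042.spk.npy")

# Values close to 1.0 suggest the two utterances come from the same speaker.
print(f"cosine similarity: {cosine_similarity(emb_a, emb_b):.3f}")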