Showing 6 changed files with 397 additions and 374 deletions.
Empty file.
@@ -0,0 +1,118 @@
import datetime
import os
import re

import torch
from TTS.speaker_encoder.model import SpeakerEncoder
from TTS.utils.generic_utils import check_argument


def to_camel(text):
    text = text.capitalize()
    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)


def setup_model(c):
    model = SpeakerEncoder(c.model['input_dim'], c.model['proj_dim'],
                           c.model['lstm_dim'], c.model['num_lstm_layers'])
    return model


def save_checkpoint(model, optimizer, model_loss, out_path,
                    current_step, epoch):
    checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
    checkpoint_path = os.path.join(out_path, checkpoint_path)
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))

    new_state_dict = model.state_dict()
    state = {
        'model': new_state_dict,
        'optimizer': optimizer.state_dict() if optimizer is not None else None,
        'step': current_step,
        'epoch': epoch,
        'loss': model_loss,
        'date': datetime.date.today().strftime("%B %d, %Y"),
    }
    torch.save(state, checkpoint_path)


def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step):
    if model_loss < best_loss:
        new_state_dict = model.state_dict()
        state = {
            'model': new_state_dict,
            'optimizer': optimizer.state_dict(),
            'step': current_step,
            'loss': model_loss,
            'date': datetime.date.today().strftime("%B %d, %Y"),
        }
        best_loss = model_loss
        bestmodel_path = 'best_model.pth.tar'
        bestmodel_path = os.path.join(out_path, bestmodel_path)
        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
            model_loss, bestmodel_path))
        torch.save(state, bestmodel_path)
    return best_loss


def check_config_speaker_encoder(c):
    """Check the config.json file of the speaker encoder."""
    check_argument('run_name', c, restricted=True, val_type=str)
    check_argument('run_description', c, val_type=str)

    # audio processing parameters
    check_argument('audio', c, restricted=True, val_type=dict)
    check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
    check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
    check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
    check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
    check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
    check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
    check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
    check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
    check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
    check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)

    # training parameters
    check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str)
    check_argument('grad_clip', c, restricted=True, val_type=float)
    check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
    check_argument('lr', c, restricted=True, val_type=float, min_val=0)
    check_argument('lr_decay', c, restricted=True, val_type=bool)
    check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
    check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
    check_argument('num_speakers_in_batch', c, restricted=True, val_type=int)
    check_argument('num_loader_workers', c, restricted=True, val_type=int)
    check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)

    # checkpoint and output parameters
    check_argument('steps_plot_stats', c, restricted=True, val_type=int)
    check_argument('checkpoint', c, restricted=True, val_type=bool)
    check_argument('save_step', c, restricted=True, val_type=int)
    check_argument('print_step', c, restricted=True, val_type=int)
    check_argument('output_path', c, restricted=True, val_type=str)

    # model parameters
    check_argument('model', c, restricted=True, val_type=dict)
    check_argument('input_dim', c['model'], restricted=True, val_type=int)
    check_argument('proj_dim', c['model'], restricted=True, val_type=int)
    check_argument('lstm_dim', c['model'], restricted=True, val_type=int)
    check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int)
    check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool)

    # in-memory storage parameters
    check_argument('storage', c, restricted=True, val_type=dict)
    check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
    check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100)
    check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)

    # datasets - check every entry
    check_argument('datasets', c, restricted=True, val_type=list)
    for dataset_entry in c['datasets']:
        check_argument('name', dataset_entry, restricted=True, val_type=str)
        check_argument('path', dataset_entry, restricted=True, val_type=str)
        check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
        check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
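Taken together, these helpers would typically be driven from the speaker-encoder training script. A minimal sketch of that wiring, assuming `c` is the parsed config (TTS configs allow both attribute and key access), `load_config` lives where noted below, and the loss/step values stand in for a real training loop:

    import torch
    from TTS.utils.io import load_config  # assumed location of the config loader

    c = load_config("config.json")  # hypothetical config path
    check_config_speaker_encoder(c)
    model = setup_model(c)
    optimizer = torch.optim.Adam(model.parameters(), lr=c['lr'])

    best_loss = float('inf')
    for epoch in range(c['epochs']):
        loss, step = 0.5, epoch * 100  # placeholders for real loop values
        save_checkpoint(model, optimizer, loss, "checkpoints/", step, epoch)
        best_loss = save_best_model(model, optimizer, loss, best_loss,
                                    "checkpoints/", step)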
Empty file.
@@ -0,0 +1,233 @@
# coding=utf-8
# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Only supports eager mode and TF>=2.0.0
# pylint: disable=no-member, invalid-name, relative-beyond-top-level
# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
''' voxceleb 1 & 2 '''

import os
import sys
import zipfile
import subprocess
import hashlib
import pandas
from absl import logging
import tensorflow as tf
import soundfile as sf

gfile = tf.compat.v1.gfile

SUBSETS = {
    "vox1_dev_wav": [
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad"],
    "vox1_test_wav": [
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"],
    "vox2_dev_aac": [
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah"],
    "vox2_test_aac": [
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"]
}

MD5SUM = {
    "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b",
    "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402",
    "vox1_test_wav": "185fdc63c3c739954633d50379a3d102",
    "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312"
}

USER = {
    "user": "",
    "password": ""
}

speaker_id_dict = {}


def download_and_extract(directory, subset, urls):
    """Download and extract the given split of the dataset.

    Args:
        directory: the directory where to put the downloaded data.
        subset: subset name of the corpus.
        urls: the list of urls to download the data file.
    """
    if not gfile.Exists(directory):
        gfile.MakeDirs(directory)

    try:
        for url in urls:
            zip_filepath = os.path.join(directory, url.split("/")[-1])
            if os.path.exists(zip_filepath):
                continue
            logging.info("Downloading %s to %s" % (url, zip_filepath))
            subprocess.call('wget %s --user %s --password %s -O %s' %
                            (url, USER["user"], USER["password"], zip_filepath), shell=True)

            statinfo = os.stat(zip_filepath)
            logging.info(
                "Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)
            )

        # concatenate all parts into a single zip file
        if ".zip" not in zip_filepath:
            zip_filepath = "_".join(zip_filepath.split("_")[:-1])
            subprocess.call('cat %s* > %s.zip' %
                            (zip_filepath, zip_filepath), shell=True)
            zip_filepath += ".zip"
        # str.strip() removes characters, not a suffix, so slice the extension off
        extract_path = zip_filepath[:-len(".zip")]

        # check the zip file's md5sum
        with open(zip_filepath, 'rb') as f:
            md5 = hashlib.md5(f.read()).hexdigest()
        if md5 != MD5SUM[subset]:
            raise ValueError("md5sum of %s mismatch" % zip_filepath)

        with zipfile.ZipFile(zip_filepath, "r") as zfile:
            zfile.extractall(directory)
            extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename)
            subprocess.call('mv %s %s' % (extract_path_ori, extract_path), shell=True)
    finally:
        # gfile.Remove(zip_filepath)
        pass


def exec_cmd(cmd):
    """Run a command in a subprocess.

    Args:
        cmd: command line to be executed.
    Return:
        int, the return code.
    """
    try:
        retcode = subprocess.call(cmd, shell=True)
        if retcode < 0:
            logging.info(f"Child was terminated by signal {retcode}")
    except OSError as e:
        logging.info(f"Execution failed: {e}")
        retcode = -999
    return retcode


def decode_aac_with_ffmpeg(aac_file, wav_file):
    """Decode a given AAC file into WAV using ffmpeg.

    Args:
        aac_file: file path to input AAC file.
        wav_file: file path to output WAV file.
    Return:
        bool, True if success.
    """
    cmd = f"ffmpeg -i {aac_file} {wav_file}"
    logging.info(f"Decoding aac file using command line: {cmd}")
    ret = exec_cmd(cmd)
    if ret != 0:
        logging.error(f"Failed to decode aac file with retcode {ret}")
        logging.error("Please check your ffmpeg installation.")
        return False
    return True


def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
    """Optionally convert AAC to WAV and make speaker labels.

    Args:
        input_dir: the directory which holds the input dataset.
        subset: the name of the specified subset, e.g. vox1_dev_wav.
        output_dir: the directory to place the newly generated csv files.
        output_file: the name of the newly generated csv file, e.g. vox1_dev_wav.csv.
    """
    logging.info("Preprocessing audio and label for subset %s" % subset)
    source_dir = os.path.join(input_dir, subset)

    files = []
    # Convert all AAC files into WAV format and collect the csv rows.
    for root, _, filenames in gfile.Walk(source_dir):
        for filename in filenames:
            name, ext = os.path.splitext(filename)
            if ext.lower() == ".wav":
                _, ext2 = os.path.splitext(name)
                if ext2:
                    continue
                wav_file = os.path.join(root, filename)
            elif ext.lower() == ".m4a":
                # Convert AAC to WAV.
                aac_file = os.path.join(root, filename)
                wav_file = aac_file + ".wav"
                if not gfile.Exists(wav_file):
                    if not decode_aac_with_ffmpeg(aac_file, wav_file):
                        raise RuntimeError("Audio decoding failed.")
            else:
                continue
            speaker_name = root.split(os.path.sep)[-2]
            if speaker_name not in speaker_id_dict:
                num = len(speaker_id_dict)
                speaker_id_dict[speaker_name] = num
            # note: the length is measured in samples, despite the column name below
            wav_length = len(sf.read(wav_file)[0])
            files.append(
                (os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)
            )

    # Write a tab-separated csv with four columns:
    # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name".
    csv_file_path = os.path.join(output_dir, output_file)
    df = pandas.DataFrame(
        data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
    df.to_csv(csv_file_path, index=False, sep="\t")
    logging.info("Successfully generated csv file {}".format(csv_file_path))


def processor(directory, subset, force_process):
    """Download and process a VoxCeleb subset."""
    urls = SUBSETS
    if subset not in urls:
        raise ValueError("subset %s is not in voxceleb" % subset)

    subset_csv = os.path.join(directory, subset + '.csv')
    if not force_process and os.path.exists(subset_csv):
        return subset_csv

    logging.info("Downloading and processing voxceleb in %s", directory)
    logging.info("Preparing subset %s", subset)
    download_and_extract(directory, subset, urls[subset])
    convert_audio_and_make_label(
        directory,
        subset,
        directory,
        subset + ".csv"
    )
    logging.info("Finished downloading and processing")
    return subset_csv


if __name__ == "__main__":
    logging.set_verbosity(logging.INFO)
    if len(sys.argv) != 4:
        print("Usage: python prepare_data.py save_directory user password")
        sys.exit()

    DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3]
    for SUBSET in SUBSETS:
        processor(DIR, SUBSET, False)
@@ -0,0 +1,46 @@
import umap
import numpy as np
import matplotlib

# the backend must be selected before pyplot is imported for it to take effect
matplotlib.use("Agg")

import matplotlib.pyplot as plt


colormap = (
    np.array(
        [
            [76, 255, 0],
            [0, 127, 70],
            [255, 0, 0],
            [255, 217, 38],
            [0, 135, 255],
            [165, 0, 165],
            [255, 167, 255],
            [0, 255, 255],
            [255, 96, 38],
            [142, 76, 0],
            [33, 0, 127],
            [0, 0, 0],
            [183, 183, 183],
        ],
        dtype=float,  # np.float is deprecated (removed in NumPy 1.24)
    )
    / 255
)


def plot_embeddings(embeddings, num_utter_per_speaker):
    # keep at most 10 speakers so every speaker fits the 13-color palette
    embeddings = embeddings[: 10 * num_utter_per_speaker]
    model = umap.UMAP()
    projection = model.fit_transform(embeddings)
    num_speakers = embeddings.shape[0] // num_utter_per_speaker
    ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker)
    colors = [colormap[i] for i in ground_truth]

    fig, ax = plt.subplots(figsize=(16, 10))
    _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
    plt.gca().set_aspect("equal", "datalim")
    plt.title("UMAP projection")
    plt.tight_layout()
    plt.savefig("umap")
    return fig
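A short usage sketch: the function expects the embedding rows grouped by speaker (all utterances of speaker 0 first, then speaker 1, and so on). The sizes below are made up for illustration:

    import numpy as np

    num_speakers, num_utter_per_speaker, dim = 4, 25, 256  # hypothetical sizes
    embeddings = np.random.randn(num_speakers * num_utter_per_speaker, dim)
    fig = plot_embeddings(embeddings, num_utter_per_speaker)  # writes "umap.png" via Agg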