Merge pull request #734 from andabi/master
[CUDA-Optimized/FastSpeech]
nv-kkudrynski authored Nov 4, 2020
2 parents b5741a9 + fd32b99 commit 64ea93d
Showing 39 changed files with 1,154 additions and 318 deletions.
6 changes: 0 additions & 6 deletions CUDA-Optimized/FastSpeech/.gitmodules

This file was deleted.

11 changes: 9 additions & 2 deletions CUDA-Optimized/FastSpeech/Dockerfile
@@ -1,7 +1,14 @@
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.03-py3
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.10-py3
FROM ${FROM_IMAGE_NAME}

# ARG UNAME
# ARG UID
# ARG GID
# RUN groupadd -g $GID -o $UNAME
# RUN useradd -m -u $UID -g $GID -o -s /bin/bash $UNAME
# USER $UNAME

ADD . /workspace/fastspeech
WORKDIR /workspace/fastspeech

RUN sh ./scripts/install.sh
RUN sh ./scripts/install.sh
58 changes: 31 additions & 27 deletions CUDA-Optimized/FastSpeech/README.md
@@ -95,9 +95,9 @@ and encapsulates some dependencies. Aside from these dependencies, ensure you
have the following components:

* [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
* [PyTorch 20.03-py3+ NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
* [PyTorch 20.10-py3 NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
or newer
* [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) or [Turing](https://www.nvidia.com/en-us/geforce/turing/) based GPU
* [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/), [Turing](https://www.nvidia.com/en-us/geforce/turing/)<!--, or [Ampere](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/) based GPU-->

For more information about how to get started with NGC containers, see the
following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning
@@ -120,11 +120,6 @@ To train your model using mixed precision with Tensor Cores or using FP32, perfo
git clone https://github.com/NVIDIA/DeepLearningExamples.git
cd DeepLearningExamples/CUDA-Optimized/FastSpeech
```
and pull submodules.
```
git submodule init
git submodule update
```
2. Download and preprocess the dataset. Data is downloaded to the ./LJSpeech-1.1 directory (on the host). The ./LJSpeech-1.1 directory is mounted to the /workspace/fastspeech/LJSpeech-1.1 location in the NGC container.
```
@@ -148,7 +143,7 @@ To train your model using mixed precision with Tensor Cores or using FP32, perfo
The preprocessed mel-spectrograms are stored in the ./mels_ljspeech1.1 directory.
Next, calculate alignments on the LJSpeech dataset using a pre-trained [NVIDIA Tacotron2 checkpoint](https://drive.google.com/file/d/1c5ZTuT7J08wLUoVZ2KkUs_VdZuJ86ZqA/view). The output directory is specified with `--aligns_path`.
Next, preprocess the alignments on the LJSpeech dataset by feed-forwarding through the teacher model. Download the NVIDIA [pretrained Tacotron2 checkpoint](https://drive.google.com/file/d/1c5ZTuT7J08wLUoVZ2KkUs_VdZuJ86ZqA/view) to use as the pretrained teacher model. Set --tacotron2_path to the Tacotron2 checkpoint file path; the resulting alignments are stored in the directory given by --aligns_path.
```
python fastspeech/align_tacotron2.py --dataset_path="./LJSpeech-1.1" --tacotron2_path="tacotron2_statedict.pt" --aligns_path="aligns_ljspeech1.1"
```
@@ -169,23 +164,23 @@ Next, calculate alignments on the LJSpeech dataset using a pre-trained [NVIDIA T
python fastspeech/train.py --dataset_path="./LJSpeech-1.1" --mels_path="./mels_ljspeech1.1" --aligns_path="./aligns_ljspeech1.1" --log_path="./logs" --checkpoint_path="./checkpoints" --use_amp
```
6. Start generation. To generate waveforms with WaveGlow Vocoder, Get [pretrained WaveGlow model](https://drive.google.com/open?id=1rpK8CzAAirq9sWZhe9nlfvxMF1dRgFbF) in the home directory, for example, ./waveglow_256channels.pt.
6. Start generation. To generate waveforms with the WaveGlow vocoder, get the [pretrained WaveGlow model](https://ngc.nvidia.com/catalog/models/nvidia:waveglow_ckpt_amp_256/files?version=19.10.0) from NGC and place it in the home directory, for example, ./nvidia_waveglow256pyt_fp16.
After you have trained the FastSpeech model, you can perform generation using the checkpoint stored in ./checkpoints. Then run:
```
python generate.py --waveglow_path="./waveglow_256channels.pt" --checkpoint_path="./checkpoints" --text="./test_sentences.txt"
python generate.py --waveglow_path="./nvidia_waveglow256pyt_fp16" --checkpoint_path="./checkpoints" --text="./test_sentences.txt"
```
The script automatically loads the latest checkpoint (if any exists), or you can pass a specific checkpoint file through --ckpt_file. It loads the input texts in ./test_sentences.txt and stores the results in the ./results directory. You can also set the result directory path with --results_path.
You can also run with a sample text:
```
python generate.py --waveglow_path="./waveglow_256channels.pt" --checkpoint_path="./checkpoints" --text="The more you buy, the more you save."
python generate.py --waveglow_path="./nvidia_waveglow256pyt_fp16" --checkpoint_path="./checkpoints" --text="The more you buy, the more you save."
```
7. Accelerate generation(inferencing of FastSpeech and WaveGlow) with TensorRT. Set parameters config file with --hparam=trt.yaml to enable TensorRT inference mode. To prepare for running WaveGlow on TensorRT, first extract a TensorRT engine file via [DeepLearningExamples/PyTorch/SpeechSynthesis/Tacotron2/trt](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2/trt) and copy this in the home directory, for example, ./waveglow.fp16.trt. Then run with --waveglow_engine_path:
7. Accelerate generation (inference of FastSpeech and WaveGlow) with TensorRT. Set the parameters config file with --hparam=trt.yaml to enable TensorRT inference mode. To prepare for running WaveGlow on TensorRT, first get an ONNX file via [DeepLearningExamples/PyTorch/SpeechSynthesis/Tacotron2/tensorrt](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2/tensorrt), convert it to a TensorRT engine using scripts/waveglow/convert_onnx2trt.py, and copy the engine into the home directory, for example, ./waveglow.fp16.trt. Then run with --waveglow_engine_path (a conversion sketch follows this step):
```
python generate.py --hparam=trt.yaml --waveglow_path="./waveglow_256channels.pt" --checkpoint_path="./checkpoints" --text="./test_sentences.txt" --waveglow_engine_path="waveglow.fp16.trt"
python generate.py --hparam=trt.yaml --waveglow_path="./nvidia_waveglow256pyt_fp16" --checkpoint_path="./checkpoints" --text="./test_sentences.txt" --waveglow_engine_path="waveglow.fp16.trt"
```
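The supported conversion path is the repository's scripts/waveglow/convert_onnx2trt.py; purely as an illustration of what that step involves, the following is a minimal sketch using the TensorRT 7 Python API. The file names here are placeholders and are not the script's actual arguments.
```
# Minimal sketch: build an FP16 TensorRT engine from a WaveGlow ONNX file.
# Placeholder paths; prefer scripts/waveglow/convert_onnx2trt.py for the real conversion.
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)

with open("waveglow.onnx", "rb") as f:          # ONNX file exported beforehand
    if not parser.parse(f.read()):
        raise RuntimeError("ONNX parse failed: {}".format(parser.get_error(0)))

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30             # 1 GiB workspace
config.set_flag(trt.BuilderFlag.FP16)           # half-precision engine

engine = builder.build_engine(network, config)
with open("waveglow.fp16.trt", "wb") as f:
    f.write(engine.serialize())
```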
## Advanced
@@ -293,33 +288,29 @@ For more details, refer to [accelerating inference with TensorRT](fastspeech/trt

#### Generation

To generate waveforms with WaveGlow Vocoder, 1) Make sure to pull [Nvidia WaveGlow](https://github.com/NVIDIA/waveglow) through git submodule, 2) get [pretrained WaveGlow model](https://drive.google.com/open?id=1rpK8CzAAirq9sWZhe9nlfvxMF1dRgFbF) in the home directory, for example, ./waveglow_256channels.pt.
```
git submodule init
git submodule update
```
To generate waveforms with the WaveGlow vocoder, get the [pretrained WaveGlow model](https://ngc.nvidia.com/catalog/models/nvidia:waveglow_ckpt_amp_256/files?version=19.10.0) from NGC and place it in the home directory, for example, ./nvidia_waveglow256pyt_fp16.

Run generate.py with:
* --text - an input text or the text file path.
* --results_path - result waveforms directory path. (default=./results).
* --ckpt_file - checkpoint file path. (default checkpoint file is the latest file in --checkpoint_path)
```
python generate.py --waveglow_path="./waveglow_256channels.pt" --text="The more you buy, the more you save."
python generate.py --waveglow_path="./nvidia_waveglow256pyt_fp16" --text="The more you buy, the more you save."
```
or
```
python generate.py --waveglow_path="./waveglow_256channels.pt" --text=test_sentences.txt
python generate.py --waveglow_path="./nvidia_waveglow256pyt_fp16" --text=test_sentences.txt
```

Sample result waveforms are [here](https://gitlab-master.nvidia.com/dahn/fastspeech/tree/master/samples).
Sample result waveforms are [here](samples).

To generate waveforms with the whole pipeline of FastSpeech and WaveGlow with TensorRT, extract a WaveGlow TRT engine file through https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2/trt and run generate.py with --hparam=trt.yaml and --waveglow_engine_path.
To generate waveforms with the whole FastSpeech and WaveGlow pipeline on TensorRT, extract a WaveGlow TRT engine file via https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2/tensorrt and run generate.py with --hparam=trt.yaml and --waveglow_engine_path.

```
python generate.py --hparam=trt.yaml --waveglow_path="./waveglow_256channels.pt" --waveglow_engine_path="waveglow.fp16.trt" --text="The more you buy, the more you save."
python generate.py --hparam=trt.yaml --waveglow_path="./nvidia_waveglow256pyt_fp16" --waveglow_engine_path="waveglow.fp16.trt" --text="The more you buy, the more you save."
```

Sample result waveforms are [FP32](https://gitlab-master.nvidia.com/dahn/fastspeech/-/tree/master/fastspeech/trt/samples) and [FP16](https://gitlab-master.nvidia.com/dahn/fastspeech/-/tree/master/fastspeech/trt/samples_fp16).
Sample result waveforms are [FP32](fastspeech/trt/samples) and [FP16](fastspeech/trt/samples_fp16).


## Performance
@@ -391,7 +382,17 @@ The following sections provide details on how we achieved our performance and ac

#### Training performance results

Our results were obtained by running the script in [training performance benchmark](#training-performance-benchmark) in the PyTorch-20.03-py3 NGC container on NVIDIA DGX-1 with 8x V100 16G GPUs. Performance numbers (in number of mels per second) were averaged over an entire training epoch.
Our results were obtained by running the script in [training performance benchmark](#training-performance-benchmark) on <!--NVIDIA DGX A100 with 8x A100 40G GPUs and -->NVIDIA DGX-1 with 8x V100 16G GPUs. Performance numbers (in number of mels per second) were averaged over an entire training epoch.

<!-- ##### Training performance: NVIDIA DGX A100 (8x A100 40GB)
| GPUs | Batch size / GPU | Throughput(mels/s) - FP32 | Throughput(mels/s) - mixed precision | Throughput speedup (FP32 - mixed precision) | Multi-GPU Weak scaling - FP32 | Multi-GPU Weak scaling - mixed precision
|---|----|--------|--------|------|-----|------|
| 1 | 32 | | | | | 1 |
| 4 | 32 | | | | | |
| 8 | 32 | | | | | | -->

##### Training performance: NVIDIA DGX-1 (8x V100 16GB)

| GPUs | Batch size / GPU | Throughput(mels/s) - FP32 | Throughput(mels/s) - mixed precision | Throughput speedup (FP32 - mixed precision) | Multi-GPU Weak scaling - FP32 | Multi-GPU Weak scaling - mixed precision
|---|----|--------|--------|------|-----|------|
@@ -401,7 +402,7 @@

#### Inference performance results

Our results were obtained by running the script in [inference performance benchmark](#inference-performance-benchmark) in the PyTorch-20.03-py3 NGC container on NVIDIA DGX-1 with 1x V100 16GB GPU and a NVIDIA T4. The following tables show inference statistics for the FastSpeech and WaveGlow text-to-speech system on PyTorch and comparisons by framework with batch size 1 in FP16, gathered from 1000 inference runs. Latency is measured from the start of FastSpeech inference to the end of WaveGlow inference. The tables include average latency, latency standard deviation, and latency confidence intervals. Throughput is measured as the number of generated audio samples per second. RTF is the real-time factor which tells how many seconds of speech are generated in 1 second of compute. The used WaveGlow model is a 256-channel model. The numbers reported below were taken with a moderate length of 128 characters.
Our results were obtained by running the script in [inference performance benchmark](#inference-performance-benchmark) on NVIDIA DGX-1 with 1x V100 16GB GPU and an NVIDIA T4. The following tables show inference statistics for the FastSpeech and WaveGlow text-to-speech system on PyTorch and comparisons by framework with batch size 1 in FP16, gathered from 1000 inference runs. Latency is measured from the start of FastSpeech inference to the end of WaveGlow inference. The tables include average latency, latency standard deviation, and latency confidence intervals. Throughput is measured as the number of generated audio samples per second. RTF is the real-time factor, which tells how many seconds of speech are generated in 1 second of compute. The WaveGlow model used is a 256-channel model. The numbers reported below were taken with a moderate input length of 128 characters.
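To make the throughput and RTF definitions concrete, here is a small worked example with made-up numbers, assuming the 22050 Hz LJSpeech sampling rate; none of these values come from the tables below.
```
# Worked example (illustrative numbers, not measured results):
# throughput and RTF as defined above, for 22050 Hz audio output.
sample_rate = 22050          # audio samples per second of speech
num_samples = 8 * 22050      # suppose 8 seconds of audio are generated
latency_s = 0.5              # suppose end-to-end inference took 0.5 s

throughput = num_samples / latency_s                # generated audio samples per second
rtf = (num_samples / sample_rate) / latency_s       # seconds of speech per second of compute

print(f"throughput={throughput:.0f} samples/s, RTF={rtf:.1f}")  # 352800 samples/s, RTF=16.0
```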

##### Inference performance: NVIDIA DGX-1 (1x V100 16GB)

@@ -442,9 +443,12 @@ Our results were obtained by running the script in [inference performance benchm
## Release notes

### Changelog
Oct 2020
- PyTorch 1.7, TensorRT 7.2 support <!--and Nvidia Ampere architecture support-->

July 2020
- Initial release

### Known issues

There are no known issues in this release.
There are no known issues in this release.
13 changes: 10 additions & 3 deletions CUDA-Optimized/FastSpeech/fastspeech/dataset/ljspeech_dataset.py
@@ -24,11 +24,15 @@

import csv

import pprint

import librosa
from torch.utils.data import Dataset
import pandas as pd
from fastspeech.text_norm import text_to_sequence
from fastspeech import audio
from fastspeech.utils.logging import tprint

import os
import pathlib

@@ -38,6 +42,8 @@

from fastspeech import hparam as hp

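# Wide pretty-printer used to log the resolved hyperparameter set in preprocess_mel below.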
pp = pprint.PrettyPrinter(indent=4, width=1000)

class LJSpeechDataset(Dataset):

    def __init__(self, root_path, meta_file="metadata.csv",
@@ -130,7 +136,7 @@ def __getitem__(self, idx):
        return data


def preprocess_mel(hparam="base.yaml"):
def preprocess_mel(hparam="base.yaml", **kwargs):
    """The script for preprocessing mel-spectrograms from the dataset.
    By default, this script assumes to load parameters in the default config file, fastspeech/hparams/base.yaml.
@@ -147,8 +153,9 @@ def preprocess_mel(hparam="base.yaml"):
        hparam (str, optional): Path to default config file. Defaults to "base.yaml".
    """

    hp.set_hparam(hparam)

    hp.set_hparam(hparam, kwargs)
    tprint("Hparams:\n{}".format(pp.pformat(hp)))

    pathlib.Path(hp.mels_path).mkdir(parents=True, exist_ok=True)

    dataset = LJSpeechDataset(hp.dataset_path, mels_path=None)
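The added **kwargs path forwards keyword overrides straight to hp.set_hparam, so hyperparameters can be overridden programmatically as well as from the command line. A hypothetical call illustrating this (the paths are placeholders):
```
# Hypothetical invocation of the updated preprocess_mel (placeholder paths),
# showing how keyword overrides are forwarded to hp.set_hparam via **kwargs.
from fastspeech.dataset.ljspeech_dataset import preprocess_mel

preprocess_mel(hparam="base.yaml",
               dataset_path="./LJSpeech-1.1",    # overrides dataset_path from base.yaml
               mels_path="./mels_ljspeech1.1")   # overrides mels_path from base.yaml
```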
2 changes: 1 addition & 1 deletion CUDA-Optimized/FastSpeech/fastspeech/hparams/base.yaml
@@ -1,7 +1,7 @@
# Path
dataset_path: "/workspace/fastspeech/LJSpeech-1.1"
tacotron2_path: "/workspace/fastspeech/tacotron2_statedict.pt"
waveglow_path: "/workspace/fastspeech/waveglow_256channels.pt"
waveglow_path: "/workspace/fastspeech/nvidia_waveglow256pyt_fp16"
mels_path: "/workspace/fastspeech/mels_ljspeech1.1"
aligns_path: "/workspace/fastspeech/aligns_ljspeech1.1"
log_path: "/workspace/fastspeech/logs"
2 changes: 1 addition & 1 deletion CUDA-Optimized/FastSpeech/fastspeech/hparams/trt.yaml
@@ -3,7 +3,7 @@ parent_yaml: 'infer.yaml'
# Inference
batch_size: 1 # Batch size.
use_trt: True # Usage of TensorRT. Must be True to enable TensorRT.
use_fp16: True # Usage of FP16. Set to True to enable half precision for the engine.
use_fp16: True # Usage of FP16. Set to True to enable half precision for the engine.

# TRT
trt_file_path: "/workspace/fastspeech/fastspeech.fp16.b1.trt" # Built TensorRT engine file path.
CUDA-Optimized/FastSpeech/fastspeech/inferencer/waveglow_inferencer.py
@@ -30,6 +30,21 @@
from fastspeech.utils.pytorch import to_cpu_numpy, to_device_async
from fastspeech.inferencer.denoiser import Denoiser

from waveglow.model import WaveGlow
import argparse

def unwrap_distributed(state_dict):
    """
    Unwraps model from DistributedDataParallel.
    DDP wraps model in additional "module.", it needs to be removed for single
    GPU inference.
    :param state_dict: model's state dict
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key.replace('module.', '')
        new_state_dict[new_key] = value
    return new_state_dict

class WaveGlowInferencer(object):

@@ -40,11 +55,36 @@ def __init__(self, ckpt_file, device='cuda', use_fp16=False, use_denoiser=False)
        self.use_denoiser = use_denoiser

        # model
        sys.path.append('waveglow')
        self.model = torch.load(self.ckpt_file, map_location=self.device)['model']
        # sys.path.append('waveglow')

        from waveglow.arg_parser import parse_waveglow_args
        parser = argparse.ArgumentParser()
        model_parser = parse_waveglow_args(parser)
        args, _ = model_parser.parse_known_args()
        model_config = dict(
            n_mel_channels=args.n_mel_channels,
            n_flows=args.flows,
            n_group=args.groups,
            n_early_every=args.early_every,
            n_early_size=args.early_size,
            WN_config=dict(
                n_layers=args.wn_layers,
                kernel_size=args.wn_kernel_size,
                n_channels=args.wn_channels
            )
        )
        self.model = WaveGlow(**model_config)

        state_dict = torch.load(self.ckpt_file, map_location=self.device)['state_dict']
        state_dict = unwrap_distributed(state_dict)
        self.model.load_state_dict(state_dict)

        self.model = to_device_async(self.model, self.device)

        self.model = self.model.remove_weightnorm(self.model)

        self.model.eval()
        self.model = to_device_async(self.model, self.device)

        if self.use_fp16:
            self.model = self.model.half()
        self.model = self.model
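For reference, the 'module.'-stripping step above can be reproduced standalone when inspecting an NGC WaveGlow checkpoint before loading. A minimal sketch, assuming the same checkpoint layout as the loader above (the path is a placeholder):
```
# Minimal sketch (assumed checkpoint layout, matching the loader above):
# strip the DistributedDataParallel 'module.' prefix before load_state_dict.
import torch

ckpt = torch.load("./nvidia_waveglow256pyt_fp16", map_location="cpu")
state_dict = {k.replace("module.", ""): v for k, v in ckpt["state_dict"].items()}
print(len(state_dict), "tensors ready for WaveGlow.load_state_dict")
```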