Chunk gpu #11

Merged · 4 commits · Oct 16, 2023
README.md (18 changes: 16 additions & 2 deletions)
@@ -119,20 +119,20 @@ re-initialize the same model for each tensor.

To avoid this time-consuming step, one can manually instantiate the model and data processor, then pass them directly
as args to the `predict` function. To do so, one has to use the underlying methods from `pesto.utils`:

```python
import torch

from pesto import predict
from pesto.utils import load_model, load_dataprocessor


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = load_model("mir-1k", device=device)
data_processor = load_dataprocessor(step_size=0.01, device=device)

for x, sr in ...:
    data_processor.sampling_rate = sr  # The data_processor handles waveform->CQT conversion so it must know the sampling rate
-    predictions = predict(x, sr, model=model, data_processor=data_processor)
+    predictions = predict(x, sr, model=model)
    ...
```
Note that when passing a list of files to `pesto.predict_from_files(...)` or the CLI directly, the model is loaded only
@@ -175,6 +175,20 @@ Note that the *y*-axis is in log-scale: with a step size of 10ms (the default),
PESTO would perform pitch estimation of the file in 13 seconds (~12 times faster than real-time) while CREPE would take 12 minutes!
PESTO is therefore better suited to applications that need very fast pitch estimation without relying on GPU resources.

### Inference on GPU

The underlying PESTO pitch estimator is a standard PyTorch module and can therefore use the GPU,
if available, by setting the `--gpu` option to the index of the device to use for pitch estimation.

Under the hood, the input is passed to the model as a single batch of CQT frames,
so pitch is estimated for the whole track in parallel, making inference extremely fast.

However, when dealing with very large audio files, processing the whole track at once can lead to out-of-memory (OOM) errors.
To circumvent this, one can split the batch of CQT frames into several chunks by setting the `-c`/`--num_chunks` option.
Chunks are then processed sequentially, reducing peak memory usage.

As an example, a one-hour audio file sampled at 48kHz can be processed in only 20 seconds on a single GTX 1080 Ti when split into 10 chunks.
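
The same chunked inference can also be driven from Python by combining `num_chunks` with a manually loaded model. A minimal sketch (the file name is hypothetical, and `torchaudio` is just one convenient way of loading audio):

```python
import torch
import torchaudio  # assumption: any loader returning (waveform, sampling_rate) works

from pesto import predict
from pesto.utils import load_model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = load_model("mir-1k", device=device)

x, sr = torchaudio.load("one_hour_track.wav")  # hypothetical file
x = x.to(device)

# The batch of CQT frames is split into 10 chunks that are processed
# sequentially, trading a bit of speed for a much smaller memory peak.
predictions = predict(x, sr, model=model, num_chunks=10)
```

From the CLI, this corresponds to combining the `--gpu` and `-c`/`--num_chunks` options.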

## Contributing

- Currently, only a single model trained on [MIR-1K](https://zenodo.org/record/3532216#.ZG0kWhlBxhE) is provided.
pesto/core.py (27 changes: 17 additions & 10 deletions)
@@ -18,6 +18,7 @@ def predict(
        data_preprocessor=None,
        step_size: Optional[float] = None,
        reduction: str = "argmax",
+        num_chunks: int = 1,
        convert_to_freq: bool = False
):
r"""Main prediction function.
Expand All @@ -31,7 +32,10 @@ def predict(
data_preprocessor: Module handling the data processing pipeline (waveform to CQT, cropping, etc.)
step_size (float, optional): step size between each CQT frame in milliseconds.
If the data_preprocessor is passed, its value will be used instead.
reduction (str):
reduction (str): reduction method for converting activation probabilities to log-frequencies.
num_chunks (int): number of chunks to split the input audios in.
Default is 1 (all CQT frames in parallel) but it can be increased to reduce memory usage
and prevent out-of-memory errors.
convert_to_freq (bool): whether predictions should be converted to frequencies or not.
"""
# convert to mono
@@ -53,7 +57,15 @@

    # apply model
    cqt = data_preprocessor(x)
-    activations = model(cqt)
+    try:
+        activations = torch.cat([
+            model(chunk) for chunk in cqt.chunk(chunks=num_chunks)
+        ])
+    except torch.cuda.OutOfMemoryError:
+        raise torch.cuda.OutOfMemoryError("Got an out-of-memory error while performing pitch estimation. "
+                                          "Please increase the number of chunks with option `-c`/`--num_chunks` "
+                                          "to reduce GPU memory usage.")

    if batch_size:
        total_batch_size, num_predictions = activations.size()
        activations = activations.view(batch_size, total_batch_size // batch_size, num_predictions)
@@ -84,6 +96,7 @@ def predict_from_files(
        reduction: str = "alwa",
        export_format: Sequence[str] = ("csv",),
        no_convert_to_freq: bool = False,
+        num_chunks: int = 1,
        gpu: int = -1
):
    r"""
@@ -130,14 +143,8 @@
        x = x.to(device)

        # compute the predictions
-        predictions = predict(
-            x,
-            sr,
-            model=model,
-            data_preprocessor=data_preprocessor,
-            reduction=reduction,
-            convert_to_freq=not no_convert_to_freq
-        )
+        predictions = predict(x, sr, model=model, data_preprocessor=data_preprocessor, reduction=reduction,
+                              convert_to_freq=not no_convert_to_freq, num_chunks=num_chunks)

        output_file = file.rsplit('.', 1)[0] + "." + ("semitones" if no_convert_to_freq else "f0")
        if output is not None:
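For intuition about the chunked forward pass above: `Tensor.chunk` splits the batch of CQT frames along its first dimension, and `torch.cat` concatenates the per-chunk outputs back together, so for a frame-wise model the result is identical to a single full-batch call, only with a smaller peak memory footprint. A minimal sketch, with made-up shapes and a stand-in model:

```python
import torch

frames = torch.randn(1000, 128)  # stand-in for a batch of CQT frames
model = lambda t: 2.0 * t        # stand-in for a frame-wise model

full = model(frames)
chunked = torch.cat([model(chunk) for chunk in frames.chunk(chunks=10)])

assert torch.equal(full, chunked)  # same result, smaller memory peak
```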
pesto/model.py (2 changes: 1 addition & 1 deletion)
@@ -56,7 +56,7 @@ def __init__(
        if len(n_ch) < 5:
            n_ch.append(1)

-        # Layer normalization over frequency and channels (harmonics of HCQT)
+        # Layer normalization over frequency
        self.layernorm = nn.LayerNorm(normalized_shape=[1, n_bins_in])

        # Prefiltering
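The corrected comment matches what `normalized_shape=[1, n_bins_in]` actually computes: with a single input channel, each frame is normalized over its frequency bins. A small illustration, with an arbitrary bin count:

```python
import torch
from torch import nn

n_bins_in = 99  # arbitrary number of frequency bins for illustration
layernorm = nn.LayerNorm(normalized_shape=[1, n_bins_in])

x = torch.randn(8, 1, n_bins_in)  # (batch, channel, frequency)
y = layernorm(x)

# statistics are computed per frame, over the channel and frequency dims
print(y.mean(dim=(-2, -1)))  # close to 0 for every frame
print(y.std(dim=(-2, -1)))   # close to 1 for every frame
```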
pesto/parser.py (3 changes: 3 additions & 0 deletions)
@@ -17,6 +17,9 @@ def parse_args():
    parser.add_argument('-F', '--no_convert_to_freq', action='store_true',
                        help='if true, does not convert the predicted pitch to frequency domain and '
                             'returns predictions as semitones')
+    parser.add_argument('-c', '--num_chunks', type=int, default=1,
+                        help='number of chunks to split the input data into (default: 1). '
+                             'Can be increased to prevent out-of-memory errors.')
    parser.add_argument('--gpu', type=int, default=-1,
                        help='the index of the GPU to use, -1 for CPU')
    return parser.parse_args()
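
A sketch of how the new flag would flow from the parser into `predict_from_files`; the `main()` wrapper and the `audio_files` attribute are hypothetical, since the actual entry point is not part of this diff:

```python
from pesto.core import predict_from_files
from pesto.parser import parse_args


def main():
    args = parse_args()
    predict_from_files(
        args.audio_files,            # hypothetical positional argument
        num_chunks=args.num_chunks,  # new `-c`/`--num_chunks` option
        gpu=args.gpu,
    )
```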
pesto/utils.py (6 changes: 3 additions & 3 deletions)
@@ -3,9 +3,9 @@

import torch

-from pesto.config import model_args, cqt_args, bins_per_semitone
-from pesto.data import DataProcessor
-from pesto.model import PESTOEncoder
+from .config import model_args, cqt_args, bins_per_semitone
+from .data import DataProcessor
+from .model import PESTOEncoder


def load_dataprocessor(step_size, device: Optional[torch.device] = None):