
InferenceMode causes RuntimeError when storing PESTO model and DataProcessor on LightningModule using DDP strategy #18

Open
ben-hayes opened this issue Nov 28, 2023 · 3 comments
Labels: bug (Something isn't working)

@ben-hayes
Contributor

Context

In some use cases (e.g. DDSP audio synthesis) we want to perform F0 estimation on the GPU, so it's helpful to store PESTO as a submodule of our pytorch_lightning.LightningModule.

Bug description

When training with the DistributedDataParallel strategy and using pesto.predict inside the LightningModule, the _sync_buffers method causes the following exception to be thrown on the second training iteration:

RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See https://github.com/pytorch/rfcs/pull/17 for more details.

Note that this persists whether or not the output is cloned, i.e. the problematic InferenceMode tensor is not the output.

Expected behavior

PESTO should be usable as a submodule.

Minimal example

import pesto
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import RandomDataset

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Store the PESTO model and its DataProcessor as submodules so F0
        # estimation runs on the same device as the rest of the model.
        self.f0_extractor = pesto.utils.load_model("mir-1k")
        self.preprocessor = pesto.utils.load_dataprocessor(1e-2, device="cuda")
        self.preprocessor.sampling_rate = 44100
        self.net = nn.Linear(201, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, f0_hz, _, _ = pesto.predict(
            x,
            44100,
            data_preprocessor=self.preprocessor,
            model=self.f0_extractor,
            convert_to_freq=True,
        )
        f0_hz = f0_hz.clone()  # avoid in-place operation on InferenceMode output

        return self.net(f0_hz)

    def training_step(self, batch, batch_idx):
        x = batch
        y = self(x)
        loss = y.mean()
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


model = MyModel()
dataset = RandomDataset(88200, 64)  # 64 random examples of 2 s at 44.1 kHz
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

trainer = pl.Trainer(accelerator="gpu", max_epochs=100, strategy="ddp_find_unused_parameters_true")
trainer.fit(model, dataloader)

Diagnostics

As far as I can tell, the issue arises because data_preprocessor.sampling_rate is set inside pesto.predict, which is decorated with torch.inference_mode():

data_preprocessor.sampling_rate = sr

This means that if the sample rate has changed, or is being set for the first time (as it is likely to be on the first call to pesto.predict), the CQT buffers (or parameters) are created as inference-mode tensors.
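
This mirrors a general PyTorch behavior that can be reproduced without PESTO: a tensor created under torch.inference_mode() cannot later be updated in place outside inference mode, which is exactly what DDP's buffer sync attempts. A minimal sketch:

import torch

with torch.inference_mode():
    buf = torch.zeros(4)  # created as an inference tensor, like the lazily-initialized CQT buffers

# Outside inference mode, any in-place write (as DDP's buffer sync performs) fails:
buf.copy_(torch.ones(4))
# RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed. ...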

Workaround

A temporary workaround is to set DataProcessor.sampling_rate before calling pesto.predict, i.e. outside torch.inference_mode(), so that the CQT buffers are created as regular tensors.
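
For illustration, a sketch of that workaround outside a LightningModule, reusing the calls from the minimal example above (variable names are illustrative):

preprocessor = pesto.utils.load_dataprocessor(1e-2, device="cuda")
preprocessor.sampling_rate = 44100  # CQT buffers are created here, outside inference mode

f0_extractor = pesto.utils.load_model("mir-1k")
timesteps, f0_hz, confidence, activations = pesto.predict(
    audio,  # input audio tensor at 44.1 kHz
    44100,
    data_preprocessor=preprocessor,
    model=f0_extractor,
    convert_to_freq=True,
)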

Possible solution

Use the torch.inference_mode() context manager around only the inference section of pesto.predict, rather than decorating the whole function with it.
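
A hypothetical sketch of what that restructuring could look like (paraphrased, not the actual pesto source):

def predict(x, sr, data_preprocessor=None, model=None, convert_to_freq=False):
    # Buffer/parameter creation (e.g. the lazy CQT init triggered by the
    # sampling_rate setter) runs outside inference mode, so the resulting
    # tensors are ordinary tensors that DDP can later update in place.
    data_preprocessor.sampling_rate = sr

    with torch.inference_mode():
        # Only the actual forward computation is wrapped in inference mode.
        ...  # preprocessing and model forward pass as in the current implementation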

@aRI0U aRI0U self-assigned this Nov 28, 2023
@aRI0U aRI0U added the bug label Nov 28, 2023
@aRI0U
Collaborator

aRI0U commented Nov 28, 2023

Hi,

Yeah, in the end always decorating pesto.predict with torch.inference_mode is maybe a bit restrictive. I'll consider adding the possibility to choose between torch.no_grad and torch.inference_mode when running predict; that should prevent such issues.
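
For instance (purely illustrative; the inference_mode argument below is hypothetical, not an existing pesto parameter):

def predict(x, sr, inference_mode: bool = True, **kwargs):
    # torch.inference_mode() is faster but yields inference tensors;
    # torch.no_grad() keeps outputs and lazily created buffers compatible
    # with later in-place updates such as DDP's buffer sync.
    ctx = torch.inference_mode() if inference_mode else torch.no_grad()
    with ctx:
        ...  # existing predict body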

Also, I'm not sure why it only fails when using DDP. When training on a single GPU, does it work as expected?

@ben-hayes
Contributor Author

Training without the DDP strategy is fine, as there are no ops that modify the buffers. The bug occurs when DDP tries to sync buffers. It appears to be the call to torch.distributed._broadcast_coalesced that triggers an in-place modification:

https://github.com/pytorch/pytorch/blob/b6a30bbfb6c1bcb9c785e7a853c2622c8bc17093/torch/nn/parallel/distributed.py#L1978-L1983
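
Roughly speaking (a simplified picture, not the actual DDP implementation), buffer syncing overwrites every registered buffer in place with rank 0's values on each forward pass:

import torch.distributed as dist

# `module` stands for the DDP-wrapped LightningModule; if any buffer is an
# inference tensor, this in-place write raises the RuntimeError above.
for buf in module.buffers():
    dist.broadcast(buf, src=0)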

@ben-hayes
Contributor Author

Also just to say... this issue is a side effect of the lazy CQT init discussed in #19.
