updated demo and style transfer
csteinmetz1 committed Nov 4, 2023
1 parent a143dc8 · commit 8215377
Showing 2 changed files with 137 additions and 84 deletions.
examples/demo.py (133 additions, 80 deletions)
@@ -8,87 +8,140 @@
     noise_shaped_reverberation,
 )
 
-os.makedirs("outputs/demo", exist_ok=True)
-
-if not os.path.exists("examples/demo/di_guitar.wav"):
-    os.system(
-        "wget csteinmetz1.github.io/sounds/assets/short_riff.wav -O outputs/demo/short_riff.wav"
-    )
-
-# load DI guitar sample
-x, sr = torchaudio.load("outputs/demo/short_riff.wav", backend="soundfile")
-
-# add batch dim
-x = x.unsqueeze(0)
-
-# process with chain of effects
-y = parametric_eq(
-    x,
-    sr,
-    low_shelf_gain_db=-12.0,
-    low_shelf_cutoff_freq=1000,
-    low_shelf_q_factor=0.5,
-    band0_gain_db=0.0,
-    band0_cutoff_freq=1000,
-    band0_q_factor=0.5,
-    band1_gain_db=0.0,
-    band1_cutoff_freq=1000,
-    band1_q_factor=0.5,
-    band2_gain_db=0.0,
-    band2_cutoff_freq=1000,
-    band2_q_factor=0.5,
-    band3_gain_db=0.0,
-    band3_cutoff_freq=1000,
-    band3_q_factor=0.5,
-    high_shelf_gain_db=12.0,
-    high_shelf_cutoff_freq=4000,
-    high_shelf_q_factor=0.5,
-)
-print(y.shape)
-y = compressor(
-    y,
-    sr,
-    threshold_db=-12.0,
-    ratio=4.0,
-    attack_ms=10.0,
-    release_ms=100.0,
-    knee_db=12.0,
-    makeup_gain_db=0.0,
-)
-print(y.shape)
-y = distortion(y, sr, drive_db=16.0)
-print(y.shape)
-y = noise_shaped_reverberation(
-    y,
-    sr,
-    band0_gain=0.8,
-    band1_gain=0.4,
-    band2_gain=0.2,
-    band3_gain=0.5,
-    band4_gain=0.5,
-    band5_gain=0.5,
-    band6_gain=0.5,
-    band7_gain=0.6,
-    band8_gain=0.7,
-    band9_gain=0.8,
-    band10_gain=0.9,
-    band11_gain=1.0,
-    band0_decay=0.5,
-    band1_decay=0.5,
-    band2_decay=0.5,
-    band3_decay=0.5,
-    band4_decay=0.5,
-    band5_decay=0.5,
-    band6_decay=0.5,
-    band7_decay=0.5,
-    band8_decay=0.5,
-    band9_decay=0.5,
-    band10_decay=0.5,
-    band11_decay=0.5,
-    mix=0.2,
-)
-print(y.shape)
-y = y.squeeze(0)
-print(y.shape)
-
-torchaudio.save("outputs/demo/short_riff_output.wav", y, sr, backend="soundfile")
+if __name__ == "__main__":
+    use_gpu = torch.cuda.is_available()
+
+    if use_gpu:
+        torch.set_default_device("cuda")
+
+    os.makedirs("outputs/demo", exist_ok=True)
+
+    if not os.path.exists("outputs/demo/idmt-rock-input-varying-gain.wav"):
+        os.system(
+            "wget csteinmetz1.github.io/sounds/assets/amps/idmt-rock-input-varying-gain.wav -O outputs/demo/idmt-rock-input-varying-gain.wav"
+        )
+
+    # load DI guitar sample
+    x, sr = torchaudio.load(
+        "outputs/demo/idmt-rock-input-varying-gain.wav", backend="soundfile"
+    )
+
+    start_idx = int(49.0 * sr)
+    end_idx = start_idx + 441000
+    x = x[0:1, start_idx:end_idx]
+
+    # add batch dim
+    x = x.unsqueeze(0)
+
+    if use_gpu:
+        x = x.cuda()
+
+    # process with chain of effects
+    y = parametric_eq(
+        x,
+        sr,
+        low_shelf_gain_db=torch.tensor([-12.0]),
+        low_shelf_cutoff_freq=torch.tensor([1000]),
+        low_shelf_q_factor=torch.tensor([0.5]),
+        band0_gain_db=torch.tensor([0.0]),
+        band0_cutoff_freq=torch.tensor([1000]),
+        band0_q_factor=torch.tensor([0.5]),
+        band1_gain_db=torch.tensor([0.0]),
+        band1_cutoff_freq=torch.tensor([1000]),
+        band1_q_factor=torch.tensor([0.5]),
+        band2_gain_db=torch.tensor([0.0]),
+        band2_cutoff_freq=torch.tensor([1000]),
+        band2_q_factor=torch.tensor([0.5]),
+        band3_gain_db=torch.tensor([0.0]),
+        band3_cutoff_freq=torch.tensor([1000]),
+        band3_q_factor=torch.tensor([0.5]),
+        high_shelf_gain_db=torch.tensor([12.0]),
+        high_shelf_cutoff_freq=torch.tensor([4000]),
+        high_shelf_q_factor=torch.tensor([0.5]),
+    )
+    y = compressor(
+        y,
+        sr,
+        threshold_db=torch.tensor([-12.0]),
+        ratio=torch.tensor([4.0]),
+        attack_ms=torch.tensor([10.0]),
+        release_ms=torch.tensor([100.0]),
+        knee_db=torch.tensor([12.0]),
+        makeup_gain_db=torch.tensor([0.0]),
+    )
+
+    y = distortion(y, sr, drive_db=torch.tensor([42.0]))
+
+    y = parametric_eq(
+        y,
+        sr,
+        low_shelf_gain_db=torch.tensor([0.0]),
+        low_shelf_cutoff_freq=torch.tensor([1000]),
+        low_shelf_q_factor=torch.tensor([0.5]),
+        band0_gain_db=torch.tensor([0.0]),
+        band0_cutoff_freq=torch.tensor([1000]),
+        band0_q_factor=torch.tensor([0.5]),
+        band1_gain_db=torch.tensor([0.0]),
+        band1_cutoff_freq=torch.tensor([1000]),
+        band1_q_factor=torch.tensor([0.5]),
+        band2_gain_db=torch.tensor([0.0]),
+        band2_cutoff_freq=torch.tensor([1000]),
+        band2_q_factor=torch.tensor([0.5]),
+        band3_gain_db=torch.tensor([0.0]),
+        band3_cutoff_freq=torch.tensor([1000]),
+        band3_q_factor=torch.tensor([0.5]),
+        high_shelf_gain_db=torch.tensor([-18.0]),
+        high_shelf_cutoff_freq=torch.tensor([4000]),
+        high_shelf_q_factor=torch.tensor([0.5]),
+    )
+
+    y = noise_shaped_reverberation(
+        y,
+        sr,
+        band0_gain=torch.tensor([0.8]),
+        band1_gain=torch.tensor([0.4]),
+        band2_gain=torch.tensor([0.2]),
+        band3_gain=torch.tensor([0.5]),
+        band4_gain=torch.tensor([0.5]),
+        band5_gain=torch.tensor([0.5]),
+        band6_gain=torch.tensor([0.5]),
+        band7_gain=torch.tensor([0.6]),
+        band8_gain=torch.tensor([0.7]),
+        band9_gain=torch.tensor([0.8]),
+        band10_gain=torch.tensor([0.9]),
+        band11_gain=torch.tensor([1.0]),
+        band0_decay=torch.tensor([0.5]),
+        band1_decay=torch.tensor([0.5]),
+        band2_decay=torch.tensor([0.5]),
+        band3_decay=torch.tensor([0.5]),
+        band4_decay=torch.tensor([0.5]),
+        band5_decay=torch.tensor([0.5]),
+        band6_decay=torch.tensor([0.5]),
+        band7_decay=torch.tensor([0.5]),
+        band8_decay=torch.tensor([0.5]),
+        band9_decay=torch.tensor([0.5]),
+        band10_decay=torch.tensor([0.5]),
+        band11_decay=torch.tensor([0.5]),
+        mix=torch.tensor([0.15]),
+    )
+
+    y = y.squeeze(0)
+    x = x.squeeze(0)
+    print(y.shape)
+
+    y = y / torch.max(torch.abs(y))
+    x = x / torch.max(torch.abs(x))
+
+    torchaudio.save(
+        "outputs/demo/idmt-rock-input-varying-input.wav",
+        x.cpu(),
+        sr,
+        backend="soundfile",
+    )
+
+    torchaudio.save(
+        "outputs/demo/idmt-rock-input-varying-output.wav",
+        y.cpu(),
+        sr,
+        backend="soundfile",
+    )
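
Taken together, the demo changes switch the input to the IDMT rock DI recording, wrap execution in an `if __name__ == "__main__":` guard, run on the GPU when one is available, trim a 441000-sample excerpt starting at 49 s, add a second post-distortion EQ, and peak-normalize both signals before saving. The key API shift is that every effect parameter is now a batched tensor instead of a Python scalar. A minimal sketch of that calling convention, assuming these functions are importable from `dasp_pytorch.functional` (an assumed import path; the diff only shows the tail of the import block):

import torch
import torchaudio
from dasp_pytorch.functional import distortion  # assumed import path

# load audio and add a batch dim: (batch, channels, samples)
x, sr = torchaudio.load(
    "outputs/demo/idmt-rock-input-varying-gain.wav", backend="soundfile"
)
x = x.unsqueeze(0)

# each parameter is a 1-D tensor with one value per batch item,
# so it can live on the GPU and participate in autograd
y = distortion(x, sr, drive_db=torch.tensor([42.0]))

Because parameters are tensors, one call processes a whole batch and gradients can flow back into the parameter values, which is what the style-transfer example below relies on.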
examples/style_transfer.py (4 additions, 4 deletions)
@@ -191,7 +191,7 @@ def __init__(
 
             self.examples.append((filepath, frame_offset))
 
-        self.examples = self.examples * 100
+        self.examples = self.examples
 
     def __len__(self):
         return len(self.examples)
@@ -270,7 +270,7 @@ def validate(
 
 def step(input: torch.Tensor, model: torch.nn.Module):
     # generate reference by randomly processing input
-    torch.manual_seed(1)
+    # torch.manual_seed(1)
     rand_equalizer_params = torch.rand(
         input.shape[0],
        model.equalizer.num_params,
@@ -394,8 +394,8 @@ def train(
     train_filepaths = filepaths[: int(len(filepaths) * 0.8)]
     val_filepaths = filepaths[int(len(filepaths) * 0.8) :]
 
-    train_filepaths = train_filepaths[:1]
-    val_filepaths = train_filepaths[:1]
+    # train_filepaths = train_filepaths[:1]
+    # val_filepaths = train_filepaths[:1]
 
     train_dataset = AudioFileDataset(train_filepaths, length=262144)
     train_dataloader = torch.utils.data.DataLoader(
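
These style_transfer.py edits all back out an earlier debugging configuration: the file list is no longer repeated 100 times per epoch, the fixed seed in `step` is commented out so each step draws fresh random effect parameters for the reference signal, and the `[:1]` slices that restricted training to a single file are disabled, restoring the 80/20 train/validation split. A rough sketch of the reference-generation pattern the seed change affects, with placeholder shapes standing in for the real ones:

import torch

batch_size, num_equalizer_params = 4, 18  # placeholder shapes

# with torch.manual_seed(1) removed, each call returns different values,
# so every training step synthesizes a new random reference "style"
rand_equalizer_params = torch.rand(batch_size, num_equalizer_params)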
