
added sdpa for vits
huseinzol05 committed Dec 20, 2024
1 parent e004fb7 commit 1ba78fe
Showing 18 changed files with 15,943 additions and 9,409 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -75,6 +75,7 @@ conformer-large.py
 malay_vits/config/*
 malay_vits/logs/*
 malay_vits/*.json
+!malay_vits/config/*.json
 malay_vits/*.ipynb
 conformer-large-mixed-32
 female.ipynb
@@ -124,4 +125,5 @@ malay-asr-test.json
 malaya-malay-test-set.json
 pretrained-model/stt/conformer-ctc/*/checkpoint*
 pretrained-model/stt/conformer-ctc/*/runs*
-pretrained-model/stt/conformer-ctc/out
+pretrained-model/stt/conformer-ctc/out
+malay_vits/monotonic_align/monotonic_align
2 changes: 1 addition & 1 deletion gradio/f5-tts/docker-compose.yaml
@@ -9,7 +9,7 @@ services:
         reservations:
           devices:
             - driver: nvidia
-              device_ids: ["2"]
+              device_ids: ["1"]
               capabilities: [gpu]
     container_name: gradio-f5-tts
     ports:
1 change: 1 addition & 0 deletions malay_vits/README.md
@@ -17,6 +17,7 @@ python3 setup.py build_ext --inplace
 
 ```bash
 python3 train.py -c config.json -m speaker
+CUDA_VISIBLE_DEVICES=1 python3.10 train_ms.py -c config/multispeaker-clean.json -m multispeaker
 ```
 
 ### VITS 2
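
A side note on the added command: `CUDA_VISIBLE_DEVICES=1` exposes only physical GPU 1 to the process, which then shows up as `cuda:0`, so `train_ms.py` needs no device flag of its own. A small illustration (not part of the repo):

```python
import os

# CUDA_VISIBLE_DEVICES must be set before CUDA is initialized; the remap
# makes the single exposed GPU appear as device index 0 in-process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch

if torch.cuda.is_available():
    print(torch.cuda.device_count())    # 1 -- only the exposed GPU
    print(torch.cuda.current_device())  # 0 -- physical GPU 1, remapped
```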
71 changes: 42 additions & 29 deletions malay_vits/attentions.py
@@ -11,7 +11,7 @@
 
 
 class Encoder(nn.Module):
-    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
+    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, use_sdpa=False, **kwargs):
         super().__init__()
         self.hidden_channels = hidden_channels
         self.filter_channels = filter_channels
@@ -20,6 +20,7 @@ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_s
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout
         self.window_size = window_size
+        self.use_sdpa = use_sdpa
 
         self.drop = nn.Dropout(p_dropout)
         self.attn_layers = nn.ModuleList()
@@ -28,14 +29,18 @@ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_s
         self.norm_layers_2 = nn.ModuleList()
         for i in range(self.n_layers):
             self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
-                                                       n_heads, p_dropout=p_dropout, window_size=window_size))
+                                                       n_heads, p_dropout=p_dropout, window_size=window_size, use_sdpa=use_sdpa))
             self.norm_layers_1.append(LayerNorm(hidden_channels))
             self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
                                        filter_channels, kernel_size, p_dropout=p_dropout))
             self.norm_layers_2.append(LayerNorm(hidden_channels))
 
     def forward(self, x, x_mask):
         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        if self.use_sdpa:
+            dtype = x.dtype
+            min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
+            attn_mask = torch.where(attn_mask == 1, torch.tensor(0, dtype=dtype), min_value)
         x = x * x_mask
         for i in range(self.n_layers):
             y = self.attn_layers[i](x, x, attn_mask)
@@ -50,7 +55,7 @@ def forward(self, x, x_mask):
 
 
 class Decoder(nn.Module):
-    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
+    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, use_sdpa=False, **kwargs):
         super().__init__()
         self.hidden_channels = hidden_channels
         self.filter_channels = filter_channels
@@ -70,10 +75,10 @@ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_s
         self.norm_layers_2 = nn.ModuleList()
         for i in range(self.n_layers):
             self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads,
-                                                            p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
+                                                            p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init, use_sdpa=use_sdpa))
             self.norm_layers_0.append(LayerNorm(hidden_channels))
             self.encdec_attn_layers.append(MultiHeadAttention(
-                hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
+                hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, use_sdpa=use_sdpa))
             self.norm_layers_1.append(LayerNorm(hidden_channels))
             self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels,
                                        kernel_size, p_dropout=p_dropout, causal=True))
@@ -104,7 +109,7 @@ def forward(self, x, x_mask, h, h_mask):
 
 
 class MultiHeadAttention(nn.Module):
-    def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
+    def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False, use_sdpa=False):
         super().__init__()
         assert channels % n_heads == 0
 
@@ -118,6 +123,7 @@ def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=No
         self.proximal_bias = proximal_bias
         self.proximal_init = proximal_init
         self.attn = None
+        self.use_sdpa = use_sdpa
 
         self.k_channels = channels // n_heads
         self.conv_q = nn.Conv1d(channels, channels, 1)
@@ -157,29 +163,36 @@ def attention(self, query, key, value, mask=None):
         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
         value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
 
-        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
-        if self.window_size is not None:
-            assert t_s == t_t, "Relative attention is only available for self-attention."
-            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
-            scores_local = self._relative_position_to_absolute_position(rel_logits)
-            scores = scores + scores_local
-        if self.proximal_bias:
-            assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
-        if mask is not None:
-            scores = scores.masked_fill(mask == 0, -1e4)
-            if self.block_length is not None:
-                assert t_s == t_t, "Local attention is only available for self-attention."
-                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
-                scores = scores.masked_fill(block_mask == 0, -1e4)
-        p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
-        p_attn = self.drop(p_attn)
-        output = torch.matmul(p_attn, value)
-        if self.window_size is not None:
-            relative_weights = self._absolute_position_to_relative_position(p_attn)
-            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
-            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
+        if self.use_sdpa:
+            output = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=mask,
+            )
+            p_attn = None
+        else:
+            scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+            if self.window_size is not None:
+                assert t_s == t_t, "Relative attention is only available for self-attention."
+                key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+                rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
+                scores_local = self._relative_position_to_absolute_position(rel_logits)
+                scores = scores + scores_local
+            if self.proximal_bias:
+                assert t_s == t_t, "Proximal bias is only available for self-attention."
+                scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+            if mask is not None:
+                scores = scores.masked_fill(mask == 0, -1e4)
+                if self.block_length is not None:
+                    assert t_s == t_t, "Local attention is only available for self-attention."
+                    block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
+                    scores = scores.masked_fill(block_mask == 0, -1e4)
+            p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+            p_attn = self.drop(p_attn)
+            output = torch.matmul(p_attn, value)
+            if self.window_size is not None:
+                relative_weights = self._absolute_position_to_relative_position(p_attn)
+                value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+                output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
 
         output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
         return output, p_attn
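
The mask handling is the subtle part of this change: `F.scaled_dot_product_attention` adds a floating-point `attn_mask` to the attention scores before the softmax, which is why `Encoder.forward` now rewrites its 0/1 mask into 0 / dtype-min values instead of passing it through. A self-contained sketch (not the repository module) checking that, with relative attention, proximal bias, and block masking out of the picture, the SDPA path reproduces the manual path:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, n_heads, t, d_k = 2, 2, 5, 8
query = torch.randn(b, n_heads, t, d_k)
key = torch.randn(b, n_heads, t, d_k)
value = torch.randn(b, n_heads, t, d_k)

# 0/1 self-attention mask, built the same way Encoder.forward builds it;
# the second batch element is padded after 3 frames.
x_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]], dtype=torch.float).unsqueeze(1)
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)  # [b, 1, t, t]

# Old path: fill masked scores with a large negative number, then softmax.
scores = torch.matmul(query / math.sqrt(d_k), key.transpose(-2, -1))
scores = scores.masked_fill(attn_mask == 0, -1e4)
manual = torch.matmul(F.softmax(scores, dim=-1), value)

# New path: convert the 0/1 mask to an additive float mask, as the commit does.
min_value = torch.finfo(query.dtype).min
additive = torch.where(attn_mask == 1, torch.tensor(0, dtype=query.dtype), min_value)
sdpa = F.scaled_dot_product_attention(query, key, value, attn_mask=additive)

print(torch.allclose(manual, sdpa, atol=1e-4))  # True
```

Note that, as committed, the SDPA branch bypasses the relative-position terms (`window_size`), proximal bias, block masking, and attention dropout entirely, and returns `p_attn = None` instead of the attention weights.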
54 changes: 54 additions & 0 deletions malay_vits/config/multispeaker-clean-small.json
@@ -0,0 +1,54 @@
{
  "train": {
    "log_interval": 2,
    "eval_interval": 1000,
    "seed": 1234,
    "epochs": 20000,
    "learning_rate": 2e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 48,
    "fp16_run": true,
    "lr_decay": 0.999875,
    "segment_size": 8192,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45,
    "c_kl": 1.0
  },
  "data": {
    "training_files":"/home/husein/ssd3/tts/multispeaker-clean-vits.json",
    "validation_files":"/home/husein/ssd3/tts/multispeaker-clean-vits.json",
    "text_cleaners":[""],
    "max_wav_value": 32768.0,
    "sampling_rate": 22050,
    "filter_length": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "n_mel_channels": 80,
    "mel_fmin": 0.0,
    "mel_fmax": null,
    "add_blank": true,
    "n_speakers": 8,
    "cleaned_text": true
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 3,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "upsample_rates": [8,8,2,2],
    "upsample_initial_channel": 384,
    "upsample_kernel_sizes": [16,16,4,4],
    "n_layers_q": 3,
    "use_spectral_norm": false,
    "gin_channels": 256,
    "use_sdpa": true
  }
}
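
The new `use_sdpa` flag sits under `model`, so it reaches the attention layers through the usual VITS pattern of forwarding the config's `model` dict into the network constructors. A hedged sketch of that plumbing, constructing the patched `Encoder` directly (the real wiring goes through `models.py`/`train_ms.py`, which are not part of this commit; run from inside `malay_vits/` so `attentions` and its imports resolve):

```python
import json

import torch
from attentions import Encoder  # assumes malay_vits/ is the working directory

with open("config/multispeaker-clean-small.json") as f:
    hps = json.load(f)
m = hps["model"]

enc = Encoder(m["hidden_channels"], m["filter_channels"], m["n_heads"],
              m["n_layers"], kernel_size=m["kernel_size"],
              p_dropout=m["p_dropout"], use_sdpa=m["use_sdpa"])

x = torch.randn(2, m["hidden_channels"], 50)  # [batch, channels, frames]
x_mask = torch.ones(2, 1, 50)                 # every frame is valid
out = enc(x, x_mask)                          # attention now runs through SDPA
print(out.shape)                              # torch.Size([2, 192, 50])
```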
54 changes: 54 additions & 0 deletions malay_vits/config/multispeaker-clean.json
@@ -0,0 +1,54 @@
{
  "train": {
    "log_interval": 2,
    "eval_interval": 1000,
    "seed": 1234,
    "epochs": 20000,
    "learning_rate": 2e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 48,
    "fp16_run": true,
    "lr_decay": 0.999875,
    "segment_size": 8192,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45,
    "c_kl": 1.0
  },
  "data": {
    "training_files":"/home/husein/ssd3/tts/multispeaker-clean-vits.json",
    "validation_files":"/home/husein/ssd3/tts/multispeaker-clean-vits.json",
    "text_cleaners":[""],
    "max_wav_value": 32768.0,
    "sampling_rate": 22050,
    "filter_length": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "n_mel_channels": 80,
    "mel_fmin": 0.0,
    "mel_fmax": null,
    "add_blank": true,
    "n_speakers": 8,
    "cleaned_text": true
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "upsample_rates": [8,8,2,2],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [16,16,4,4],
    "n_layers_q": 3,
    "use_spectral_norm": false,
    "gin_channels": 256,
    "use_sdpa": true
  }
}
53 changes: 53 additions & 0 deletions malay_vits/config/osman.json
@@ -0,0 +1,53 @@
{
  "train": {
    "log_interval": 2,
    "eval_interval": 1000,
    "seed": 1234,
    "epochs": 20000,
    "learning_rate": 2e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 64,
    "fp16_run": true,
    "lr_decay": 0.999875,
    "segment_size": 8192,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45,
    "c_kl": 1.0
  },
  "data": {
    "training_files":"/home/husein/speech-bahasa/osman-vits-train-set.txt",
    "validation_files":"/home/husein/speech-bahasa/osman-vits-test-set.txt",
    "text_cleaners":[""],
    "max_wav_value": 32768.0,
    "sampling_rate": 22050,
    "filter_length": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "n_mel_channels": 80,
    "mel_fmin": 0.0,
    "mel_fmax": null,
    "add_blank": true,
    "n_speakers": 0,
    "cleaned_text": true
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0,
    "resblock": "1",
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "upsample_rates": [8,8,2,2],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [16,16,4,4],
    "n_layers_q": 3,
    "use_spectral_norm": false,
    "use_sdpa": true
  }
}
52 changes: 52 additions & 0 deletions malay_vits/config/yasmin.json
@@ -0,0 +1,52 @@
{
  "train": {
    "log_interval": 50,
    "eval_interval": 1000,
    "seed": 1234,
    "epochs": 20000,
    "learning_rate": 2e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 64,
    "fp16_run": true,
    "lr_decay": 0.999875,
    "segment_size": 8192,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45,
    "c_kl": 1.0
  },
  "data": {
    "training_files":"/home/husein/speech-bahasa/yasmin-vits-train-set.txt",
    "validation_files":"/home/husein/speech-bahasa/yasmin-vits-test-set.txt",
    "text_cleaners":[""],
    "max_wav_value": 32768.0,
    "sampling_rate": 22050,
    "filter_length": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "n_mel_channels": 80,
    "mel_fmin": 0.0,
    "mel_fmax": null,
    "add_blank": true,
    "n_speakers": 0,
    "cleaned_text": true
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "upsample_rates": [8,8,2,2],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [16,16,4,4],
    "n_layers_q": 3,
    "use_spectral_norm": false
  }
}
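
One property all four added configs share: the HiFi-GAN-style decoder in VITS must expand each mel frame by exactly `hop_length` samples, so the product of `upsample_rates` has to equal `hop_length`. A quick sanity check (illustration only; assumes the configs are on disk at these paths):

```python
import json
import math

# 8 * 8 * 2 * 2 == 256 == hop_length in every config added by this commit.
for name in ("config/multispeaker-clean-small.json", "config/multispeaker-clean.json",
             "config/osman.json", "config/yasmin.json"):
    with open(name) as f:
        hps = json.load(f)
    assert math.prod(hps["model"]["upsample_rates"]) == hps["data"]["hop_length"]
```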
(The remaining 10 of the 18 changed files are not shown.)
