diff --git a/CookieTTS/_2_ttm/VDVAETTS/model.py b/CookieTTS/_2_ttm/VDVAETTS/model.py
index dd60d35..f0681de 100644
--- a/CookieTTS/_2_ttm/VDVAETTS/model.py
+++ b/CookieTTS/_2_ttm/VDVAETTS/model.py
@@ -1443,10 +1443,10 @@ def update_device(self, **inputs):
             outputs[key] = input
         return outputs
     
-    def inference(self, text_seq, text_lengths, speaker_id, torchmoji_hdn,
+    def inference(self, text_seq, text_lengths, speaker_id, torchmoji_hdn, multispeaker_mode,
                   char_sigma=1.0, frame_sigma=1.0,
                   bn_logdur=None, char_dur=None, gt_mel=None, alignment=None,
-                  mel_lengths=None,):# [B, enc_T], [B], [B], [B], [B, tm_dim]
+                  mel_lengths=None):# [B, enc_T], [B], [B], [B], [B, tm_dim]
         outputs = {}
         
         memory = []
@@ -1458,7 +1458,18 @@ def inference(self, text_seq, text_lengths, speaker_id, torchmoji_hdn,
         # (Speaker) speaker_id -> speaker_embed
         if hasattr(self, "speaker_embedding"):
            speaker_embed = self.speaker_embedding(speaker_id)# [B, embed]
-            outputs["speaker_embed"] = speaker_embed# [B, embed]
+            if multispeaker_mode == "hybrid_voices" and speaker_embed.shape[0] > 1:
+                splits = int(speaker_embed.shape[0] / 2)
+                mix_1, mix_2 = torch.split(speaker_embed, splits)
+                speaker_embed = torch.add(mix_1, mix_2)
+                speaker_embed = torch.div(speaker_embed, 2)
+                speaker_embed = speaker_embed.repeat(2, 1)
+            #outputs["speaker_embed"] = speaker_embed# [B, embed]
+            #speaker_embed_mix = self.speaker_embedding(speaker_mix)# [B, embed]
+            #outputs["speaker_embed_mix"] = speaker_embed_mix# [B, embed]
+            #print(speaker_embed_mix)
+            #speaker_embed = torch.div(torch.add(speaker_embed, speaker_embed_mix), 2)
+            outputs["speaker_embed"] = speaker_embed
         
         # (TorchMoji)
         if hasattr(self, 'tm_bn'):
diff --git a/CookieTTS/_2_ttm/VDVAETTS/train.py b/CookieTTS/_2_ttm/VDVAETTS/train.py
index 42aece9..49981bd 100644
--- a/CookieTTS/_2_ttm/VDVAETTS/train.py
+++ b/CookieTTS/_2_ttm/VDVAETTS/train.py
@@ -1005,4 +1005,3 @@ def train(args, rank, group_name, hparams):
         pass
     
     train(args, args.rank, args.group_name, hparams)
-
diff --git a/CookieTTS/_5_infer/VDVAETTS_server/templates/main.html b/CookieTTS/_5_infer/VDVAETTS_server/templates/main.html
index a767b71..f7f6c32 100644
--- a/CookieTTS/_5_infer/VDVAETTS_server/templates/main.html
+++ b/CookieTTS/_5_infer/VDVAETTS_server/templates/main.html
@@ -127,6 +127,7 @@
 
 				Text To Speech
 
+				<option value="hybrid_voices">Hybrid Voices</option>
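
Note on the model.py hunk: when the new "hybrid_voices" option is selected, inference() splits the
batch of speaker embeddings into two halves, averages the halves element-wise, and tiles the blend
back to the full batch size, so every item in the batch is conditioned on the same 50/50 mix of two
voices. A minimal standalone sketch of that arithmetic (the batch and embedding sizes are made up
for illustration):

    import torch

    batch, embed_dim = 4, 8                          # hypothetical sizes
    speaker_embed = torch.randn(batch, embed_dim)    # stand-in for self.speaker_embedding(speaker_id)

    half = speaker_embed.shape[0] // 2
    mix_1, mix_2 = torch.split(speaker_embed, half)  # two [batch/2, embed] chunks
    blend = torch.div(torch.add(mix_1, mix_2), 2)    # element-wise average of the halves
    speaker_embed = blend.repeat(2, 1)               # tile back to [batch, embed]

Caveat: torch.split(speaker_embed, half) returns three chunks when the batch size is odd, so the
two-way unpacking (both here and in the patch, which only guards speaker_embed.shape[0] > 1)
assumes an even batch.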
diff --git a/CookieTTS/_5_infer/VDVAETTS_server/text2speech.py b/CookieTTS/_5_infer/VDVAETTS_server/text2speech.py
index b8ed22f..e28cfd7 100644
--- a/CookieTTS/_5_infer/VDVAETTS_server/text2speech.py
+++ b/CookieTTS/_5_infer/VDVAETTS_server/text2speech.py
@@ -436,20 +436,25 @@ def shuffle_and_return():
                 speaker_names.append(speaker_names.pop(0))
                 return first_speaker
             batch_speaker_names = [shuffle_and_return() for i in range(simultaneous_texts)]
+        elif multispeaker_mode == "hybrid_voices":
+            batch_speaker_names = speaker_names * -(-simultaneous_texts//len(speaker_names))
         else:
             raise NotImplementedError
         
         if 0:# (optional) use different speaker list for text inside quotes
             speaker_ids = [random.choice(speakers).split("|")[2] if ('"' in text) else random.choice(narrators).split("|")[2] for text in text_batch] # pick speaker if quotemark in text, else narrator
-            text_batch = [text.replace('"',"") for text in text_batch] # remove quotes from text
+            text_batch = [text.replace('"',"") for text in text_batch] # remove quotes from text
         
         if len(batch_speaker_names) > len(text_batch):
             batch_speaker_names = batch_speaker_names[:len(text_batch)]
-            simultaneous_texts = len(text_batch)
+            simultaneous_texts = len(text_batch)
         
         # get speaker_ids (VDVAETTS)
         VDVAETTS_speaker_ids = [self.ttm_sp_name_lookup[speaker] for speaker in batch_speaker_names]
         VDVAETTS_speaker_ids = torch.LongTensor(VDVAETTS_speaker_ids).cuda().repeat_interleave(batch_size_per_text)
+        #VDVAETTS_speaker_mix = [44]
+        #print(VDVAETTS_speaker_mix)
+        #VDVAETTS_speaker_mix = torch.LongTensor(VDVAETTS_speaker_mix).cuda().repeat_interleave(batch_size_per_text)
         
         # get style input
         try:
@@ -503,7 +508,7 @@
         while np.amin(best_score) < target_score:
             # run VDVAETTS
             if status_updates: print("..", end='')
-            outputs = self.VDVAETTS.inference(sequence, text_lengths.repeat_interleave(batch_size_per_text, dim=0), VDVAETTS_speaker_ids, style_input, char_sigma=char_sigma, frame_sigma=frame_sigma)
+            outputs = self.VDVAETTS.inference(sequence, text_lengths.repeat_interleave(batch_size_per_text, dim=0), VDVAETTS_speaker_ids, style_input, multispeaker_mode, char_sigma=char_sigma, frame_sigma=frame_sigma)
             batch_pred_mel = outputs['hifigan_inputs'] if self.MTW_conf.uses_latent_input else outputs['pred_mel']
             
             # metric for html side
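
Note on the text2speech.py hunk: speaker_names * -(-simultaneous_texts//len(speaker_names)) uses
negated floor division as ceiling division, repeating the speaker list just enough times to cover
simultaneous_texts; the existing batch_speaker_names[:len(text_batch)] slice then trims any excess.
A quick illustration with made-up values:

    speaker_names = ["Speaker A", "Speaker B"]               # hypothetical names
    simultaneous_texts = 5
    repeats = -(-simultaneous_texts // len(speaker_names))   # ceil(5 / 2) == 3
    batch_speaker_names = speaker_names * repeats            # 6 entries, trimmed to 5 later

Also note that multispeaker_mode is now passed positionally to self.VDVAETTS.inference(), between
style_input and the keyword arguments, which matches the new fifth parameter added after
torchmoji_hdn in model.py; any other callers of inference() would need the same update.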