import math
import time
import traceback

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.modules.conv import Conv1d

from python.xvapitch.glow_tts import RelativePositionTransformer
from python.xvapitch.wavenet import WN
from python.xvapitch.hifigan import HifiganGenerator
from python.xvapitch.sdp import StochasticDurationPredictor
from python.xvapitch.util import maximum_path, rand_segments, segment, sequence_mask, generate_path
from python.xvapitch.text import get_text_preprocessor, ALL_SYMBOLS, lang_names


class xVAPitch(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.args.init_discriminator = True
        self.args.speaker_embedding_channels = 512
        self.args.use_spectral_norm_disriminator = False
        self.args.d_vector_dim = 512
        self.args.use_language_embedding = True
        self.args.detach_dp_input = True

        self.END2END = True
        self.embedded_language_dim = 12
        self.latent_size = 256

        num_languages = len(list(lang_names.keys()))
        self.emb_l = nn.Embedding(num_languages, self.embedded_language_dim)

        self.length_scale = 1.0
        self.noise_scale = 1.0
        self.inference_noise_scale = 0.333
        self.inference_noise_scale_dp = 0.333
        self.noise_scale_dp = 1.0
        self.max_inference_len = None
        self.spec_segment_size = 32

        self.text_encoder = TextEncoder(
            len(ALL_SYMBOLS),
            self.latent_size,
            self.latent_size,
            768,
            2,
            10,
            3,
            0.1,
            language_emb_dim=self.embedded_language_dim,
        )
        self.posterior_encoder = PosteriorEncoder(
            513,
            self.latent_size,
            self.latent_size,
            kernel_size=5,
            dilation_rate=1,
            num_layers=16,
            cond_channels=self.args.d_vector_dim,
        )
        self.flow = ResidualCouplingBlocks(
            self.latent_size,
            self.latent_size,
            kernel_size=5,
            dilation_rate=1,
            num_layers=4,
            cond_channels=self.args.d_vector_dim,
            args=self.args,
        )
        self.duration_predictor = StochasticDurationPredictor(
            self.latent_size,
            self.latent_size,
            3,
            0.5,
            4,
            cond_channels=self.args.d_vector_dim,
            language_emb_dim=self.embedded_language_dim,
        )
        self.waveform_decoder = HifiganGenerator(
            self.latent_size,
            1,
            "1",
            [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            [3, 7, 11],
            [16, 16, 4, 4],
            512,
            [8, 8, 2, 2],
            inference_padding=0,
            cond_channels=self.args.d_vector_dim,
            conv_pre_weight_norm=False,
            conv_post_weight_norm=False,
            conv_post_bias=False,
        )

        self.USE_PITCH_COND = False
        if self.USE_PITCH_COND:
            self.pitch_predictor = RelativePositioningPitchEnergyEncoder(
                out_channels=1,
                hidden_channels=self.latent_size + self.embedded_language_dim,
                hidden_channels_ffn=768,
                num_heads=2,
                num_layers=3,
                kernel_size=3,
                dropout_p=0.1,
                conditioning_emb_dim=self.args.d_vector_dim,
            )
            self.pitch_emb = nn.Conv1d(
                1,
                self.args.expanded_flow_dim if args.expanded_flow else self.latent_size,
                kernel_size=3,
                padding=int((3 - 1) / 2),
            )

        self.TEMP_timing = []
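
    # (Summary note, not original code) The layout mirrors VITS: the text encoder produces the prior stats,
    # the stochastic duration predictor expands them in time, the residual-coupling flow maps between prior
    # and posterior latents, and the HiFi-GAN generator decodes latents to a waveform, all conditioned on
    # 512-dim speaker d-vectors plus a learned per-language embedding.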

    def infer_get_lang_emb(self, language_id):
        aux_input = {
            "language_ids": language_id,
        }
        sid, g, lid = self._set_cond_input(aux_input)
        lang_emb = self.emb_l(lid).unsqueeze(-1)
        return lang_emb

    def infer_advanced(self, logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace=1.0, editor_data=None, old_sequence=None, pitch_amp=None):

        if (editor_data is not None) and ((editor_data[0] is not None and len(editor_data[0])) or (editor_data[1] is not None and len(editor_data[1]))):
            pitch_pred, dur_pred, energy_pred, em_angry_pred, em_happy_pred, em_sad_pred, em_surprise_pred, _ = editor_data
            # TODO: use energy_pred

            dur_pred = torch.tensor(dur_pred)
            dur_pred = dur_pred.view((1, dur_pred.shape[0])).float().to(self.device)
            pitch_pred = torch.tensor(pitch_pred)
            pitch_pred = pitch_pred.view((1, pitch_pred.shape[0])).float().to(self.device)
            energy_pred = torch.tensor(energy_pred)
            energy_pred = energy_pred.view((1, energy_pred.shape[0])).float().to(self.device)

            # Re-use existing tensors where possible; otherwise build them from the editor lists
            em_angry_pred = em_angry_pred.clone().detach() if torch.is_tensor(em_angry_pred) else torch.tensor(em_angry_pred)
            em_angry_pred = em_angry_pred.view((1, em_angry_pred.shape[0])).float().to(self.device)
            em_happy_pred = em_happy_pred.clone().detach() if torch.is_tensor(em_happy_pred) else torch.tensor(em_happy_pred)
            em_happy_pred = em_happy_pred.view((1, em_happy_pred.shape[0])).float().to(self.device)
            em_sad_pred = em_sad_pred.clone().detach() if torch.is_tensor(em_sad_pred) else torch.tensor(em_sad_pred)
            em_sad_pred = em_sad_pred.view((1, em_sad_pred.shape[0])).float().to(self.device)
            em_surprise_pred = em_surprise_pred.clone().detach() if torch.is_tensor(em_surprise_pred) else torch.tensor(em_surprise_pred)
            em_surprise_pred = em_surprise_pred.view((1, em_surprise_pred.shape[0])).float().to(self.device)

            # Pitch speaker embedding deltas
            if not self.USE_PITCH_COND and pitch_pred.shape[1] == speaker_embs.shape[2]:
                pitch_delta = self.pitch_emb_values.to(pitch_pred.device) * pitch_pred
                speaker_embs = speaker_embs + pitch_delta.float()

            # Emotion speaker embedding deltas
            emotions_strength = 0.00003  # Global scaling
            if em_angry_pred.shape[1] == speaker_embs.shape[2]:
                em_angry_delta = self.angry_emb_values.to(em_angry_pred.device) * em_angry_pred * emotions_strength
                speaker_embs = speaker_embs + em_angry_delta.float()
            if em_happy_pred.shape[1] == speaker_embs.shape[2]:
                em_happy_delta = self.happy_emb_values.to(em_happy_pred.device) * em_happy_pred * emotions_strength
                speaker_embs = speaker_embs + em_happy_delta.float()
            if em_sad_pred.shape[1] == speaker_embs.shape[2]:
                em_sad_delta = self.sad_emb_values.to(em_sad_pred.device) * em_sad_pred * emotions_strength
                speaker_embs = speaker_embs + em_sad_delta.float()
            if em_surprise_pred.shape[1] == speaker_embs.shape[2]:
                em_surprise_delta = self.surprise_emb_values.to(em_surprise_pred.device) * em_surprise_pred * emotions_strength
                speaker_embs = speaker_embs + em_surprise_delta.float()

            try:
                logger.info("editor data infer_using_vals")
                wav, dur_pred, pitch_pred_out, energy_pred, em_pred_out, start_index, end_index, wav_mult = self.infer_using_vals(
                    logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace,
                    dur_pred_existing=dur_pred, pitch_pred_existing=pitch_pred, energy_pred_existing=energy_pred,
                    em_pred_existing=[em_angry_pred, em_happy_pred, em_sad_pred, em_surprise_pred],
                    old_sequence=old_sequence, new_sequence=text, pitch_amp=pitch_amp)
                [em_angry_pred_out, em_happy_pred_out, em_sad_pred_out, em_surprise_pred_out] = em_pred_out

                # Return the editor-provided pitch/emotion values rather than the re-predicted ones
                pitch_pred_out = pitch_pred
                em_angry_pred_out = em_angry_pred
                em_happy_pred_out = em_happy_pred
                em_sad_pred_out = em_sad_pred
                em_surprise_pred_out = em_surprise_pred

                return wav, dur_pred, pitch_pred_out, energy_pred, [em_angry_pred_out, em_happy_pred_out, em_sad_pred_out, em_surprise_pred_out], start_index, end_index, wav_mult
            except:
                print(traceback.format_exc())
                logger.info(traceback.format_exc())
                logger.info("editor data corrupt; fallback to infer_using_vals")
                return self.infer_using_vals(logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace, None, None, None, None, None, None, pitch_amp=pitch_amp)
        else:
            logger.info("no editor infer_using_vals")
            return self.infer_using_vals(logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace, None, None, None, None, None, None, pitch_amp=pitch_amp)
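
    # For reference (a summary of the unpacking above, not original code): `editor_data` is an 8-element
    # sequence of per-symbol values,
    #   [pitch, durations, energy, emAngry, emHappy, emSad, emSurprise, <unused>],
    # where each entry is a list or 1D tensor with one value per input symbol.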

    def infer_using_vals(self, logger, plugin_manager, cleaned_text, sequence, lang_embs, speaker_embs, pace, dur_pred_existing, pitch_pred_existing, energy_pred_existing, em_pred_existing, old_sequence, new_sequence, pitch_amp=None):

        start_index = None
        end_index = None
        [em_angry_pred_existing, em_happy_pred_existing, em_sad_pred_existing, em_surprise_pred_existing] = em_pred_existing if em_pred_existing is not None else [None, None, None, None]

        # Calculate text splicing bounds, if needed
        if old_sequence is not None:

            old_sequence_np = old_sequence.cpu().detach().numpy()
            old_sequence_np = list(old_sequence_np[0])
            new_sequence_np = new_sequence.cpu().detach().numpy()
            new_sequence_np = list(new_sequence_np[0])

            # Get the index of the first changed value
            if old_sequence_np[0] == new_sequence_np[0]:  # If the starts of both sequences are the same, the change is not at the start
                for i in range(len(old_sequence_np)):
                    if i < len(new_sequence_np):
                        if old_sequence_np[i] != new_sequence_np[i]:
                            start_index = i - 1
                            break
                    else:
                        start_index = i - 1
                        break
                if start_index is None:
                    start_index = len(old_sequence_np) - 1

            # Get the index of the last changed value
            old_sequence_np.reverse()
            new_sequence_np.reverse()
            if old_sequence_np[0] == new_sequence_np[0]:  # If the ends of both reversed sequences are the same, the change is not at the end
                for i in range(len(old_sequence_np)):
                    if i < len(new_sequence_np):
                        if old_sequence_np[i] != new_sequence_np[i]:
                            end_index = len(old_sequence_np) - 1 - i + 1
                            break
                    else:
                        end_index = len(old_sequence_np) - 1 - i + 1
                        break
            old_sequence_np.reverse()
            new_sequence_np.reverse()

        # cleaned_text is the actual text phonemes
        input_symbols = sequence
        x_lengths = torch.where(input_symbols > 0, torch.ones_like(input_symbols), torch.zeros_like(input_symbols)).sum(dim=1)

        lang_emb_full = None  # TODO
        self.text_encoder.logger = logger

        # TODO: store a bank of the trained 31 language embeddings, to use for interpolating
        lang_emb = self.emb_l(lang_embs).unsqueeze(-1)
        if len(lang_embs.shape) > 1:  # Batch mode
            lang_emb_full = lang_emb.squeeze(1).squeeze(-1)
        else:  # Individual line from the UI
            lang_emb_full = lang_emb.transpose(2, 1).squeeze(1).unsqueeze(0)

        x, x_emb, x_mask = self.text_encoder(input_symbols, x_lengths, lang_emb=None, stats=False, lang_emb_full=lang_emb_full)
        m_p, logs_p = self.text_encoder(x, x_lengths, lang_emb=None, lang_emb_full=lang_emb_full, stats=True, x_mask=x_mask)

        lang_emb_full = lang_emb_full.reshape(lang_emb_full.shape[0], lang_emb_full.shape[2], lang_emb_full.shape[1])

        self.inference_noise_scale_dp = 0  # TEMP DEBUGGING: higher values seem to make the output worse

        # Calculate pitch and duration values here, if these were not already provided
        if (dur_pred_existing is None or dur_pred_existing.shape[1] == 0) or old_sequence is not None:
            # Predict durations
            self.duration_predictor.logger = logger
            logw = self.duration_predictor(x, x_mask, g=speaker_embs, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb_full)
            w = torch.exp(logw) * x_mask * self.length_scale
            w = w * (pace.unsqueeze(2) if torch.is_tensor(pace) else pace)
            w_ceil = torch.ceil(w)
            dur_pred = w_ceil
        else:
            dur_pred = dur_pred_existing.unsqueeze(dim=0)
            dur_pred = dur_pred * pace

        y_lengths = torch.clamp_min(torch.sum(torch.round(dur_pred), [1, 2]), 1).long()
        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)

        if dur_pred.shape[0] > 1:
            # Batch mode: expand each item in the batch separately
            attn_all = []
            m_p_all = []
            logs_p_all = []
            for b in range(dur_pred.shape[0]):
                attn_mask = torch.unsqueeze(x_mask[b, :].unsqueeze(0), 2) * torch.unsqueeze(y_mask[b, :].unsqueeze(0), -1)
                attn_all.append(generate_path(dur_pred.squeeze(1)[b, :].unsqueeze(0), attn_mask.squeeze(0).transpose(1, 2)))
                m_p_all.append(torch.matmul(attn_all[-1].transpose(1, 2), m_p[b, :].unsqueeze(0).transpose(1, 2)).transpose(1, 2))
                logs_p_all.append(torch.matmul(attn_all[-1].transpose(1, 2), logs_p[b, :].unsqueeze(0).transpose(1, 2)).transpose(1, 2))
            del attn_all
            m_p = torch.stack(m_p_all, dim=1).squeeze(dim=0)
            logs_p = torch.stack(logs_p_all, dim=1).squeeze(dim=0)
            pitch_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)
        else:
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = generate_path(dur_pred.squeeze(1), attn_mask.squeeze(0).transpose(1, 2))
            m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
            logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
            pitch_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)

        emAngry_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)
        emHappy_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)
        emSad_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)
        emSurprise_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)

        # Splice/replace pitch/duration values from the old input if simulating only a partial re-generation
        if start_index is not None or end_index is not None:
            dur_pred_np = list(dur_pred.cpu().detach().numpy())[0][0]
            pitch_pred_np = list(pitch_pred.cpu().detach().numpy())[0][0]
            emAngry_pred_np = list(emAngry_pred.cpu().detach().numpy())[0][0]
            emHappy_pred_np = list(emHappy_pred.cpu().detach().numpy())[0][0]
            emSad_pred_np = list(emSad_pred.cpu().detach().numpy())[0][0]
            emSurprise_pred_np = list(emSurprise_pred.cpu().detach().numpy())[0][0]
            dur_pred_existing_np = list(dur_pred_existing.cpu().detach().numpy())[0]
            pitch_pred_existing_np = list(pitch_pred_existing.cpu().detach().numpy())[0]
            emAngry_pred_existing_np = list(em_angry_pred_existing.cpu().detach().numpy())[0]
            emHappy_pred_existing_np = list(em_happy_pred_existing.cpu().detach().numpy())[0]
            emSad_pred_existing_np = list(em_sad_pred_existing.cpu().detach().numpy())[0]
            emSurprise_pred_existing_np = list(em_surprise_pred_existing.cpu().detach().numpy())[0]

            if start_index is not None:  # Replace starting values
                for i in range(start_index + 1):
                    dur_pred_np[i] = dur_pred_existing_np[i]
                    pitch_pred_np[i] = pitch_pred_existing_np[i]
                    emAngry_pred_np[i] = emAngry_pred_existing_np[i]
                    emHappy_pred_np[i] = emHappy_pred_existing_np[i]
                    emSad_pred_np[i] = emSad_pred_existing_np[i]
                    emSurprise_pred_np[i] = emSurprise_pred_existing_np[i]

            if end_index is not None:  # Replace end values
                for i in range(len(old_sequence_np) - end_index):
                    dur_pred_np[-i - 1] = dur_pred_existing_np[-i - 1]
                    pitch_pred_np[-i - 1] = pitch_pred_existing_np[-i - 1]
                    emAngry_pred_np[-i - 1] = emAngry_pred_existing_np[-i - 1]
                    emHappy_pred_np[-i - 1] = emHappy_pred_existing_np[-i - 1]
                    emSad_pred_np[-i - 1] = emSad_pred_existing_np[-i - 1]
                    emSurprise_pred_np[-i - 1] = emSurprise_pred_existing_np[-i - 1]

            dur_pred = torch.tensor(dur_pred_np).to(self.device).unsqueeze(0)
            pitch_pred = torch.tensor(pitch_pred_np).to(self.device).unsqueeze(0).unsqueeze(0)
            emAngry_pred = torch.tensor(emAngry_pred_np).to(self.device).unsqueeze(0).unsqueeze(0)
            emHappy_pred = torch.tensor(emHappy_pred_np).to(self.device).unsqueeze(0).unsqueeze(0)
            emSad_pred = torch.tensor(emSad_pred_np).to(self.device).unsqueeze(0).unsqueeze(0)
            emSurprise_pred = torch.tensor(emSurprise_pred_np).to(self.device).unsqueeze(0).unsqueeze(0)

        if pitch_amp is not None:
            pitch_pred = pitch_pred * pitch_amp.unsqueeze(dim=-1)

        if plugin_manager is not None and len(plugin_manager.plugins["synth-line"]["mid"]):
            pitch_pred_numpy = pitch_pred.cpu().detach().numpy()
            plugin_data = {
                "pace": pace,
                "duration": dur_pred.cpu().detach().numpy(),
                "pitch": pitch_pred_numpy.reshape((pitch_pred_numpy.shape[0], pitch_pred_numpy.shape[2])),
                "emAngry": emAngry_pred.reshape((emAngry_pred.shape[0], emAngry_pred.shape[2])),
                "emHappy": emHappy_pred.reshape((emHappy_pred.shape[0], emHappy_pred.shape[2])),
                "emSad": emSad_pred.reshape((emSad_pred.shape[0], emSad_pred.shape[2])),
                "emSurprise": emSurprise_pred.reshape((emSurprise_pred.shape[0], emSurprise_pred.shape[2])),
                "sequence": sequence,
                "is_fresh_synth": pitch_pred_existing is None and dur_pred_existing is None,
                "pluginsContext": plugin_manager.context,
                "hasDataChanged": False,
            }
            plugin_manager.run_plugins(plist=plugin_manager.plugins["synth-line"]["mid"], event="mid synth-line", data=plugin_data)

            if (
                pace != plugin_data["pace"]
                or plugin_data["hasDataChanged"]
            ):
                logger.info("Inference data has been changed by plugins, rerunning infer_advanced")
                pace = plugin_data["pace"]
                editor_data = [
                    plugin_data["pitch"][0],
                    plugin_data["duration"][0][0],
                    [1.0 for _ in range(pitch_pred_numpy.shape[-1])],
                    plugin_data["emAngry"][0],
                    plugin_data["emHappy"][0],
                    plugin_data["emSad"][0],
                    plugin_data["emSurprise"][0],
                    None,
                ]
                # Re-run infer_advanced so that the emotion values take effect.
                # Passing plugin_manager=None as the second argument ensures there is no infinite loop.
                return self.infer_advanced(logger, None, cleaned_text, sequence, lang_embs, speaker_embs, pace=pace, editor_data=editor_data, old_sequence=sequence, pitch_amp=None)
            else:
                # Skip re-running infer_advanced
                logger.info("Inference data unchanged by plugins")

        # TODO: incorporate some sort of control for this
        # self.inference_noise_scale = 0

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
        z = self.flow(z_p, y_mask, g=speaker_embs, reverse=True)

        self.waveform_decoder.logger = logger
        wav = self.waveform_decoder((z * y_mask.unsqueeze(1))[:, :, : self.max_inference_len], g=speaker_embs)

        # In batch mode, trim the shorter audio waves in the batch. The masking doesn't seem to work, so it has to be done manually
        if dur_pred.shape[0] > 1:
            wav_all = []
            for b in range(dur_pred.shape[0]):
                percent_to_mask = torch.sum(y_mask[b]) / y_mask.shape[1]
                wav_all.append(wav[b, 0, 0:int((wav.shape[2] * percent_to_mask).item())])
            wav = wav_all

        start_index = -1 if start_index is None else start_index
        end_index = -1 if end_index is None else end_index

        # Apply volume adjustments
        stretched_energy_mult = None
        if energy_pred_existing is not None and pitch_pred_existing is not None:
            energy_mult = self.expand_vals_by_durations(energy_pred_existing.unsqueeze(0), dur_pred, logger=logger)
            stretched_energy_mult = torch.nn.functional.interpolate(energy_mult.unsqueeze(0).unsqueeze(0), (1, 1, wav.shape[2])).squeeze()
            stretched_energy_mult = stretched_energy_mult.cpu().detach().numpy()
            energy_pred = energy_pred_existing.squeeze()
        else:
            energy_pred = [1.0 for _ in range(pitch_pred.shape[-1])]
            energy_pred = torch.tensor(energy_pred)

        em_pred_out = [emAngry_pred, emHappy_pred, emSad_pred, emSurprise_pred]
        return wav, dur_pred, pitch_pred, energy_pred, em_pred_out, start_index, end_index, stretched_energy_mult

    def voice_conversion(self, y, y_lengths=None, spk1_emb=None, spk2_emb=None):
        if y_lengths is None:
            y_lengths = self.y_lengths_default

        z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=spk1_emb)
        y_mask = y_mask.squeeze(0)
        z_p = self.flow(z, y_mask, g=spk1_emb)
        z_hat = self.flow(z_p, y_mask, g=spk2_emb, reverse=True)
        o_hat = self.waveform_decoder(z_hat * y_mask, g=spk2_emb)
        return o_hat
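
    # (Clarifying aside, not original code) Voice conversion happens entirely in latent space: the source
    # audio is encoded to z with the source speaker's embedding, pushed forward through the flow, and the
    # flow is then inverted with the target speaker's embedding before decoding, keeping the content while
    # swapping the speaker identity.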

    def _set_cond_input(self, aux_input):
        """Set the speaker conditioning input based on the multi-speaker mode."""
        sid, g, lid = None, None, None

        if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
            g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
            if g.ndim == 2:
                g = g.unsqueeze_(0)

        if "language_ids" in aux_input and aux_input["language_ids"] is not None:
            lid = aux_input["language_ids"]
            if lid.ndim == 0:
                lid = lid.unsqueeze_(0)

        return sid, g, lid

    # Opposite of average_pitch: repeat per-symbol values by their durations, to get sequence-wide values
    def expand_vals_by_durations(self, vals, durations, logger=None):
        vals = vals.view((vals.shape[0], vals.shape[2]))
        if len(durations.shape) > 2:
            durations = durations.view((durations.shape[0], durations.shape[2]))
        max_dur = int(torch.max(torch.sum(torch.round(durations), dim=1)).item())
        expanded = torch.zeros((vals.shape[0], 1, max_dur)).to(vals)

        for b in range(vals.shape[0]):
            b_vals = vals[b]
            b_durs = durations[b]
            expanded_vals = []
            for vi in range(b_vals.shape[0]):
                for dur_i in range(round(b_durs[vi].item())):
                    if len(durations.shape) > 2:
                        expanded_vals.append(b_vals[vi])
                    else:
                        expanded_vals.append(b_vals[vi].unsqueeze(dim=0))
            expanded_vals = torch.tensor(expanded_vals).to(expanded)
            expanded[b, :, 0:expanded_vals.shape[0]] += expanded_vals

        return expanded
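
    # Worked example (illustrative, not original code): conceptually, per-symbol vals [0.2, 0.5, 0.9] with
    # durations [2, 1, 3] expand to the frame-level sequence [0.2, 0.2, 0.5, 0.9, 0.9, 0.9] -- each value
    # is repeated for as many output frames as its rounded duration.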


class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab: int,  # len(ALL_SYMBOLS)
        out_channels: int,  # 192
        hidden_channels: int,  # 192
        hidden_channels_ffn: int,  # 768
        num_heads: int,  # 2
        num_layers: int,  # 10
        kernel_size: int,  # 3
        dropout_p: float,  # 0.1
        language_emb_dim: int = None,
    ):
        """Text Encoder for the VITS model.

        Args:
            n_vocab (int): Number of characters for the embedding layer.
            out_channels (int): Number of channels for the output.
            hidden_channels (int): Number of channels for the hidden layers.
            hidden_channels_ffn (int): Number of channels for the convolutional layers.
            num_heads (int): Number of attention heads for the Transformer layers.
            num_layers (int): Number of Transformer layers.
            kernel_size (int): Kernel size for the FFN layers in the Transformer network.
            dropout_p (float): Dropout rate for the Transformer layers.
        """
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

        if language_emb_dim:
            hidden_channels += language_emb_dim

        self.encoder = RelativePositionTransformer(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            hidden_channels=hidden_channels,
            hidden_channels_ffn=hidden_channels_ffn,
            num_heads=num_heads,
            num_layers=num_layers,
            kernel_size=kernel_size,
            dropout_p=dropout_p,
            layer_norm_type="2",
            rel_attn_window_size=4,
        )

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, lang_emb=None, stats=False, x_mask=None, lang_emb_full=None):
        """
        Shapes:
            - x: :math:`[B, T]`
            - x_lengths: :math:`[B]`
        """
        if stats:
            stats = self.proj(x) * x_mask
            m, logs = torch.split(stats, self.out_channels, dim=1)
            return m, logs
        else:
            x_emb = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]

            # Concatenate the language embedding onto the embedded characters
            if lang_emb is not None or lang_emb_full is not None:
                if lang_emb_full is None:
                    lang_emb_full = lang_emb.transpose(2, 1).expand(x_emb.size(0), x_emb.size(1), -1)
                x = torch.cat((x_emb, lang_emb_full), dim=-1)

            x = torch.transpose(x, 1, -1)  # [b, h, t]
            x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

            x = self.encoder(x * x_mask, x_mask)
            return x, x_emb, x_mask
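
# Typical usage, as a sketch of the two-pass pattern seen in xVAPitch.infer_using_vals above
# (the names are illustrative; shapes assume a single input line, i.e. B=1):
#   x, x_emb, x_mask = text_encoder(token_ids, x_lengths, lang_emb_full=lang_emb_full)            # x: [B, H, T]
#   m_p, logs_p = text_encoder(x, x_lengths, lang_emb_full=lang_emb_full, stats=True, x_mask=x_mask)
# The first pass embeds the symbols, concatenates the language embedding and runs the
# relative-position transformer; the second pass projects the hidden states to the prior
# mean / log-scale consumed by the flow.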


class RelativePositioningPitchEnergyEncoder(nn.Module):
    def __init__(
        self,
        out_channels: int,  # 192
        hidden_channels: int,  # 192
        hidden_channels_ffn: int,  # 768
        num_heads: int,  # 2
        num_layers: int,  # 10
        kernel_size: int,  # 3
        dropout_p: float,  # 0.1
        conditioning_emb_dim: int = None,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels

        if conditioning_emb_dim:
            hidden_channels += conditioning_emb_dim

        self.encoder = RelativePositionTransformer(
            in_channels=hidden_channels,
            out_channels=1,
            hidden_channels=hidden_channels,
            hidden_channels_ffn=hidden_channels_ffn,
            num_heads=num_heads,
            num_layers=num_layers,
            kernel_size=kernel_size,
            dropout_p=dropout_p,
            layer_norm_type="2",
            rel_attn_window_size=4,
        )

    def forward(self, x, x_lengths=None, speaker_emb=None, stats=False, x_mask=None):
        """
        Shapes:
            - x: :math:`[B, T]`
            - x_lengths: :math:`[B]`
        """
        # Concatenate the speaker embedding onto the input features
        if speaker_emb is not None:
            x = torch.cat((x, speaker_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)

        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.encoder(x * x_mask, x_mask)
        return x


class ResidualCouplingBlocks(nn.Module):
    def __init__(
        self,
        channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        num_layers: int,
        num_flows=4,
        cond_channels=0,
        args=None,
    ):
        """Residual Coupling blocks for the VITS flow layers.

        Args:
            channels (int): Number of input and output tensor channels.
            hidden_channels (int): Number of hidden network channels.
            kernel_size (int): Kernel size of the WaveNet layers.
            dilation_rate (int): Dilation rate of the WaveNet layers.
            num_layers (int): Number of the WaveNet layers.
            num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4.
            cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0.
        """
        super().__init__()
        self.args = args
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.num_layers = num_layers
        self.num_flows = num_flows
        self.cond_channels = cond_channels

        self.flows = nn.ModuleList()
        for flow_i in range(num_flows):
            # The final flow can optionally be widened when args.expanded_flow is set
            is_expanded = flow_i == (num_flows - 1) and self.args.expanded_flow
            expanded_channels = (192 + self.args.expanded_flow_dim + self.args.expanded_flow_dim) if is_expanded else None
            self.flows.append(
                ResidualCouplingBlock(
                    expanded_channels if is_expanded else channels,
                    expanded_channels if is_expanded else hidden_channels,
                    kernel_size,
                    dilation_rate,
                    num_layers,
                    cond_channels=cond_channels,
                    out_channels_override=expanded_channels if is_expanded else None,
                    mean_only=True,
                )
            )

    def forward(self, x, x_mask, g=None, reverse=False):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`
            - g: :math:`[B, C, 1]`
        """
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
                x = torch.flip(x, [1])
        else:
            for flow in reversed(self.flows):
                x = torch.flip(x, [1])
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        num_layers: int,
        cond_channels=0,
    ):
        """Posterior Encoder of the VITS model.

        ::
            x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z

        Args:
            in_channels (int): Number of input tensor channels.
            out_channels (int): Number of output tensor channels.
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Kernel size of the WaveNet convolution layers.
            dilation_rate (int): Dilation rate of the WaveNet layers.
            num_layers (int): Number of the WaveNet layers.
            cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.num_layers = num_layers
        self.cond_channels = cond_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_lengths: :math:`[B, 1]`
            - g: :math:`[B, C, 1]`
        """
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        mean, log_scale = torch.split(stats, self.out_channels, dim=1)
        z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
        return z, mean, log_scale, x_mask
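
# Aside on the sampling step above (clarifying note, not original code): the posterior is a diagonal
# Gaussian, so z is drawn with the reparameterisation trick, z = mean + eps * exp(log_scale) with
# eps ~ N(0, I), which keeps the sample differentiable with respect to mean and log_scale.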


class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_layers,
        dropout_p=0,
        cond_channels=0,
        out_channels_override=None,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.half_channels = channels // 2
        self.mean_only = mean_only
        # input layer
        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        # coupling layers
        self.enc = WN(
            hidden_channels,
            hidden_channels,
            kernel_size,
            dilation_rate,
            num_layers,
            dropout_p=dropout_p,
            c_in_channels=cond_channels,
        )
        # output layer
        # Initializing the last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability.
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)

        self.conv1d_projector = None
        if out_channels_override:
            self.conv1d_projector = nn.Conv1d(192, out_channels_override, 1)

    def forward(self, x, x_mask, g=None, reverse=False):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`
            - g: :math:`[B, C, 1]`
        """
        if self.conv1d_projector is not None and not reverse:
            x = self.conv1d_projector(x)

        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask.unsqueeze(1)
        h = self.enc(h, x_mask.unsqueeze(1), g=g)
        stats = self.post(h) * x_mask.unsqueeze(1)
        if not self.mean_only:
            m, log_scale = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            log_scale = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(log_scale) * x_mask.unsqueeze(1)
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(log_scale, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-log_scale) * x_mask.unsqueeze(1)
            x = torch.cat([x0, x1], 1)
            return x
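
# Sanity note on the coupling transform above (clarifying aside, not original code): the forward pass
# maps x1 -> m + x1 * exp(log_scale) while x0 passes through unchanged, and the reverse pass applies the
# exact inverse, x1 -> (x1 - m) * exp(-log_scale). Because m and log_scale depend only on x0 (and the
# conditioning g), the block is invertible and its log-determinant is simply sum(log_scale).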


def mask_from_lens(lens, max_len=None):
    if max_len is None:
        max_len = lens.max()
    ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
    mask = torch.lt(ids, lens.unsqueeze(1))
    return mask
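
# Example (illustrative): mask_from_lens(torch.tensor([2, 4])) returns
#   [[True, True, False, False],
#    [True, True, True,  True ]]
# i.e. a [B, max_len] boolean mask marking the valid positions of each sequence.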


class TemporalPredictor(nn.Module):
    """Predicts a single float for each temporal location."""

    def __init__(self, input_size, filter_size, kernel_size, dropout,
                 n_layers=2, n_predictions=1):
        super(TemporalPredictor, self).__init__()

        self.layers = nn.Sequential(*[
            ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
                         kernel_size=kernel_size, dropout=dropout)
            for i in range(n_layers)
        ])
        self.n_predictions = n_predictions
        self.fc = nn.Linear(filter_size, self.n_predictions, bias=True)

    def forward(self, enc_out, enc_out_mask):
        out = enc_out * enc_out_mask
        out = self.layers(out.transpose(1, 2)).transpose(1, 2)
        out = self.fc(out) * enc_out_mask
        return out


class ConvReLUNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0):
        super(ConvReLUNorm, self).__init__()
        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size,
                                    padding=(kernel_size // 2))
        self.norm = torch.nn.LayerNorm(out_channels)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, signal):
        out = F.relu(self.conv(signal))
        out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype)
        return self.dropout(out)
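

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes the python.xvapitch
    # package used by the imports above is importable, and that the training `args` namespace only
    # needs the two expanded-flow fields read during construction; both values below are illustrative.
    from types import SimpleNamespace

    args = SimpleNamespace(expanded_flow=False, expanded_flow_dim=0)
    model = xVAPitch(args).eval()

    # Sanity-check the learned language embedding table (12-dim per language, per __init__ above)
    lang_id = torch.tensor([0])
    with torch.no_grad():
        lang_emb = model.emb_l(lang_id)
    print(lang_emb.shape)  # expected: torch.Size([1, 12])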