mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-04-08 16:04:16 +00:00
integrate new autoregressive model and fix new diffusion bug
This commit is contained in:
parent
9043dde3f9
commit
33e4bc7907
5 changed files with 549 additions and 10 deletions
|
|
@@ -212,7 +212,7 @@ class DiffusionTts(nn.Module):
|
|||
}
|
||||
return groups
|
||||
|
||||
def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
|
||||
def timestep_independent(self, aligned_conditioning, conditioning_input, expected_seq_len, return_code_pred):
|
||||
# Shuffle aligned_latent to BxCxS format
|
||||
if is_latent(aligned_conditioning):
|
||||
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
||||
|
|
@@ -227,7 +227,7 @@ class DiffusionTts(nn.Module):
|
|||
cond_emb = conds.mean(dim=-1)
|
||||
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
||||
if is_latent(aligned_conditioning):
|
||||
code_emb = self.latent_converter(aligned_conditioning)
|
||||
code_emb = self.autoregressive_latent_converter(aligned_conditioning)
|
||||
else:
|
||||
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
||||
code_emb = self.code_converter(code_emb)
|
||||
|
|
@@ -240,7 +240,7 @@ class DiffusionTts(nn.Module):
|
|||
device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
||||
code_emb)
|
||||
expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')
|
||||
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
||||
|
||||
if not return_code_pred:
|
||||
return expanded_code_emb
|
||||
|
|
@@ -250,7 +250,6 @@ class DiffusionTts(nn.Module):
|
|||
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
||||
return expanded_code_emb, mel_pred
|
||||
|
||||
|
||||
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
|
@@ -275,11 +274,12 @@ class DiffusionTts(nn.Module):
|
|||
if precomputed_aligned_embeddings is not None:
|
||||
code_emb = precomputed_aligned_embeddings
|
||||
else:
|
||||
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
|
||||
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
|
||||
if is_latent(aligned_conditioning):
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
else:
|
||||
unused_params.extend(list(self.latent_converter.parameters()))
|
||||
|
||||
unused_params.append(self.unconditioned_embedding)
|
||||
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue