mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-04-08 16:04:16 +00:00
support latents into the diffusion decoder
This commit is contained in:
parent
e2ee843098
commit
3214ca0dfe
5 changed files with 55 additions and 315 deletions
|
|
@ -176,7 +176,13 @@ class DiffusionTts(nn.Module):
|
|||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
)
|
||||
self.code_norm = normalization(model_channels)
|
||||
self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1)
|
||||
self.latent_conditioner = nn.Sequential(
|
||||
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||
)
|
||||
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
|
||||
nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
|
||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||
|
|
@ -190,6 +196,7 @@ class DiffusionTts(nn.Module):
|
|||
DiffusionLayer(model_channels, dropout, num_heads),
|
||||
DiffusionLayer(model_channels, dropout, num_heads),
|
||||
)
|
||||
|
||||
self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
|
||||
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
|
|
@ -206,7 +213,7 @@ class DiffusionTts(nn.Module):
|
|||
groups = {
|
||||
'minicoder': list(self.contextual_embedder.parameters()),
|
||||
'layers': list(self.layers.parameters()),
|
||||
'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()),
|
||||
'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters()),
|
||||
'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
|
||||
'time_embed': list(self.time_embed.parameters()),
|
||||
}
|
||||
|
|
@ -227,7 +234,7 @@ class DiffusionTts(nn.Module):
|
|||
cond_emb = conds.mean(dim=-1)
|
||||
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
||||
if is_latent(aligned_conditioning):
|
||||
code_emb = self.autoregressive_latent_converter(aligned_conditioning)
|
||||
code_emb = self.latent_conditioner(aligned_conditioning)
|
||||
else:
|
||||
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
||||
code_emb = self.code_converter(code_emb)
|
||||
|
|
@ -269,7 +276,7 @@ class DiffusionTts(nn.Module):
|
|||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
unused_params.extend(list(self.latent_converter.parameters()))
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||
else:
|
||||
if precomputed_aligned_embeddings is not None:
|
||||
code_emb = precomputed_aligned_embeddings
|
||||
|
|
@ -278,7 +285,7 @@ class DiffusionTts(nn.Module):
|
|||
if is_latent(aligned_conditioning):
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
else:
|
||||
unused_params.extend(list(self.latent_converter.parameters()))
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||
|
||||
unused_params.append(self.unconditioned_embedding)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue