Mirror of https://github.com/neonbjb/tortoise-tts.git, synced 2026-04-08 16:04:16 +00:00
Add in ASR filtration
This commit is contained in: parent 9ad0f0e6e8, commit c66954b6a6
5 changed files with 70 additions and 66 deletions
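Note: the hunks shown below only cover the DiffusionTts module; the ASR filtration named in the commit title lives in the other changed files and is not reproduced here. As a rough, purely illustrative sketch of what ASR-based clip filtration generally looks like (the `transcribe` callable and the similarity threshold are hypothetical, not this commit's code):

# Hypothetical illustration of ASR-based dataset filtration; not taken from this commit.
import difflib

def asr_filter(clips, transcribe, min_similarity=0.8):
    # `clips` is an iterable of (audio_path, expected_text) pairs and `transcribe` is
    # any ASR callable mapping an audio path to a transcript string; both are assumed.
    kept = []
    for audio_path, expected_text in clips:
        hypothesis = transcribe(audio_path).lower().strip()
        reference = expected_text.lower().strip()
        # Keep the clip only when the ASR transcript closely matches the expected text.
        if difflib.SequenceMatcher(None, reference, hypothesis).ratio() >= min_similarity:
            kept.append((audio_path, expected_text))
    return kept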
@@ -486,66 +486,40 @@ class DiffusionTts(nn.Module):
        aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
        return x, aligned_conditioning

    def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
        :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
        :param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
        :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert conditioning_input is not None
        if self.super_sampling_enabled:
            assert lr_input is not None
            if self.training and self.super_sampling_max_noising_factor > 0:
                noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
                lr_input = torch.randn_like(lr_input) * noising_factor + lr_input
            lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
            x = torch.cat([x, lr_input], dim=1)

    def timestep_independent(self, aligned_conditioning, conditioning_input):
        # Shuffle aligned_latent to BxCxS format
        if is_latent(aligned_conditioning):
            aligned_conditioning = aligned_conditioning.permute(0, 2, 1)

        # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
        orig_x_shape = x.shape[-1]
        x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
        with autocast(aligned_conditioning.device.type, enabled=self.enable_fp16):
            cond_emb = self.contextual_embedder(conditioning_input)
            if len(cond_emb.shape) == 3:  # Just take the first element.
                cond_emb = cond_emb[:, :, 0]
            if is_latent(aligned_conditioning):
                code_emb = self.latent_converter(aligned_conditioning)
            else:
                code_emb = self.code_converter(aligned_conditioning)
            cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
            code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
        return code_emb

    def forward(self, x, timesteps, precomputed_aligned_embeddings, conditioning_free=False):
        assert x.shape[-1] % self.alignment_size == 0

        with autocast(x.device.type, enabled=self.enable_fp16):
            hs = []
            time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

            # Note: this block does not need to repeated on inference, since it is not timestep-dependent.
            if conditioning_free:
                code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
            else:
                cond_emb = self.contextual_embedder(conditioning_input)
                if len(cond_emb.shape) == 3:  # Just take the first element.
                    cond_emb = cond_emb[:, :, 0]
                if is_latent(aligned_conditioning):
                    code_emb = self.latent_converter(aligned_conditioning)
                else:
                    code_emb = self.code_converter(aligned_conditioning)
                cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
                code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
                # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
                if self.training and self.unconditioned_percentage > 0:
                    unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
                                                       device=code_emb.device) < self.unconditioned_percentage
                    code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
                                           code_emb)
                code_emb = precomputed_aligned_embeddings

            # Everything after this comment is timestep dependent.
            time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
            code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
            code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)

            first = True
            time_emb = time_emb.float()
            h = x
            hs = []
            for k, module in enumerate(self.input_blocks):
                if isinstance(module, nn.Conv1d):
                    h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
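The hunk above splits the conditioning computation out of `forward`: `timestep_independent` builds `code_emb` once per sample, and the refactored `forward` consumes it as `precomputed_aligned_embeddings` on every denoising step, matching the in-code note that this block is not timestep-dependent. A minimal sketch of how a sampling loop could use that split; `model`, `diffuser`, and the input tensors are assumptions, not part of this diff:

# Hypothetical sampling loop built around the two-stage API shown above.
import torch

@torch.no_grad()
def sample(model, diffuser, noise, aligned_conditioning, conditioning_input):
    # The conditioning embeddings do not depend on the timestep, so compute them once.
    precomputed = model.timestep_independent(aligned_conditioning, conditioning_input)
    x = noise
    for t in reversed(range(diffuser.num_timesteps)):
        timesteps = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
        # Each denoising step reuses the precomputed embeddings instead of re-running
        # the contextual embedder and code converter.
        eps = model(x, timesteps, precomputed_aligned_embeddings=precomputed)
        x = diffuser.step(eps, t, x)  # placeholder for the actual diffusion update rule
    return x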
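The `conditioning_free` flag and the `unconditioned_percentage` masking in the same hunk are the training side of a classifier-free-guidance-style setup: a random subset of batch items has its conditioning replaced by the learned unconditioned embedding. At sampling time this is typically paired with a guided combination of the two predictions, roughly as in the sketch below (the guidance scale and call pattern are assumptions, not shown in this diff):

# Hypothetical classifier-free-guidance combination at sampling time.
def guided_prediction(model, x, timesteps, precomputed, guidance_scale=2.0):
    cond = model(x, timesteps, precomputed)  # conditioned pass
    uncond = model(x, timesteps, precomputed, conditioning_free=True)  # unconditioned pass
    # Push the output away from the unconditioned prediction toward the conditioned one.
    return uncond + guidance_scale * (cond - uncond)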
@@ -565,14 +539,7 @@ class DiffusionTts(nn.Module):
        h = h.float()
        out = self.out(h)

        # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
        extraneous_addition = 0
        params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
        for p in params:
            extraneous_addition = extraneous_addition + p.mean()
        out = out + extraneous_addition * 0

        return out[:, :, :orig_x_shape]
        return out


if __name__ == '__main__':
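The `extraneous_addition` block above is a common workaround for DistributedDataParallel complaining about parameters that did not receive gradients: possibly-unused parameters are folded into the output with a zero coefficient, so every rank still produces a gradient for them. A generic, self-contained sketch of the same pattern (the module and parameter names are illustrative, not from this file):

# Generic illustration of keeping possibly-unused parameters in the DDP graph.
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Linear(8, 8)
        self.rarely_used = nn.Linear(8, 8)  # only exercised on some batches

    def forward(self, x, use_rare_branch=False):
        out = self.main(x)
        if use_rare_branch:
            out = out + self.rarely_used(x)
        else:
            # Touch the skipped branch with a zero-weighted term so DDP still sees
            # a gradient for its parameters on every step.
            extraneous = sum(p.mean() for p in self.rarely_used.parameters())
            out = out + extraneous * 0
        return out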