mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-04-08 16:04:16 +00:00
Upgrade CLIP model and add eval_multiple
This commit is contained in:
parent
c66954b6a6
commit
b78ae92890
6 changed files with 350 additions and 67 deletions
|
|
@ -15,7 +15,8 @@ from torch.nn import Linear
|
|||
from torch.utils.checkpoint import checkpoint
|
||||
from x_transformers import ContinuousTransformerWrapper, Encoder
|
||||
|
||||
from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock
|
||||
from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock, \
|
||||
CheckpointedXTransformerEncoder
|
||||
|
||||
|
||||
def is_latent(t):
|
||||
|
|
@ -157,43 +158,6 @@ class ResBlock(TimestepBlock):
|
|||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class CheckpointedLayer(nn.Module):
|
||||
"""
|
||||
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
|
||||
checkpoint for all other args.
|
||||
"""
|
||||
def __init__(self, wrap):
|
||||
super().__init__()
|
||||
self.wrap = wrap
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
|
||||
partial = functools.partial(self.wrap, **kwargs)
|
||||
return torch.utils.checkpoint.checkpoint(partial, x, *args)
|
||||
|
||||
|
||||
class CheckpointedXTransformerEncoder(nn.Module):
|
||||
"""
|
||||
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
|
||||
to channels-last that XTransformer expects.
|
||||
"""
|
||||
def __init__(self, needs_permute=True, **xtransformer_kwargs):
|
||||
super().__init__()
|
||||
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
|
||||
self.needs_permute = needs_permute
|
||||
|
||||
for i in range(len(self.transformer.attn_layers.layers)):
|
||||
n, b, r = self.transformer.attn_layers.layers[i]
|
||||
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
if self.needs_permute:
|
||||
x = x.permute(0,2,1)
|
||||
h = self.transformer(x, **kwargs)
|
||||
return h.permute(0,2,1)
|
||||
|
||||
|
||||
class DiffusionTts(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue