Upgrade CLIP model and add eval_multiple

James Betker 2022-03-28 19:33:31 -06:00
parent c66954b6a6
commit b78ae92890
6 changed files with 350 additions and 67 deletions


@@ -15,7 +15,8 @@ from torch.nn import Linear
 from torch.utils.checkpoint import checkpoint
 from x_transformers import ContinuousTransformerWrapper, Encoder
-from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock
+from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock, \
+    CheckpointedXTransformerEncoder


 def is_latent(t):
@@ -157,43 +158,6 @@ class ResBlock(TimestepBlock):
         return self.skip_connection(x) + h


-class CheckpointedLayer(nn.Module):
-    """
-    Wraps a module. When forward() is called, runs the wrapped module through torch.utils.checkpoint() with its
-    positional args; kwargs must not require grad and are bound outside the checkpoint via functools.partial.
-    """
-    def __init__(self, wrap):
-        super().__init__()
-        self.wrap = wrap
-
-    def forward(self, x, *args, **kwargs):
-        for k, v in kwargs.items():
-            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
-        partial = functools.partial(self.wrap, **kwargs)
-        return torch.utils.checkpoint.checkpoint(partial, x, *args)
-
-
-class CheckpointedXTransformerEncoder(nn.Module):
-    """
-    Wraps a ContinuousTransformerWrapper, applies CheckpointedLayer to each of its layers, and permutes from the
-    channels-first layout (b, c, s) to the channels-last layout (b, s, c) that XTransformer expects.
-    """
-    def __init__(self, needs_permute=True, **xtransformer_kwargs):
-        super().__init__()
-        self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
-        self.needs_permute = needs_permute
-
-        for i in range(len(self.transformer.attn_layers.layers)):
-            n, b, r = self.transformer.attn_layers.layers[i]
-            self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
-
-    def forward(self, x, **kwargs):
-        if self.needs_permute:
-            x = x.permute(0, 2, 1)
-        h = self.transformer(x, **kwargs)
-        return h.permute(0, 2, 1)
-
-
 class DiffusionTts(nn.Module):
     """
     The full UNet model with attention and timestep embedding.
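
The two classes removed above now live in models.arch_util, hence the widened import in the first hunk. For reference, a minimal usage sketch of the relocated CheckpointedXTransformerEncoder, assuming its signature is unchanged by the move; the dimensions, depth, heads, and max_seq_len below are illustrative and not taken from this diff:

import torch
from x_transformers import Encoder
from models.arch_util import CheckpointedXTransformerEncoder  # new home as of this commit

# Illustrative hyperparameters; the real callers in this repo configure their own.
encoder = CheckpointedXTransformerEncoder(
    needs_permute=True,                           # input arrives channels-first: (batch, channels, sequence)
    dim_in=512,
    dim_out=512,
    max_seq_len=1024,
    attn_layers=Encoder(dim=512, depth=4, heads=8),
)

# requires_grad on the input keeps torch.utils.checkpoint from becoming a no-op; each
# attention/feedforward block is wrapped in CheckpointedLayer, which binds kwargs via
# functools.partial so only tensor positional args flow through the checkpoint.
x = torch.randn(2, 512, 100, requires_grad=True)
h = encoder(x)                                    # permuted to (b, s, c) internally, back to (b, c, s) on return
assert h.shape == (2, 512, 100)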