added deepspeed inference

This commit is contained in:
manmay-nakhashi 2023-07-09 18:40:10 +05:30
parent c27641791a
commit 5a9707d93c
4 changed files with 311 additions and 200 deletions

View file

@ -1,6 +1,6 @@
tqdm
rotary_embedding_torch
transformers==4.29.2
transformers==4.19
tokenizers
inflect
progressbar
@ -15,3 +15,7 @@ torchaudio
threadpoolctl
llvmlite
appdirs
nbconvert==5.3.1
tornado==4.2
pydantic==1.9.0
deepspeed=9.0.0

View file

@ -198,7 +198,7 @@ class TextToSpeech:
Main entry point into Tortoise.
"""
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None):
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, use_deepspeed=False, device=None):
"""
Constructor
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@ -229,7 +229,8 @@ class TextToSpeech:
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
train_solo_embeddings=False).cpu().eval()
self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed)
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
layer_drop=0, unconditioned_percentage=0).cpu().eval()

View file

@ -3,6 +3,7 @@ import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepspeed
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
@ -340,7 +341,34 @@ class UnifiedVoice(nn.Module):
embeddings.append(self.mel_embedding)
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=.02)
def post_init_gpt2_config(self, use_deepspeed=False):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.model_dim,
n_layer=self.layers,
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True,
)
self.inference_model = GPT2InferenceModel(
gpt_config,
self.gpt,
self.mel_pos_embedding,
self.mel_embedding,
self.final_norm,
self.mel_head
)
if use_deepspeed:
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
mp_size=1,
replace_with_kernel_inject=True,
dtype=torch.float32)
self.inference_model = self.ds_engine.module.eval()
# self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.gpt.wte = self.mel_embedding
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1,0), value=start_token)
tar = F.pad(input, (0,1), value=stop_token)
@ -458,23 +486,10 @@ class UnifiedVoice(nn.Module):
return loss_text.mean(), loss_mel.mean(), mel_logits
def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
if not hasattr(self, 'inference_model'):
# TODO: Decouple gpt_config from this inference model.
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.model_dim,
n_layer=self.layers,
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True)
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.gpt.wte = self.mel_embedding
max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
conds = speech_conditioning_latent.unsqueeze(1)

File diff suppressed because one or more lines are too long