Mirror of https://github.com/neonbjb/tortoise-tts.git (synced 2026-01-07 00:59:59 +01:00)
added deepspeed inference

commit 5a9707d93c
parent c27641791a
requirements.txt
@@ -1,6 +1,6 @@
 tqdm
 rotary_embedding_torch
-transformers==4.19
+transformers==4.29.2
 tokenizers
 inflect
 progressbar
@@ -15,3 +15,7 @@ torchaudio
 threadpoolctl
 llvmlite
 appdirs
+nbconvert==5.3.1
+tornado==4.2
+pydantic==1.9.0
+deepspeed==0.9.0
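This moves transformers from 4.19 to 4.29.2 alongside the new deepspeed pin, presumably so the injected GPT-2 kernels and the pinned Hugging Face code agree. A quick post-install sanity check, standard library only, with the package names taken from the pins above:

    # check_pins.py: print the installed versions of the packages this commit pins
    from importlib.metadata import version, PackageNotFoundError

    for pkg in ("transformers", "deepspeed", "pydantic", "tornado", "nbconvert"):
        try:
            print(f"{pkg}=={version(pkg)}")
        except PackageNotFoundError:
            print(f"{pkg}: not installed")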
tortoise/api.py
@@ -198,7 +198,7 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """
 
-    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None):
+    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, use_deepspeed=False, device=None):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -229,7 +229,8 @@ class TextToSpeech:
                                           heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
                                           train_solo_embeddings=False).cpu().eval()
         self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
+        self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed)
 
         self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                       in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
                                       layer_drop=0, unconditioned_percentage=0).cpu().eval()
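The new flag surfaces on the public constructor and is forwarded into the autoregressive model right after its weights load. A minimal usage sketch, assuming only the defaults already in this file (models download to MODELS_DIR as before):

    from tortoise.api import TextToSpeech

    # use_deepspeed=True makes the constructor call
    # UnifiedVoice.post_init_gpt2_config(use_deepspeed=True), which wraps the
    # GPT-2 inference model in a DeepSpeed inference engine (see autoregressive.py below)
    tts = TextToSpeech(use_deepspeed=True)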
tortoise/models/autoregressive.py
@@ -3,6 +3,7 @@ import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import deepspeed
 from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
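The import is unconditional, so deepspeed becomes a hard dependency of autoregressive.py even when use_deepspeed=False. A guarded import, shown only as a possible alternative and not what this commit does, would keep CPU-only installs importing cleanly:

    try:
        import deepspeed  # only exercised when use_deepspeed=True
    except ImportError:
        deepspeed = None  # installs without deepspeed can still use the default path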
@@ -340,7 +341,34 @@ class UnifiedVoice(nn.Module):
             embeddings.append(self.mel_embedding)
         for module in embeddings:
             module.weight.data.normal_(mean=0.0, std=.02)
 
+    def post_init_gpt2_config(self, use_deepspeed=False):
+        seq_length = self.max_mel_tokens + self.max_text_tokens + 2
+        gpt_config = GPT2Config(
+            vocab_size=self.max_mel_tokens,
+            n_positions=seq_length,
+            n_ctx=seq_length,
+            n_embd=self.model_dim,
+            n_layer=self.layers,
+            n_head=self.heads,
+            gradient_checkpointing=False,
+            use_cache=True,
+        )
+        self.inference_model = GPT2InferenceModel(
+            gpt_config,
+            self.gpt,
+            self.mel_pos_embedding,
+            self.mel_embedding,
+            self.final_norm,
+            self.mel_head
+        )
+        if use_deepspeed:
+            self.ds_engine = deepspeed.init_inference(model=self.inference_model,
+                                                      mp_size=1,
+                                                      replace_with_kernel_inject=True,
+                                                      dtype=torch.float32)
+            self.inference_model = self.ds_engine.module.eval()
+        # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
+        self.gpt.wte = self.mel_embedding
+
     def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
         inp = F.pad(input, (1,0), value=start_token)
         tar = F.pad(input, (0,1), value=stop_token)
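For context, deepspeed.init_inference returns an InferenceEngine: mp_size=1 disables tensor (model) parallelism, replace_with_kernel_inject=True swaps supported transformer layers for DeepSpeed's fused inference kernels, and .module gives back the wrapped model so the existing generate() call sites keep working. The same pattern on a stock Hugging Face GPT-2, as a standalone sketch against the DeepSpeed 0.9-era API (kernel injection needs a CUDA device):

    import torch
    import deepspeed
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
    engine = deepspeed.init_inference(model=model,
                                      mp_size=1,                       # no tensor parallelism
                                      replace_with_kernel_inject=True, # fuse supported layers into CUDA kernels
                                      dtype=torch.float32)             # fp32, matching this commit
    model = engine.module.eval()  # unwrap the InferenceEngine, as the commit does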
@@ -458,23 +486,10 @@ class UnifiedVoice(nn.Module):
         return loss_text.mean(), loss_mel.mean(), mel_logits
 
     def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
                          max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
-        seq_length = self.max_mel_tokens + self.max_text_tokens + 2
-        if not hasattr(self, 'inference_model'):
-            # TODO: Decouple gpt_config from this inference model.
-            gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
-                                    n_positions=seq_length,
-                                    n_ctx=seq_length,
-                                    n_embd=self.model_dim,
-                                    n_layer=self.layers,
-                                    n_head=self.heads,
-                                    gradient_checkpointing=False,
-                                    use_cache=True)
-            self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
-            self.gpt.wte = self.mel_embedding
 
         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
-        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
+        text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
 
         conds = speech_conditioning_latent.unsqueeze(1)
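Removing the lazy construction changes the contract: inference_speech no longer builds self.inference_model on first use (it still calls self.inference_model.generate later in the method), so post_init_gpt2_config must run first; the api.py change above does that in the TextToSpeech constructor. The text_targets to _ change just drops an unused binding. Code that builds UnifiedVoice directly needs the same step, roughly as follows (default hyperparameters assumed for brevity):

    from tortoise.models.autoregressive import UnifiedVoice

    model = UnifiedVoice().eval()   # defaults; real use loads autoregressive.pth weights first
    model.post_init_gpt2_config()   # now mandatory before inference_speech
    # skipping post_init_gpt2_config makes inference_speech raise:
    # AttributeError: 'UnifiedVoice' object has no attribute 'inference_model'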
File diff suppressed because one or more lines are too long