mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-01-06 08:40:24 +01:00
import deepspeed only if use_deepspeed is True
This commit is contained in:
parent
31a2e153ff
commit
82724cca54
|
|
@ -3,7 +3,6 @@ import functools
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import deepspeed
|
||||
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
||||
|
|
@ -362,6 +361,7 @@ class UnifiedVoice(nn.Module):
|
|||
self.mel_head
|
||||
)
|
||||
if use_deepspeed:
|
||||
import deepspeed
|
||||
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
|
||||
mp_size=1,
|
||||
replace_with_kernel_inject=True,
|
||||
|
|
|
|||
Loading…
Reference in a new issue