diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py index 03aa29f..4d04908 100644 --- a/tortoise/models/autoregressive.py +++ b/tortoise/models/autoregressive.py @@ -3,7 +3,6 @@ import functools import torch import torch.nn as nn import torch.nn.functional as F -import deepspeed from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from transformers.utils.model_parallel_utils import get_device_map, assert_device_map @@ -362,6 +361,7 @@ class UnifiedVoice(nn.Module): self.mel_head ) if use_deepspeed: + import deepspeed self.ds_engine = deepspeed.init_inference(model=self.inference_model, mp_size=1, replace_with_kernel_inject=True,