diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py index 4c91534..7483d45 100644 --- a/tortoise/models/autoregressive.py +++ b/tortoise/models/autoregressive.py @@ -375,14 +375,14 @@ class UnifiedVoice(nn.Module): self.mel_head, kv_cache=kv_cache, ) - if use_deepspeed and half and torch.backends.cuda.is_available(): + if use_deepspeed and half and torch.cuda.is_available(): import deepspeed self.ds_engine = deepspeed.init_inference(model=self.inference_model, mp_size=1, replace_with_kernel_inject=True, dtype=torch.float16) self.inference_model = self.ds_engine.module.eval() - elif use_deepspeed and torch.backends.cuda.is_available(): + elif use_deepspeed and torch.cuda.is_available(): import deepspeed self.ds_engine = deepspeed.init_inference(model=self.inference_model, mp_size=1,