mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-01-24 09:30:27 +01:00
Merge pull request #565 from NourEldin-Osama/main-1
Update autoregressive.py
This commit is contained in:
commit
4003544b6f
|
|
@ -375,14 +375,14 @@ class UnifiedVoice(nn.Module):
|
|||
self.mel_head,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
if use_deepspeed and half and torch.backends.cuda.is_available():
|
||||
if use_deepspeed and half and torch.cuda.is_available():
|
||||
import deepspeed
|
||||
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
|
||||
mp_size=1,
|
||||
replace_with_kernel_inject=True,
|
||||
dtype=torch.float16)
|
||||
self.inference_model = self.ds_engine.module.eval()
|
||||
elif use_deepspeed and torch.backends.cuda.is_available():
|
||||
elif use_deepspeed and torch.cuda.is_available():
|
||||
import deepspeed
|
||||
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
|
||||
mp_size=1,
|
||||
|
|
|
|||
Loading…
Reference in a new issue