Merge pull request #565 from NourEldin-Osama/main-1

Update autoregressive.py
This commit is contained in:
manmay nakhashi 2023-08-11 23:08:33 +05:30 committed by GitHub
commit 4003544b6f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -375,14 +375,14 @@ class UnifiedVoice(nn.Module):
self.mel_head,
kv_cache=kv_cache,
)
if use_deepspeed and half and torch.backends.cuda.is_available():
if use_deepspeed and half and torch.cuda.is_available():
import deepspeed
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
mp_size=1,
replace_with_kernel_inject=True,
dtype=torch.float16)
self.inference_model = self.ds_engine.module.eval()
elif use_deepspeed and torch.backends.cuda.is_available():
elif use_deepspeed and torch.cuda.is_available():
import deepspeed
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
mp_size=1,