Merge pull request #777 from damithsenanayake/fix_issue_776

adding a parameter check for device.
This commit is contained in:
manmay nakhashi 2024-05-15 17:33:55 +05:30 committed by GitHub
commit e2d9fba0bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 10 additions and 2 deletions

View file

@ -194,7 +194,11 @@ class TextToSpeech:
self.models_dir = models_dir
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
self.enable_redaction = enable_redaction
self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
else:
self.device = torch.device(device)
if torch.backends.mps.is_available():
self.device = torch.device('mps')
if self.enable_redaction:

View file

@ -193,7 +193,11 @@ class TextToSpeech:
self.models_dir = models_dir
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
self.enable_redaction = enable_redaction
self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
else:
self.device = torch.device(device)
if torch.backends.mps.is_available():
self.device = torch.device('mps')
if self.enable_redaction: