adding a parameter check for device.

This commit is contained in:
dsenanayake 2024-05-13 12:57:36 +10:00
parent 572bdf3d24
commit b6822c725d
2 changed files with 10 additions and 2 deletions

View file

@ -194,7 +194,11 @@ class TextToSpeech:
self.models_dir = models_dir
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
self.enable_redaction = enable_redaction
self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
else:
self.device = torch.device(device)
if torch.backends.mps.is_available():
self.device = torch.device('mps')
if self.enable_redaction:

View file

@ -193,7 +193,11 @@ class TextToSpeech:
self.models_dir = models_dir
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
self.enable_redaction = enable_redaction
self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
else:
self.device = torch.device(device)
if torch.backends.mps.is_available():
self.device = torch.device('mps')
if self.enable_redaction: