From b6822c725d1e32b4d9272357ee3a3a99dbd8a31a Mon Sep 17 00:00:00 2001 From: dsenanayake Date: Mon, 13 May 2024 12:57:36 +1000 Subject: [PATCH] adding a parameter check for device. --- tortoise/api.py | 6 +++++- tortoise/api_fast.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tortoise/api.py b/tortoise/api.py index a5b95dd..8a010c2 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -194,7 +194,11 @@ class TextToSpeech: self.models_dir = models_dir self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size self.enable_redaction = enable_redaction - self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + if device is None: + self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + else: + self.device = torch.device(device) + if torch.backends.mps.is_available(): self.device = torch.device('mps') if self.enable_redaction: diff --git a/tortoise/api_fast.py b/tortoise/api_fast.py index 216ecf2..fd7c590 100644 --- a/tortoise/api_fast.py +++ b/tortoise/api_fast.py @@ -193,7 +193,11 @@ class TextToSpeech: self.models_dir = models_dir self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size self.enable_redaction = enable_redaction - self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + if device is None: + self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + else: + self.device = torch.device(device) + if torch.backends.mps.is_available(): self.device = torch.device('mps') if self.enable_redaction: