diff --git a/tortoise/api.py b/tortoise/api.py
index a5b95dd..8a010c2 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -194,7 +194,11 @@ class TextToSpeech:
         self.models_dir = models_dir
         self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
         self.enable_redaction = enable_redaction
-        self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
+        if device is None:
+            self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
+        else:
+            self.device = torch.device(device)
+
         if torch.backends.mps.is_available():
             self.device = torch.device('mps')
         if self.enable_redaction:
diff --git a/tortoise/api_fast.py b/tortoise/api_fast.py
index 216ecf2..fd7c590 100644
--- a/tortoise/api_fast.py
+++ b/tortoise/api_fast.py
@@ -193,7 +193,11 @@ class TextToSpeech:
         self.models_dir = models_dir
         self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
         self.enable_redaction = enable_redaction
-        self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
+        if device is None:
+            self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
+        else:
+            self.device = torch.device(device)
+
         if torch.backends.mps.is_available():
             self.device = torch.device('mps')
         if self.enable_redaction:
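
For reference, a minimal usage sketch of the change above. It assumes the TextToSpeech constructor in both files now accepts a device keyword argument, as the hunks imply; 'cuda:1' is a hypothetical second GPU chosen purely for illustration.

    import torch
    from tortoise.api import TextToSpeech

    # Pass an explicit device string (or torch.device) to pin the model
    # to a specific accelerator. 'cuda:1' is a hypothetical second GPU.
    tts = TextToSpeech(device='cuda:1')

    # Omitting device keeps the previous behavior: CUDA if available,
    # otherwise CPU.
    tts_default = TextToSpeech()

Note that, per the patch, an available MPS backend still overrides whatever device was chosen, since the torch.backends.mps.is_available() check runs after the new if/else block.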