mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-01-29 11:54:40 +01:00
Enable mps support
This commit is contained in:
parent
8a2563ecab
commit
bf7976172e
|
|
@ -243,7 +243,7 @@ class TextToSpeech:
|
|||
self.rlg_auto = None
|
||||
self.rlg_diffusion = None
|
||||
@contextmanager
|
||||
def temporary_cuda(self, model):
|
||||
def temporary_device(self, model):
|
||||
m = model.to(self.device)
|
||||
yield m
|
||||
m = model.cpu()
|
||||
|
|
@ -410,8 +410,9 @@ class TextToSpeech:
|
|||
if verbose:
|
||||
print("Generating autoregressive samples..")
|
||||
if not torch.backends.mps.is_available():
|
||||
with self.temporary_cuda(self.autoregressive
|
||||
) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
|
||||
with self.temporary_device(self.autoregressive) as autoregressive, torch.autocast(
|
||||
device_type="cuda", dtype=torch.float16, enabled=self.half
|
||||
):
|
||||
for b in tqdm(range(num_batches), disable=not verbose):
|
||||
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
|
||||
do_sample=True,
|
||||
|
|
@ -426,7 +427,9 @@ class TextToSpeech:
|
|||
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
|
||||
samples.append(codes)
|
||||
else:
|
||||
with self.temporary_cuda(self.autoregressive) as autoregressive:
|
||||
with self.temporary_device(self.autoregressive) as autoregressive, torch.autocast(
|
||||
device_type="mps", dtype=torch.float16, enabled=self.half
|
||||
):
|
||||
for b in tqdm(range(num_batches), disable=not verbose):
|
||||
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
|
||||
do_sample=True,
|
||||
|
|
@ -444,8 +447,10 @@ class TextToSpeech:
|
|||
clip_results = []
|
||||
|
||||
if not torch.backends.mps.is_available():
|
||||
with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
|
||||
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
|
||||
with self.temporary_device(self.clvp) as clvp, torch.autocast(
|
||||
device_type=self.device.type,
|
||||
dtype=torch.float16,
|
||||
enabled=self.half
|
||||
):
|
||||
if cvvp_amount > 0:
|
||||
if self.cvvp is None:
|
||||
|
|
@ -476,7 +481,7 @@ class TextToSpeech:
|
|||
samples = torch.cat(samples, dim=0)
|
||||
best_results = samples[torch.topk(clip_results, k=k).indices]
|
||||
else:
|
||||
with self.temporary_cuda(self.clvp) as clvp:
|
||||
with self.temporary_device(self.clvp) as clvp:
|
||||
if cvvp_amount > 0:
|
||||
if self.cvvp is None:
|
||||
self.load_cvvp()
|
||||
|
|
@ -513,10 +518,12 @@ class TextToSpeech:
|
|||
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
|
||||
# results, but will increase memory usage.
|
||||
if not torch.backends.mps.is_available():
|
||||
with self.temporary_cuda(
|
||||
with self.temporary_device(
|
||||
self.autoregressive
|
||||
) as autoregressive, torch.autocast(
|
||||
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
|
||||
device_type=self.device.type,
|
||||
dtype=torch.float16,
|
||||
enabled=self.half
|
||||
):
|
||||
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
|
||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
|
||||
|
|
@ -524,7 +531,7 @@ class TextToSpeech:
|
|||
return_latent=True, clip_inputs=False)
|
||||
del auto_conditioning
|
||||
else:
|
||||
with self.temporary_cuda(
|
||||
with self.temporary_device(
|
||||
self.autoregressive
|
||||
) as autoregressive:
|
||||
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
|
||||
|
|
@ -537,7 +544,7 @@ class TextToSpeech:
|
|||
print("Transforming autoregressive outputs into audio..")
|
||||
wav_candidates = []
|
||||
if not torch.backends.mps.is_available():
|
||||
with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
|
||||
with self.temporary_device(self.diffusion) as diffusion, self.temporary_device(
|
||||
self.vocoder
|
||||
) as vocoder:
|
||||
for b in range(best_results.shape[0]):
|
||||
|
|
|
|||
|
|
@ -371,7 +371,7 @@ class TextToSpeech:
|
|||
if verbose:
|
||||
print("Generating autoregressive samples..")
|
||||
with torch.autocast(
|
||||
device_type="cuda" , dtype=torch.float16, enabled=self.half
|
||||
device_type="cuda" if not torch.backends.mps.is_available() else "mps" , dtype=torch.float16, enabled=self.half
|
||||
):
|
||||
fake_inputs = self.autoregressive.compute_embeddings(
|
||||
auto_conditioning,
|
||||
|
|
@ -400,7 +400,7 @@ class TextToSpeech:
|
|||
while not is_end:
|
||||
try:
|
||||
with torch.autocast(
|
||||
device_type="cuda", dtype=torch.float16, enabled=self.half
|
||||
device_type="cuda" if not torch.backends.mps.is_available() else "mps", dtype=torch.float16, enabled=self.half
|
||||
):
|
||||
codes, latent = next(gpt_generator)
|
||||
all_latents += [latent]
|
||||
|
|
@ -477,9 +477,9 @@ class TextToSpeech:
|
|||
with torch.no_grad():
|
||||
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
|
||||
if verbose:
|
||||
print("Generating autoregressive samples..")
|
||||
print("Generating autoregressive samples..")
|
||||
with torch.autocast(
|
||||
device_type="cuda" , dtype=torch.float16, enabled=self.half
|
||||
device_type="cuda" if not torch.backends.mps.is_available() else "mps", dtype=torch.float16, enabled=self.half
|
||||
):
|
||||
codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
|
||||
top_k=50,
|
||||
|
|
|
|||
Loading…
Reference in a new issue