Enable mps support

This commit is contained in:
Simon Sardorf 2024-12-18 15:44:50 +01:00
parent 8a2563ecab
commit bf7976172e
No known key found for this signature in database
GPG key ID: DDCDE930BECDBC83
2 changed files with 22 additions and 15 deletions

View file

@ -243,7 +243,7 @@ class TextToSpeech:
self.rlg_auto = None
self.rlg_diffusion = None
@contextmanager
def temporary_cuda(self, model):
def temporary_device(self, model):
m = model.to(self.device)
yield m
m = model.cpu()
@ -410,8 +410,9 @@ class TextToSpeech:
if verbose:
print("Generating autoregressive samples..")
if not torch.backends.mps.is_available():
with self.temporary_cuda(self.autoregressive
) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
with self.temporary_device(self.autoregressive) as autoregressive, torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=self.half
):
for b in tqdm(range(num_batches), disable=not verbose):
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
@ -426,7 +427,9 @@ class TextToSpeech:
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
else:
with self.temporary_cuda(self.autoregressive) as autoregressive:
with self.temporary_device(self.autoregressive) as autoregressive, torch.autocast(
device_type="mps", dtype=torch.float16, enabled=self.half
):
for b in tqdm(range(num_batches), disable=not verbose):
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
@ -444,8 +447,10 @@ class TextToSpeech:
clip_results = []
if not torch.backends.mps.is_available():
with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
with self.temporary_device(self.clvp) as clvp, torch.autocast(
device_type=self.device.type,
dtype=torch.float16,
enabled=self.half
):
if cvvp_amount > 0:
if self.cvvp is None:
@ -476,7 +481,7 @@ class TextToSpeech:
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices]
else:
with self.temporary_cuda(self.clvp) as clvp:
with self.temporary_device(self.clvp) as clvp:
if cvvp_amount > 0:
if self.cvvp is None:
self.load_cvvp()
@ -513,10 +518,12 @@ class TextToSpeech:
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage.
if not torch.backends.mps.is_available():
with self.temporary_cuda(
with self.temporary_device(
self.autoregressive
) as autoregressive, torch.autocast(
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
device_type=self.device.type,
dtype=torch.float16,
enabled=self.half
):
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
@ -524,7 +531,7 @@ class TextToSpeech:
return_latent=True, clip_inputs=False)
del auto_conditioning
else:
with self.temporary_cuda(
with self.temporary_device(
self.autoregressive
) as autoregressive:
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
@ -537,7 +544,7 @@ class TextToSpeech:
print("Transforming autoregressive outputs into audio..")
wav_candidates = []
if not torch.backends.mps.is_available():
with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
with self.temporary_device(self.diffusion) as diffusion, self.temporary_device(
self.vocoder
) as vocoder:
for b in range(best_results.shape[0]):

View file

@ -371,7 +371,7 @@ class TextToSpeech:
if verbose:
print("Generating autoregressive samples..")
with torch.autocast(
device_type="cuda" , dtype=torch.float16, enabled=self.half
device_type="cuda" if not torch.backends.mps.is_available() else "mps" , dtype=torch.float16, enabled=self.half
):
fake_inputs = self.autoregressive.compute_embeddings(
auto_conditioning,
@ -400,7 +400,7 @@ class TextToSpeech:
while not is_end:
try:
with torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=self.half
device_type="cuda" if not torch.backends.mps.is_available() else "mps", dtype=torch.float16, enabled=self.half
):
codes, latent = next(gpt_generator)
all_latents += [latent]
@ -477,9 +477,9 @@ class TextToSpeech:
with torch.no_grad():
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
if verbose:
print("Generating autoregressive samples..")
print("Generating autoregressive samples..")
with torch.autocast(
device_type="cuda" , dtype=torch.float16, enabled=self.half
device_type="cuda" if not torch.backends.mps.is_available() else "mps", dtype=torch.float16, enabled=self.half
):
codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
top_k=50,