diff --git a/requirements.txt b/requirements.txt
index 5168c32..48df0b6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,3 +21,4 @@ pydantic==1.9.1
 deepspeed==0.8.3
 py-cpuinfo
 hjson
+psutil
diff --git a/tortoise/api.py b/tortoise/api.py
index efa01fb..a095a88 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -189,6 +189,16 @@ def pick_best_batch_size_for_gpu():
             return 8
         elif availableGb > 7:
             return 4
+    if torch.backends.mps.is_available():
+        import psutil
+        available = psutil.virtual_memory().total
+        availableGb = available / (1024 ** 3)
+        if availableGb > 14:
+            return 16
+        elif availableGb > 10:
+            return 8
+        elif availableGb > 7:
+            return 4
     return 1
 
 class TextToSpeech:
@@ -212,7 +222,9 @@ class TextToSpeech:
         self.models_dir = models_dir
         self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
         self.enable_redaction = enable_redaction
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        if torch.backends.mps.is_available():
+            self.device = torch.device('mps')
         if self.enable_redaction:
             self.aligner = Wav2VecAlignment()
 
@@ -254,6 +266,7 @@ class TextToSpeech:
         m = model.to(self.device)
         yield m
         m = model.cpu()
+
     def load_cvvp(self):
         """Load CVVP model."""
 
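Note on the `pick_best_batch_size_for_gpu()` hunk: the new MPS branch sizes batches from total system RAM via `psutil`, because Apple-silicon GPUs draw on unified memory and PyTorch offers no MPS analogue of `torch.cuda.mem_get_info()`. A minimal sketch of the combined probe (the helper name is ours, not the patch's; `.total` deliberately mirrors the hunk, even though `psutil`'s `.available` would be a more conservative estimate):

```python
import psutil
import torch

def available_memory_gb() -> float:
    """Rough per-backend memory probe mirroring the hunk above."""
    if torch.cuda.is_available():
        free_bytes, _ = torch.cuda.mem_get_info()  # driver-reported free VRAM
        return free_bytes / (1024 ** 3)
    if torch.backends.mps.is_available():
        # MPS shares the host's unified memory; total RAM is the closest proxy.
        return psutil.virtual_memory().total / (1024 ** 3)
    return 0.0
```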
@@ -410,54 +423,102 @@ class TextToSpeech:
             calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
             if verbose:
                 print("Generating autoregressive samples..")
-            with self.temporary_cuda(self.autoregressive
-            ) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
-                for b in tqdm(range(num_batches), disable=not verbose):
-                    codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
-                                                            do_sample=True,
-                                                            top_p=top_p,
-                                                            temperature=temperature,
-                                                            num_return_sequences=self.autoregressive_batch_size,
-                                                            length_penalty=length_penalty,
-                                                            repetition_penalty=repetition_penalty,
-                                                            max_generate_length=max_mel_tokens,
-                                                            **hf_generate_kwargs)
-                    padding_needed = max_mel_tokens - codes.shape[1]
-                    codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
-                    samples.append(codes)
+            if not torch.backends.mps.is_available():
+                with self.temporary_cuda(self.autoregressive
+                ) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
+                    for b in tqdm(range(num_batches), disable=not verbose):
+                        codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
+                                                                do_sample=True,
+                                                                top_p=top_p,
+                                                                temperature=temperature,
+                                                                num_return_sequences=self.autoregressive_batch_size,
+                                                                length_penalty=length_penalty,
+                                                                repetition_penalty=repetition_penalty,
+                                                                max_generate_length=max_mel_tokens,
+                                                                **hf_generate_kwargs)
+                        padding_needed = max_mel_tokens - codes.shape[1]
+                        codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
+                        samples.append(codes)
+            else:
+                with self.temporary_cuda(self.autoregressive) as autoregressive:
+                    for b in tqdm(range(num_batches), disable=not verbose):
+                        codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
+                                                                do_sample=True,
+                                                                top_p=top_p,
+                                                                temperature=temperature,
+                                                                num_return_sequences=self.autoregressive_batch_size,
+                                                                length_penalty=length_penalty,
+                                                                repetition_penalty=repetition_penalty,
+                                                                max_generate_length=max_mel_tokens,
+                                                                **hf_generate_kwargs)
+                        padding_needed = max_mel_tokens - codes.shape[1]
+                        codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
+                        samples.append(codes)
 
             clip_results = []
-            with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
-                device_type="cuda", dtype=torch.float16, enabled=self.half
-            ):
-                if cvvp_amount > 0:
-                    if self.cvvp is None:
-                        self.load_cvvp()
-                    self.cvvp = self.cvvp.to(self.device)
-                if verbose:
-                    if self.cvvp is None:
-                        print("Computing best candidates using CLVP")
-                    else:
-                        print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
-                for batch in tqdm(samples, disable=not verbose):
-                    for i in range(batch.shape[0]):
-                        batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
-                    if cvvp_amount != 1:
-                        clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
-                    if auto_conds is not None and cvvp_amount > 0:
-                        cvvp_accumulator = 0
-                        for cl in range(auto_conds.shape[1]):
-                            cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
-                        cvvp = cvvp_accumulator / auto_conds.shape[1]
-                        if cvvp_amount == 1:
-                            clip_results.append(cvvp)
+
+            if not torch.backends.mps.is_available():
+                with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
+                    device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
+                ):
+                    if cvvp_amount > 0:
+                        if self.cvvp is None:
+                            self.load_cvvp()
+                        self.cvvp = self.cvvp.to(self.device)
+                    if verbose:
+                        if self.cvvp is None:
+                            print("Computing best candidates using CLVP")
                         else:
-                            clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
-                    else:
-                        clip_results.append(clvp_out)
-            clip_results = torch.cat(clip_results, dim=0)
-            samples = torch.cat(samples, dim=0)
-            best_results = samples[torch.topk(clip_results, k=k).indices]
+                            print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
+                    for batch in tqdm(samples, disable=not verbose):
+                        for i in range(batch.shape[0]):
+                            batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
+                        if cvvp_amount != 1:
+                            clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
+                        if auto_conds is not None and cvvp_amount > 0:
+                            cvvp_accumulator = 0
+                            for cl in range(auto_conds.shape[1]):
+                                cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
+                            cvvp = cvvp_accumulator / auto_conds.shape[1]
+                            if cvvp_amount == 1:
+                                clip_results.append(cvvp)
+                            else:
+                                clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
+                        else:
+                            clip_results.append(clvp_out)
+                clip_results = torch.cat(clip_results, dim=0)
+                samples = torch.cat(samples, dim=0)
+                best_results = samples[torch.topk(clip_results, k=k).indices]
+            else:
+                with self.temporary_cuda(self.clvp) as clvp:
+                    if cvvp_amount > 0:
+                        if self.cvvp is None:
+                            self.load_cvvp()
+                        self.cvvp = self.cvvp.to(self.device)
+                    if verbose:
+                        if self.cvvp is None:
+                            print("Computing best candidates using CLVP")
+                        else:
+                            print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
+                    for batch in tqdm(samples, disable=not verbose):
+                        for i in range(batch.shape[0]):
+                            batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
+                        if cvvp_amount != 1:
+                            clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
+                        if auto_conds is not None and cvvp_amount > 0:
+                            cvvp_accumulator = 0
+                            for cl in range(auto_conds.shape[1]):
+                                cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
+                            cvvp = cvvp_accumulator / auto_conds.shape[1]
+                            if cvvp_amount == 1:
+                                clip_results.append(cvvp)
+                            else:
+                                clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount))
+                        else:
+                            clip_results.append(clvp_out)
+                clip_results = torch.cat(clip_results, dim=0)
+                samples = torch.cat(samples, dim=0)
+                best_results = samples[torch.topk(clip_results, k=k).indices]
 
             if self.cvvp is not None:
                 self.cvvp = self.cvvp.cpu()
             del samples
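The duplicated branches above exist because `torch.autocast` did not support `device_type="mps"` when this patch was written, so the MPS path simply drops the autocast context. (Also note that `device_type="cuda" if not torch.backends.mps.is_available() else 'mps'` is dead code: it sits inside the branch that only runs when MPS is unavailable, so it always evaluates to `"cuda"`.) A sketch of how the two branches could be collapsed with a no-op context; the helper is hypothetical, not part of the patch:

```python
import contextlib
import torch

def maybe_autocast(device: torch.device, enabled: bool):
    # Autocast on CUDA only; a no-op elsewhere, since torch.autocast had no
    # MPS backend at the time of this patch.
    if device.type == "cuda":
        return torch.autocast(device_type="cuda", dtype=torch.float16, enabled=enabled)
    return contextlib.nullcontext()
```

Written that way, `with self.temporary_cuda(self.clvp) as clvp, maybe_autocast(self.device, self.half):` would serve both devices; the patch keeps the explicit `if`/`else` instead.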
@@ -465,26 +526,58 @@ class TextToSpeech:
             # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
             # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
             # results, but will increase memory usage.
-            with self.temporary_cuda(
-                self.autoregressive
-            ) as autoregressive, torch.autocast(
-                device_type="cuda", dtype=torch.float16, enabled=self.half
-            ):
-                best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
-                                              torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
-                                              torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
-                                              return_latent=True, clip_inputs=False)
-            del auto_conditioning
+            if not torch.backends.mps.is_available():
+                with self.temporary_cuda(
+                    self.autoregressive
+                ) as autoregressive, torch.autocast(
+                    device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
+                ):
+                    best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
+                                                  torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
+                                                  torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
+                                                  return_latent=True, clip_inputs=False)
+                del auto_conditioning
+            else:
+                with self.temporary_cuda(
+                    self.autoregressive
+                ) as autoregressive:
+                    best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
+                                                  torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
+                                                  torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
+                                                  return_latent=True, clip_inputs=False)
+                del auto_conditioning
 
             if verbose:
                 print("Transforming autoregressive outputs into audio..")
             wav_candidates = []
-            with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
-                self.vocoder
-            ) as vocoder:
+            if not torch.backends.mps.is_available():
+                with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
+                    self.vocoder
+                ) as vocoder:
+                    for b in range(best_results.shape[0]):
+                        codes = best_results[b].unsqueeze(0)
+                        latents = best_latents[b].unsqueeze(0)
+
+                        # Find the first occurrence of the "calm" token and trim the codes to that.
+                        ctokens = 0
+                        for k in range(codes.shape[-1]):
+                            if codes[0, k] == calm_token:
+                                ctokens += 1
+                            else:
+                                ctokens = 0
+                            if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
+                                latents = latents[:, :k]
+                                break
+                        mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature,
+                                                       verbose=verbose)
+                        wav = vocoder.inference(mel)
+                        wav_candidates.append(wav.cpu())
+            else:
+                diffusion, vocoder = self.diffusion, self.vocoder
+                diffusion_conditioning = diffusion_conditioning.cpu()
                 for b in range(best_results.shape[0]):
-                    codes = best_results[b].unsqueeze(0)
-                    latents = best_latents[b].unsqueeze(0)
+                    codes = best_results[b].unsqueeze(0).cpu()
+                    latents = best_latents[b].unsqueeze(0).cpu()
 
                     # Find the first occurrence of the "calm" token and trim the codes to that.
                     ctokens = 0
@@ -496,7 +589,6 @@ class TextToSpeech:
                         if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
                             latents = latents[:, :k]
                             break
-
                     mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature,
                                                    verbose=verbose)
                     wav = vocoder.inference(mel)
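Both the CUDA loop and the MPS fallback above repeat the same calm-token scan (on MPS the diffusion model and vocoder run on the CPU, hence the `.cpu()` moves on the conditioning, codes, and latents). For clarity, here is that scan restated as a standalone function; the name and signature are ours, and the logic mirrors the duplicated loop:

```python
import torch

CALM_TOKEN = 83  # token the autoregressive model emits for silence (from the patch)

def trim_latents_at_calm_run(codes: torch.Tensor, latents: torch.Tensor,
                             run_len: int = 8) -> torch.Tensor:
    """Trim latents at the first run of more than `run_len` consecutive calm
    tokens, giving the diffusion model room to terminate speech."""
    ctokens = 0
    for k in range(codes.shape[-1]):
        ctokens = ctokens + 1 if codes[0, k] == CALM_TOKEN else 0
        if ctokens > run_len:
            return latents[:, :k]
    return latents
```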
diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py
index 9a6eec9..2d01066 100644
--- a/tortoise/models/autoregressive.py
+++ b/tortoise/models/autoregressive.py
@@ -47,7 +47,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.cached_mel_emb = None
     def parallelize(self, device_map=None):
         self.device_map = (
-            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
+            get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count())))
             if device_map is None
             else device_map
         )
@@ -62,6 +62,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.lm_head = self.lm_head.to("cpu")
         self.model_parallel = False
         torch.cuda.empty_cache()
+        if torch.backends.mps.is_available():
+            torch.mps.empty_cache()
 
     def get_output_embeddings(self):
         return self.lm_head
@@ -162,7 +164,10 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
 
         # Set device for model parallelism
         if self.model_parallel:
-            torch.cuda.set_device(self.transformer.first_device)
+            if torch.backends.mps.is_available():
+                self.to(self.transformer.first_device)
+            else:
+                torch.cuda.set_device(self.transformer.first_device)
             hidden_states = hidden_states.to(self.lm_head.weight.device)
 
         lm_logits = self.lm_head(hidden_states)
diff --git a/tortoise/models/diffusion_decoder.py b/tortoise/models/diffusion_decoder.py
index f67d21a..e969129 100644
--- a/tortoise/models/diffusion_decoder.py
+++ b/tortoise/models/diffusion_decoder.py
@@ -302,7 +302,10 @@ class DiffusionTts(nn.Module):
                 unused_params.extend(list(lyr.parameters()))
             else:
                 # First and last blocks will have autocast disabled for improved precision.
-                with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
+                if not torch.backends.mps.is_available():
+                    with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
+                        x = lyr(x, time_emb)
+                else:
                     x = lyr(x, time_emb)
 
         x = x.float()
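In the hunks above, `torch.mps.empty_cache()` is paired with `torch.cuda.empty_cache()` so cached allocator blocks are released on whichever backend is active, and `max(1, torch.cuda.device_count())` keeps `get_device_map` from receiving an empty range on machines with no CUDA device. A device-agnostic sketch (assuming a PyTorch recent enough to ship `torch.mps`; the `hasattr` guard covers older builds):

```python
import torch

def empty_device_cache() -> None:
    # Release cached allocator blocks on the active backend.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available() and hasattr(torch.mps, "empty_cache"):
        torch.mps.empty_cache()
```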
diff --git a/tortoise/utils/audio.py b/tortoise/utils/audio.py
index 91237dd..6842af5 100644
--- a/tortoise/utils/audio.py
+++ b/tortoise/utils/audio.py
@@ -180,7 +180,7 @@ class TacotronSTFT(torch.nn.Module):
         return mel_output
 
 
-def wav_to_univnet_mel(wav, do_normalization=False, device='cuda'):
+def wav_to_univnet_mel(wav, do_normalization=False, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
     stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
     stft = stft.to(device)
     mel = stft.mel_spectrogram(wav)
diff --git a/tortoise/utils/diffusion.py b/tortoise/utils/diffusion.py
index e877ff2..6d4d594 100644
--- a/tortoise/utils/diffusion.py
+++ b/tortoise/utils/diffusion.py
@@ -1244,7 +1244,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
                        dimension equal to the length of timesteps.
     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
     """
-    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+    res = th.from_numpy(arr.astype(np.float32)).to(device=timesteps.device)[timesteps]
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res.expand(broadcast_shape)
\ No newline at end of file
diff --git a/tortoise/utils/wav2vec_alignment.py b/tortoise/utils/wav2vec_alignment.py
index bbe3285..adc39e3 100644
--- a/tortoise/utils/wav2vec_alignment.py
+++ b/tortoise/utils/wav2vec_alignment.py
@@ -49,7 +49,7 @@ class Wav2VecAlignment:
     """
     Uses wav2vec2 to perform audio<->text alignment.
     """
-    def __init__(self, device='cuda'):
+    def __init__(self, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
         self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
         self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
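The `_extract_into_tensor` change works around MPS's lack of float64 support: diffusion schedules are NumPy float64 arrays by default, and moving a float64 tensor to an MPS device raises a TypeError, so the array is narrowed to float32 on the CPU side before `th.from_numpy` rather than calling `.float()` after the device move. A self-contained illustration (the schedule values below are stand-ins, not the library's):

```python
import numpy as np
import torch as th

schedule = np.linspace(1e-4, 2e-2, 1000)  # float64 by default, like a beta schedule
timesteps = th.tensor([0, 10, 500])

# Cast before from_numpy: float64 tensors cannot live on an MPS device.
res = th.from_numpy(schedule.astype(np.float32))[timesteps]
print(res.dtype)  # torch.float32
```

The `device='cuda' if not torch.backends.mps.is_available() else 'mps'` default arguments in `audio.py` and `wav2vec_alignment.py` are evaluated once at import time, which is safe here since backend availability does not change while the process runs.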