This commit is contained in:
Simon Sardorf 2024-12-18 16:08:52 +01:00 committed by GitHub
commit bc06041c99
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 27 additions and 27 deletions

View file

@ -28,9 +28,10 @@ setuptools.setup(
'scipy',
'librosa',
'transformers==4.31.0',
'tokenizers==0.14.0',
'scipy==1.13.1'
# 'deepspeed==0.8.3',
'tokenizers',
'scipy==1.13.1',
'deepspeed',
'py-cpuinfo'
],
classifiers=[
"Programming Language :: Python :: 3",

View file

@ -243,7 +243,7 @@ class TextToSpeech:
self.rlg_auto = None
self.rlg_diffusion = None
@contextmanager
def temporary_cuda(self, model):
def temporary_device(self, model):
m = model.to(self.device)
yield m
m = model.cpu()
@ -410,8 +410,9 @@ class TextToSpeech:
if verbose:
print("Generating autoregressive samples..")
if not torch.backends.mps.is_available():
with self.temporary_cuda(self.autoregressive
) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
with self.temporary_device(self.autoregressive) as autoregressive, torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=self.half
):
for b in tqdm(range(num_batches), disable=not verbose):
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
@ -426,7 +427,9 @@ class TextToSpeech:
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
else:
with self.temporary_cuda(self.autoregressive) as autoregressive:
with self.temporary_device(self.autoregressive) as autoregressive, torch.autocast(
device_type="mps", dtype=torch.float16, enabled=self.half
):
for b in tqdm(range(num_batches), disable=not verbose):
codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
do_sample=True,
@ -444,8 +447,10 @@ class TextToSpeech:
clip_results = []
if not torch.backends.mps.is_available():
with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
with self.temporary_device(self.clvp) as clvp, torch.autocast(
device_type=self.device.type,
dtype=torch.float16,
enabled=self.half
):
if cvvp_amount > 0:
if self.cvvp is None:
@ -476,7 +481,7 @@ class TextToSpeech:
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices]
else:
with self.temporary_cuda(self.clvp) as clvp:
with self.temporary_device(self.clvp) as clvp:
if cvvp_amount > 0:
if self.cvvp is None:
self.load_cvvp()
@ -513,10 +518,12 @@ class TextToSpeech:
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage.
if not torch.backends.mps.is_available():
with self.temporary_cuda(
with self.temporary_device(
self.autoregressive
) as autoregressive, torch.autocast(
device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
device_type=self.device.type,
dtype=torch.float16,
enabled=self.half
):
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
@ -524,7 +531,7 @@ class TextToSpeech:
return_latent=True, clip_inputs=False)
del auto_conditioning
else:
with self.temporary_cuda(
with self.temporary_device(
self.autoregressive
) as autoregressive:
best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
@ -537,7 +544,7 @@ class TextToSpeech:
print("Transforming autoregressive outputs into audio..")
wav_candidates = []
if not torch.backends.mps.is_available():
with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
with self.temporary_device(self.diffusion) as diffusion, self.temporary_device(
self.vocoder
) as vocoder:
for b in range(best_results.shape[0]):

View file

@ -371,7 +371,7 @@ class TextToSpeech:
if verbose:
print("Generating autoregressive samples..")
with torch.autocast(
device_type="cuda" , dtype=torch.float16, enabled=self.half
device_type="cuda" if not torch.backends.mps.is_available() else "mps" , dtype=torch.float16, enabled=self.half
):
fake_inputs = self.autoregressive.compute_embeddings(
auto_conditioning,
@ -400,7 +400,7 @@ class TextToSpeech:
while not is_end:
try:
with torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=self.half
device_type="cuda" if not torch.backends.mps.is_available() else "mps", dtype=torch.float16, enabled=self.half
):
codes, latent = next(gpt_generator)
all_latents += [latent]
@ -477,9 +477,9 @@ class TextToSpeech:
with torch.no_grad():
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
if verbose:
print("Generating autoregressive samples..")
print("Generating autoregressive samples..")
with torch.autocast(
device_type="cuda" , dtype=torch.float16, enabled=self.half
device_type="cuda" if not torch.backends.mps.is_available() else "mps", dtype=torch.float16, enabled=self.half
):
codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
top_k=50,

View file

@ -13,7 +13,7 @@ if __name__ == '__main__':
parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
parser.add_argument('--use_deepspeed', type=str, help='Use deepspeed for speed bump.', default=False)
parser.add_argument('--use_deepspeed', type=bool, help='Use deepspeed for speed bump.', default=False)
parser.add_argument('--kv_cache', type=bool, help='If you disable this please wait for a long a time to get the output', default=True)
parser.add_argument('--half', type=bool, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True)
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
@ -25,8 +25,6 @@ if __name__ == '__main__':
parser.add_argument('--cvvp_amount', type=float, help='How much the CVVP model should influence the output.'
'Increasing this can in some cases reduce the likelihood of multiple speakers. Defaults to 0 (disabled)', default=.0)
args = parser.parse_args()
if torch.backends.mps.is_available():
args.use_deepspeed = False
os.makedirs(args.output_path, exist_ok=True)
tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)

View file

@ -30,8 +30,6 @@ if __name__ == '__main__':
args = parser.parse_args()
if torch.backends.mps.is_available():
args.use_deepspeed = False
tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
outpath = args.output_path

View file

@ -28,8 +28,6 @@ if __name__ == '__main__':
args = parser.parse_args()
if torch.backends.mps.is_available():
args.use_deepspeed = False
tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
outpath = args.output_path

View file

@ -37,8 +37,6 @@ if __name__ == '__main__':
args = parser.parse_args()
if torch.backends.mps.is_available():
args.use_deepspeed = False
tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
outpath = args.output_path