mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-04-21 06:14:04 +00:00
Addes MPS support
This commit is contained in:
parent
b4988c24b3
commit
8d67995ba7
2 changed files with 3 additions and 1 deletions
|
|
@ -100,7 +100,7 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
|
|||
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
|
||||
|
||||
|
||||
def format_conditioning(clip, cond_length=132300, device='cuda'):
|
||||
def format_conditioning(clip, cond_length=132300, device="cuda" if not torch.backends.mps.is_available() else 'mps'):
|
||||
"""
|
||||
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -319,6 +319,8 @@ class TorchMelSpectrogram(nn.Module):
|
|||
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
|
||||
inp = inp.squeeze(1)
|
||||
assert len(inp.shape) == 2
|
||||
if torch.backends.mps.is_available():
|
||||
inp = inp.to('cpu')
|
||||
self.mel_stft = self.mel_stft.to(inp.device)
|
||||
mel = self.mel_stft(inp)
|
||||
# Perform dynamic range compression
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue