Addes MPS support

This commit is contained in:
Jerry-Master 2023-08-06 19:01:10 +02:00
parent b4988c24b3
commit 8d67995ba7
2 changed files with 3 additions and 1 deletions

View file

@ -100,7 +100,7 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
def format_conditioning(clip, cond_length=132300, device='cuda'):
def format_conditioning(clip, cond_length=132300, device="cuda" if not torch.backends.mps.is_available() else 'mps'):
"""
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
"""

View file

@ -319,6 +319,8 @@ class TorchMelSpectrogram(nn.Module):
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
inp = inp.squeeze(1)
assert len(inp.shape) == 2
if torch.backends.mps.is_available():
inp = inp.to('cpu')
self.mel_stft = self.mel_stft.to(inp.device)
mel = self.mel_stft(inp)
# Perform dynamic range compression