mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-01-24 09:30:27 +01:00
Merge pull request #635 from wgaylord/wgaylord-patch-1
Add CPU only support to hifigan_decoder.py
This commit is contained in:
commit
c8a3f8a3e0
|
|
@ -230,6 +230,10 @@ class HifiganGenerator(torch.nn.Module):
|
|||
if not conv_post_weight_norm:
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu')
|
||||
if torch.backends.mps.is_available():
|
||||
self.device = torch.device('mps')
|
||||
|
||||
def forward(self, x, g=None):
|
||||
"""
|
||||
Args:
|
||||
|
|
@ -287,7 +291,7 @@ class HifiganGenerator(torch.nn.Module):
|
|||
mode="linear",
|
||||
)
|
||||
g = g.unsqueeze(0)
|
||||
return self.forward(up_2.to("cuda"), g.transpose(1,2))
|
||||
return self.forward(up_2.to(self.device), g.transpose(1,2))
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print("Removing weight norm...")
|
||||
|
|
|
|||
Loading…
Reference in a new issue