diff --git a/tortoise/models/hifigan_decoder.py b/tortoise/models/hifigan_decoder.py index 1a428d1..ae2f627 100644 --- a/tortoise/models/hifigan_decoder.py +++ b/tortoise/models/hifigan_decoder.py @@ -230,6 +230,10 @@ class HifiganGenerator(torch.nn.Module): if not conv_post_weight_norm: remove_weight_norm(self.conv_post) + self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + if torch.backends.mps.is_available(): + self.device = torch.device('mps') + def forward(self, x, g=None): """ Args: @@ -287,7 +291,7 @@ class HifiganGenerator(torch.nn.Module): mode="linear", ) g = g.unsqueeze(0) - return self.forward(up_2.to("cuda"), g.transpose(1,2)) + return self.forward(up_2.to(self.device), g.transpose(1,2)) def remove_weight_norm(self): print("Removing weight norm...")