This commit is contained in:
Chigoma333 2024-11-25 21:39:11 +00:00 committed by GitHub
commit 29c75f3554
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -49,11 +49,18 @@ class Wav2VecAlignment:
"""
Uses wav2vec2 to perform audio<->text alignment.
"""
def __init__(self, device=None):
    """
    Load the wav2vec2 alignment model, feature extractor and tokenizer,
    and pick the device this aligner should target.

    :param device: Optional device identifier (e.g. 'cuda', 'mps', 'cpu').
                   When None, the device is auto-detected in this order:
                   'mps' if the MPS backend is available, else 'cuda' if
                   CUDA is available, else 'cpu'.
    """
    # The model is explicitly placed on the CPU at construction time;
    # self.device only records the target device.
    # NOTE(review): presumably the model is moved to self.device when it
    # is actually used - that code is not visible here, confirm.
    self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
    self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
    self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')

    # An explicit caller choice always wins; otherwise fall back through
    # the available accelerators so that a machine with neither MPS nor
    # CUDA ends up on 'cpu' instead of an unusable 'cuda' default.
    if device is not None:
        self.device = device
    elif torch.backends.mps.is_available():
        self.device = 'mps'
    elif torch.cuda.is_available():
        self.device = 'cuda'
    else:
        self.device = 'cpu'
def align(self, audio, expected_text, audio_sample_rate=24000):
orig_len = audio.shape[-1]