mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-04-21 06:14:04 +00:00
Merge 138a7b9070 into 8a2563ecab
This commit is contained in:
commit
29c75f3554
1 changed files with 9 additions and 2 deletions
|
|
@ -49,11 +49,18 @@ class Wav2VecAlignment:
|
|||
"""
|
||||
Uses wav2vec2 to perform audio<->text alignment.
|
||||
"""
|
||||
def __init__(self, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
|
||||
def __init__(self, device=None):
|
||||
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
|
||||
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
|
||||
self.device = device
|
||||
if device is not None:
|
||||
self.device = device
|
||||
elif torch.backends.mps.is_available():
|
||||
self.device = 'mps'
|
||||
elif torch.cuda.is_available():
|
||||
self.device = 'cuda'
|
||||
else:
|
||||
self.device = 'cpu'
|
||||
|
||||
def align(self, audio, expected_text, audio_sample_rate=24000):
|
||||
orig_len = audio.shape[-1]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue