mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-01-28 11:24:28 +01:00
updated reqs, do_tts.py and dockerfile for version consistency, some logging
This commit is contained in:
parent
1ead1dd35c
commit
c7dbcef434
17
.vscode/launch.json
vendored
Normal file
17
.vscode/launch.json
vendored
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: tortoise/do_tts.py",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/tortoise/do_tts.py",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}"
|
||||
},
|
||||
"args": [
|
||||
"text='this is a test'"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
12
Dockerfile
12
Dockerfile
|
|
@ -41,13 +41,15 @@ RUN conda create --name tortoise python=3.9 numba inflect -y && \
|
|||
# Set conda environment to be activated by default in future RUN instructions
|
||||
RUN echo "conda activate tortoise" >> ~/.bashrc
|
||||
|
||||
FROM conda AS runner
|
||||
FROM conda_base AS runner
|
||||
|
||||
# Install the application
|
||||
WORKDIR /app
|
||||
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python setup.py install"
|
||||
|
||||
# Provide default CMD if no arguments are passed
|
||||
CMD ["--help"]
|
||||
|
||||
# Default entrypoint
|
||||
ENTRYPOINT ["/bin/bash", "-c", "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python tortoise/do_tts.py"]
|
||||
RUN chmod +x /app/scripts/docker-entrypoint.sh
|
||||
ENTRYPOINT ["/app/scripts/docker-entrypoint.sh"]
|
||||
|
||||
# Provide default CMD if no arguments are passed
|
||||
CMD ["--help"]
|
||||
|
|
@ -41,7 +41,7 @@ RUN conda create --name tortoise python=3.9 numba inflect -y && \
|
|||
# Set conda environment to be activated by default in future RUN instructions
|
||||
RUN echo "conda activate tortoise" >> ~/.bashrc
|
||||
|
||||
FROM conda AS runner
|
||||
FROM conda_base AS runner
|
||||
|
||||
# Install the application
|
||||
WORKDIR /app
|
||||
|
|
|
|||
|
|
@ -94,6 +94,11 @@ docker run --gpus all \
|
|||
-v /root:/work \
|
||||
-it tts
|
||||
```
|
||||
(Note: the current image already sets an `ENTRYPOINT` that runs `python tortoise/do_tts.py`.)
|
||||
```
|
||||
|
||||
```
|
||||
If the docker container is
|
||||
This gives you an interactive terminal in an environment that's ready to do some tts. Now you can explore the different interfaces that tortoise exposes for tts.
|
||||
|
||||
For example:
|
||||
|
|
|
|||
38
app/main.py
38
app/main.py
|
|
@ -1,15 +1,21 @@
|
|||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from beartype import beartype
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.models.request import TranscriptionRequest
|
||||
from app.models.tts import TTSArgs
|
||||
|
||||
from tortoise.do_tts import main as tts_main
|
||||
|
||||
import os
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@beartype
|
||||
def local_inference_tts(args: TTSArgs):
|
||||
"""
|
||||
Run the TTS directly using the `main` function from `tortoise/do_tts.py`.
|
||||
|
|
@ -43,3 +49,31 @@ async def transcribe(request: TranscriptionRequest):
|
|||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def main():
    """Command-line entry point: parse arguments and run TTS locally.

    Builds a ``TTSArgs`` from the parsed command line and passes it to
    ``local_inference_tts``.

    Returns:
        The value returned by ``local_inference_tts``.

    Raises:
        SystemExit: with status 1 when generation raises any exception
            (the error is reported on stderr first); argparse may also exit
            on invalid command-line input.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('text', type=str, help='Text to speak. This argument is required.')
    # Adjacent string literals rather than a triple-quoted string, so the help
    # output contains no stray quote characters or embedded newlines.
    parser.add_argument(
        '--voice',
        type=str,
        help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
             'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.',
        default='random',
    )
    parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
    parser.add_argument('--output_path', type=str, help='Where to store outputs (directory).', default='data/results/')

    args = parser.parse_args()

    tts_args = TTSArgs(
        text=args.text,
        voice=args.voice,
        output_path=args.output_path,
        preset=args.preset,
    )

    try:
        return local_inference_tts(tts_args)
    except Exception as e:
        # Top-level CLI boundary: report the failure on stderr and exit non-zero.
        print(f"Error during TTS generation: {e}", file=sys.stderr)
        sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,17 +1,18 @@
|
|||
import os
from typing import Optional

from pydantic import BaseModel

from tortoise.api import MODELS_DIR
|
||||
|
||||
class TTSArgs(BaseModel):
    """Arguments for a single text-to-speech run.

    Only ``text`` is required; every other field has a default. Each field is
    declared exactly once.
    """

    # Text to speak.
    text: str
    # Voice selection string.
    voice: str = 'random'
    # Directory where outputs are stored.
    output_path: str = 'results/'
    # Name of the voice preset to use.
    preset: str = 'fast'
    # Evaluated once, when the class is defined: TORTOISE_MODELS_DIR if set,
    # otherwise MODELS_DIR from tortoise.api.
    model_dir: str = os.getenv("TORTOISE_MODELS_DIR", MODELS_DIR)
    use_deepspeed: bool = False
    kv_cache: bool = True
    half: bool = True
    candidates: int = 3
    # The default is None, so the annotation must allow None.
    seed: Optional[int] = None
    cvvp_amount: float = 0.0
    produce_debug_state: bool = True
    # tortoise/do_tts.py's main() reads args.autoregressive_batch_size, so the
    # attribute has to exist on objects handed to it.
    # NOTE(review): None is assumed to let TextToSpeech choose its own batch
    # size - confirm against tortoise.api.
    autoregressive_batch_size: Optional[int] = None
|
||||
|
|
|
|||
43
app/services/oauth.py
Normal file
43
app/services/oauth.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import os
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from fastapi import HTTPException, Depends
|
||||
from fastapi.security import OAuth2AuthorizationCodeBearer
|
||||
|
||||
# Configuration
|
||||
CLIENT_SECRETS_FILE = 'client_secret.json'
|
||||
SCOPES = ['https://www.googleapis.com/auth/youtube.upload']
|
||||
REDIRECT_URI = 'http://localhost:8000/callback'
|
||||
OAUTH2_SCHEME = OAuth2AuthorizationCodeBearer(authorizationUrl='https://accounts.google.com/o/oauth2/auth',
|
||||
tokenUrl='https://oauth2.googleapis.com/token')
|
||||
|
||||
def get_authenticated_service():
    """Run the interactive OAuth flow and return an authorized session.

    Prints the authorization URL, reads the authorization code from standard
    input, exchanges it for a token, prints the JSON returned by the userinfo
    endpoint, and returns the authorized session.
    """
    oauth_flow = Flow.from_client_secrets_file(
        CLIENT_SECRETS_FILE,
        SCOPES,
        redirect_uri=REDIRECT_URI,
    )

    # Step 1: show the consent URL to the user.
    consent_url, _state = oauth_flow.authorization_url(prompt='consent')
    print('Please go to this URL: {}'.format(consent_url))

    # Step 2: the user pastes back the authorization code, which is traded
    # for an access token.
    auth_code = input('Enter the authorization code: ')
    oauth_flow.fetch_token(code=auth_code)

    # Step 3: build a session carrying the credentials and echo the user info.
    authed_session = oauth_flow.authorized_session()
    profile = authed_session.get('https://www.googleapis.com/userinfo/v2/me').json()
    print(profile)
    return authed_session
|
||||
|
||||
async def get_current_user(token: str = Depends(OAUTH2_SCHEME)):
    """FastAPI dependency that returns the caller's user info as parsed JSON.

    NOTE(review): ``token`` is extracted by ``OAUTH2_SCHEME`` but is not used
    or validated here. The function instead calls
    ``get_authenticated_service()``, which prompts on standard input - confirm
    this is intended before relying on it inside a running server.

    Returns:
        The JSON body returned by the userinfo endpoint.

    Raises:
        HTTPException: status 401 if obtaining the session or the user info
            raises any exception.
    """
    try:
        session = get_authenticated_service()
        return session.get('https://www.googleapis.com/userinfo/v2/me').json()
    except Exception as e:
        # Chain the cause so the underlying failure stays visible in tracebacks.
        raise HTTPException(status_code=401, detail="Invalid authentication credentials") from e
|
||||
28
logging_conf.yml
Normal file
28
logging_conf.yml
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
version: 1
|
||||
disable_existing_loggers: False
|
||||
|
||||
formatters:
|
||||
standard:
|
||||
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
|
||||
handlers:
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
level: DEBUG
|
||||
formatter: standard
|
||||
stream: ext://sys.stdout
|
||||
|
||||
loggers:
|
||||
app:
|
||||
level: DEBUG
|
||||
handlers: [console]
|
||||
propagate: no
|
||||
|
||||
tortoise:
|
||||
level: DEBUG
|
||||
handlers: [console]
|
||||
propagate: no
|
||||
|
||||
root:
|
||||
level: DEBUG
|
||||
handlers: [console]
|
||||
|
|
@ -3,13 +3,13 @@ rotary_embedding_torch
|
|||
transformers==4.31.0
|
||||
tokenizers
|
||||
inflect
|
||||
progressbar
|
||||
# progressbar2
|
||||
einops==0.4.1
|
||||
unidecode
|
||||
scipy
|
||||
librosa==0.9.1
|
||||
ffmpeg
|
||||
numpy
|
||||
numpy==1.24.1
|
||||
numba
|
||||
torchaudio
|
||||
threadpoolctl
|
||||
|
|
|
|||
7
scripts/docker-entrypoint.sh
Normal file
7
scripts/docker-entrypoint.sh
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
#!/bin/bash
# Container entrypoint: activate the "tortoise" conda environment, then run
# tortoise/do_tts.py with the arguments passed to the container.

# Use CONDA_DIR when the environment provides it (the Dockerfile refers to
# ${CONDA_DIR}); otherwise fall back to the previously hard-coded location.
source "${CONDA_DIR:-/root/miniconda}/etc/profile.d/conda.sh"
conda activate tortoise

# exec replaces this shell with Python, so signals sent to the container
# reach the Python process directly and its exit status is the container's.
exec python /app/tortoise/do_tts.py "$@"
|
||||
83
tests/test_transcribe.py
Normal file
83
tests/test_transcribe.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import os, sys
|
||||
import tempfile
|
||||
import logging
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Add the root directory of the repo to sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
# Add the tortoise directory to sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise')))
|
||||
|
||||
from app.models.tts import TTSArgs
|
||||
from app.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Configure logging to print to the console
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_transcribe():
    """POST /transcribe should answer 200 with audio/wav and leave a .wav in the output directory."""
    logger.debug("Starting test_transcribe")
    # The output directory is temporary and removed when the block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        logger.debug(f"Created temporary directory at {temp_dir}")
        request_body = {
            "text": "Hello, this is a test.",
            "voice": "random",
            "output_path": temp_dir,
            "preset": "fast",
        }
        response = client.post("/transcribe", json=request_body)
        logger.debug(f"Received response with status code {response.status_code}")

        assert response.status_code == 200
        assert response.headers["content-type"] == "audio/wav"

        # At least one file must have been written, and the first listed
        # entry must be a WAV file.
        files = os.listdir(temp_dir)
        assert len(files) >= 1
        assert files[0].endswith('.wav'), f"Expected the file to end with '.wav', but got {files[0]}"
        logger.debug("test_transcribe completed successfully")
|
||||
|
||||
def test_transcribe_file_not_found():
    """POST /transcribe with an output path that was just removed should answer 404 "File not found"."""
    logger.debug("Starting test_transcribe_file_not_found")

    # Make a directory and remove it again, leaving a path that no longer exists.
    temp_dir = tempfile.mkdtemp()
    logger.debug(f"Created temporary directory at {temp_dir}")
    os.rmdir(temp_dir)
    logger.debug(f"Deleted temporary directory at {temp_dir}")

    request_body = {
        "text": "Hello, this is a test.",
        "voice": "random",
        "output_path": temp_dir,
        "preset": "fast",
    }
    response = client.post("/transcribe", json=request_body)
    logger.debug(f"Received response with status code {response.status_code}")

    expected_body = {"detail": "File not found"}
    assert response.status_code == 404
    assert response.json() == expected_body
    logger.debug("test_transcribe_file_not_found completed successfully")
|
||||
|
||||
def test_transcribe_internal_server_error(monkeypatch):
    """When local_inference_tts raises, POST /transcribe should answer 500 with the exception text."""
    logger.debug("Starting test_transcribe_internal_server_error")

    def _always_fail(args):
        # Stand-in for app.main.local_inference_tts that fails unconditionally.
        logger.debug("Mock local_inference_tts called")
        raise Exception("Test exception")

    monkeypatch.setattr("app.main.local_inference_tts", _always_fail)
    logger.debug("Replaced local_inference_tts with mock")

    request_body = {
        "text": "Hello, this is a test.",
        "voice": "random",
        "output_path": "output.wav",
        "preset": "fast",
    }
    response = client.post("/transcribe", json=request_body)
    logger.debug(f"Received response with status code {response.status_code}")

    expected_body = {"detail": "Test exception"}
    assert response.status_code == 500
    assert response.json() == expected_body
    logger.debug("test_transcribe_internal_server_error completed successfully")
|
||||
|
|
@ -1,56 +1,99 @@
|
|||
import argparse
|
||||
import os
|
||||
import logging
|
||||
import logging.config
|
||||
import os, sys
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
import torchaudio
|
||||
|
||||
from tortoise.api import TextToSpeech, MODELS_DIR
|
||||
from utils.audio import load_voices
|
||||
# Add the root directory of the repo to sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from tortoise.api import TextToSpeech, MODELS_DIR, pick_best_batch_size_for_gpu
|
||||
from tortoise.utils.audio import load_voices
|
||||
|
||||
# Configure logging to print to the console
|
||||
# Load logging configuration
|
||||
with open("logging_conf.yml", 'r') as file:
|
||||
config = yaml.safe_load(file.read())
|
||||
logging.config.dictConfig(config)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main(args):
|
||||
if torch.backends.mps.is_available():
|
||||
args.use_deepspeed = False
|
||||
# if torch.backends.mps.is_available():
|
||||
# args.use_deepspeed = False
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
|
||||
tts = TextToSpeech(
|
||||
models_dir=args.model_dir, autoregressive_batch_size=args.autoregressive_batch_size, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
|
||||
|
||||
selected_voices = args.voice.split(',')
|
||||
for k, selected_voice in enumerate(selected_voices):
|
||||
for k, selected_voice in tqdm(enumerate(selected_voices), desc="generating using selected voice"):
|
||||
if '&' in selected_voice:
|
||||
voice_sel = selected_voice.split('&')
|
||||
else:
|
||||
voice_sel = [selected_voice]
|
||||
voice_samples, conditioning_latents = load_voices(voice_sel)
|
||||
|
||||
gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
|
||||
preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount)
|
||||
gen, dbg_state = tts.tts_with_preset(
|
||||
args.text, k=args.candidates, voice_samples=voice_samples,
|
||||
conditioning_latents=conditioning_latents,
|
||||
preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount
|
||||
)
|
||||
if isinstance(gen, list):
|
||||
for j, g in enumerate(gen):
|
||||
torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
|
||||
torchaudio.save(
|
||||
os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000
|
||||
)
|
||||
else:
|
||||
torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)
|
||||
|
||||
print(f"Audio saved to {args.output_path} as {selected_voice}_{k}.wav")
|
||||
if args.produce_debug_state:
|
||||
os.makedirs('debug_states', exist_ok=True)
|
||||
torch.save(dbg_state, f'debug_states/do_tts_debug_{selected_voice}.pth')
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
class TTSArgs(BaseModel):
|
||||
text: str
|
||||
voice: str
|
||||
output_path: str
|
||||
preset: str
|
||||
model_dir: str = os.getenv("TORTOISE_MODELS_DIR", "data/models")
|
||||
use_deepspeed: bool = False
|
||||
kv_cache: bool = False
|
||||
half: bool = False
|
||||
candidates: int = 1
|
||||
seed: int = None
|
||||
cvvp_amount: float = 0.0
|
||||
produce_debug_state: bool = False
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('text', type=str, help='Text to speak. This argument is required.')
|
||||
parser.add_argument('--voice', type=str, help="""
|
||||
Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
|
||||
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.""", default='random')
|
||||
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
|
||||
parser.add_argument('--use_deepspeed', type=str, help='Use deepspeed for speed bump.', default=False)
|
||||
parser.add_argument('--kv_cache', type=bool, help='If you disable this please wait for a long a time to get the output', default=True)
|
||||
parser.add_argument('--half', type=bool, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True)
|
||||
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
|
||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
||||
# Command-line options for do_tts.py.
parser.add_argument(
    'text', type=str, help='Text to speak. This argument is required.')
parser.add_argument(
    '--voice', type=str, help="Selects the voice to use for generation. See options in voices/ directory (and add your own!) Use the & character to join two voices together. Use a comma to perform inference on multiple voices.", default='random')
parser.add_argument(
    '--preset', type=str, help='Which voice preset to use.', default='fast')
# Boolean flags use BooleanOptionalAction (--flag / --no-flag). They take no
# value, so no type= is passed: combining type= with this action is deprecated
# in newer Python versions.
parser.add_argument(
    '--use_deepspeed', action=argparse.BooleanOptionalAction, help='Use deepspeed for speed bump.', default=False)
parser.add_argument(
    '--kv_cache', action=argparse.BooleanOptionalAction, help='If you disable this please wait for a long a time to get the output', default=True)
# The default is computed once, while the parser is being built.
parser.add_argument(
    '--autoregressive_batch_size', type=int, help='Batch size for autoregressive inference.', default=pick_best_batch_size_for_gpu())
parser.add_argument(
    '--half', action=argparse.BooleanOptionalAction, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True)
parser.add_argument(
    '--output_path', type=str, help='Where to store outputs (directory).', default='data/samples/')
parser.add_argument(
    '--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this '
                                  'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
|
||||
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
|
||||
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
|
||||
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
||||
# --produce_debug_state / --no-produce_debug_state; the flag takes no value, so
# type= is not passed alongside BooleanOptionalAction.
parser.add_argument('--produce_debug_state', action=argparse.BooleanOptionalAction, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
||||
parser.add_argument('--cvvp_amount', type=float, help='How much the CVVP model should influence the output.'
|
||||
'Increasing this can in some cases reduce the likelihood of multiple speakers. Defaults to 0 (disabled)', default=.0)
|
||||
args = parser.parse_args()
|
||||
# Reject --use_deepspeed together with half precision. This is an explicit check
# rather than an assert, so it still runs under `python -O` and prints a usage
# error instead of a bare AssertionError.
if args.use_deepspeed and args.half:
    parser.error('--use_deepspeed cannot be combined with half precision; pass --no-half.')
|
||||
main(args)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue