diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..b128482
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,17 @@
+{
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: tortoise/do_tts.py",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${workspaceFolder}/tortoise/do_tts.py",
+            "env": {
+                "PYTHONPATH": "${workspaceFolder}"
+            },
+            "args": [
+                "this is a test"
+            ]
+        }
+    ]
+}
diff --git a/Dockerfile b/Dockerfile
index 36b363b..4b7b740 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -41,13 +41,15 @@ RUN conda create --name tortoise python=3.9 numba inflect -y && \
 # Set conda environment to be activated by default in future RUN instructions
 RUN echo "conda activate tortoise" >> ~/.bashrc
 
-FROM conda AS runner
+FROM conda_base AS runner
+
 # Install the application
 WORKDIR /app
 RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python setup.py install"
 
-# Provide default CMD if no arguments are passed
-CMD ["--help"]
-
 # Default entrypoint
-ENTRYPOINT ["/bin/bash", "-c", "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python tortoise/do_tts.py"]
+RUN chmod +x /app/scripts/docker-entrypoint.sh
+ENTRYPOINT ["/app/scripts/docker-entrypoint.sh"]
+
+# Provide default CMD if no arguments are passed
+CMD ["--help"]
\ No newline at end of file
diff --git a/Dockerfile.app b/Dockerfile.app
index 6602c1e..9dc7e81 100644
--- a/Dockerfile.app
+++ b/Dockerfile.app
@@ -41,7 +41,7 @@ RUN conda create --name tortoise python=3.9 numba inflect -y && \
 # Set conda environment to be activated by default in future RUN instructions
 RUN echo "conda activate tortoise" >> ~/.bashrc
 
-FROM conda AS runner
+FROM conda_base AS runner
 
 # Install the application
 WORKDIR /app
diff --git a/README.md b/README.md
index 6f82564..809a6de 100644
--- a/README.md
+++ b/README.md
@@ -94,6 +94,11 @@ docker run --gpus all \
     -v /root:/work \
     -it tts
 ```
+(Note: the current version of the image already sets `ENTRYPOINT` to run `tortoise/do_tts.py`, so arguments passed to `docker run` are forwarded to that script:)
+```
+docker run --gpus all -it tts "This is a test"
+```
+If the docker container is run without arguments, the default `CMD ["--help"]` prints the script's usage.
 
 This gives you an interactive terminal in an environment that's ready to do some tts. Now you can explore the different interfaces that tortoise exposes for tts. For example:
 
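With the entrypoint change above, `docker run` no longer drops into a shell; whatever follows the image name is handed by `scripts/docker-entrypoint.sh` to `tortoise/do_tts.py`. A minimal sketch of the new flow, assuming the image is tagged `tts` as in the README:

```
# Build the image (tag assumed from the README).
docker build -t tts .

# With no arguments, the default CMD ["--help"] kicks in and the
# entrypoint prints do_tts.py's usage text.
docker run --gpus all -it tts
```
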
diff --git a/app/main.py b/app/main.py
index 51d21c5..641653b 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,15 +1,21 @@
+import argparse
+import os
+import sys
+
+
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import FileResponse
+from beartype import beartype
 from pydantic import BaseModel
+
 from app.models.request import TranscriptionRequest
 from app.models.tts import TTSArgs
 from tortoise.do_tts import main as tts_main
 
-import os
-
 
 app = FastAPI()
 
+@beartype
 def local_inference_tts(args: TTSArgs):
     """
     Run the TTS directly using the `main` function from `tortoise/do_tts.py`.
@@ -43,3 +49,31 @@ async def transcribe(request: TranscriptionRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('text', type=str, help='Text to speak. This argument is required.')
+    parser.add_argument('--voice', type=str, default='random',
+                        help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
+                             'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.')
+    parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
+    parser.add_argument('--output_path', type=str, help='Where to store outputs (directory).', default='data/results/')
+
+    args = parser.parse_args()
+
+    tts_args = TTSArgs(
+        text=args.text,
+        voice=args.voice,
+        output_path=args.output_path,
+        preset=args.preset,
+    )
+
+    try:
+        output_path = local_inference_tts(tts_args)
+        return output_path
+    except Exception as e:
+        print(f"Error during TTS generation: {str(e)}")
+        sys.exit(1)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/app/models/tts.py b/app/models/tts.py
index 88ad8c5..86b9084 100644
--- a/app/models/tts.py
+++ b/app/models/tts.py
@@ -1,17 +1,18 @@
 import os
-
 from pydantic import BaseModel
 
+from tortoise.api import MODELS_DIR
+
 class TTSArgs(BaseModel):
     text: str
-    voice: str
-    output_path: str
-    preset: str
-    model_dir: str = os.getenv("TORTOISE_MODELS_DIR", "data/models")
+    voice: str = 'random'
+    output_path: str = 'results/'
+    preset: str = 'fast'
+    model_dir: str = os.getenv("TORTOISE_MODELS_DIR", MODELS_DIR)
     use_deepspeed: bool = False
-    kv_cache: bool = False
-    half: bool = False
-    candidates: int = 1
+    kv_cache: bool = True
+    half: bool = True
+    candidates: int = 3
     seed: int = None
     cvvp_amount: float = 0.0
-    produce_debug_state: bool = False
+    produce_debug_state: bool = True
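The new `main()` turns `app/main.py` into a small CLI wrapper around `local_inference_tts`, with `TTSArgs` filling in the remaining defaults. A minimal sketch of an invocation from the repo root, assuming `PYTHONPATH` is set the same way launch.json does it:

```
# "text" is positional; the remaining options fall back to TTSArgs defaults.
PYTHONPATH=. python app/main.py "Hello, this is a test." \
    --voice random --preset fast --output_path data/results/
```
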
diff --git a/app/services/oauth.py b/app/services/oauth.py
new file mode 100644
index 0000000..a2bfc67
--- /dev/null
+++ b/app/services/oauth.py
@@ -0,0 +1,43 @@
+import os
+from google_auth_oauthlib.flow import Flow
+from fastapi import HTTPException, Depends
+from fastapi.security import OAuth2AuthorizationCodeBearer
+
+# Configuration
+CLIENT_SECRETS_FILE = 'client_secret.json'
+SCOPES = ['https://www.googleapis.com/auth/youtube.upload']
+REDIRECT_URI = 'http://localhost:8000/callback'
+OAUTH2_SCHEME = OAuth2AuthorizationCodeBearer(authorizationUrl='https://accounts.google.com/o/oauth2/auth',
+                                              tokenUrl='https://oauth2.googleapis.com/token')
+
+def get_authenticated_service():
+    """
+    Authenticate and return a service object for the YouTube API.
+    """
+    flow = Flow.from_client_secrets_file(
+        CLIENT_SECRETS_FILE, SCOPES,
+        redirect_uri=REDIRECT_URI
+    )
+
+    # Tell the user to go to the authorization URL.
+    auth_url, _ = flow.authorization_url(prompt='consent')
+    print('Please go to this URL: {}'.format(auth_url))
+
+    # The user will get an authorization code. This code is used to get the access token.
+    code = input('Enter the authorization code: ')
+    flow.fetch_token(code=code)
+
+    # You can use flow.credentials, or you can just get a requests session using flow.authorized_session.
+    session = flow.authorized_session()
+    print(session.get('https://www.googleapis.com/userinfo/v2/me').json())
+    return session
+
+async def get_current_user(token: str = Depends(OAUTH2_SCHEME)):
+    # Assuming the token is valid and we can get user info
+    # In practice, you would validate the token and fetch user info
+    try:
+        session = get_authenticated_service()
+        user_info = session.get('https://www.googleapis.com/userinfo/v2/me').json()
+        return user_info
+    except Exception as e:
+        raise HTTPException(status_code=401, detail="Invalid authentication credentials")
diff --git a/logging_conf.yml b/logging_conf.yml
new file mode 100644
index 0000000..576e643
--- /dev/null
+++ b/logging_conf.yml
@@ -0,0 +1,28 @@
+version: 1
+disable_existing_loggers: False
+
+formatters:
+  standard:
+    format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+
+handlers:
+  console:
+    class: logging.StreamHandler
+    level: DEBUG
+    formatter: standard
+    stream: ext://sys.stdout
+
+loggers:
+  app:
+    level: DEBUG
+    handlers: [console]
+    propagate: no
+
+  tortoise:
+    level: DEBUG
+    handlers: [console]
+    propagate: no
+
+root:
+  level: DEBUG
+  handlers: [console]
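`logging_conf.yml` is loaded by `tortoise/do_tts.py` below via `logging.config.dictConfig`. The same file should also work when serving the FastAPI app, since uvicorn's `--log-config` accepts a YAML dict-config file; a sketch, assuming uvicorn is the server (it is not pinned anywhere in this diff):

```
# Serve the API with the same console logging configuration.
uvicorn app.main:app --host 0.0.0.0 --port 8000 --log-config logging_conf.yml
```
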
diff --git a/requirements.txt b/requirements.txt
index fd8d538..57d0fc5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,13 +3,13 @@ rotary_embedding_torch
 transformers==4.31.0
 tokenizers
 inflect
-progressbar
+# progressbar2
 einops==0.4.1
 unidecode
 scipy
 librosa==0.9.1
 ffmpeg
-numpy
+numpy==1.24.1
 numba
 torchaudio
 threadpoolctl
diff --git a/scripts/docker-entrypoint.sh b/scripts/docker-entrypoint.sh
new file mode 100644
index 0000000..769b620
--- /dev/null
+++ b/scripts/docker-entrypoint.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+# Source the conda environment
+source /root/miniconda/etc/profile.d/conda.sh
+conda activate tortoise
+
+# Execute the Python script with passed arguments
+python /app/tortoise/do_tts.py "$@"
\ No newline at end of file
diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py
new file mode 100644
index 0000000..4dcca9d
--- /dev/null
+++ b/tests/test_transcribe.py
@@ -0,0 +1,83 @@
+import os, sys
+import tempfile
+import logging
+from fastapi.testclient import TestClient
+
+# Add the root directory of the repo to sys.path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+# Add the tortoise directory to sys.path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise')))
+
+from app.models.tts import TTSArgs
+from app.main import app
+
+client = TestClient(app)
+
+# Configure logging to print to the console
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+def test_transcribe():
+    logger.debug("Starting test_transcribe")
+    # Create a temporary directory to store the output
+    with tempfile.TemporaryDirectory() as temp_dir:
+        logger.debug(f"Created temporary directory at {temp_dir}")
+        response = client.post(
+            "/transcribe", json={
+                "text": "Hello, this is a test.",
+                "voice": "random",
+                "output_path": temp_dir,
+                "preset": "fast"
+            })
+        logger.debug(f"Received response with status code {response.status_code}")
+
+        assert response.status_code == 200
+        assert response.headers["content-type"] == "audio/wav"
+        files = os.listdir(temp_dir)
+        assert len(files) >= 1  # or should equal args.candidates
+        # check if ended in .wav
+        assert files[0].endswith('.wav'), f"Expected the file to end with '.wav', but got {files[0]}"
+        logger.debug("test_transcribe completed successfully")
+
+def test_transcribe_file_not_found():
+    logger.debug("Starting test_transcribe_file_not_found")
+    # Create a temporary directory and delete it immediately to ensure the path does not exist
+    temp_dir = tempfile.mkdtemp()
+    logger.debug(f"Created temporary directory at {temp_dir}")
+    os.rmdir(temp_dir)
+    logger.debug(f"Deleted temporary directory at {temp_dir}")
+
+    response = client.post(
+        "/transcribe", json={
+            "text": "Hello, this is a test.",
+            "voice": "random",
+            "output_path": temp_dir,
+            "preset": "fast"
+        })
+    logger.debug(f"Received response with status code {response.status_code}")
+
+    assert response.status_code == 404
+    assert response.json() == {"detail": "File not found"}
+    logger.debug("test_transcribe_file_not_found completed successfully")
+
+def test_transcribe_internal_server_error(monkeypatch):
+    logger.debug("Starting test_transcribe_internal_server_error")
+    # Simulate an exception being raised in the `local_inference_tts` function
+    def mock_local_inference_tts(args):
+        logger.debug("Mock local_inference_tts called")
+        raise Exception("Test exception")
+
+    monkeypatch.setattr("app.main.local_inference_tts", mock_local_inference_tts)
+    logger.debug("Replaced local_inference_tts with mock")
+
+    response = client.post("/transcribe", json={
+        "text": "Hello, this is a test.",
+        "voice": "random",
+        "output_path": "output.wav",
+        "preset": "fast"
+    })
+    logger.debug(f"Received response with status code {response.status_code}")
+
+    assert response.status_code == 500
+    assert response.json() == {"detail": "Test exception"}
+    logger.debug("test_transcribe_internal_server_error completed successfully")
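The tests drive the `/transcribe` endpoint through FastAPI's `TestClient`, so no running server is needed. A sketch of running them from the repo root (plain pytest options, nothing added by this diff):

```
# -v for per-test output; --log-cli-level surfaces the DEBUG logging
# configured at the top of the test module.
python -m pytest tests/test_transcribe.py -v --log-cli-level=DEBUG
```
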
diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py
index 902cdb1..1a5a49a 100644
--- a/tortoise/do_tts.py
+++ b/tortoise/do_tts.py
@@ -1,56 +1,99 @@
 import argparse
-import os
+import logging
+import logging.config
+import os, sys
+from tqdm import tqdm
 
 import torch
+import yaml
 import torchaudio
 
-from tortoise.api import TextToSpeech, MODELS_DIR
-from utils.audio import load_voices
+# Add the root directory of the repo to sys.path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from tortoise.api import TextToSpeech, MODELS_DIR, pick_best_batch_size_for_gpu
+from tortoise.utils.audio import load_voices
+
+# Configure logging to print to the console,
+# using the shared configuration in logging_conf.yml
+with open("logging_conf.yml", 'r') as file:
+    config = yaml.safe_load(file.read())
+    logging.config.dictConfig(config)
+logger = logging.getLogger(__name__)
 
 def main(args):
-    if torch.backends.mps.is_available():
-        args.use_deepspeed = False
+    # if torch.backends.mps.is_available():
+    #     args.use_deepspeed = False
     os.makedirs(args.output_path, exist_ok=True)
-    tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
+    tts = TextToSpeech(
+        models_dir=args.model_dir, autoregressive_batch_size=args.autoregressive_batch_size, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
 
     selected_voices = args.voice.split(',')
-    for k, selected_voice in enumerate(selected_voices):
+    for k, selected_voice in tqdm(enumerate(selected_voices), desc="generating with selected voices"):
         if '&' in selected_voice:
             voice_sel = selected_voice.split('&')
         else:
             voice_sel = [selected_voice]
         voice_samples, conditioning_latents = load_voices(voice_sel)
 
-        gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
-                                  preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount)
+        gen, dbg_state = tts.tts_with_preset(
+            args.text, k=args.candidates, voice_samples=voice_samples,
+            conditioning_latents=conditioning_latents,
+            preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount
+        )
         if isinstance(gen, list):
             for j, g in enumerate(gen):
-                torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
+                torchaudio.save(
+                    os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000
+                )
         else:
             torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)
-
+            print(f"Audio saved to {args.output_path} as {selected_voice}_{k}.wav")
         if args.produce_debug_state:
             os.makedirs('debug_states', exist_ok=True)
             torch.save(dbg_state, f'debug_states/do_tts_debug_{selected_voice}.pth')
 
 if __name__ == '__main__':
+    """
+    class TTSArgs(BaseModel):
+        text: str
+        voice: str
+        output_path: str
+        preset: str
+        model_dir: str = os.getenv("TORTOISE_MODELS_DIR", "data/models")
+        use_deepspeed: bool = False
+        kv_cache: bool = False
+        half: bool = False
+        candidates: int = 1
+        seed: int = None
+        cvvp_amount: float = 0.0
+        produce_debug_state: bool = False
+    """
     parser = argparse.ArgumentParser()
-    parser.add_argument('text', type=str, help='Text to speak. This argument is required.')
-    parser.add_argument('--voice', type=str, help="""
-        Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
-        'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.""", default='random')
-    parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
-    parser.add_argument('--use_deepspeed', type=str, help='Use deepspeed for speed bump.', default=False)
-    parser.add_argument('--kv_cache', type=bool, help='If you disable this please wait for a long a time to get the output', default=True)
-    parser.add_argument('--half', type=bool, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True)
-    parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
-    parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
-                        'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
+    parser.add_argument(
+        'text', type=str, help='Text to speak. This argument is required.')
+    parser.add_argument(
+        '--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
+    parser.add_argument(
+        '--preset', type=str, help='Which voice preset to use.', default='fast')
+    parser.add_argument(
+        '--use_deepspeed', action=argparse.BooleanOptionalAction, help='Use deepspeed for speed bump.', default=False)
+    parser.add_argument(
+        '--kv_cache', action=argparse.BooleanOptionalAction, help='If you disable this, expect a long wait for output.', default=True)
+    parser.add_argument(
+        '--autoregressive_batch_size', type=int, help='Batch size for autoregressive inference.', default=pick_best_batch_size_for_gpu())
+    parser.add_argument(
+        '--half', action=argparse.BooleanOptionalAction, help='float16 (half) precision inference: faster and uses less VRAM and RAM.', default=True)
+    parser.add_argument(
+        '--output_path', type=str, help='Where to store outputs (directory).', default='data/samples/')
+    parser.add_argument(
+        '--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this should only be specified if you have custom checkpoints.', default=MODELS_DIR)
     parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
     parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
-    parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
+    parser.add_argument('--produce_debug_state', action=argparse.BooleanOptionalAction, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
     parser.add_argument('--cvvp_amount', type=float, help='How much the CVVP model should influence the output.'
                         'Increasing this can in some cases reduce the likelihood of multiple speakers. Defaults to 0 (disabled)', default=.0)
 
     args = parser.parse_args()
+    assert not (args.use_deepspeed and args.half), 'half precision must be disabled when use_deepspeed is enabled'
     main(args)
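With `argparse.BooleanOptionalAction`, every boolean flag gains a `--no-` twin, and the final assert means a deepspeed run must disable half precision explicitly. A minimal sketch of an invocation exercising the new flags (output path per the new default):

```
# --no-half satisfies the assert when --use_deepspeed is given.
python tortoise/do_tts.py "Hello world." --voice random --preset fast \
    --use_deepspeed --no-half --output_path data/samples/
```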