Update requirements, do_tts.py, and Dockerfile for version consistency; add logging

This commit is contained in:
supermomo668 2024-08-07 00:15:18 +00:00
parent 1ead1dd35c
commit c7dbcef434
12 changed files with 305 additions and 42 deletions

.vscode/launch.json (new file)

@@ -0,0 +1,17 @@
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: tortoise/do_tts.py",
            "type": "debugpy",
            "request": "launch",
            "program": "${workspaceFolder}/tortoise/do_tts.py",
            "env": {
                "PYTHONPATH": "${workspaceFolder}"
            },
            "args": [
                "this is a test"
            ]
        }
    ]
}
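For reference, a rough command-line equivalent of this launch configuration, as a sketch that assumes the repository root is the working directory:

```python
# Minimal sketch of what the debugger launch does: PYTHONPATH mirrors the
# "env" block above, and the positional text argument is passed through.
import os
import subprocess

env = dict(os.environ, PYTHONPATH=os.getcwd())
subprocess.run(
    ["python", "tortoise/do_tts.py", "this is a test"],
    env=env,
    check=True,  # raise CalledProcessError on a non-zero exit
)
```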

@@ -41,13 +41,15 @@ RUN conda create --name tortoise python=3.9 numba inflect -y && \
# Set conda environment to be activated by default in future RUN instructions
RUN echo "conda activate tortoise" >> ~/.bashrc
FROM conda AS runner
FROM conda_base AS runner
# Install the application
WORKDIR /app
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python setup.py install"
# Provide default CMD if no arguments are passed
CMD ["--help"]
# Default entrypoint
ENTRYPOINT ["/bin/bash", "-c", "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python tortoise/do_tts.py"]
RUN chmod +x /app/scripts/docker-entrypoint.sh
ENTRYPOINT ["/app/scripts/docker-entrypoint.sh"]
# Provide default CMD if no arguments are passed
CMD ["--help"]

@@ -41,7 +41,7 @@ RUN conda create --name tortoise python=3.9 numba inflect -y && \
# Set conda environment to be activated by default in future RUN instructions
RUN echo "conda activate tortoise" >> ~/.bashrc
FROM conda AS runner
FROM conda_base AS runner
# Install the application
WORKDIR /app

README.md

@@ -94,6 +94,11 @@ docker run --gpus all \
-v /root:/work \
-it tts
```
(Note: the current image already sets an `ENTRYPOINT` that runs `tortoise/do_tts.py`, so arguments passed to `docker run` go straight to the script; with no arguments, the default `CMD ["--help"]` prints the usage.)
This gives you an interactive terminal in an environment that's ready to do some tts. Now you can explore the different interfaces that tortoise exposes for tts.
For example:

app/main.py

@@ -1,15 +1,21 @@
import argparse
import os
import sys
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from beartype import beartype
from pydantic import BaseModel
from app.models.request import TranscriptionRequest
from app.models.tts import TTSArgs
from tortoise.do_tts import main as tts_main
import os
app = FastAPI()
@beartype
def local_inference_tts(args: TTSArgs):
"""
Run the TTS directly using the `main` function from `tortoise/do_tts.py`.
@@ -43,3 +49,31 @@ async def transcribe(request: TranscriptionRequest):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('text', type=str, help='Text to speak. This argument is required.')
    parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!). '
                        'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
    parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
    parser.add_argument('--output_path', type=str, help='Where to store outputs (directory).', default='data/results/')
    args = parser.parse_args()

    tts_args = TTSArgs(
        text=args.text,
        voice=args.voice,
        output_path=args.output_path,
        preset=args.preset,
    )
    try:
        output_path = local_inference_tts(tts_args)
        return output_path
    except Exception as e:
        print(f"Error during TTS generation: {str(e)}")
        sys.exit(1)

if __name__ == "__main__":
    main()
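With the `/transcribe` route above, the service can be exercised end to end. A minimal client sketch; the host and port are assumptions (e.g. the app started with `uvicorn app.main:app --port 8000`):

```python
# Minimal client sketch for the /transcribe endpoint.
import requests

resp = requests.post(
    "http://localhost:8000/transcribe",
    json={
        "text": "Hello, this is a test.",
        "voice": "random",
        "output_path": "data/results/",
        "preset": "fast",
    },
)
resp.raise_for_status()
# The endpoint returns the generated file as audio/wav.
with open("out.wav", "wb") as f:
    f.write(resp.content)
```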

app/models/tts.py

@@ -1,17 +1,18 @@
import os
from pydantic import BaseModel
from tortoise.api import MODELS_DIR
class TTSArgs(BaseModel):
text: str
voice: str
output_path: str
preset: str
model_dir: str = os.getenv("TORTOISE_MODELS_DIR", "data/models")
voice: str = 'random'
output_path: str = 'results/'
preset: str = 'fast'
model_dir: str = os.getenv("TORTOISE_MODELS_DIR", MODELS_DIR)
use_deepspeed: bool = False
kv_cache: bool = False
half: bool = False
candidates: int = 1
kv_cache: bool = True
half: bool = True
candidates: int = 3
seed: int = None
cvvp_amount: float = 0.0
produce_debug_state: bool = False
produce_debug_state: bool = True
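With the relaxed defaults above, only `text` is required to build a request model; a quick sketch of the changed behavior:

```python
# Sketch: after this change, a minimal TTSArgs only needs `text`.
from app.models.tts import TTSArgs

args = TTSArgs(text="Hello, this is a test.")
assert args.voice == "random" and args.preset == "fast"
assert args.kv_cache and args.half  # both now default to True
assert args.candidates == 3         # was 1
```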

app/services/oauth.py (new file)

@@ -0,0 +1,43 @@
import os
from google_auth_oauthlib.flow import Flow
from fastapi import HTTPException, Depends
from fastapi.security import OAuth2AuthorizationCodeBearer
# Configuration
CLIENT_SECRETS_FILE = 'client_secret.json'
SCOPES = ['https://www.googleapis.com/auth/youtube.upload']
REDIRECT_URI = 'http://localhost:8000/callback'
OAUTH2_SCHEME = OAuth2AuthorizationCodeBearer(authorizationUrl='https://accounts.google.com/o/oauth2/auth',
tokenUrl='https://oauth2.googleapis.com/token')
def get_authenticated_service():
"""
Authenticate and return a service object for the YouTube API.
"""
flow = Flow.from_client_secrets_file(
CLIENT_SECRETS_FILE, SCOPES,
redirect_uri=REDIRECT_URI
)
# Tell the user to go to the authorization URL.
auth_url, _ = flow.authorization_url(prompt='consent')
print('Please go to this URL: {}'.format(auth_url))
# The user will get an authorization code. This code is used to get the access token.
code = input('Enter the authorization code: ')
flow.fetch_token(code=code)
# You can use flow.credentials, or you can just get a requests session using flow.authorized_session.
session = flow.authorized_session()
print(session.get('https://www.googleapis.com/userinfo/v2/me').json())
return session
async def get_current_user(token: str = Depends(OAUTH2_SCHEME)):
# Assuming the token is valid and we can get user info
# In practice, you would validate the token and fetch user info
try:
session = get_authenticated_service()
user_info = session.get('https://www.googleapis.com/userinfo/v2/me').json()
return user_info
except Exception as e:
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
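`get_current_user` can then guard any route as a standard FastAPI dependency; a sketch (the `/me` route is hypothetical, not part of this commit):

```python
# Sketch: protecting a route with the new OAuth dependency.
# The /me path is hypothetical and only illustrates the wiring.
from fastapi import Depends, FastAPI
from app.services.oauth import get_current_user

app = FastAPI()

@app.get("/me")
async def read_current_user(user: dict = Depends(get_current_user)):
    return user  # the userinfo payload fetched by get_current_user
```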

logging_conf.yml (new file)

@@ -0,0 +1,28 @@
version: 1
disable_existing_loggers: False
formatters:
  standard:
    format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
handlers:
  console:
    class: logging.StreamHandler
    level: DEBUG
    formatter: standard
    stream: ext://sys.stdout
loggers:
  app:
    level: DEBUG
    handlers: [console]
    propagate: no
  tortoise:
    level: DEBUG
    handlers: [console]
    propagate: no
root:
  level: DEBUG
  handlers: [console]
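The config is consumed with `logging.config.dictConfig`, as `tortoise/do_tts.py` does below; the same pattern works for any other entry point:

```python
# Load logging_conf.yml and emit through the configured console handler.
import logging
import logging.config

import yaml

with open("logging_conf.yml") as f:
    logging.config.dictConfig(yaml.safe_load(f))

logging.getLogger("app").debug("logging configured")
```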

requirements.txt

@@ -3,13 +3,13 @@ rotary_embedding_torch
transformers==4.31.0
tokenizers
inflect
progressbar
# progressbar2
einops==0.4.1
unidecode
scipy
librosa==0.9.1
ffmpeg
numpy
numpy==1.24.1
numba
torchaudio
threadpoolctl

scripts/docker-entrypoint.sh (new file)

@@ -0,0 +1,7 @@
#!/bin/bash
# Source the conda environment
source /root/miniconda/etc/profile.d/conda.sh
conda activate tortoise
# Execute the Python script with passed arguments
python /app/tortoise/do_tts.py "$@"

tests/test_transcribe.py (new file)

@@ -0,0 +1,83 @@
import os, sys
import tempfile
import logging
from fastapi.testclient import TestClient
# Add the root directory of the repo to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Add the tortoise directory to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise')))
from app.models.tts import TTSArgs
from app.main import app
client = TestClient(app)
# Configure logging to print to the console
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_transcribe():
logger.debug("Starting test_transcribe")
# Create a temporary directory to store the output
with tempfile.TemporaryDirectory() as temp_dir:
logger.debug(f"Created temporary directory at {temp_dir}")
response = client.post(
"/transcribe", json={
"text": "Hello, this is a test.",
"voice": "random",
"output_path": temp_dir,
"preset": "fast"
})
logger.debug(f"Received response with status code {response.status_code}")
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
files = os.listdir(temp_dir)
assert len(files) >= 1  # or should equal args.candidates
# check if ended in .wav
assert files[0].endswith('.wav'), f"Expected the file to end with '.wav', but got {files[0]}"
logger.debug("test_transcribe completed successfully")
def test_transcribe_file_not_found():
logger.debug("Starting test_transcribe_file_not_found")
# Create a temporary directory and delete it immediately to ensure the path does not exist
temp_dir = tempfile.mkdtemp()
logger.debug(f"Created temporary directory at {temp_dir}")
os.rmdir(temp_dir)
logger.debug(f"Deleted temporary directory at {temp_dir}")
response = client.post(
"/transcribe", json={
"text": "Hello, this is a test.",
"voice": "random",
"output_path": temp_dir,
"preset": "fast"
})
logger.debug(f"Received response with status code {response.status_code}")
assert response.status_code == 404
assert response.json() == {"detail": "File not found"}
logger.debug("test_transcribe_file_not_found completed successfully")
def test_transcribe_internal_server_error(monkeypatch):
logger.debug("Starting test_transcribe_internal_server_error")
# Simulate an exception being raised in the `local_inference_tts` function
def mock_local_inference_tts(args):
logger.debug("Mock local_inference_tts called")
raise Exception("Test exception")
monkeypatch.setattr("app.main.local_inference_tts", mock_local_inference_tts)
logger.debug("Replaced local_inference_tts with mock")
response = client.post("/transcribe", json={
"text": "Hello, this is a test.",
"voice": "random",
"output_path": "output.wav",
"preset": "fast"
})
logger.debug(f"Received response with status code {response.status_code}")
assert response.status_code == 500
assert response.json() == {"detail": "Test exception"}
logger.debug("test_transcribe_internal_server_error completed successfully")

tortoise/do_tts.py

@@ -1,56 +1,99 @@
import argparse
import os
import logging
import logging.config
import os, sys
from tqdm import tqdm
import torch
import yaml
import torchaudio
from tortoise.api import TextToSpeech, MODELS_DIR
from utils.audio import load_voices
# Add the root directory of the repo to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tortoise.api import TextToSpeech, MODELS_DIR, pick_best_batch_size_for_gpu
from tortoise.utils.audio import load_voices
# Configure logging to print to the console
# Load logging configuration
with open("logging_conf.yml", 'r') as file:
config = yaml.safe_load(file.read())
logging.config.dictConfig(config)
logger = logging.getLogger(__name__)
def main(args):
if torch.backends.mps.is_available():
args.use_deepspeed = False
# if torch.backends.mps.is_available():
# args.use_deepspeed = False
os.makedirs(args.output_path, exist_ok=True)
tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
tts = TextToSpeech(
models_dir=args.model_dir, autoregressive_batch_size=args.autoregressive_batch_size, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
selected_voices = args.voice.split(',')
for k, selected_voice in enumerate(selected_voices):
for k, selected_voice in tqdm(enumerate(selected_voices), desc="generating using selected voice"):
if '&' in selected_voice:
voice_sel = selected_voice.split('&')
else:
voice_sel = [selected_voice]
voice_samples, conditioning_latents = load_voices(voice_sel)
gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount)
gen, dbg_state = tts.tts_with_preset(
args.text, k=args.candidates, voice_samples=voice_samples,
conditioning_latents=conditioning_latents,
preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount
)
if isinstance(gen, list):
for j, g in enumerate(gen):
torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
torchaudio.save(
os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000
)
else:
torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)
print(f"Audio saved to {args.output_path} as {selected_voice}_{k}.wav")
if args.produce_debug_state:
os.makedirs('debug_states', exist_ok=True)
torch.save(dbg_state, f'debug_states/do_tts_debug_{selected_voice}.pth')
if __name__ == '__main__':
"""
class TTSArgs(BaseModel):
text: str
voice: str
output_path: str
preset: str
model_dir: str = os.getenv("TORTOISE_MODELS_DIR", "data/models")
use_deepspeed: bool = False
kv_cache: bool = False
half: bool = False
candidates: int = 1
seed: int = None
cvvp_amount: float = 0.0
produce_debug_state: bool = False
"""
parser = argparse.ArgumentParser()
parser.add_argument('text', type=str, help='Text to speak. This argument is required.')
parser.add_argument('--voice', type=str, help="""
Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.""", default='random')
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
parser.add_argument('--use_deepspeed', type=str, help='Use deepspeed for speed bump.', default=False)
parser.add_argument('--kv_cache', type=bool, help='If you disable this please wait for a long a time to get the output', default=True)
parser.add_argument('--half', type=bool, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True)
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
parser.add_argument(
'text', type=str, help='Text to speak. This argument is required.')
parser.add_argument(
'--voice', type=str, help="Selects the voice to use for generation. See options in voices/ directory (and add your own!) Use the & character to join two voices together. Use a comma to perform inference on multiple voices.", default='random')
parser.add_argument(
'--preset', type=str, help='Which voice preset to use.', default='fast')
parser.add_argument(
'--use_deepspeed', action=argparse.BooleanOptionalAction, type=bool, help='Use deepspeed for speed bump.', default=False)
parser.add_argument(
'--kv_cache', type=bool, action=argparse.BooleanOptionalAction, help='If you disable this, expect to wait a long time for the output.', default=True)
parser.add_argument(
'--autoregressive_batch_size', type=int, help='Batch size for autoregressive inference.', default=pick_best_batch_size_for_gpu())
parser.add_argument(
'--half', type=bool, action=argparse.BooleanOptionalAction, help="float16 (half) precision inference; if True it's faster and takes less VRAM and RAM.", default=True)
parser.add_argument(
'--output_path', type=str, help='Where to store outputs (directory).', default='data/samples/')
parser.add_argument(
'--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this should only be specified if you have custom checkpoints.', default=MODELS_DIR)
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
parser.add_argument('--produce_debug_state', type=bool, action=argparse.BooleanOptionalAction, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
parser.add_argument('--cvvp_amount', type=float, help='How much the CVVP model should influence the output. '
'Increasing this can in some cases reduce the likelihood of multiple speakers. Defaults to 0 (disabled)', default=.0)
args = parser.parse_args()
assert not (args.use_deepspeed and args.half), "--use_deepspeed requires --no-half"
main(args)
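Since `app/main.py` imports this `main` directly, it can also be driven programmatically with a plain namespace; a sketch (field values are illustrative, and `autoregressive_batch_size=None` assumes the API picks a batch size itself):

```python
# Sketch: calling do_tts.main without the CLI, mirroring app/main.py's usage.
from argparse import Namespace

from tortoise.api import MODELS_DIR
from tortoise.do_tts import main

main(Namespace(
    text="Hello, this is a test.",
    voice="random",
    preset="fast",
    output_path="data/results/",
    model_dir=MODELS_DIR,
    autoregressive_batch_size=None,  # assumption: let the API choose
    use_deepspeed=False,
    kv_cache=True,
    half=True,
    candidates=1,
    seed=None,
    cvvp_amount=0.0,
    produce_debug_state=False,
))
```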