Stable, concurrent queue system for transcription and dockerized application

supermomo668 2024-08-07 06:38:19 +00:00
parent c7dbcef434
commit ef6535712f
15 changed files with 277 additions and 128 deletions


@@ -1,36 +1,80 @@
import argparse
import os
import sys
import uuid
import asyncio
import dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
import concurrent.futures
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.responses import FileResponse, JSONResponse
from contextlib import asynccontextmanager
from beartype import beartype
from pydantic import BaseModel
from app.models.request import TranscriptionRequest
from app.models.tts import TTSArgs
from app.services.oauth import get_current_user
from app.utils import pick_max_worker_function
from app.services.auth import get_current_username
from tortoise.do_tts import main as tts_main
app = FastAPI()
dotenv.load_dotenv()
# Create the ThreadPoolExecutor with the determined number of max workers
max_workers = pick_max_worker_function()
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
# Dictionary to store task information
tasks = {}
@beartype
def local_inference_tts(args: TTSArgs):
"""
Run the TTS directly using the `main` function from `tortoise/do_tts.py`.
Args:
- args (Args): The arguments to pass to the TTS function.
- args (TTSArgs): The arguments to pass to the TTS function.
Returns:
- str: Path to the output audio file.
"""
tts_main(args)
return args.output_path
@app.post("/transcribe")
output_path = tts_main(args)
return output_path
async def process_requests():
while True:
task_id, (args, future) = await fifo_queue.get() # Wait for a request from the queue
try:
tasks[task_id]['status'] = 'in_progress'
output_path = await asyncio.get_event_loop().run_in_executor(executor, local_inference_tts, args)
future.set_result(output_path)
tasks[task_id]['status'] = 'completed'
tasks[task_id]['result'] = output_path
except Exception as e:
future.set_exception(e)
tasks[task_id]['status'] = 'failed'
tasks[task_id]['error'] = str(e)
finally:
fifo_queue.task_done() # Indicate that the request has been processed
@asynccontextmanager
async def lifespan(app: FastAPI):
global fifo_queue
fifo_queue = asyncio.Queue() # Initialize the queue within the context
# Start the request processing task
task = asyncio.create_task(process_requests())
yield
# Clean up the task on shutdown
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Assign the lifespan context manager to the app
app = FastAPI(lifespan=lifespan)
@app.get("/")
async def home():
return JSONResponse(content={"message": "Hello, FiCast-TTS! Check the docs at /docs."})
@app.post("/transcribe", dependencies=[Depends(get_current_username)])
async def transcribe(request: TranscriptionRequest):
try:
args = TTSArgs(
@@ -39,41 +83,76 @@ async def transcribe(request: TranscriptionRequest):
output_path=request.output_path,
preset=request.preset
)
output_path = local_inference_tts(args)
# Use a future to get the result of the inference
future = asyncio.get_event_loop().create_future()
# Generate a unique task ID
task_id = str(uuid.uuid4())
await fifo_queue.put((task_id, (args, future)))
# Store task information
tasks[task_id] = {
'status': 'queued',
'request': request,
'result': None,
'error': None
}
# Await the result of the future
output_path = await future
# Check if file exists
if not os.path.isfile(output_path):
raise HTTPException(status_code=404, detail="File not found")
raise HTTPException(status_code=404, detail=f"File not found: {output_path}")
return FileResponse(output_path, media_type='audio/wav', filename=os.path.basename(output_path))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def main():
parser = argparse.ArgumentParser()
parser.add_argument('text', type=str, help='Text to speak. This argument is required.')
parser.add_argument('--voice', type=str, help="""
Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.""", default='random')
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
parser.add_argument('--output_path', type=str, help='Where to store outputs (directory).', default='data/results/')
@app.get("/queue-status", dependencies=[Depends(get_current_username)])
async def queue_status():
"""
Endpoint to get the current status of the queue.
"""
return {
"queue_length": fifo_queue.qsize(),
"tasks": tasks
}
args = parser.parse_args()
tts_args = TTSArgs(
text=args.text,
voice=args.voice,
output_path=args.output_path,
preset=args.preset,
)
try:
output_path = local_inference_tts(tts_args)
return output_path
except Exception as e:
print(f"Error during TTS generation: {str(e)}")
sys.exit(1)
@app.get("/task-status/{task_id}", dependencies=[Depends(get_current_username)])
async def task_status(task_id: str):
"""
Endpoint to get the status of a specific task.
"""
task = tasks.get(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
if __name__ == "__main__":
main()
import argparse
def main():
parser = argparse.ArgumentParser()
parser.add_argument('text', type=str, help='Text to speak. This argument is required.')
parser.add_argument('--voice', type=str, help="""
Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.""", default='random')
parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='ultra_fast')
args = parser.parse_args()
tts_args = TTSArgs(
text=args.text,
voice=args.voice,
preset=args.preset,
)
try:
output_path = local_inference_tts(tts_args)
print(f"Output stored at: {output_path}")
return output_path
except Exception as e:
print(f"Error during TTS generation: {str(e)}")
sys.exit(1)
main()
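For reference, a minimal client sketch for the queue-backed endpoint, assuming the API is served on 127.0.0.1:42110 (as in the sample curl call further down), that TEST_USERNAME/TEST_PASSWORD hold the Basic-auth credentials, and that the requests library is available:

import os
import requests

BASE_URL = "http://127.0.0.1:42110"  # assumed host/port, matching the sample curl call
auth = (os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))

# /transcribe enqueues the request, waits for the worker thread to finish,
# and streams the resulting WAV back as the response body.
resp = requests.post(
    f"{BASE_URL}/transcribe",
    auth=auth,
    json={"text": "Hello, how are you?", "voice": "random", "preset": "ultra_fast"},
)
resp.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(resp.content)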


@@ -1,7 +1,8 @@
from typing import Optional
from pydantic import BaseModel
class TranscriptionRequest(BaseModel):
text: str
voice: str
output_path: str
output_path: Optional[str] = "data/samples/"
preset: str = "ultra_fast"


@@ -1,18 +1,27 @@
import os
from pydantic import BaseModel
from enum import Enum
from tortoise.do_tts import pick_best_batch_size_for_gpu
from tortoise.api import MODELS_DIR
class Presets(str, Enum):
ULTRA_FAST='ultra_fast'
FAST='fast'
STANDARD='standard'
HIGH_QUALITY='high_quality'
class TTSArgs(BaseModel):
text: str
voice: str = 'random'
output_path: str = 'results/'
preset: str = 'fast'
output_path: str = 'data/samples/'
preset: Presets = 'ultra_fast'
model_dir: str = os.getenv("TORTOISE_MODELS_DIR", MODELS_DIR)
use_deepspeed: bool = False
kv_cache: bool = True
autoregressive_batch_size: int = pick_best_batch_size_for_gpu()
half: bool = True
candidates: int = 3
candidates: int = 1
seed: int = None
cvvp_amount: float = 0.0
produce_debug_state: bool = True
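The model above is what both the API and the CLI hand to the TTS backend; a minimal sketch of the blocking, queue-free path (the same one main() in app/main.py takes), assuming the tortoise weights are available under the configured model_dir:

from app.models.tts import TTSArgs
from app.main import local_inference_tts

# Uses the model defaults defined above (output_path='data/samples/',
# preset=ultra_fast, candidates=1); runs synchronously, no queue involved.
args = TTSArgs(text="Hello, how are you?", voice="random")
wav_path = local_inference_tts(args)
print(f"Output stored at: {wav_path}")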

app/services/auth.py Normal file

@@ -0,0 +1,19 @@
import os
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
import secrets

security = HTTPBasic()

def get_current_username(credentials: HTTPBasicCredentials = Depends(security)):
    assert os.getenv("TEST_USERNAME")
    assert os.getenv("TEST_PASSWORD")
    correct_username = secrets.compare_digest(credentials.username, os.getenv("TEST_USERNAME"))
    correct_password = secrets.compare_digest(credentials.password, os.getenv("TEST_PASSWORD"))
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username
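Both environment variables must be set before the dependency runs (app/main.py loads .env via dotenv); a minimal .env sketch, with values assumed to match the credentials used in the sample curl call:

TEST_USERNAME=ficast-uzer
TEST_PASSWORD=ficast-testpazz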

app/utils.py Normal file

@@ -0,0 +1,13 @@
import os

def pick_max_worker_function():
    # Determine the number of CPU cores
    cpu_count = os.cpu_count()
    if cpu_count is None:
        cpu_count = 1  # Default to 1 if unable to determine
    # Assume the tasks are I/O-bound; we can afford to have more workers
    # You might want to adjust this logic based on the nature of your tasks
    max_workers = cpu_count * 2
    print(f"Picked max workers: {max_workers}")
    return max_workers


@@ -0,0 +1,8 @@
curl -X POST "http://127.0.0.1:42110/transcribe" \
-u ficast-uzer:ficast-testpazz \
-H "Content-Type: application/json" \
-d '{
"text": "Hello, how are you?",
"voice": "random",
"preset": "ultra_fast"
}'
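The same Basic-auth credentials gate the new monitoring endpoints; a small polling sketch (again assuming 127.0.0.1:42110 and the requests library) that lists the queue and checks each stored task by its server-generated UUID:

import os
import requests

BASE_URL = "http://127.0.0.1:42110"
auth = (os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))

# /queue-status returns the pending-queue length plus the in-memory task records.
status = requests.get(f"{BASE_URL}/queue-status", auth=auth).json()
print("queued:", status["queue_length"])

# /task-status/{task_id} returns one record: queued, in_progress, completed, or failed.
for task_id in status["tasks"]:
    detail = requests.get(f"{BASE_URL}/task-status/{task_id}", auth=auth).json()
    print(task_id, detail["status"])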


@@ -1,8 +1,9 @@
image="tts:app"
docker run --gpus all \
-e TORTOISE_MODELS_DIR=/models \
-v "${PWD}/data/models":/models \
-v "${PWD}/data/results":/results \
-v "${PWD}/data/.cache/huggingface":/root/.cache/huggingface \
-v /root:/work \
--name tts-app \
-it tts
--name tts-api \
-it $image


@@ -3,7 +3,6 @@ rotary_embedding_torch
transformers==4.31.0
tokenizers
inflect
# progressbar2
einops==0.4.1
unidecode
scipy
@@ -24,3 +23,6 @@ hjson
psutil
sounddevice
spacy==3.7.5
# for api
uvicorn==0.30.5
google_auth_oauthlib


@@ -22,7 +22,6 @@ setuptools.setup(
'tqdm',
'rotary_embedding_torch',
'inflect',
'progressbar',
'einops',
'unidecode',
'scipy',


@@ -1,83 +1,94 @@
import os, sys
import tempfile
import logging
import pytest
import os
from fastapi.testclient import TestClient
from app.main import app, tasks
from app.models.request import TranscriptionRequest
# Add the root directory of the repo to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Add the tortoise directory to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise')))
from app.models.tts import TTSArgs
from app.main import app
import dotenv
dotenv.load_dotenv()
client = TestClient(app)
# Configure logging to print to the console
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Helper function to simulate authentication
def basic_auth(username: str, password: str):
import base64
credentials = f"{username}:{password}"
encoded_credentials = base64.b64encode(credentials.encode('utf-8')).decode('utf-8')
return {"Authorization": f"Basic {encoded_credentials}"}
def test_transcribe():
logger.debug("Starting test_transcribe")
# Create a temporary directory to store the output
with tempfile.TemporaryDirectory() as temp_dir:
logger.debug(f"Created temporary directory at {temp_dir}")
response = client.post(
"/transcribe", json={
"text": "Hello, this is a test.",
"voice": "random",
"output_path": temp_dir,
"preset": "fast"
})
logger.debug(f"Received response with status code {response.status_code}")
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
files = os.listdir(temp_dir)
assert len(files) >= 1 # or should equal args.candidate
# check if ended in .wav
assert files[0].endswith('.wav'), f"Expected the file to end with '.wav', but got {files[0]}"
logger.debug("test_transcribe completed successfully")
def test_transcribe_file_not_found():
logger.debug("Starting test_transcribe_file_not_found")
# Create a temporary directory and delete it immediately to ensure the path does not exist
temp_dir = tempfile.mkdtemp()
logger.debug(f"Created temporary directory at {temp_dir}")
os.rmdir(temp_dir)
logger.debug(f"Deleted temporary directory at {temp_dir}")
# Test for /transcribe endpoint
@pytest.mark.asyncio
async def test_transcribe(mocker):
# Mock the local_inference_tts function to avoid running the actual TTS
mocker.patch('app.main.local_inference_tts', return_value='data/results/output.wav')
request_data = {
"text": "Hello, how are you?",
"voice": "random",
"output_path": "data/results",
"preset": "fast"
}
response = client.post(
"/transcribe", json={
"text": "Hello, this is a test.",
"/transcribe",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")),
json=request_data
)
assert response.status_code == 200
assert response.headers['content-type'] == 'audio/wav'
assert response.headers['content-disposition'] == 'attachment; filename=output.wav'
# Test for /queue-status endpoint
def test_queue_status():
response = client.get(
"/queue-status",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
assert response.status_code == 200
assert 'queue_length' in response.json()
assert 'tasks' in response.json()
# Test for /task-status/{task_id} endpoint
def test_task_status():
# Add a dummy task to the tasks dictionary
task_id = 'test-task-id'
tasks[task_id] = {
'status': 'queued',
'request': TranscriptionRequest(
text="Hello, how are you?",
voice="random",
output_path="data/tests",
preset="fast"
),
'result': None,
'error': None
}
response = client.get(
f"/task-status/{task_id}",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
assert response.status_code == 200
assert response.json()['status'] == 'queued'
# Clean up the tasks dictionary
del tasks[task_id]
# Test for failed authentication
def test_failed_authentication():
response = client.post(
"/transcribe",
headers=basic_auth("wrong_username", "wrong_password"),
json={
"text": "Hello, how are you?",
"voice": "random",
"output_path": temp_dir,
"output_path": "data/tests",
"preset": "fast"
})
logger.debug(f"Received response with status code {response.status_code}")
}
)
assert response.status_code == 404
assert response.json() == {"detail": "File not found"}
logger.debug("test_transcribe_file_not_found completed successfully")
def test_transcribe_internal_server_error(monkeypatch):
logger.debug("Starting test_transcribe_internal_server_error")
# Simulate an exception being raised in the `local_inference_tts` function
def mock_local_inference_tts(args):
logger.debug("Mock local_inference_tts called")
raise Exception("Test exception")
monkeypatch.setattr("app.main.local_inference_tts", mock_local_inference_tts)
logger.debug("Replaced local_inference_tts with mock")
response = client.post("/transcribe", json={
"text": "Hello, this is a test.",
"voice": "random",
"output_path": "output.wav",
"preset": "fast"
})
logger.debug(f"Received response with status code {response.status_code}")
assert response.status_code == 500
assert response.json() == {"detail": "Test exception"}
logger.debug("test_transcribe_internal_server_error completed successfully")
assert response.status_code == 401
assert response.json() == {"detail": "Incorrect username or password"}


@@ -1,12 +1,10 @@
import os
import random
import uuid
from time import time
from urllib import request
# from urllib import request
import torch
import torch.nn.functional as F
import progressbar
import torchaudio
from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead


@@ -6,7 +6,6 @@ from urllib import request
import torch
import torch.nn.functional as F
import progressbar
import torchaudio
import numpy as np
from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead


@@ -25,6 +25,8 @@ def main(args):
# if torch.backends.mps.is_available():
# args.use_deepspeed = False
os.makedirs(args.output_path, exist_ok=True)
if not args.autoregressive_batch_size:
args.autoregressive_batch_size = pick_best_batch_size_for_gpu()
tts = TextToSpeech(
models_dir=args.model_dir, autoregressive_batch_size=args.autoregressive_batch_size, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
@@ -43,15 +45,18 @@ def main(args):
)
if isinstance(gen, list):
for j, g in enumerate(gen):
output_path = os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav')
torchaudio.save(
os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000
output_path, g.squeeze(0).cpu(), 24000
)
else:
torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)
output_path = os.path.join(args.output_path, f'{selected_voice}_{k}.wav')
torchaudio.save(output_path, gen.squeeze(0).cpu(), 24000)
print(f"Audio saved to {args.output_path} as {selected_voice}_{k}.wav")
if args.produce_debug_state:
os.makedirs('debug_states', exist_ok=True)
torch.save(dbg_state, f'debug_states/do_tts_debug_{selected_voice}.pth')
return output_path
if __name__ == '__main__':
"""
@@ -75,7 +80,12 @@ if __name__ == '__main__':
parser.add_argument(
'--voice', type=str, help="Selects the voice to use for generation. See options in voices/ directory (and add your own!) Use the & character to join two voices together. Use a comma to perform inference on multiple voices.", default='random')
parser.add_argument(
'--preset', type=str, help='Which voice preset to use.', default='fast')
'--preset', type=str, help="""Which voice preset to use. Available presets = {
'ultra_fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False},
'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 80},
'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200},
'high_quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400}
}""", choices=['ultra_fast', 'fast', 'standard', 'high_quality'], default='fast')
parser.add_argument(
'--use_deepspeed', action=argparse.BooleanOptionalAction, type=bool, help='Use deepspeed for speed bump.', default=False)
parser.add_argument(

Binary file not shown.