From c469c009762168ba3376cdcdef68ea53bfcb8de6 Mon Sep 17 00:00:00 2001 From: supermomo668 Date: Fri, 9 Aug 2024 08:00:07 +0000 Subject: [PATCH] stable api docker app, fixed bug with api token. stable tests --- app/main.py | 199 ++++++++++++++------ app/services/auth.py | 118 ++++++++++-- app/services/oauth.py | 1 - app/utils.py | 2 +- bin/curl/test-app_main.sh | 10 - bin/docker-run_app.sh | 9 + logging_conf.yml | 28 --- requirements.txt | 4 +- tests/curl/test-app_TTS-credential.sh | 11 ++ tests/curl/test-app_TTS-token.sh | 10 + tests/curl/test-app_api_key_queue_status.sh | 2 + tests/curl/test-app_login.sh | 6 + tests/curl/test-app_task_result.sh | 5 + tests/curl/test-app_task_status.sh | 3 + tests/test_transcribe.py | 124 ------------ tests/test_tts_api.py | 126 +++++++++++++ tortoise/api.py | 6 +- tortoise/api_fast.py | 5 +- tortoise/do_tts.py | 24 ++- 19 files changed, 444 insertions(+), 249 deletions(-) delete mode 100644 bin/curl/test-app_main.sh create mode 100644 bin/docker-run_app.sh delete mode 100644 logging_conf.yml create mode 100644 tests/curl/test-app_TTS-credential.sh create mode 100644 tests/curl/test-app_TTS-token.sh create mode 100644 tests/curl/test-app_api_key_queue_status.sh create mode 100644 tests/curl/test-app_login.sh create mode 100644 tests/curl/test-app_task_result.sh create mode 100644 tests/curl/test-app_task_status.sh delete mode 100644 tests/test_transcribe.py create mode 100644 tests/test_tts_api.py diff --git a/app/main.py b/app/main.py index 9d74b9e..83d6f1f 100644 --- a/app/main.py +++ b/app/main.py @@ -3,48 +3,62 @@ import sys import uuid import asyncio import dotenv +import logging +import time +from datetime import timedelta import concurrent.futures from fastapi import FastAPI, Depends, HTTPException, status from fastapi.responses import FileResponse, JSONResponse from contextlib import asynccontextmanager from beartype import beartype +from fastapi.security import OAuth2PasswordRequestForm from app.models.request import TranscriptionRequest from app.models.tts import TTSArgs -from app.services.oauth import get_current_user -from app.utils import pick_max_worker_function -from app.services.auth import get_current_username +# from app.services.oauth import get_current_user +from app.utils import pick_max_worker +from app.services.auth import get_current_user, verify_user, ACCESS_TOKEN_EXPIRE_MINUTES, create_access_token, timedelta from tortoise.utils.audio import BUILTIN_VOICES_DIR -from tortoise.do_tts import main as tts_main +from tortoise.do_tts import _initialized_tts, infer_voice + +load_envar = dotenv.load_dotenv() +assert load_envar and os.getenv("DEFAULT_USERNAME"), "Missing environment variables at .env" + +# Environment-specific variable to skip initialization during testing +IS_TESTING = os.getenv("TESTING", "False").lower() in ("true", "1") + +# Configure logging +logging.basicConfig( + level=logging.DEBUG, # Set the logging level to DEBUG + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # Log format + handlers=[ + logging.FileHandler("app_debug.log"), # Log to a file named `app_debug.log` + logging.StreamHandler() # Also log to console + ] +) +logger = logging.getLogger(__name__) -dotenv.load_dotenv() # Create the ThreadPoolExecutor with the determined number of max workers -max_workers = pick_max_worker_function() -executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) +executor = concurrent.futures.ThreadPoolExecutor( + max_workers=int(os.getenv("MAX_WORKERS", pick_max_worker()))) +TASK_TIMEOUT = int(os.getenv("TASK_TIMEOUT", 60*30)) # 30 minutes # Dictionary to store task information tasks = {} @beartype -def local_inference_tts(args: TTSArgs): - """ - Run the TTS directly using the `main` function from `tortoise/do_tts.py`. - Args: - - args (TTSArgs): The arguments to pass to the TTS function. - Returns: - - str: Path to the output audio file. - """ - output_path = tts_main(args) +def local_inference_tts(tts, args): + output_path = infer_voice(tts, args) return output_path -async def process_requests(): +async def process_requests(tts): while True: - task_id, (args, future) = await fifo_queue.get() # Wait for a request from the queue + task_id, (args, future) = await fifo_queue.get() try: tasks[task_id]['status'] = 'in_progress' - output_path = await asyncio.get_event_loop().run_in_executor(executor, local_inference_tts, args) + output_path = await asyncio.get_event_loop().run_in_executor(executor, local_inference_tts, tts, args) future.set_result(output_path) tasks[task_id]['status'] = 'completed' tasks[task_id]['result'] = output_path @@ -53,34 +67,78 @@ async def process_requests(): tasks[task_id]['status'] = 'failed' tasks[task_id]['error'] = str(e) finally: - fifo_queue.task_done() # Indicate that the request has been processed + fifo_queue.task_done() + +async def post_initialization_event(): + try: + request = TranscriptionRequest( + text="Initialized! World", + voice="random", preset="ultra_fast" + ) + response = await text_to_speech(request) + print("Initialization TTS Response:", response) + except Exception as e: + print("Error during initialization TTS:", str(e)) @asynccontextmanager async def lifespan(app: FastAPI): global fifo_queue fifo_queue = asyncio.Queue() # Initialize the queue within the context - # Start the request processing task - task = asyncio.create_task(process_requests()) - yield - # Clean up the task on shutdown - task.cancel() - try: - await task - except asyncio.CancelledError: - pass -# Assign the lifespan context manager to the app + if not IS_TESTING: + args = TTSArgs(text="") + tts = _initialized_tts(args) + task = asyncio.create_task(process_requests(tts)) # Start the request processing task + await post_initialization_event() # Post initialization event + else: + print(f"Skipping initialization due to TESTING={IS_TESTING}") + task = None + yield + # Graceful shutdown + if task: + await fifo_queue.join() # Wait for the queue to empty + task.cancel() # Cancel the task to exit the loop + try: + await task + except asyncio.CancelledError: + pass + + # Shut down the executor + executor.shutdown(wait=True) + app = FastAPI(lifespan=lifespan) @app.get("/") async def home(): - return JSONResponse(content={"message": "Hello, FiCast-TTS! Check the docs at /docs."}) + return JSONResponse(content={ + "message": "Hello, FiCast-TTS! Check the docs at /docs."}) + +@app.post("/login") +async def login(form_data: OAuth2PasswordRequestForm = Depends()): + if verify_user(form_data.username, form_data.password): + try: + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_access_token( + data={"sub": form_data.username}, expires_delta=access_token_expires + ) + print({"sub": form_data.username}) + print(access_token) + return { + "access_token": access_token, + "token_type": "bearer", + "user": form_data.username + } + except: + raise HTTPException( + status_code=401, detail="Unable to create access token") + raise HTTPException( + status_code=401, detail="Incorrect username or password") @app.get("/voices") async def available_voices(): return JSONResponse(content={"voices": os.listdir(BUILTIN_VOICES_DIR)}) -@app.post("/tts", dependencies=[Depends(get_current_username)]) +@app.post("/tts", dependencies=[Depends(get_current_user)]) async def text_to_speech(request: TranscriptionRequest): try: args = TTSArgs( @@ -89,54 +147,71 @@ async def text_to_speech(request: TranscriptionRequest): preset=request.preset ) - # Use a future to get the result of the inference future = asyncio.get_event_loop().create_future() - # Generate a unique task ID task_id = str(uuid.uuid4()) await fifo_queue.put((task_id, (args, future))) - # Store task information tasks[task_id] = { 'status': 'queued', 'request': request, 'result': None, 'error': None } - - # Await the result of the future - output_path = await future - - # Check if file exists - if not os.path.isfile(output_path): - raise HTTPException(status_code=404, detail=f"File not found: {output_path}") - - return FileResponse(output_path, media_type='audio/wav', filename=os.path.basename(output_path)) - + return {"task_id": task_id, "status": "queued"} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/queue-status", dependencies=[Depends(get_current_username)]) + +@app.get("/queue-status", dependencies=[Depends(get_current_user)]) async def queue_status(): - """ - Endpoint to get the current status of the queue. - """ - return { - "queue_length": fifo_queue.qsize(), - "tasks": tasks - } + try: + return { + "queue_length": fifo_queue.qsize(), + "tasks": tasks + } + except HTTPException as e: + return JSONResponse( + status_code=e.status_code, + content={"detail": e.detail, "headers": e.headers, "error": str(e)}, + ) -@app.get("/task-status/{task_id}", dependencies=[Depends(get_current_username)]) +@app.get("/task-status/{task_id}", dependencies=[Depends(get_current_user)]) async def task_status(task_id: str): - """ - Endpoint to get the status of a specific task. - """ task = tasks.get(task_id) if not task: raise HTTPException(status_code=404, detail="Task not found") return task +@app.get("/task-result/{task_id}", dependencies=[Depends(get_current_user)]) +async def wait_for_result(task_id: str): + try: + task = tasks.get(task_id) + if not task: + raise HTTPException(status_code=404, detail="Task not found") + start_time = time.time() + while task["status"] != "completed": + if task["status"] == "failed": + raise HTTPException( + status_code=500, detail=f"Task failed: {task.get('error', 'Unknown error')}") + elif time.time() - start_time > TASK_TIMEOUT: + raise HTTPException(status_code=408, detail="Task did not complete within the timeout") + await asyncio.sleep(5) + task = tasks.get(task_id) + + output_path = task["result"] + if not os.path.isfile(output_path): + raise HTTPException(status_code=404, detail=f"File not found: {output_path}") + # Expected output + return FileResponse( + output_path, + filename=os.path.basename(output_path), + media_type="audio/wav") + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + if __name__ == "__main__": import argparse + import uvicorn def main(): parser = argparse.ArgumentParser() parser.add_argument('text', type=str, help='Text to speak. This argument is required.') @@ -152,12 +227,14 @@ if __name__ == "__main__": voice=args.voice, preset=args.preset, ) - + tts = _initialized_tts(tts_args) try: - output_path = local_inference_tts(tts_args) + output_path = local_inference_tts(tts, tts_args) print(f"Output stored at: {output_path}") return output_path except Exception as e: print(f"Error during TTS generation: {str(e)}") sys.exit(1) - main() + + uvicorn.run( + app, host="0.0.0.0", port=42110, log_level="debug") diff --git a/app/services/auth.py b/app/services/auth.py index a7c2671..72920ba 100644 --- a/app/services/auth.py +++ b/app/services/auth.py @@ -1,19 +1,115 @@ +import base64 import os -from fastapi import Depends, HTTPException, status -from fastapi.security import HTTPBasic, HTTPBasicCredentials import secrets +import logging +from datetime import datetime, timedelta, timezone +from typing import Optional + +import jwt + +from fastapi import Depends, HTTPException, Request, Security, status +from fastapi.security import HTTPBasic, HTTPBasicCredentials, APIKeyHeader + +# Setup logging +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # Log format + handlers=[ + logging.FileHandler("app_debug.log"), + logging.StreamHandler() # Also log to console + ] +) +logger = logging.getLogger(__name__) + +# Secret key to encode and decode JWT tokens +DEFAULT_SECRET_KEY = str(os.getenv("DEFAULT_SECRET_KEY", "fek3kz9xzlsndSuczhgjds0vndi")) +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 24*60 security = HTTPBasic() +api_key_header = APIKeyHeader( + name="Authorization", auto_error=False) -def get_current_username(credentials: HTTPBasicCredentials = Depends(security)): - assert os.getenv("TEST_USERNAME") - assert os.getenv("TEST_PASSWORD") - correct_username = secrets.compare_digest(credentials.username, os.getenv("TEST_USERNAME")) - correct_password = secrets.compare_digest(credentials.password, os.getenv("TEST_PASSWORD")) - if not (correct_username and correct_password): + +def verify_user(username: str, password: str): + user = os.getenv("DEFAULT_USERNAME") + if user and secrets.compare_digest(os.getenv("DEFAULT_PASSWORD"), password): + return True + return False + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode( + to_encode, DEFAULT_SECRET_KEY, algorithm=ALGORITHM) + print(f"Encoded JWT: {encoded_jwt}") + return encoded_jwt + +async def get_current_user( + request: Request, # Access the full request to manually check headers + authorization: Optional[str] = Security(api_key_header) +): + logger.debug("Attempting to authenticate user...") + # First, try to authenticate using Basic Authentication + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Basic "): + print("Using Basic Auth...") + auth = auth_header.split(" ")[1] + credentials = base64.b64decode(auth).decode("utf-8").split(":") + username = credentials[0] + password = credentials[1] + + correct_username = secrets.compare_digest(username, os.getenv("DEFAULT_USERNAME")) + correct_password = secrets.compare_digest(password, os.getenv("DEFAULT_PASSWORD")) + + if correct_username and correct_password: + logger.debug("Basic auth successful.") + return {"auth": "basic", "user": username} + else: + logger.debug("Basic auth failed.") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Basic"}, + ) + + # If Basic Auth is not provided, try to authenticate using API Key (JWT) + elif authorization: + print("Using API Key...") + logger.debug(f"Authorization header received: {authorization[:5]}...") + scheme, _, token = authorization.partition(" ") + + if scheme.lower() != "bearer": + logger.debug("Invalid authentication scheme.") + raise HTTPException( + status_code=401, detail="Invalid authentication scheme") + + try: + payload = jwt.decode( + token, DEFAULT_SECRET_KEY, + algorithms=[ALGORITHM]) + username: str = payload.get("sub") + + if username is None: + logger.debug("JWT payload does not contain a username.") + raise HTTPException(status_code=401, detail="Could not validate credentials") + + logger.debug("JWT validation successful.") + return {"auth": "api_key", "user": username} + + except jwt.PyJWTError as e: + logger.debug(f"JWT validation failed: {e}") + raise HTTPException( + status_code=401, detail="Could not validate credentials") + + else: + logger.debug("No credentials provided.") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password", + detail="Could not validate credentials", headers={"WWW-Authenticate": "Basic"}, - ) - return credentials.username + ) \ No newline at end of file diff --git a/app/services/oauth.py b/app/services/oauth.py index a2bfc67..91eba6d 100644 --- a/app/services/oauth.py +++ b/app/services/oauth.py @@ -18,7 +18,6 @@ def get_authenticated_service(): CLIENT_SECRETS_FILE, SCOPES, redirect_uri=REDIRECT_URI ) - # Tell the user to go to the authorization URL. auth_url, _ = flow.authorization_url(prompt='consent') print('Please go to this URL: {}'.format(auth_url)) diff --git a/app/utils.py b/app/utils.py index b409eba..60eaf2e 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,6 +1,6 @@ import os -def pick_max_worker_function(): +def pick_max_worker(): # Determine the number of CPU cores cpu_count = os.cpu_count() if cpu_count is None: diff --git a/bin/curl/test-app_main.sh b/bin/curl/test-app_main.sh deleted file mode 100644 index d675650..0000000 --- a/bin/curl/test-app_main.sh +++ /dev/null @@ -1,10 +0,0 @@ -curl -X POST "http://127.0.0.1:42110/transcribe" \ - -u ficast-uzer:ficast-testpazz \ - -H "Content-Type: application/json" \ - -d '{ - "text": "Hello, how are you?", - "voice": "random", - "preset": "ultra_fast" - }' \ - -o data/samples/api-output.wav - diff --git a/bin/docker-run_app.sh b/bin/docker-run_app.sh new file mode 100644 index 0000000..ed5a66c --- /dev/null +++ b/bin/docker-run_app.sh @@ -0,0 +1,9 @@ +image="tts:app" +docker run --gpus all \ + -e TORTOISE_MODELS_DIR=/models \ + -v "${PWD}/data/models":/models \ + -v "${PWD}/data/results":/results \ + -v "${PWD}/data/.cache/huggingface":/root/.cache/huggingface \ + -p 42110:42110 \ + --name tts-api \ + -it $image --port 42110 --host 0.0.0.0 \ No newline at end of file diff --git a/logging_conf.yml b/logging_conf.yml deleted file mode 100644 index 576e643..0000000 --- a/logging_conf.yml +++ /dev/null @@ -1,28 +0,0 @@ -version: 1 -disable_existing_loggers: False - -formatters: - standard: - format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - -handlers: - console: - class: logging.StreamHandler - level: DEBUG - formatter: standard - stream: ext://sys.stdout - -loggers: - app: - level: DEBUG - handlers: [console] - propagate: no - - tortoise: - level: DEBUG - handlers: [console] - propagate: no - -root: - level: DEBUG - handlers: [console] diff --git a/requirements.txt b/requirements.txt index 0dbe226..0f844a8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,6 @@ uvicorn==0.30.5 google_auth_oauthlib==1.2.1 python-dotenv==1.0.1 fastapi==0.112.0 -beartype==0.18.5 \ No newline at end of file +beartype==0.18.5 +PyJWT==2.9.0 +python-multipart==0.0.9 \ No newline at end of file diff --git a/tests/curl/test-app_TTS-credential.sh b/tests/curl/test-app_TTS-credential.sh new file mode 100644 index 0000000..8e5d461 --- /dev/null +++ b/tests/curl/test-app_TTS-credential.sh @@ -0,0 +1,11 @@ +response=$(curl -X POST "http://127.0.0.1:42110/tts" \ + -u $TEST_USERNAME:$TEST_PASSWORD \ + -H "Content-Type: application/json" \ + -d '{ + "text": "Hello, curl test?", + "voice": "random", + "preset": "ultra_fast" + }') + +TASK_ID=$(echo $response | jq -r .task_id) +echo $TASK_ID \ No newline at end of file diff --git a/tests/curl/test-app_TTS-token.sh b/tests/curl/test-app_TTS-token.sh new file mode 100644 index 0000000..1805f68 --- /dev/null +++ b/tests/curl/test-app_TTS-token.sh @@ -0,0 +1,10 @@ +response=$(curl -X POST "http://127.0.0.1:42110/tts" \ + -H "Authorization: Bearer ${ACCESS_TOKEN}" \ + -H "Content-Type: application/json" \ + -d '{ + "text": "Hello, curl test?", + "voice": "random", + "preset": "ultra_fast" + }') +TASK_ID=$(echo $response | jq -r .task_id) +echo $TASK_ID \ No newline at end of file diff --git a/tests/curl/test-app_api_key_queue_status.sh b/tests/curl/test-app_api_key_queue_status.sh new file mode 100644 index 0000000..20c79a2 --- /dev/null +++ b/tests/curl/test-app_api_key_queue_status.sh @@ -0,0 +1,2 @@ +curl "http://127.0.0.1:42110/queue-status" \ + -H "Authorization: Bearer ${ACCESS_TOKEN}" \ No newline at end of file diff --git a/tests/curl/test-app_login.sh b/tests/curl/test-app_login.sh new file mode 100644 index 0000000..a54e3e5 --- /dev/null +++ b/tests/curl/test-app_login.sh @@ -0,0 +1,6 @@ +. tests/.env +response=$(curl -X POST "http://127.0.0.1:42110/login" \ + -d "username=${TEST_USERNAME}&password=${TEST_PASSWORD}" \ + -H "Content-Type: application/x-www-form-urlencoded") +ACCESS_TOKEN=$(echo $response | jq -r .access_token) +echo $ACCESS_TOKEN \ No newline at end of file diff --git a/tests/curl/test-app_task_result.sh b/tests/curl/test-app_task_result.sh new file mode 100644 index 0000000..5eb761d --- /dev/null +++ b/tests/curl/test-app_task_result.sh @@ -0,0 +1,5 @@ +TASK_ID=9ae03d68-9511-4642-b9bb-abb74948af61 +curl "http://127.0.0.1:42110/task-result/$TASK_ID" \ + -H "Authorization: Bearer ${ACCESS_TOKEN}" \ + -H "Content-Type: application/json" \ + -o data/samples/curl-task-result.wav \ No newline at end of file diff --git a/tests/curl/test-app_task_status.sh b/tests/curl/test-app_task_status.sh new file mode 100644 index 0000000..2c1f35d --- /dev/null +++ b/tests/curl/test-app_task_status.sh @@ -0,0 +1,3 @@ +curl "http://127.0.0.1:42110/task-status/$TASK_ID" \ + -u $TEST_USERNAME:$TEST_PASSWORD \ + -H "Content-Type: application/json" diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py deleted file mode 100644 index ae7d230..0000000 --- a/tests/test_transcribe.py +++ /dev/null @@ -1,124 +0,0 @@ -import pytest -import os -import sys -import logging -import dotenv -from fastapi.testclient import TestClient - -# Add the root directory of the repo to sys.path -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -# Add the tortoise directory to sys.path -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise'))) - -from app.main import app, tasks -from app.models.request import TranscriptionRequest - - -load_envvar = dotenv.load_dotenv() -assert load_envvar and os.getenv("TEST_USERNAME") and os.getenv("TEST_PASSWORD"), "Missing environment variables" - -# Configure logging to print to the console -logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -client = TestClient(app) - -# Helper function to simulate authentication -def basic_auth(username, password): - import base64 - credentials = f"{username}:{password}" - encoded_credentials = base64.b64encode(credentials.encode("ascii")).decode("ascii") - return {"Authorization": f"Basic {encoded_credentials}"} - -@pytest.mark.asyncio -async def test_transcribe(mocker): - # Mock the local_inference_tts function to avoid running the actual TTS - mocker.patch('app.main.local_inference_tts') - request_data = { - "text": "Hello, how are you?", - "voice": "random", - "preset": "ultra_fast" - } - - with TestClient(app) as client: - response = client.post( - "/tts", - headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")), - json=request_data - ) - - # Log detailed information about the response for debugging - logger.debug(f"Response status code: {response.status_code}") - logger.debug(f"Response headers: {response.headers}") - logger.debug(f"Response content: {response.content}") - - assert response.status_code == 200 - assert response.headers['content-type'] == 'audio/wav' - # assert response.headers['content-disposition'] == 'attachment; filename=random_0.wav' - -def test_queue_status(): - with TestClient(app) as client: - response = client.get( - "/queue-status", - headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")) - ) - - # Log detailed information about the response for debugging - logger.debug(f"Response status code: {response.status_code}") - logger.debug(f"Response headers: {response.headers}") - logger.debug(f"Response content: {response.content}") - - assert response.status_code == 200 - assert 'queue_length' in response.json() - assert 'tasks' in response.json() - -def test_task_status(): - # Add a dummy task to the tasks dictionary - task_id = 'test-task-id' - tasks[task_id] = { - 'status': 'queued', - 'request': TranscriptionRequest( - text="Hello, how are you?", - voice="random", - preset="fast" - ), - 'result': None, - 'error': None - } - - with TestClient(app) as client: - response = client.get( - f"/task-status/{task_id}", - headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")) - ) - - # Log detailed information about the response for debugging - logger.debug(f"Response status code: {response.status_code}") - logger.debug(f"Response headers: {response.headers}") - logger.debug(f"Response content: {response.content}") - - assert response.status_code == 200 - assert response.json()['status'] == 'queued' - - # Clean up the tasks dictionary - del tasks[task_id] - -def test_failed_authentication(): - with TestClient(app) as client: - response = client.post( - "/tts", - headers=basic_auth("wrong_username", "wrong_password"), - json={ - "text": "Hello, how are you?", - "voice": "random", - "output_path": "data/tests", - "preset": "fast" - } - ) - - # Log detailed information about the response for debugging - logger.debug(f"Response status code: {response.status_code}") - logger.debug(f"Response headers: {response.headers}") - logger.debug(f"Response content: {response.content}") - - assert response.status_code == 401 \ No newline at end of file diff --git a/tests/test_tts_api.py b/tests/test_tts_api.py new file mode 100644 index 0000000..02f733e --- /dev/null +++ b/tests/test_tts_api.py @@ -0,0 +1,126 @@ +import pytest +import os, sys, time +import dotenv +from unittest.mock import patch, MagicMock +from fastapi.testclient import TestClient + +# Add the root directory of the repo to sys.path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise'))) + +# load environment variables +load_envvar = dotenv.load_dotenv('tests/.env', override=True) +assert load_envvar and os.getenv("TESTING").lower()=="true", "Missing environment variables" + +from app.main import app +client = TestClient(app) + +# Helper function to simulate authentication +def basic_auth(username, password): + import base64 + credentials = f"{username}:{password}" + encoded_credentials = base64.b64encode(credentials.encode("ascii")).decode("ascii") + return {"Authorization": f"Basic {encoded_credentials}"} + +@pytest.fixture(scope="module") +def task_id(): + with patch('app.main.text_to_speech') as mock_tts: + mock_tts.return_value = { + "task_id": "mock_task_id", "status": "queued"} + request_data = { + "text": "Hello, how are you?", + "voice": "random", + "preset": "ultra_fast" + } + with TestClient(app) as client: + response = client.post( + "/tts", + headers=basic_auth( + os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")), + json=request_data + ) + response.raise_for_status() + return response.json().get("task_id") + +@pytest.fixture(scope="module") +def access_token(): + with TestClient(app) as client: + response = client.post( + "/login", + data={ + "username": os.getenv("TEST_USERNAME"), + "password": os.getenv("TEST_PASSWORD") + } + ) + response.raise_for_status() + token = response.json().get("access_token") + + assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}" + assert "access_token" in response.json(), f"Access token missing in response. Response: {response.text}" + assert "token_type" in response.json(), f"Token type missing in response. Response: {response.text}" + assert response.json()["token_type"] == "bearer", f"Expected token type 'bearer', got {response.json()['token_type']}. Response: {response.text}" + return token + +def test_queue_status_api_key(access_token): + with TestClient(app) as client: + response = client.get( + "/queue-status", + headers={ + "Authorization": f"Bearer {access_token}" + } + ) + assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}" + assert 'tasks' in response.json(), f"'tasks' key missing in response. Response: {response.json()}" + +def test_queue_status(): + with TestClient(app) as client: + response = client.get( + "/queue-status", + headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")) + ) + + assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}" + assert 'tasks' in response.json(), f"'tasks' key missing in response. Response: {response.json()}" + +def test_tts_task_creation(task_id): + assert task_id is not None, f"Task ID is None. Check the task creation process." + + with TestClient(app) as client: + response = client.get( + "/queue-status", + headers=basic_auth( + os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")) + ) + response.raise_for_status() + queue_status = response.json() + assert task_id in queue_status["tasks"], f"Task ID {task_id} not found in queue tasks. Queue status: {queue_status}" + assert queue_status["tasks"][task_id]["status"] in ["queued", "in_progress", "completed"], f"Expected task status 'queued', got {queue_status['tasks'][task_id]['status']}. Queue status: {queue_status}" + +def test_tts_task_completion(task_id): + with patch('app.main.tasks') as mock_tasks: + mock_tasks[task_id] = { + "status": "completed", + "result": {"message": "Task completed"} + } + + with TestClient(app) as client: + start_time = time.time() + while True: + response = client.get( + f"/task-status/{task_id}", + headers=basic_auth( + os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")) + ) + if response.status_code == 200 and response.json().get("status") == "completed": + break + elif response.status_code == 500: + raise AssertionError(f"Task failed: {response.json().get('detail')}. Response: {response.text}") + elif time.time() - start_time > 60: + raise AssertionError(f"Task did not complete within 1 minute. Task ID: {task_id}") + time.sleep(5) + + # Verify that the task is completed + assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}" + assert response.json()["status"] == "completed", f"Expected task status 'completed', got {response.json()['status']}. Response: {response.json()}" + assert response.json()["result"]["message"] == "Task completed", f"Expected result message 'Task completed', got {response.json()['result']['message']}. Response: {response.json()}" + diff --git a/tortoise/api.py b/tortoise/api.py index 1c6853b..a5438fe 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -40,10 +40,14 @@ MODELS = { def get_model_path(model_name, models_dir=MODELS_DIR): """ Get path to given model, download it if it doesn't exist. + Uses the cache provided by HuggingFace's hf_hub_download. """ if model_name not in MODELS: raise ValueError(f'Model {model_name} not found in available models.') - model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir) + # hf_hub_download will automatically check the cache + model_path = hf_hub_download( + repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir + ) return model_path diff --git a/tortoise/api_fast.py b/tortoise/api_fast.py index 009359b..bc6a509 100644 --- a/tortoise/api_fast.py +++ b/tortoise/api_fast.py @@ -137,7 +137,10 @@ def classify_audio_clip(clip): classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32, dropout=0, kernel_size=5, distribute_zero_label=False) - classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu'))) + # Use the get_model_path function to get the model's path + model_path = get_model_path('classifier.pth') + # Load the model state dictionary using the path returned + classifier.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) clip = clip.cpu().unsqueeze(0) results = F.softmax(classifier(clip), dim=-1) return results[0][0] diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py index 10ce29c..0809c17 100644 --- a/tortoise/do_tts.py +++ b/tortoise/do_tts.py @@ -16,20 +16,15 @@ from tortoise.utils.audio import load_voices # Configure logging to print to the console # Load logging configuration -with open("logging_conf.yml", 'r') as file: - config = yaml.safe_load(file.read()) - logging.config.dictConfig(config) logger = logging.getLogger(__name__) -def main(args): - # if torch.backends.mps.is_available(): - # args.use_deepspeed = False - os.makedirs(args.output_path, exist_ok=True) - if not args.autoregressive_batch_size: - args.autoregressive_batch_size = pick_best_batch_size_for_gpu() +def _initialized_tts(args): tts = TextToSpeech( - models_dir=args.model_dir, autoregressive_batch_size=args.autoregressive_batch_size, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half) + models_dir=args.model_dir, autoregressive_batch_size=args.autoregressive_batch_size, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, + half=args.half) + return tts +def infer_voice(tts: TextToSpeech, args: argparse.Namespace): selected_voices = args.voice.split(',') for k, selected_voice in tqdm(enumerate(selected_voices), desc="generating using selected voice"): if '&' in selected_voice: @@ -57,6 +52,15 @@ def main(args): os.makedirs('debug_states', exist_ok=True) torch.save(dbg_state, f'debug_states/do_tts_debug_{selected_voice}.pth') return output_path + +def main(args): + if torch.backends.mps.is_available(): + args.use_deepspeed = False + os.makedirs(args.output_path, exist_ok=True) + if not args.autoregressive_batch_size: + args.autoregressive_batch_size = pick_best_batch_size_for_gpu() + tts = _initialized_tts(args) + return infer_voice(tts, args) if __name__ == '__main__': """