diff --git a/Dockerfile b/Dockerfile index 4b7b740..7109d0c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -45,11 +45,11 @@ FROM conda_base AS runner # Install the application WORKDIR /app -RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python setup.py install" +RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && pip install -r requirements.txt && python setup.py install" # Default entrypoint -RUN chmod +x /app/scripts/docker-entrypoint.sh -ENTRYPOINT ["/app/scripts/docker-entrypoint.sh"] +RUN chmod +x /app/scripts/tts-entrypoint.sh +ENTRYPOINT ["/app/scripts/tts-entrypoint.sh"] # Provide default CMD if no arguments are passed CMD ["--help"] \ No newline at end of file diff --git a/Dockerfile.app b/Dockerfile.app index c0afb16..3bc9e02 100644 --- a/Dockerfile.app +++ b/Dockerfile.app @@ -48,11 +48,10 @@ WORKDIR /app RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python setup.py install" # Install FastAPI and Uvicorn -RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && pip install fastapi uvicorn" +RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && pip install -r requirements.txt" -# Copy the FastAPI app -COPY app /app/api +# Default entrypoint +RUN chmod +x /app/scripts/tts_app-entrypoint.sh +ENTRYPOINT ["/app/scripts/tts_app-entrypoint.sh"] -# Default command to run the FastAPI app -ENTRYPOINT ["uvicorn", "app.api:app"] CMD ["--host", "0.0.0.0", "--port", "8000"] diff --git a/app/models/request.py b/app/models/request.py index 7645180..ea1df28 100644 --- a/app/models/request.py +++ b/app/models/request.py @@ -1,8 +1,14 @@ +from enum import Enum from typing import Optional from pydantic import BaseModel +class Presets(str, Enum): + ULTRA_FAST='ultra_fast' + FAST='fast' + STANDARD='standard' + HIGH_QUALITY='high_quality' + class TranscriptionRequest(BaseModel): text: str voice: str - output_path: Optional[str] = "data/samples/" - preset: str = "ultra_fast" + preset: Presets = "ultra_fast" diff --git a/app/models/tts.py b/app/models/tts.py index d4186b9..2b5450f 100644 --- a/app/models/tts.py +++ b/app/models/tts.py @@ -5,11 +5,7 @@ from enum import Enum from tortoise.do_tts import pick_best_batch_size_for_gpu from tortoise.api import MODELS_DIR -class Presets(str, Enum): - ULTRA_FAST='ultra_fast' - FAST='fast' - STANDARD='standard' - HIGH_QUALITY='high_quality' +from .request import Presets class TTSArgs(BaseModel): text: str diff --git a/requirements.txt b/requirements.txt index 09469f2..0dbe226 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,4 +25,7 @@ sounddevice spacy==3.7.5 # for api uvicorn==0.30.5 -google_auth_oauthlib \ No newline at end of file +google_auth_oauthlib==1.2.1 +python-dotenv==1.0.1 +fastapi==0.112.0 +beartype==0.18.5 \ No newline at end of file diff --git a/scripts/docker-entrypoint.sh b/scripts/tts-entrypoint.sh similarity index 100% rename from scripts/docker-entrypoint.sh rename to scripts/tts-entrypoint.sh diff --git a/scripts/tts_app-entrypoint.sh b/scripts/tts_app-entrypoint.sh new file mode 100644 index 0000000..5c6f456 --- /dev/null +++ b/scripts/tts_app-entrypoint.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# Source the conda environment +source /root/miniconda/etc/profile.d/conda.sh +conda activate tortoise + +# Execute the Python script with passed arguments +uvicorn app.main:app "$@" \ No newline at end of file diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 8b589bd..50cdff2 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -1,56 +1,77 @@ import pytest import os - +import sys +import logging +import asyncio from fastapi.testclient import TestClient + +# Add the root directory of the repo to sys.path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +# Add the tortoise directory to sys.path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise'))) + from app.main import app, tasks from app.models.request import TranscriptionRequest import dotenv dotenv.load_dotenv() +# Configure logging to print to the console +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + client = TestClient(app) +@pytest.fixture(scope="module", autouse=True) +def setup_fifo_queue(): + global fifo_queue + fifo_queue = asyncio.Queue() + +@pytest.fixture(scope="module") +def mocker(): + from pytest_mock import mocker + return mocker + # Helper function to simulate authentication -def basic_auth(username: str, password: str): +def basic_auth(username, password): import base64 credentials = f"{username}:{password}" - encoded_credentials = base64.b64encode(credentials.encode('utf-8')).decode('utf-8') + encoded_credentials = base64.b64encode(credentials.encode("ascii")).decode("ascii") return {"Authorization": f"Basic {encoded_credentials}"} -# Test for /transcribe endpoint @pytest.mark.asyncio async def test_transcribe(mocker): # Mock the local_inference_tts function to avoid running the actual TTS - mocker.patch('app.main.local_inference_tts', return_value='data/results/output.wav') + mocker.patch('app.main.local_inference_tts', return_value='data/tests/random_0.wav') request_data = { "text": "Hello, how are you?", "voice": "random", - "output_path": "data/results", - "preset": "fast" + "output_path": "data/tests", + "preset": "ultra_fast" } - response = client.post( - "/transcribe", - headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")), - json=request_data - ) + with TestClient(app) as client: + response = client.post( + "/tts", + headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")), + json=request_data + ) assert response.status_code == 200 assert response.headers['content-type'] == 'audio/wav' - assert response.headers['content-disposition'] == 'attachment; filename=output.wav' + # assert response.headers['content-disposition'] == 'attachment; filename=random_0.wav' -# Test for /queue-status endpoint def test_queue_status(): - response = client.get( - "/queue-status", - headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")) - ) + with TestClient(app) as client: + response = client.get( + "/queue-status", + headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")) + ) assert response.status_code == 200 assert 'queue_length' in response.json() assert 'tasks' in response.json() -# Test for /task-status/{task_id} endpoint def test_task_status(): # Add a dummy task to the tasks dictionary task_id = 'test-task-id' @@ -59,17 +80,17 @@ def test_task_status(): 'request': TranscriptionRequest( text="Hello, how are you?", voice="random", - output_path="data/tests", preset="fast" ), 'result': None, 'error': None } - response = client.get( - f"/task-status/{task_id}", - headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")) - ) + with TestClient(app) as client: + response = client.get( + f"/task-status/{task_id}", + headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")) + ) assert response.status_code == 200 assert response.json()['status'] == 'queued' @@ -77,18 +98,17 @@ def test_task_status(): # Clean up the tasks dictionary del tasks[task_id] -# Test for failed authentication def test_failed_authentication(): - response = client.post( - "/transcribe", - headers=basic_auth("wrong_username", "wrong_password"), - json={ - "text": "Hello, how are you?", - "voice": "random", - "output_path": "data/tests", - "preset": "fast" - } - ) + with TestClient(app) as client: + response = client.post( + "/transcribe", + headers=basic_auth("wrong_username", "wrong_password"), + json={ + "text": "Hello, how are you?", + "voice": "random", + "output_path": "data/tests", + "preset": "fast" + } + ) - assert response.status_code == 401 - assert response.json() == {"detail": "Incorrect username or password"} + assert response.status_code == 401 \ No newline at end of file