stable tts inference docker application

This commit is contained in:
supermomo668 2024-08-07 23:35:51 +00:00
parent 6975fab600
commit 7a44f8ffda
8 changed files with 84 additions and 53 deletions

View file

@ -45,11 +45,11 @@ FROM conda_base AS runner
# Install the application
WORKDIR /app
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python setup.py install"
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && pip install -r requirements.txt && python setup.py install"
# Default entrypoint
RUN chmod +x /app/scripts/docker-entrypoint.sh
ENTRYPOINT ["/app/scripts/docker-entrypoint.sh"]
RUN chmod +x /app/scripts/tts-entrypoint.sh
ENTRYPOINT ["/app/scripts/tts-entrypoint.sh"]
# Provide default CMD if no arguments are passed
CMD ["--help"]

View file

@ -48,11 +48,10 @@ WORKDIR /app
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python setup.py install"
# Install FastAPI and Uvicorn
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && pip install fastapi uvicorn"
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && pip install -r requirements.txt"
# Copy the FastAPI app
COPY app /app/api
# Default entrypoint
RUN chmod +x /app/scripts/tts_app-entrypoint.sh
ENTRYPOINT ["/app/scripts/tts_app-entrypoint.sh"]
# Default command to run the FastAPI app
ENTRYPOINT ["uvicorn", "app.api:app"]
CMD ["--host", "0.0.0.0", "--port", "8000"]

View file

@ -1,8 +1,14 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
class Presets(str, Enum):
ULTRA_FAST='ultra_fast'
FAST='fast'
STANDARD='standard'
HIGH_QUALITY='high_quality'
class TranscriptionRequest(BaseModel):
text: str
voice: str
output_path: Optional[str] = "data/samples/"
preset: str = "ultra_fast"
preset: Presets = "ultra_fast"

View file

@ -5,11 +5,7 @@ from enum import Enum
from tortoise.do_tts import pick_best_batch_size_for_gpu
from tortoise.api import MODELS_DIR
class Presets(str, Enum):
ULTRA_FAST='ultra_fast'
FAST='fast'
STANDARD='standard'
HIGH_QUALITY='high_quality'
from .request import Presets
class TTSArgs(BaseModel):
text: str

View file

@ -25,4 +25,7 @@ sounddevice
spacy==3.7.5
# for api
uvicorn==0.30.5
google_auth_oauthlib
google_auth_oauthlib==1.2.1
python-dotenv==1.0.1
fastapi==0.112.0
beartype==0.18.5

View file

@ -0,0 +1,7 @@
#!/bin/bash
# Source the conda environment
source /root/miniconda/etc/profile.d/conda.sh
conda activate tortoise
# Execute the Python script with passed arguments
uvicorn app.main:app "$@"

View file

@ -1,56 +1,77 @@
import pytest
import os
import sys
import logging
import asyncio
from fastapi.testclient import TestClient
# Add the root directory of the repo to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Add the tortoise directory to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise')))
from app.main import app, tasks
from app.models.request import TranscriptionRequest
import dotenv
dotenv.load_dotenv()
# Configure logging to print to the console
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
client = TestClient(app)
@pytest.fixture(scope="module", autouse=True)
def setup_fifo_queue():
global fifo_queue
fifo_queue = asyncio.Queue()
@pytest.fixture(scope="module")
def mocker():
from pytest_mock import mocker
return mocker
# Helper function to simulate authentication
def basic_auth(username: str, password: str):
def basic_auth(username, password):
import base64
credentials = f"{username}:{password}"
encoded_credentials = base64.b64encode(credentials.encode('utf-8')).decode('utf-8')
encoded_credentials = base64.b64encode(credentials.encode("ascii")).decode("ascii")
return {"Authorization": f"Basic {encoded_credentials}"}
# Test for /transcribe endpoint
@pytest.mark.asyncio
async def test_transcribe(mocker):
# Mock the local_inference_tts function to avoid running the actual TTS
mocker.patch('app.main.local_inference_tts', return_value='data/results/output.wav')
mocker.patch('app.main.local_inference_tts', return_value='data/tests/random_0.wav')
request_data = {
"text": "Hello, how are you?",
"voice": "random",
"output_path": "data/results",
"preset": "fast"
"output_path": "data/tests",
"preset": "ultra_fast"
}
response = client.post(
"/transcribe",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")),
json=request_data
)
with TestClient(app) as client:
response = client.post(
"/tts",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")),
json=request_data
)
assert response.status_code == 200
assert response.headers['content-type'] == 'audio/wav'
assert response.headers['content-disposition'] == 'attachment; filename=output.wav'
# assert response.headers['content-disposition'] == 'attachment; filename=random_0.wav'
# Test for /queue-status endpoint
def test_queue_status():
response = client.get(
"/queue-status",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
with TestClient(app) as client:
response = client.get(
"/queue-status",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
assert response.status_code == 200
assert 'queue_length' in response.json()
assert 'tasks' in response.json()
# Test for /task-status/{task_id} endpoint
def test_task_status():
# Add a dummy task to the tasks dictionary
task_id = 'test-task-id'
@ -59,17 +80,17 @@ def test_task_status():
'request': TranscriptionRequest(
text="Hello, how are you?",
voice="random",
output_path="data/tests",
preset="fast"
),
'result': None,
'error': None
}
response = client.get(
f"/task-status/{task_id}",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
with TestClient(app) as client:
response = client.get(
f"/task-status/{task_id}",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
assert response.status_code == 200
assert response.json()['status'] == 'queued'
@ -77,18 +98,17 @@ def test_task_status():
# Clean up the tasks dictionary
del tasks[task_id]
# Test for failed authentication
def test_failed_authentication():
response = client.post(
"/transcribe",
headers=basic_auth("wrong_username", "wrong_password"),
json={
"text": "Hello, how are you?",
"voice": "random",
"output_path": "data/tests",
"preset": "fast"
}
)
with TestClient(app) as client:
response = client.post(
"/transcribe",
headers=basic_auth("wrong_username", "wrong_password"),
json={
"text": "Hello, how are you?",
"voice": "random",
"output_path": "data/tests",
"preset": "fast"
}
)
assert response.status_code == 401
assert response.json() == {"detail": "Incorrect username or password"}
assert response.status_code == 401