mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-01-25 10:00:47 +01:00
stable tts inference docker application
This commit is contained in:
parent
6975fab600
commit
7a44f8ffda
|
|
@ -45,11 +45,11 @@ FROM conda_base AS runner
|
|||
|
||||
# Install the application
|
||||
WORKDIR /app
|
||||
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python setup.py install"
|
||||
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && pip install -r requirements.txt && python setup.py install"
|
||||
|
||||
# Default entrypoint
|
||||
RUN chmod +x /app/scripts/docker-entrypoint.sh
|
||||
ENTRYPOINT ["/app/scripts/docker-entrypoint.sh"]
|
||||
RUN chmod +x /app/scripts/tts-entrypoint.sh
|
||||
ENTRYPOINT ["/app/scripts/tts-entrypoint.sh"]
|
||||
|
||||
# Provide default CMD if no arguments are passed
|
||||
CMD ["--help"]
|
||||
|
|
@ -48,11 +48,10 @@ WORKDIR /app
|
|||
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && python setup.py install"
|
||||
|
||||
# Install FastAPI and Uvicorn
|
||||
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && pip install fastapi uvicorn"
|
||||
RUN bash -c "source ${CONDA_DIR}/etc/profile.d/conda.sh && conda activate tortoise && pip install -r requirements.txt"
|
||||
|
||||
# Copy the FastAPI app
|
||||
COPY app /app/api
|
||||
# Default entrypoint
|
||||
RUN chmod +x /app/scripts/tts_app-entrypoint.sh
|
||||
ENTRYPOINT ["/app/scripts/tts_app-entrypoint.sh"]
|
||||
|
||||
# Default command to run the FastAPI app
|
||||
ENTRYPOINT ["uvicorn", "app.api:app"]
|
||||
CMD ["--host", "0.0.0.0", "--port", "8000"]
|
||||
|
|
|
|||
|
|
@ -1,8 +1,14 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Presets(str, Enum):
|
||||
ULTRA_FAST='ultra_fast'
|
||||
FAST='fast'
|
||||
STANDARD='standard'
|
||||
HIGH_QUALITY='high_quality'
|
||||
|
||||
class TranscriptionRequest(BaseModel):
|
||||
text: str
|
||||
voice: str
|
||||
output_path: Optional[str] = "data/samples/"
|
||||
preset: str = "ultra_fast"
|
||||
preset: Presets = "ultra_fast"
|
||||
|
|
|
|||
|
|
@ -5,11 +5,7 @@ from enum import Enum
|
|||
from tortoise.do_tts import pick_best_batch_size_for_gpu
|
||||
from tortoise.api import MODELS_DIR
|
||||
|
||||
class Presets(str, Enum):
|
||||
ULTRA_FAST='ultra_fast'
|
||||
FAST='fast'
|
||||
STANDARD='standard'
|
||||
HIGH_QUALITY='high_quality'
|
||||
from .request import Presets
|
||||
|
||||
class TTSArgs(BaseModel):
|
||||
text: str
|
||||
|
|
|
|||
|
|
@ -25,4 +25,7 @@ sounddevice
|
|||
spacy==3.7.5
|
||||
# for api
|
||||
uvicorn==0.30.5
|
||||
google_auth_oauthlib
|
||||
google_auth_oauthlib==1.2.1
|
||||
python-dotenv==1.0.1
|
||||
fastapi==0.112.0
|
||||
beartype==0.18.5
|
||||
7
scripts/tts_app-entrypoint.sh
Normal file
7
scripts/tts_app-entrypoint.sh
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
#!/bin/bash
|
||||
# Source the conda environment
|
||||
source /root/miniconda/etc/profile.d/conda.sh
|
||||
conda activate tortoise
|
||||
|
||||
# Execute the Python script with passed arguments
|
||||
uvicorn app.main:app "$@"
|
||||
|
|
@ -1,56 +1,77 @@
|
|||
import pytest
|
||||
import os
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import asyncio
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Add the root directory of the repo to sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
# Add the tortoise directory to sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise')))
|
||||
|
||||
from app.main import app, tasks
|
||||
from app.models.request import TranscriptionRequest
|
||||
|
||||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Configure logging to print to the console
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_fifo_queue():
|
||||
global fifo_queue
|
||||
fifo_queue = asyncio.Queue()
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mocker():
|
||||
from pytest_mock import mocker
|
||||
return mocker
|
||||
|
||||
# Helper function to simulate authentication
|
||||
def basic_auth(username: str, password: str):
|
||||
def basic_auth(username, password):
|
||||
import base64
|
||||
credentials = f"{username}:{password}"
|
||||
encoded_credentials = base64.b64encode(credentials.encode('utf-8')).decode('utf-8')
|
||||
encoded_credentials = base64.b64encode(credentials.encode("ascii")).decode("ascii")
|
||||
return {"Authorization": f"Basic {encoded_credentials}"}
|
||||
|
||||
# Test for /transcribe endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcribe(mocker):
|
||||
# Mock the local_inference_tts function to avoid running the actual TTS
|
||||
mocker.patch('app.main.local_inference_tts', return_value='data/results/output.wav')
|
||||
mocker.patch('app.main.local_inference_tts', return_value='data/tests/random_0.wav')
|
||||
request_data = {
|
||||
"text": "Hello, how are you?",
|
||||
"voice": "random",
|
||||
"output_path": "data/results",
|
||||
"preset": "fast"
|
||||
"output_path": "data/tests",
|
||||
"preset": "ultra_fast"
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
"/transcribe",
|
||||
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")),
|
||||
json=request_data
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/tts",
|
||||
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")),
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers['content-type'] == 'audio/wav'
|
||||
assert response.headers['content-disposition'] == 'attachment; filename=output.wav'
|
||||
# assert response.headers['content-disposition'] == 'attachment; filename=random_0.wav'
|
||||
|
||||
# Test for /queue-status endpoint
|
||||
def test_queue_status():
|
||||
response = client.get(
|
||||
"/queue-status",
|
||||
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
"/queue-status",
|
||||
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert 'queue_length' in response.json()
|
||||
assert 'tasks' in response.json()
|
||||
|
||||
# Test for /task-status/{task_id} endpoint
|
||||
def test_task_status():
|
||||
# Add a dummy task to the tasks dictionary
|
||||
task_id = 'test-task-id'
|
||||
|
|
@ -59,17 +80,17 @@ def test_task_status():
|
|||
'request': TranscriptionRequest(
|
||||
text="Hello, how are you?",
|
||||
voice="random",
|
||||
output_path="data/tests",
|
||||
preset="fast"
|
||||
),
|
||||
'result': None,
|
||||
'error': None
|
||||
}
|
||||
|
||||
response = client.get(
|
||||
f"/task-status/{task_id}",
|
||||
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
f"/task-status/{task_id}",
|
||||
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()['status'] == 'queued'
|
||||
|
|
@ -77,18 +98,17 @@ def test_task_status():
|
|||
# Clean up the tasks dictionary
|
||||
del tasks[task_id]
|
||||
|
||||
# Test for failed authentication
|
||||
def test_failed_authentication():
|
||||
response = client.post(
|
||||
"/transcribe",
|
||||
headers=basic_auth("wrong_username", "wrong_password"),
|
||||
json={
|
||||
"text": "Hello, how are you?",
|
||||
"voice": "random",
|
||||
"output_path": "data/tests",
|
||||
"preset": "fast"
|
||||
}
|
||||
)
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/transcribe",
|
||||
headers=basic_auth("wrong_username", "wrong_password"),
|
||||
json={
|
||||
"text": "Hello, how are you?",
|
||||
"voice": "random",
|
||||
"output_path": "data/tests",
|
||||
"preset": "fast"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {"detail": "Incorrect username or password"}
|
||||
assert response.status_code == 401
|
||||
Loading…
Reference in a new issue