tests passed with API server. stable api docker application

This commit is contained in:
supermomo668 2024-08-08 00:23:31 +00:00
parent 7a44f8ffda
commit 907cf247df
2 changed files with 27 additions and 18 deletions

View file

@ -78,7 +78,7 @@ async def home():
@app.get("/voices")
async def available_voices():
return JSONResponse(content={"message": os.listdir(BUILTIN_VOICES_DIR)})
return JSONResponse(content={"voices": os.listdir(BUILTIN_VOICES_DIR)})
@app.post("/tts", dependencies=[Depends(get_current_username)])
async def text_to_speech(request: TranscriptionRequest):
@ -86,7 +86,6 @@ async def text_to_speech(request: TranscriptionRequest):
args = TTSArgs(
text=request.text,
voice=request.voice,
output_path=request.output_path,
preset=request.preset
)

View file

@ -2,7 +2,7 @@ import pytest
import os
import sys
import logging
import asyncio
import dotenv
from fastapi.testclient import TestClient
# Add the root directory of the repo to sys.path
@ -13,8 +13,9 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..',
from app.main import app, tasks
from app.models.request import TranscriptionRequest
import dotenv
dotenv.load_dotenv()
load_envvar = dotenv.load_dotenv()
assert load_envvar and os.getenv("TEST_USERNAME") and os.getenv("TEST_PASSWORD"), "Missing environment variables"
# Configure logging to print to the console
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@ -22,16 +23,6 @@ logger = logging.getLogger(__name__)
client = TestClient(app)
@pytest.fixture(scope="module", autouse=True)
def setup_fifo_queue():
global fifo_queue
fifo_queue = asyncio.Queue()
@pytest.fixture(scope="module")
def mocker():
from pytest_mock import mocker
return mocker
# Helper function to simulate authentication
def basic_auth(username, password):
import base64
@ -42,11 +33,10 @@ def basic_auth(username, password):
@pytest.mark.asyncio
async def test_transcribe(mocker):
# Mock the local_inference_tts function to avoid running the actual TTS
mocker.patch('app.main.local_inference_tts', return_value='data/tests/random_0.wav')
mocker.patch('app.main.local_inference_tts')
request_data = {
"text": "Hello, how are you?",
"voice": "random",
"output_path": "data/tests",
"preset": "ultra_fast"
}
@ -56,6 +46,11 @@ async def test_transcribe(mocker):
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")),
json=request_data
)
# Log detailed information about the response for debugging
logger.debug(f"Response status code: {response.status_code}")
logger.debug(f"Response headers: {response.headers}")
logger.debug(f"Response content: {response.content}")
assert response.status_code == 200
assert response.headers['content-type'] == 'audio/wav'
@ -68,6 +63,11 @@ def test_queue_status():
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
# Log detailed information about the response for debugging
logger.debug(f"Response status code: {response.status_code}")
logger.debug(f"Response headers: {response.headers}")
logger.debug(f"Response content: {response.content}")
assert response.status_code == 200
assert 'queue_length' in response.json()
assert 'tasks' in response.json()
@ -92,6 +92,11 @@ def test_task_status():
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
# Log detailed information about the response for debugging
logger.debug(f"Response status code: {response.status_code}")
logger.debug(f"Response headers: {response.headers}")
logger.debug(f"Response content: {response.content}")
assert response.status_code == 200
assert response.json()['status'] == 'queued'
@ -101,7 +106,7 @@ def test_task_status():
def test_failed_authentication():
with TestClient(app) as client:
response = client.post(
"/transcribe",
"/tts",
headers=basic_auth("wrong_username", "wrong_password"),
json={
"text": "Hello, how are you?",
@ -111,4 +116,9 @@ def test_failed_authentication():
}
)
# Log detailed information about the response for debugging
logger.debug(f"Response status code: {response.status_code}")
logger.debug(f"Response headers: {response.headers}")
logger.debug(f"Response content: {response.content}")
assert response.status_code == 401