mirror of
https://github.com/neonbjb/tortoise-tts.git
synced 2026-01-25 01:50:31 +01:00
stable api docker app, fixed bug with api token. stable tests
This commit is contained in:
parent
907cf247df
commit
c469c00976
199
app/main.py
199
app/main.py
|
|
@ -3,48 +3,62 @@ import sys
|
|||
import uuid
|
||||
import asyncio
|
||||
import dotenv
|
||||
import logging
|
||||
import time
|
||||
from datetime import timedelta
|
||||
|
||||
import concurrent.futures
|
||||
from fastapi import FastAPI, Depends, HTTPException, status
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from contextlib import asynccontextmanager
|
||||
from beartype import beartype
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from app.models.request import TranscriptionRequest
|
||||
from app.models.tts import TTSArgs
|
||||
from app.services.oauth import get_current_user
|
||||
from app.utils import pick_max_worker_function
|
||||
from app.services.auth import get_current_username
|
||||
# from app.services.oauth import get_current_user
|
||||
from app.utils import pick_max_worker
|
||||
from app.services.auth import get_current_user, verify_user, ACCESS_TOKEN_EXPIRE_MINUTES, create_access_token, timedelta
|
||||
|
||||
from tortoise.utils.audio import BUILTIN_VOICES_DIR
|
||||
from tortoise.do_tts import main as tts_main
|
||||
from tortoise.do_tts import _initialized_tts, infer_voice
|
||||
|
||||
load_envar = dotenv.load_dotenv()
|
||||
assert load_envar and os.getenv("DEFAULT_USERNAME"), "Missing environment variables at .env"
|
||||
|
||||
# Environment-specific variable to skip initialization during testing
|
||||
IS_TESTING = os.getenv("TESTING", "False").lower() in ("true", "1")
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, # Set the logging level to DEBUG
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # Log format
|
||||
handlers=[
|
||||
logging.FileHandler("app_debug.log"), # Log to a file named `app_debug.log`
|
||||
logging.StreamHandler() # Also log to console
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
dotenv.load_dotenv()
|
||||
# Create the ThreadPoolExecutor with the determined number of max workers
|
||||
max_workers = pick_max_worker_function()
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
||||
executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=int(os.getenv("MAX_WORKERS", pick_max_worker())))
|
||||
TASK_TIMEOUT = int(os.getenv("TASK_TIMEOUT", 60*30)) # 30 minutes
|
||||
|
||||
# Dictionary to store task information
|
||||
tasks = {}
|
||||
|
||||
@beartype
|
||||
def local_inference_tts(args: TTSArgs):
|
||||
"""
|
||||
Run the TTS directly using the `main` function from `tortoise/do_tts.py`.
|
||||
Args:
|
||||
- args (TTSArgs): The arguments to pass to the TTS function.
|
||||
Returns:
|
||||
- str: Path to the output audio file.
|
||||
"""
|
||||
output_path = tts_main(args)
|
||||
def local_inference_tts(tts, args):
|
||||
output_path = infer_voice(tts, args)
|
||||
return output_path
|
||||
|
||||
async def process_requests():
|
||||
async def process_requests(tts):
|
||||
while True:
|
||||
task_id, (args, future) = await fifo_queue.get() # Wait for a request from the queue
|
||||
task_id, (args, future) = await fifo_queue.get()
|
||||
try:
|
||||
tasks[task_id]['status'] = 'in_progress'
|
||||
output_path = await asyncio.get_event_loop().run_in_executor(executor, local_inference_tts, args)
|
||||
output_path = await asyncio.get_event_loop().run_in_executor(executor, local_inference_tts, tts, args)
|
||||
future.set_result(output_path)
|
||||
tasks[task_id]['status'] = 'completed'
|
||||
tasks[task_id]['result'] = output_path
|
||||
|
|
@ -53,34 +67,78 @@ async def process_requests():
|
|||
tasks[task_id]['status'] = 'failed'
|
||||
tasks[task_id]['error'] = str(e)
|
||||
finally:
|
||||
fifo_queue.task_done() # Indicate that the request has been processed
|
||||
fifo_queue.task_done()
|
||||
|
||||
async def post_initialization_event():
|
||||
try:
|
||||
request = TranscriptionRequest(
|
||||
text="Initialized! World",
|
||||
voice="random", preset="ultra_fast"
|
||||
)
|
||||
response = await text_to_speech(request)
|
||||
print("Initialization TTS Response:", response)
|
||||
except Exception as e:
|
||||
print("Error during initialization TTS:", str(e))
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global fifo_queue
|
||||
fifo_queue = asyncio.Queue() # Initialize the queue within the context
|
||||
# Start the request processing task
|
||||
task = asyncio.create_task(process_requests())
|
||||
yield
|
||||
# Clean up the task on shutdown
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Assign the lifespan context manager to the app
|
||||
if not IS_TESTING:
|
||||
args = TTSArgs(text="")
|
||||
tts = _initialized_tts(args)
|
||||
task = asyncio.create_task(process_requests(tts)) # Start the request processing task
|
||||
await post_initialization_event() # Post initialization event
|
||||
else:
|
||||
print(f"Skipping initialization due to TESTING={IS_TESTING}")
|
||||
task = None
|
||||
yield
|
||||
# Graceful shutdown
|
||||
if task:
|
||||
await fifo_queue.join() # Wait for the queue to empty
|
||||
task.cancel() # Cancel the task to exit the loop
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Shut down the executor
|
||||
executor.shutdown(wait=True)
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@app.get("/")
|
||||
async def home():
|
||||
return JSONResponse(content={"message": "Hello, FiCast-TTS! Check the docs at /docs."})
|
||||
return JSONResponse(content={
|
||||
"message": "Hello, FiCast-TTS! Check the docs at /docs."})
|
||||
|
||||
@app.post("/login")
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
if verify_user(form_data.username, form_data.password):
|
||||
try:
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": form_data.username}, expires_delta=access_token_expires
|
||||
)
|
||||
print({"sub": form_data.username})
|
||||
print(access_token)
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"user": form_data.username
|
||||
}
|
||||
except:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Unable to create access token")
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Incorrect username or password")
|
||||
|
||||
@app.get("/voices")
|
||||
async def available_voices():
|
||||
return JSONResponse(content={"voices": os.listdir(BUILTIN_VOICES_DIR)})
|
||||
|
||||
@app.post("/tts", dependencies=[Depends(get_current_username)])
|
||||
@app.post("/tts", dependencies=[Depends(get_current_user)])
|
||||
async def text_to_speech(request: TranscriptionRequest):
|
||||
try:
|
||||
args = TTSArgs(
|
||||
|
|
@ -89,54 +147,71 @@ async def text_to_speech(request: TranscriptionRequest):
|
|||
preset=request.preset
|
||||
)
|
||||
|
||||
# Use a future to get the result of the inference
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
# Generate a unique task ID
|
||||
task_id = str(uuid.uuid4())
|
||||
await fifo_queue.put((task_id, (args, future)))
|
||||
|
||||
# Store task information
|
||||
tasks[task_id] = {
|
||||
'status': 'queued',
|
||||
'request': request,
|
||||
'result': None,
|
||||
'error': None
|
||||
}
|
||||
|
||||
# Await the result of the future
|
||||
output_path = await future
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.isfile(output_path):
|
||||
raise HTTPException(status_code=404, detail=f"File not found: {output_path}")
|
||||
|
||||
return FileResponse(output_path, media_type='audio/wav', filename=os.path.basename(output_path))
|
||||
|
||||
return {"task_id": task_id, "status": "queued"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/queue-status", dependencies=[Depends(get_current_username)])
|
||||
|
||||
@app.get("/queue-status", dependencies=[Depends(get_current_user)])
|
||||
async def queue_status():
|
||||
"""
|
||||
Endpoint to get the current status of the queue.
|
||||
"""
|
||||
return {
|
||||
"queue_length": fifo_queue.qsize(),
|
||||
"tasks": tasks
|
||||
}
|
||||
try:
|
||||
return {
|
||||
"queue_length": fifo_queue.qsize(),
|
||||
"tasks": tasks
|
||||
}
|
||||
except HTTPException as e:
|
||||
return JSONResponse(
|
||||
status_code=e.status_code,
|
||||
content={"detail": e.detail, "headers": e.headers, "error": str(e)},
|
||||
)
|
||||
|
||||
@app.get("/task-status/{task_id}", dependencies=[Depends(get_current_username)])
|
||||
@app.get("/task-status/{task_id}", dependencies=[Depends(get_current_user)])
|
||||
async def task_status(task_id: str):
|
||||
"""
|
||||
Endpoint to get the status of a specific task.
|
||||
"""
|
||||
task = tasks.get(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task
|
||||
|
||||
@app.get("/task-result/{task_id}", dependencies=[Depends(get_current_user)])
|
||||
async def wait_for_result(task_id: str):
|
||||
try:
|
||||
task = tasks.get(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
start_time = time.time()
|
||||
while task["status"] != "completed":
|
||||
if task["status"] == "failed":
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Task failed: {task.get('error', 'Unknown error')}")
|
||||
elif time.time() - start_time > TASK_TIMEOUT:
|
||||
raise HTTPException(status_code=408, detail="Task did not complete within the timeout")
|
||||
await asyncio.sleep(5)
|
||||
task = tasks.get(task_id)
|
||||
|
||||
output_path = task["result"]
|
||||
if not os.path.isfile(output_path):
|
||||
raise HTTPException(status_code=404, detail=f"File not found: {output_path}")
|
||||
# Expected output
|
||||
return FileResponse(
|
||||
output_path,
|
||||
filename=os.path.basename(output_path),
|
||||
media_type="audio/wav")
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import uvicorn
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('text', type=str, help='Text to speak. This argument is required.')
|
||||
|
|
@ -152,12 +227,14 @@ if __name__ == "__main__":
|
|||
voice=args.voice,
|
||||
preset=args.preset,
|
||||
)
|
||||
|
||||
tts = _initialized_tts(tts_args)
|
||||
try:
|
||||
output_path = local_inference_tts(tts_args)
|
||||
output_path = local_inference_tts(tts, tts_args)
|
||||
print(f"Output stored at: {output_path}")
|
||||
return output_path
|
||||
except Exception as e:
|
||||
print(f"Error during TTS generation: {str(e)}")
|
||||
sys.exit(1)
|
||||
main()
|
||||
|
||||
uvicorn.run(
|
||||
app, host="0.0.0.0", port=42110, log_level="debug")
|
||||
|
|
|
|||
|
|
@ -1,19 +1,115 @@
|
|||
import base64
|
||||
import os
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
import secrets
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
import jwt
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, Security, status
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials, APIKeyHeader
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # Log format
|
||||
handlers=[
|
||||
logging.FileHandler("app_debug.log"),
|
||||
logging.StreamHandler() # Also log to console
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Secret key to encode and decode JWT tokens
|
||||
DEFAULT_SECRET_KEY = str(os.getenv("DEFAULT_SECRET_KEY", "fek3kz9xzlsndSuczhgjds0vndi"))
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 24*60
|
||||
|
||||
security = HTTPBasic()
|
||||
api_key_header = APIKeyHeader(
|
||||
name="Authorization", auto_error=False)
|
||||
|
||||
def get_current_username(credentials: HTTPBasicCredentials = Depends(security)):
|
||||
assert os.getenv("TEST_USERNAME")
|
||||
assert os.getenv("TEST_PASSWORD")
|
||||
correct_username = secrets.compare_digest(credentials.username, os.getenv("TEST_USERNAME"))
|
||||
correct_password = secrets.compare_digest(credentials.password, os.getenv("TEST_PASSWORD"))
|
||||
if not (correct_username and correct_password):
|
||||
|
||||
def verify_user(username: str, password: str):
|
||||
user = os.getenv("DEFAULT_USERNAME")
|
||||
if user and secrets.compare_digest(os.getenv("DEFAULT_PASSWORD"), password):
|
||||
return True
|
||||
return False
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, DEFAULT_SECRET_KEY, algorithm=ALGORITHM)
|
||||
print(f"Encoded JWT: {encoded_jwt}")
|
||||
return encoded_jwt
|
||||
|
||||
async def get_current_user(
|
||||
request: Request, # Access the full request to manually check headers
|
||||
authorization: Optional[str] = Security(api_key_header)
|
||||
):
|
||||
logger.debug("Attempting to authenticate user...")
|
||||
# First, try to authenticate using Basic Authentication
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Basic "):
|
||||
print("Using Basic Auth...")
|
||||
auth = auth_header.split(" ")[1]
|
||||
credentials = base64.b64decode(auth).decode("utf-8").split(":")
|
||||
username = credentials[0]
|
||||
password = credentials[1]
|
||||
|
||||
correct_username = secrets.compare_digest(username, os.getenv("DEFAULT_USERNAME"))
|
||||
correct_password = secrets.compare_digest(password, os.getenv("DEFAULT_PASSWORD"))
|
||||
|
||||
if correct_username and correct_password:
|
||||
logger.debug("Basic auth successful.")
|
||||
return {"auth": "basic", "user": username}
|
||||
else:
|
||||
logger.debug("Basic auth failed.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
|
||||
# If Basic Auth is not provided, try to authenticate using API Key (JWT)
|
||||
elif authorization:
|
||||
print("Using API Key...")
|
||||
logger.debug(f"Authorization header received: {authorization[:5]}...")
|
||||
scheme, _, token = authorization.partition(" ")
|
||||
|
||||
if scheme.lower() != "bearer":
|
||||
logger.debug("Invalid authentication scheme.")
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid authentication scheme")
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, DEFAULT_SECRET_KEY,
|
||||
algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
|
||||
if username is None:
|
||||
logger.debug("JWT payload does not contain a username.")
|
||||
raise HTTPException(status_code=401, detail="Could not validate credentials")
|
||||
|
||||
logger.debug("JWT validation successful.")
|
||||
return {"auth": "api_key", "user": username}
|
||||
|
||||
except jwt.PyJWTError as e:
|
||||
logger.debug(f"JWT validation failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Could not validate credentials")
|
||||
|
||||
else:
|
||||
logger.debug("No credentials provided.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
return credentials.username
|
||||
)
|
||||
|
|
@ -18,7 +18,6 @@ def get_authenticated_service():
|
|||
CLIENT_SECRETS_FILE, SCOPES,
|
||||
redirect_uri=REDIRECT_URI
|
||||
)
|
||||
|
||||
# Tell the user to go to the authorization URL.
|
||||
auth_url, _ = flow.authorization_url(prompt='consent')
|
||||
print('Please go to this URL: {}'.format(auth_url))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
|
||||
def pick_max_worker_function():
|
||||
def pick_max_worker():
|
||||
# Determine the number of CPU cores
|
||||
cpu_count = os.cpu_count()
|
||||
if cpu_count is None:
|
||||
|
|
|
|||
|
|
@ -1,10 +0,0 @@
|
|||
curl -X POST "http://127.0.0.1:42110/transcribe" \
|
||||
-u ficast-uzer:ficast-testpazz \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"text": "Hello, how are you?",
|
||||
"voice": "random",
|
||||
"preset": "ultra_fast"
|
||||
}' \
|
||||
-o data/samples/api-output.wav
|
||||
|
||||
9
bin/docker-run_app.sh
Normal file
9
bin/docker-run_app.sh
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
image="tts:app"
|
||||
docker run --gpus all \
|
||||
-e TORTOISE_MODELS_DIR=/models \
|
||||
-v "${PWD}/data/models":/models \
|
||||
-v "${PWD}/data/results":/results \
|
||||
-v "${PWD}/data/.cache/huggingface":/root/.cache/huggingface \
|
||||
-p 42110:42110 \
|
||||
--name tts-api \
|
||||
-it $image --port 42110 --host 0.0.0.0
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
version: 1
|
||||
disable_existing_loggers: False
|
||||
|
||||
formatters:
|
||||
standard:
|
||||
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
|
||||
handlers:
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
level: DEBUG
|
||||
formatter: standard
|
||||
stream: ext://sys.stdout
|
||||
|
||||
loggers:
|
||||
app:
|
||||
level: DEBUG
|
||||
handlers: [console]
|
||||
propagate: no
|
||||
|
||||
tortoise:
|
||||
level: DEBUG
|
||||
handlers: [console]
|
||||
propagate: no
|
||||
|
||||
root:
|
||||
level: DEBUG
|
||||
handlers: [console]
|
||||
|
|
@ -28,4 +28,6 @@ uvicorn==0.30.5
|
|||
google_auth_oauthlib==1.2.1
|
||||
python-dotenv==1.0.1
|
||||
fastapi==0.112.0
|
||||
beartype==0.18.5
|
||||
beartype==0.18.5
|
||||
PyJWT==2.9.0
|
||||
python-multipart==0.0.9
|
||||
11
tests/curl/test-app_TTS-credential.sh
Normal file
11
tests/curl/test-app_TTS-credential.sh
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
response=$(curl -X POST "http://127.0.0.1:42110/tts" \
|
||||
-u $TEST_USERNAME:$TEST_PASSWORD \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"text": "Hello, curl test?",
|
||||
"voice": "random",
|
||||
"preset": "ultra_fast"
|
||||
}')
|
||||
|
||||
TASK_ID=$(echo $response | jq -r .task_id)
|
||||
echo $TASK_ID
|
||||
10
tests/curl/test-app_TTS-token.sh
Normal file
10
tests/curl/test-app_TTS-token.sh
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
response=$(curl -X POST "http://127.0.0.1:42110/tts" \
|
||||
-H "Authorization: Bearer ${ACCESS_TOKEN}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"text": "Hello, curl test?",
|
||||
"voice": "random",
|
||||
"preset": "ultra_fast"
|
||||
}')
|
||||
TASK_ID=$(echo $response | jq -r .task_id)
|
||||
echo $TASK_ID
|
||||
2
tests/curl/test-app_api_key_queue_status.sh
Normal file
2
tests/curl/test-app_api_key_queue_status.sh
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
curl "http://127.0.0.1:42110/queue-status" \
|
||||
-H "Authorization: Bearer ${ACCESS_TOKEN}"
|
||||
6
tests/curl/test-app_login.sh
Normal file
6
tests/curl/test-app_login.sh
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
. tests/.env
|
||||
response=$(curl -X POST "http://127.0.0.1:42110/login" \
|
||||
-d "username=${TEST_USERNAME}&password=${TEST_PASSWORD}" \
|
||||
-H "Content-Type: application/x-www-form-urlencoded")
|
||||
ACCESS_TOKEN=$(echo $response | jq -r .access_token)
|
||||
echo $ACCESS_TOKEN
|
||||
5
tests/curl/test-app_task_result.sh
Normal file
5
tests/curl/test-app_task_result.sh
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
TASK_ID=9ae03d68-9511-4642-b9bb-abb74948af61
|
||||
curl "http://127.0.0.1:42110/task-result/$TASK_ID" \
|
||||
-H "Authorization: Bearer ${ACCESS_TOKEN}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-o data/samples/curl-task-result.wav
|
||||
3
tests/curl/test-app_task_status.sh
Normal file
3
tests/curl/test-app_task_status.sh
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
curl "http://127.0.0.1:42110/task-status/$TASK_ID" \
|
||||
-u $TEST_USERNAME:$TEST_PASSWORD \
|
||||
-H "Content-Type: application/json"
|
||||
|
|
@ -1,124 +0,0 @@
|
|||
import pytest
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import dotenv
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Add the root directory of the repo to sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
# Add the tortoise directory to sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise')))
|
||||
|
||||
from app.main import app, tasks
|
||||
from app.models.request import TranscriptionRequest
|
||||
|
||||
|
||||
load_envvar = dotenv.load_dotenv()
|
||||
assert load_envvar and os.getenv("TEST_USERNAME") and os.getenv("TEST_PASSWORD"), "Missing environment variables"
|
||||
|
||||
# Configure logging to print to the console
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Helper function to simulate authentication
|
||||
def basic_auth(username, password):
|
||||
import base64
|
||||
credentials = f"{username}:{password}"
|
||||
encoded_credentials = base64.b64encode(credentials.encode("ascii")).decode("ascii")
|
||||
return {"Authorization": f"Basic {encoded_credentials}"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcribe(mocker):
|
||||
# Mock the local_inference_tts function to avoid running the actual TTS
|
||||
mocker.patch('app.main.local_inference_tts')
|
||||
request_data = {
|
||||
"text": "Hello, how are you?",
|
||||
"voice": "random",
|
||||
"preset": "ultra_fast"
|
||||
}
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/tts",
|
||||
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")),
|
||||
json=request_data
|
||||
)
|
||||
|
||||
# Log detailed information about the response for debugging
|
||||
logger.debug(f"Response status code: {response.status_code}")
|
||||
logger.debug(f"Response headers: {response.headers}")
|
||||
logger.debug(f"Response content: {response.content}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers['content-type'] == 'audio/wav'
|
||||
# assert response.headers['content-disposition'] == 'attachment; filename=random_0.wav'
|
||||
|
||||
def test_queue_status():
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
"/queue-status",
|
||||
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
|
||||
)
|
||||
|
||||
# Log detailed information about the response for debugging
|
||||
logger.debug(f"Response status code: {response.status_code}")
|
||||
logger.debug(f"Response headers: {response.headers}")
|
||||
logger.debug(f"Response content: {response.content}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert 'queue_length' in response.json()
|
||||
assert 'tasks' in response.json()
|
||||
|
||||
def test_task_status():
|
||||
# Add a dummy task to the tasks dictionary
|
||||
task_id = 'test-task-id'
|
||||
tasks[task_id] = {
|
||||
'status': 'queued',
|
||||
'request': TranscriptionRequest(
|
||||
text="Hello, how are you?",
|
||||
voice="random",
|
||||
preset="fast"
|
||||
),
|
||||
'result': None,
|
||||
'error': None
|
||||
}
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
f"/task-status/{task_id}",
|
||||
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
|
||||
)
|
||||
|
||||
# Log detailed information about the response for debugging
|
||||
logger.debug(f"Response status code: {response.status_code}")
|
||||
logger.debug(f"Response headers: {response.headers}")
|
||||
logger.debug(f"Response content: {response.content}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()['status'] == 'queued'
|
||||
|
||||
# Clean up the tasks dictionary
|
||||
del tasks[task_id]
|
||||
|
||||
def test_failed_authentication():
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/tts",
|
||||
headers=basic_auth("wrong_username", "wrong_password"),
|
||||
json={
|
||||
"text": "Hello, how are you?",
|
||||
"voice": "random",
|
||||
"output_path": "data/tests",
|
||||
"preset": "fast"
|
||||
}
|
||||
)
|
||||
|
||||
# Log detailed information about the response for debugging
|
||||
logger.debug(f"Response status code: {response.status_code}")
|
||||
logger.debug(f"Response headers: {response.headers}")
|
||||
logger.debug(f"Response content: {response.content}")
|
||||
|
||||
assert response.status_code == 401
|
||||
126
tests/test_tts_api.py
Normal file
126
tests/test_tts_api.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
import pytest
|
||||
import os, sys, time
|
||||
import dotenv
|
||||
from unittest.mock import patch, MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Add the root directory of the repo to sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise')))
|
||||
|
||||
# load environment variables
|
||||
load_envvar = dotenv.load_dotenv('tests/.env', override=True)
|
||||
assert load_envvar and os.getenv("TESTING").lower()=="true", "Missing environment variables"
|
||||
|
||||
from app.main import app
|
||||
client = TestClient(app)
|
||||
|
||||
# Helper function to simulate authentication
|
||||
def basic_auth(username, password):
|
||||
import base64
|
||||
credentials = f"{username}:{password}"
|
||||
encoded_credentials = base64.b64encode(credentials.encode("ascii")).decode("ascii")
|
||||
return {"Authorization": f"Basic {encoded_credentials}"}
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def task_id():
|
||||
with patch('app.main.text_to_speech') as mock_tts:
|
||||
mock_tts.return_value = {
|
||||
"task_id": "mock_task_id", "status": "queued"}
|
||||
request_data = {
|
||||
"text": "Hello, how are you?",
|
||||
"voice": "random",
|
||||
"preset": "ultra_fast"
|
||||
}
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/tts",
|
||||
headers=basic_auth(
|
||||
os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")),
|
||||
json=request_data
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json().get("task_id")
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def access_token():
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/login",
|
||||
data={
|
||||
"username": os.getenv("TEST_USERNAME"),
|
||||
"password": os.getenv("TEST_PASSWORD")
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
token = response.json().get("access_token")
|
||||
|
||||
assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
|
||||
assert "access_token" in response.json(), f"Access token missing in response. Response: {response.text}"
|
||||
assert "token_type" in response.json(), f"Token type missing in response. Response: {response.text}"
|
||||
assert response.json()["token_type"] == "bearer", f"Expected token type 'bearer', got {response.json()['token_type']}. Response: {response.text}"
|
||||
return token
|
||||
|
||||
def test_queue_status_api_key(access_token):
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
"/queue-status",
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}"
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
|
||||
assert 'tasks' in response.json(), f"'tasks' key missing in response. Response: {response.json()}"
|
||||
|
||||
def test_queue_status():
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
"/queue-status",
|
||||
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
|
||||
)
|
||||
|
||||
assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
|
||||
assert 'tasks' in response.json(), f"'tasks' key missing in response. Response: {response.json()}"
|
||||
|
||||
def test_tts_task_creation(task_id):
|
||||
assert task_id is not None, f"Task ID is None. Check the task creation process."
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
"/queue-status",
|
||||
headers=basic_auth(
|
||||
os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
|
||||
)
|
||||
response.raise_for_status()
|
||||
queue_status = response.json()
|
||||
assert task_id in queue_status["tasks"], f"Task ID {task_id} not found in queue tasks. Queue status: {queue_status}"
|
||||
assert queue_status["tasks"][task_id]["status"] in ["queued", "in_progress", "completed"], f"Expected task status 'queued', got {queue_status['tasks'][task_id]['status']}. Queue status: {queue_status}"
|
||||
|
||||
def test_tts_task_completion(task_id):
|
||||
with patch('app.main.tasks') as mock_tasks:
|
||||
mock_tasks[task_id] = {
|
||||
"status": "completed",
|
||||
"result": {"message": "Task completed"}
|
||||
}
|
||||
|
||||
with TestClient(app) as client:
|
||||
start_time = time.time()
|
||||
while True:
|
||||
response = client.get(
|
||||
f"/task-status/{task_id}",
|
||||
headers=basic_auth(
|
||||
os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
|
||||
)
|
||||
if response.status_code == 200 and response.json().get("status") == "completed":
|
||||
break
|
||||
elif response.status_code == 500:
|
||||
raise AssertionError(f"Task failed: {response.json().get('detail')}. Response: {response.text}")
|
||||
elif time.time() - start_time > 60:
|
||||
raise AssertionError(f"Task did not complete within 1 minute. Task ID: {task_id}")
|
||||
time.sleep(5)
|
||||
|
||||
# Verify that the task is completed
|
||||
assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
|
||||
assert response.json()["status"] == "completed", f"Expected task status 'completed', got {response.json()['status']}. Response: {response.json()}"
|
||||
assert response.json()["result"]["message"] == "Task completed", f"Expected result message 'Task completed', got {response.json()['result']['message']}. Response: {response.json()}"
|
||||
|
||||
|
|
@ -40,10 +40,14 @@ MODELS = {
|
|||
def get_model_path(model_name, models_dir=MODELS_DIR):
|
||||
"""
|
||||
Get path to given model, download it if it doesn't exist.
|
||||
Uses the cache provided by HuggingFace's hf_hub_download.
|
||||
"""
|
||||
if model_name not in MODELS:
|
||||
raise ValueError(f'Model {model_name} not found in available models.')
|
||||
model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir)
|
||||
# hf_hub_download will automatically check the cache
|
||||
model_path = hf_hub_download(
|
||||
repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir
|
||||
)
|
||||
return model_path
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -137,7 +137,10 @@ def classify_audio_clip(clip):
|
|||
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
|
||||
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
|
||||
dropout=0, kernel_size=5, distribute_zero_label=False)
|
||||
classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu')))
|
||||
# Use the get_model_path function to get the model's path
|
||||
model_path = get_model_path('classifier.pth')
|
||||
# Load the model state dictionary using the path returned
|
||||
classifier.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
||||
clip = clip.cpu().unsqueeze(0)
|
||||
results = F.softmax(classifier(clip), dim=-1)
|
||||
return results[0][0]
|
||||
|
|
|
|||
|
|
@ -16,20 +16,15 @@ from tortoise.utils.audio import load_voices
|
|||
|
||||
# Configure logging to print to the console
|
||||
# Load logging configuration
|
||||
with open("logging_conf.yml", 'r') as file:
|
||||
config = yaml.safe_load(file.read())
|
||||
logging.config.dictConfig(config)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main(args):
|
||||
# if torch.backends.mps.is_available():
|
||||
# args.use_deepspeed = False
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
if not args.autoregressive_batch_size:
|
||||
args.autoregressive_batch_size = pick_best_batch_size_for_gpu()
|
||||
def _initialized_tts(args):
|
||||
tts = TextToSpeech(
|
||||
models_dir=args.model_dir, autoregressive_batch_size=args.autoregressive_batch_size, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
|
||||
models_dir=args.model_dir, autoregressive_batch_size=args.autoregressive_batch_size, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache,
|
||||
half=args.half)
|
||||
return tts
|
||||
|
||||
def infer_voice(tts: TextToSpeech, args: argparse.Namespace):
|
||||
selected_voices = args.voice.split(',')
|
||||
for k, selected_voice in tqdm(enumerate(selected_voices), desc="generating using selected voice"):
|
||||
if '&' in selected_voice:
|
||||
|
|
@ -57,6 +52,15 @@ def main(args):
|
|||
os.makedirs('debug_states', exist_ok=True)
|
||||
torch.save(dbg_state, f'debug_states/do_tts_debug_{selected_voice}.pth')
|
||||
return output_path
|
||||
|
||||
def main(args):
|
||||
if torch.backends.mps.is_available():
|
||||
args.use_deepspeed = False
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
if not args.autoregressive_batch_size:
|
||||
args.autoregressive_batch_size = pick_best_batch_size_for_gpu()
|
||||
tts = _initialized_tts(args)
|
||||
return infer_voice(tts, args)
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue