stable api docker app, fixed bug with api token. stable tests

This commit is contained in:
supermomo668 2024-08-09 08:00:07 +00:00
parent 907cf247df
commit c469c00976
19 changed files with 444 additions and 249 deletions

View file

@ -3,48 +3,62 @@ import sys
import uuid
import asyncio
import dotenv
import logging
import time
from datetime import timedelta
import concurrent.futures
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.responses import FileResponse, JSONResponse
from contextlib import asynccontextmanager
from beartype import beartype
from fastapi.security import OAuth2PasswordRequestForm
from app.models.request import TranscriptionRequest
from app.models.tts import TTSArgs
from app.services.oauth import get_current_user
from app.utils import pick_max_worker_function
from app.services.auth import get_current_username
# from app.services.oauth import get_current_user
from app.utils import pick_max_worker
from app.services.auth import get_current_user, verify_user, ACCESS_TOKEN_EXPIRE_MINUTES, create_access_token, timedelta
from tortoise.utils.audio import BUILTIN_VOICES_DIR
from tortoise.do_tts import main as tts_main
from tortoise.do_tts import _initialized_tts, infer_voice
load_envar = dotenv.load_dotenv()
assert load_envar and os.getenv("DEFAULT_USERNAME"), "Missing environment variables at .env"
# Environment-specific variable to skip initialization during testing
IS_TESTING = os.getenv("TESTING", "False").lower() in ("true", "1")
# Configure logging
logging.basicConfig(
level=logging.DEBUG, # Set the logging level to DEBUG
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # Log format
handlers=[
logging.FileHandler("app_debug.log"), # Log to a file named `app_debug.log`
logging.StreamHandler() # Also log to console
]
)
logger = logging.getLogger(__name__)
dotenv.load_dotenv()
# Create the ThreadPoolExecutor with the determined number of max workers
max_workers = pick_max_worker_function()
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
executor = concurrent.futures.ThreadPoolExecutor(
max_workers=int(os.getenv("MAX_WORKERS", pick_max_worker())))
TASK_TIMEOUT = int(os.getenv("TASK_TIMEOUT", 60*30)) # 30 minutes
# Dictionary to store task information
tasks = {}
@beartype
def local_inference_tts(args: TTSArgs):
"""
Run the TTS directly using the `main` function from `tortoise/do_tts.py`.
Args:
- args (TTSArgs): The arguments to pass to the TTS function.
Returns:
- str: Path to the output audio file.
"""
output_path = tts_main(args)
def local_inference_tts(tts, args):
output_path = infer_voice(tts, args)
return output_path
async def process_requests():
async def process_requests(tts):
while True:
task_id, (args, future) = await fifo_queue.get() # Wait for a request from the queue
task_id, (args, future) = await fifo_queue.get()
try:
tasks[task_id]['status'] = 'in_progress'
output_path = await asyncio.get_event_loop().run_in_executor(executor, local_inference_tts, args)
output_path = await asyncio.get_event_loop().run_in_executor(executor, local_inference_tts, tts, args)
future.set_result(output_path)
tasks[task_id]['status'] = 'completed'
tasks[task_id]['result'] = output_path
@ -53,34 +67,78 @@ async def process_requests():
tasks[task_id]['status'] = 'failed'
tasks[task_id]['error'] = str(e)
finally:
fifo_queue.task_done() # Indicate that the request has been processed
fifo_queue.task_done()
async def post_initialization_event():
try:
request = TranscriptionRequest(
text="Initialized! World",
voice="random", preset="ultra_fast"
)
response = await text_to_speech(request)
print("Initialization TTS Response:", response)
except Exception as e:
print("Error during initialization TTS:", str(e))
@asynccontextmanager
async def lifespan(app: FastAPI):
global fifo_queue
fifo_queue = asyncio.Queue() # Initialize the queue within the context
# Start the request processing task
task = asyncio.create_task(process_requests())
yield
# Clean up the task on shutdown
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Assign the lifespan context manager to the app
if not IS_TESTING:
args = TTSArgs(text="")
tts = _initialized_tts(args)
task = asyncio.create_task(process_requests(tts)) # Start the request processing task
await post_initialization_event() # Post initialization event
else:
print(f"Skipping initialization due to TESTING={IS_TESTING}")
task = None
yield
# Graceful shutdown
if task:
await fifo_queue.join() # Wait for the queue to empty
task.cancel() # Cancel the task to exit the loop
try:
await task
except asyncio.CancelledError:
pass
# Shut down the executor
executor.shutdown(wait=True)
app = FastAPI(lifespan=lifespan)
@app.get("/")
async def home():
return JSONResponse(content={"message": "Hello, FiCast-TTS! Check the docs at /docs."})
return JSONResponse(content={
"message": "Hello, FiCast-TTS! Check the docs at /docs."})
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
if verify_user(form_data.username, form_data.password):
try:
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": form_data.username}, expires_delta=access_token_expires
)
print({"sub": form_data.username})
print(access_token)
return {
"access_token": access_token,
"token_type": "bearer",
"user": form_data.username
}
except:
raise HTTPException(
status_code=401, detail="Unable to create access token")
raise HTTPException(
status_code=401, detail="Incorrect username or password")
@app.get("/voices")
async def available_voices():
return JSONResponse(content={"voices": os.listdir(BUILTIN_VOICES_DIR)})
@app.post("/tts", dependencies=[Depends(get_current_username)])
@app.post("/tts", dependencies=[Depends(get_current_user)])
async def text_to_speech(request: TranscriptionRequest):
try:
args = TTSArgs(
@ -89,54 +147,71 @@ async def text_to_speech(request: TranscriptionRequest):
preset=request.preset
)
# Use a future to get the result of the inference
future = asyncio.get_event_loop().create_future()
# Generate a unique task ID
task_id = str(uuid.uuid4())
await fifo_queue.put((task_id, (args, future)))
# Store task information
tasks[task_id] = {
'status': 'queued',
'request': request,
'result': None,
'error': None
}
# Await the result of the future
output_path = await future
# Check if file exists
if not os.path.isfile(output_path):
raise HTTPException(status_code=404, detail=f"File not found: {output_path}")
return FileResponse(output_path, media_type='audio/wav', filename=os.path.basename(output_path))
return {"task_id": task_id, "status": "queued"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/queue-status", dependencies=[Depends(get_current_username)])
@app.get("/queue-status", dependencies=[Depends(get_current_user)])
async def queue_status():
"""
Endpoint to get the current status of the queue.
"""
return {
"queue_length": fifo_queue.qsize(),
"tasks": tasks
}
try:
return {
"queue_length": fifo_queue.qsize(),
"tasks": tasks
}
except HTTPException as e:
return JSONResponse(
status_code=e.status_code,
content={"detail": e.detail, "headers": e.headers, "error": str(e)},
)
@app.get("/task-status/{task_id}", dependencies=[Depends(get_current_username)])
@app.get("/task-status/{task_id}", dependencies=[Depends(get_current_user)])
async def task_status(task_id: str):
"""
Endpoint to get the status of a specific task.
"""
task = tasks.get(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
@app.get("/task-result/{task_id}", dependencies=[Depends(get_current_user)])
async def wait_for_result(task_id: str):
try:
task = tasks.get(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
start_time = time.time()
while task["status"] != "completed":
if task["status"] == "failed":
raise HTTPException(
status_code=500, detail=f"Task failed: {task.get('error', 'Unknown error')}")
elif time.time() - start_time > TASK_TIMEOUT:
raise HTTPException(status_code=408, detail="Task did not complete within the timeout")
await asyncio.sleep(5)
task = tasks.get(task_id)
output_path = task["result"]
if not os.path.isfile(output_path):
raise HTTPException(status_code=404, detail=f"File not found: {output_path}")
# Expected output
return FileResponse(
output_path,
filename=os.path.basename(output_path),
media_type="audio/wav")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import argparse
import uvicorn
def main():
parser = argparse.ArgumentParser()
parser.add_argument('text', type=str, help='Text to speak. This argument is required.')
@ -152,12 +227,14 @@ if __name__ == "__main__":
voice=args.voice,
preset=args.preset,
)
tts = _initialized_tts(tts_args)
try:
output_path = local_inference_tts(tts_args)
output_path = local_inference_tts(tts, tts_args)
print(f"Output stored at: {output_path}")
return output_path
except Exception as e:
print(f"Error during TTS generation: {str(e)}")
sys.exit(1)
main()
uvicorn.run(
app, host="0.0.0.0", port=42110, log_level="debug")

View file

@ -1,19 +1,115 @@
import base64
import os
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
import secrets
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
import jwt
from fastapi import Depends, HTTPException, Request, Security, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials, APIKeyHeader
# Setup logging
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # Log format
handlers=[
logging.FileHandler("app_debug.log"),
logging.StreamHandler() # Also log to console
]
)
logger = logging.getLogger(__name__)
# Secret key to encode and decode JWT tokens
DEFAULT_SECRET_KEY = str(os.getenv("DEFAULT_SECRET_KEY", "fek3kz9xzlsndSuczhgjds0vndi"))
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 24*60
security = HTTPBasic()
api_key_header = APIKeyHeader(
name="Authorization", auto_error=False)
def get_current_username(credentials: HTTPBasicCredentials = Depends(security)):
assert os.getenv("TEST_USERNAME")
assert os.getenv("TEST_PASSWORD")
correct_username = secrets.compare_digest(credentials.username, os.getenv("TEST_USERNAME"))
correct_password = secrets.compare_digest(credentials.password, os.getenv("TEST_PASSWORD"))
if not (correct_username and correct_password):
def verify_user(username: str, password: str):
user = os.getenv("DEFAULT_USERNAME")
if user and secrets.compare_digest(os.getenv("DEFAULT_PASSWORD"), password):
return True
return False
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(
to_encode, DEFAULT_SECRET_KEY, algorithm=ALGORITHM)
print(f"Encoded JWT: {encoded_jwt}")
return encoded_jwt
async def get_current_user(
request: Request, # Access the full request to manually check headers
authorization: Optional[str] = Security(api_key_header)
):
logger.debug("Attempting to authenticate user...")
# First, try to authenticate using Basic Authentication
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Basic "):
print("Using Basic Auth...")
auth = auth_header.split(" ")[1]
credentials = base64.b64decode(auth).decode("utf-8").split(":")
username = credentials[0]
password = credentials[1]
correct_username = secrets.compare_digest(username, os.getenv("DEFAULT_USERNAME"))
correct_password = secrets.compare_digest(password, os.getenv("DEFAULT_PASSWORD"))
if correct_username and correct_password:
logger.debug("Basic auth successful.")
return {"auth": "basic", "user": username}
else:
logger.debug("Basic auth failed.")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Basic"},
)
# If Basic Auth is not provided, try to authenticate using API Key (JWT)
elif authorization:
print("Using API Key...")
logger.debug(f"Authorization header received: {authorization[:5]}...")
scheme, _, token = authorization.partition(" ")
if scheme.lower() != "bearer":
logger.debug("Invalid authentication scheme.")
raise HTTPException(
status_code=401, detail="Invalid authentication scheme")
try:
payload = jwt.decode(
token, DEFAULT_SECRET_KEY,
algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
logger.debug("JWT payload does not contain a username.")
raise HTTPException(status_code=401, detail="Could not validate credentials")
logger.debug("JWT validation successful.")
return {"auth": "api_key", "user": username}
except jwt.PyJWTError as e:
logger.debug(f"JWT validation failed: {e}")
raise HTTPException(
status_code=401, detail="Could not validate credentials")
else:
logger.debug("No credentials provided.")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Basic"},
)
return credentials.username
)

View file

@ -18,7 +18,6 @@ def get_authenticated_service():
CLIENT_SECRETS_FILE, SCOPES,
redirect_uri=REDIRECT_URI
)
# Tell the user to go to the authorization URL.
auth_url, _ = flow.authorization_url(prompt='consent')
print('Please go to this URL: {}'.format(auth_url))

View file

@ -1,6 +1,6 @@
import os
def pick_max_worker_function():
def pick_max_worker():
# Determine the number of CPU cores
cpu_count = os.cpu_count()
if cpu_count is None:

View file

@ -1,10 +0,0 @@
curl -X POST "http://127.0.0.1:42110/transcribe" \
-u ficast-uzer:ficast-testpazz \
-H "Content-Type: application/json" \
-d '{
"text": "Hello, how are you?",
"voice": "random",
"preset": "ultra_fast"
}' \
-o data/samples/api-output.wav

9
bin/docker-run_app.sh Normal file
View file

@ -0,0 +1,9 @@
image="tts:app"
docker run --gpus all \
-e TORTOISE_MODELS_DIR=/models \
-v "${PWD}/data/models":/models \
-v "${PWD}/data/results":/results \
-v "${PWD}/data/.cache/huggingface":/root/.cache/huggingface \
-p 42110:42110 \
--name tts-api \
-it $image --port 42110 --host 0.0.0.0

View file

@ -1,28 +0,0 @@
version: 1
disable_existing_loggers: False
formatters:
standard:
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
handlers:
console:
class: logging.StreamHandler
level: DEBUG
formatter: standard
stream: ext://sys.stdout
loggers:
app:
level: DEBUG
handlers: [console]
propagate: no
tortoise:
level: DEBUG
handlers: [console]
propagate: no
root:
level: DEBUG
handlers: [console]

View file

@ -28,4 +28,6 @@ uvicorn==0.30.5
google_auth_oauthlib==1.2.1
python-dotenv==1.0.1
fastapi==0.112.0
beartype==0.18.5
beartype==0.18.5
PyJWT==2.9.0
python-multipart==0.0.9

View file

@ -0,0 +1,11 @@
response=$(curl -X POST "http://127.0.0.1:42110/tts" \
-u $TEST_USERNAME:$TEST_PASSWORD \
-H "Content-Type: application/json" \
-d '{
"text": "Hello, curl test?",
"voice": "random",
"preset": "ultra_fast"
}')
TASK_ID=$(echo $response | jq -r .task_id)
echo $TASK_ID

View file

@ -0,0 +1,10 @@
response=$(curl -X POST "http://127.0.0.1:42110/tts" \
-H "Authorization: Bearer ${ACCESS_TOKEN}" \
-H "Content-Type: application/json" \
-d '{
"text": "Hello, curl test?",
"voice": "random",
"preset": "ultra_fast"
}')
TASK_ID=$(echo $response | jq -r .task_id)
echo $TASK_ID

View file

@ -0,0 +1,2 @@
curl "http://127.0.0.1:42110/queue-status" \
-H "Authorization: Bearer ${ACCESS_TOKEN}"

View file

@ -0,0 +1,6 @@
. tests/.env
response=$(curl -X POST "http://127.0.0.1:42110/login" \
-d "username=${TEST_USERNAME}&password=${TEST_PASSWORD}" \
-H "Content-Type: application/x-www-form-urlencoded")
ACCESS_TOKEN=$(echo $response | jq -r .access_token)
echo $ACCESS_TOKEN

View file

@ -0,0 +1,5 @@
TASK_ID=9ae03d68-9511-4642-b9bb-abb74948af61
curl "http://127.0.0.1:42110/task-result/$TASK_ID" \
-H "Authorization: Bearer ${ACCESS_TOKEN}" \
-H "Content-Type: application/json" \
-o data/samples/curl-task-result.wav

View file

@ -0,0 +1,3 @@
curl "http://127.0.0.1:42110/task-status/$TASK_ID" \
-u $TEST_USERNAME:$TEST_PASSWORD \
-H "Content-Type: application/json"

View file

@ -1,124 +0,0 @@
import pytest
import os
import sys
import logging
import dotenv
from fastapi.testclient import TestClient
# Add the root directory of the repo to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Add the tortoise directory to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise')))
from app.main import app, tasks
from app.models.request import TranscriptionRequest
load_envvar = dotenv.load_dotenv()
assert load_envvar and os.getenv("TEST_USERNAME") and os.getenv("TEST_PASSWORD"), "Missing environment variables"
# Configure logging to print to the console
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
client = TestClient(app)
# Helper function to simulate authentication
def basic_auth(username, password):
import base64
credentials = f"{username}:{password}"
encoded_credentials = base64.b64encode(credentials.encode("ascii")).decode("ascii")
return {"Authorization": f"Basic {encoded_credentials}"}
@pytest.mark.asyncio
async def test_transcribe(mocker):
# Mock the local_inference_tts function to avoid running the actual TTS
mocker.patch('app.main.local_inference_tts')
request_data = {
"text": "Hello, how are you?",
"voice": "random",
"preset": "ultra_fast"
}
with TestClient(app) as client:
response = client.post(
"/tts",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")),
json=request_data
)
# Log detailed information about the response for debugging
logger.debug(f"Response status code: {response.status_code}")
logger.debug(f"Response headers: {response.headers}")
logger.debug(f"Response content: {response.content}")
assert response.status_code == 200
assert response.headers['content-type'] == 'audio/wav'
# assert response.headers['content-disposition'] == 'attachment; filename=random_0.wav'
def test_queue_status():
with TestClient(app) as client:
response = client.get(
"/queue-status",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
# Log detailed information about the response for debugging
logger.debug(f"Response status code: {response.status_code}")
logger.debug(f"Response headers: {response.headers}")
logger.debug(f"Response content: {response.content}")
assert response.status_code == 200
assert 'queue_length' in response.json()
assert 'tasks' in response.json()
def test_task_status():
# Add a dummy task to the tasks dictionary
task_id = 'test-task-id'
tasks[task_id] = {
'status': 'queued',
'request': TranscriptionRequest(
text="Hello, how are you?",
voice="random",
preset="fast"
),
'result': None,
'error': None
}
with TestClient(app) as client:
response = client.get(
f"/task-status/{task_id}",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
# Log detailed information about the response for debugging
logger.debug(f"Response status code: {response.status_code}")
logger.debug(f"Response headers: {response.headers}")
logger.debug(f"Response content: {response.content}")
assert response.status_code == 200
assert response.json()['status'] == 'queued'
# Clean up the tasks dictionary
del tasks[task_id]
def test_failed_authentication():
with TestClient(app) as client:
response = client.post(
"/tts",
headers=basic_auth("wrong_username", "wrong_password"),
json={
"text": "Hello, how are you?",
"voice": "random",
"output_path": "data/tests",
"preset": "fast"
}
)
# Log detailed information about the response for debugging
logger.debug(f"Response status code: {response.status_code}")
logger.debug(f"Response headers: {response.headers}")
logger.debug(f"Response content: {response.content}")
assert response.status_code == 401

126
tests/test_tts_api.py Normal file
View file

@ -0,0 +1,126 @@
import pytest
import os, sys, time
import dotenv
from unittest.mock import patch, MagicMock
from fastapi.testclient import TestClient
# Add the root directory of the repo to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'tortoise')))
# load environment variables
load_envvar = dotenv.load_dotenv('tests/.env', override=True)
assert load_envvar and os.getenv("TESTING").lower()=="true", "Missing environment variables"
from app.main import app
client = TestClient(app)
# Helper function to simulate authentication
def basic_auth(username, password):
import base64
credentials = f"{username}:{password}"
encoded_credentials = base64.b64encode(credentials.encode("ascii")).decode("ascii")
return {"Authorization": f"Basic {encoded_credentials}"}
@pytest.fixture(scope="module")
def task_id():
with patch('app.main.text_to_speech') as mock_tts:
mock_tts.return_value = {
"task_id": "mock_task_id", "status": "queued"}
request_data = {
"text": "Hello, how are you?",
"voice": "random",
"preset": "ultra_fast"
}
with TestClient(app) as client:
response = client.post(
"/tts",
headers=basic_auth(
os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD")),
json=request_data
)
response.raise_for_status()
return response.json().get("task_id")
@pytest.fixture(scope="module")
def access_token():
with TestClient(app) as client:
response = client.post(
"/login",
data={
"username": os.getenv("TEST_USERNAME"),
"password": os.getenv("TEST_PASSWORD")
}
)
response.raise_for_status()
token = response.json().get("access_token")
assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert "access_token" in response.json(), f"Access token missing in response. Response: {response.text}"
assert "token_type" in response.json(), f"Token type missing in response. Response: {response.text}"
assert response.json()["token_type"] == "bearer", f"Expected token type 'bearer', got {response.json()['token_type']}. Response: {response.text}"
return token
def test_queue_status_api_key(access_token):
with TestClient(app) as client:
response = client.get(
"/queue-status",
headers={
"Authorization": f"Bearer {access_token}"
}
)
assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert 'tasks' in response.json(), f"'tasks' key missing in response. Response: {response.json()}"
def test_queue_status():
with TestClient(app) as client:
response = client.get(
"/queue-status",
headers=basic_auth(os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert 'tasks' in response.json(), f"'tasks' key missing in response. Response: {response.json()}"
def test_tts_task_creation(task_id):
assert task_id is not None, f"Task ID is None. Check the task creation process."
with TestClient(app) as client:
response = client.get(
"/queue-status",
headers=basic_auth(
os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
response.raise_for_status()
queue_status = response.json()
assert task_id in queue_status["tasks"], f"Task ID {task_id} not found in queue tasks. Queue status: {queue_status}"
assert queue_status["tasks"][task_id]["status"] in ["queued", "in_progress", "completed"], f"Expected task status 'queued', got {queue_status['tasks'][task_id]['status']}. Queue status: {queue_status}"
def test_tts_task_completion(task_id):
with patch('app.main.tasks') as mock_tasks:
mock_tasks[task_id] = {
"status": "completed",
"result": {"message": "Task completed"}
}
with TestClient(app) as client:
start_time = time.time()
while True:
response = client.get(
f"/task-status/{task_id}",
headers=basic_auth(
os.getenv("TEST_USERNAME"), os.getenv("TEST_PASSWORD"))
)
if response.status_code == 200 and response.json().get("status") == "completed":
break
elif response.status_code == 500:
raise AssertionError(f"Task failed: {response.json().get('detail')}. Response: {response.text}")
elif time.time() - start_time > 60:
raise AssertionError(f"Task did not complete within 1 minute. Task ID: {task_id}")
time.sleep(5)
# Verify that the task is completed
assert response.status_code == 200, f"Expected status code 200, got {response.status_code}. Response: {response.text}"
assert response.json()["status"] == "completed", f"Expected task status 'completed', got {response.json()['status']}. Response: {response.json()}"
assert response.json()["result"]["message"] == "Task completed", f"Expected result message 'Task completed', got {response.json()['result']['message']}. Response: {response.json()}"

View file

@ -40,10 +40,14 @@ MODELS = {
def get_model_path(model_name, models_dir=MODELS_DIR):
"""
Get path to given model, download it if it doesn't exist.
Uses the cache provided by HuggingFace's hf_hub_download.
"""
if model_name not in MODELS:
raise ValueError(f'Model {model_name} not found in available models.')
model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir)
# hf_hub_download will automatically check the cache
model_path = hf_hub_download(
repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir
)
return model_path

View file

@ -137,7 +137,10 @@ def classify_audio_clip(clip):
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
dropout=0, kernel_size=5, distribute_zero_label=False)
classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu')))
# Use the get_model_path function to get the model's path
model_path = get_model_path('classifier.pth')
# Load the model state dictionary using the path returned
classifier.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1)
return results[0][0]

View file

@ -16,20 +16,15 @@ from tortoise.utils.audio import load_voices
# Configure logging to print to the console
# Load logging configuration
with open("logging_conf.yml", 'r') as file:
config = yaml.safe_load(file.read())
logging.config.dictConfig(config)
logger = logging.getLogger(__name__)
def main(args):
# if torch.backends.mps.is_available():
# args.use_deepspeed = False
os.makedirs(args.output_path, exist_ok=True)
if not args.autoregressive_batch_size:
args.autoregressive_batch_size = pick_best_batch_size_for_gpu()
def _initialized_tts(args):
tts = TextToSpeech(
models_dir=args.model_dir, autoregressive_batch_size=args.autoregressive_batch_size, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)
models_dir=args.model_dir, autoregressive_batch_size=args.autoregressive_batch_size, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache,
half=args.half)
return tts
def infer_voice(tts: TextToSpeech, args: argparse.Namespace):
selected_voices = args.voice.split(',')
for k, selected_voice in tqdm(enumerate(selected_voices), desc="generating using selected voice"):
if '&' in selected_voice:
@ -57,6 +52,15 @@ def main(args):
os.makedirs('debug_states', exist_ok=True)
torch.save(dbg_state, f'debug_states/do_tts_debug_{selected_voice}.pth')
return output_path
def main(args):
if torch.backends.mps.is_available():
args.use_deepspeed = False
os.makedirs(args.output_path, exist_ok=True)
if not args.autoregressive_batch_size:
args.autoregressive_batch_size = pick_best_batch_size_for_gpu()
tts = _initialized_tts(args)
return infer_voice(tts, args)
if __name__ == '__main__':
"""