API: Add parallel request support for llama.cpp and ExLlamaV3

oobabooga 2026-03-05 16:49:58 -08:00
parent 2f08dce7b0
commit 9824c82cb6
10 changed files with 198 additions and 63 deletions

View file

@ -338,6 +338,35 @@ for event in client.events():
print()
```
#### Python parallel requests example

The API can handle multiple requests in parallel. For ExLlamaV3, this works out of the box. For llama.cpp, pass `--parallel N` to set the number of concurrent slots; the context size is divided equally among the slots, so to give 4 slots 8192 tokens of context each, set `ctx_size` to 32768.

```python
import concurrent.futures

import requests

url = "http://127.0.0.1:5000/v1/chat/completions"

prompts = [
    "Write a haiku about the ocean.",
    "Explain quantum computing in simple terms.",
    "Tell me a joke about programmers.",
]

def send_request(prompt):
    response = requests.post(url, json={
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 200,
    })
    return response.json()["choices"][0]["message"]["content"]

with concurrent.futures.ThreadPoolExecutor() as executor:
    results = list(executor.map(send_request, prompts))

for prompt, result in zip(prompts, results):
    print(f"Q: {prompt}\nA: {result}\n")
```
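
Streaming requests can run concurrently in the same way. Below is a minimal sketch (the `stream_request` helper is illustrative) that assumes the endpoint emits OpenAI-style chat-completion chunks as SSE `data:` lines, as in the streaming example earlier in this document; the exact chunk fields may differ:

```python
import concurrent.futures
import json

import requests

url = "http://127.0.0.1:5000/v1/chat/completions"

def stream_request(prompt):
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 200,
        "stream": True,
    }
    text = ""
    # Each streaming request occupies its own slot on the server.
    with requests.post(url, json=payload, stream=True) as response:
        for line in response.iter_lines():
            if not line or not line.startswith(b"data: "):
                continue
            data = line[len(b"data: "):]
            if data.strip() == b"[DONE]":  # defensive: some OpenAI-style servers send a终-of-stream marker
                break
            delta = json.loads(data)["choices"][0].get("delta", {})
            text += delta.get("content", "")
    return text

prompts = ["Write a haiku about the ocean.", "Tell me a joke about programmers."]
with concurrent.futures.ThreadPoolExecutor() as executor:
    for prompt, result in zip(prompts, executor.map(stream_request, prompts)):
        print(f"Q: {prompt}\nA: {result}\n")
```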
#### Python example with API key
Replace

View file

@ -165,7 +165,7 @@ def convert_history(history):
}
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False, stop_event=None) -> dict:
if body.get('functions', []):
raise InvalidRequestError(message="functions is not supported.", param='functions')
@ -211,6 +211,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
# generation parameters
generate_params = process_parameters(body, is_legacy=is_legacy)
if stop_event is not None:
generate_params['stop_event'] = stop_event
continue_ = body['continue_']
# Instruction template
@ -378,7 +380,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
yield resp
def completions_common(body: dict, is_legacy: bool = False, stream=False):
def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None):
object_type = 'text_completion.chunk' if stream else 'text_completion'
created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
@ -411,6 +413,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
generate_params = process_parameters(body, is_legacy=is_legacy)
max_tokens = generate_params['max_new_tokens']
generate_params['stream'] = stream
if stop_event is not None:
generate_params['stop_event'] = stop_event
requested_model = generate_params.pop('model')
logprob_proc = generate_params.pop('logprob_proc', None)
suffix = body['suffix'] if body['suffix'] else ''
@ -561,23 +565,23 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
yield chunk
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
generator = chat_completions_common(body, is_legacy, stream=False)
def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
generator = chat_completions_common(body, is_legacy, stream=False, stop_event=stop_event)
return deque(generator, maxlen=1).pop()
def stream_chat_completions(body: dict, is_legacy: bool = False):
for resp in chat_completions_common(body, is_legacy, stream=True):
def stream_chat_completions(body: dict, is_legacy: bool = False, stop_event=None):
for resp in chat_completions_common(body, is_legacy, stream=True, stop_event=stop_event):
yield resp
def completions(body: dict, is_legacy: bool = False) -> dict:
generator = completions_common(body, is_legacy, stream=False)
def completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
generator = completions_common(body, is_legacy, stream=False, stop_event=stop_event)
return deque(generator, maxlen=1).pop()
def stream_completions(body: dict, is_legacy: bool = False):
for resp in completions_common(body, is_legacy, stream=True):
def stream_completions(body: dict, is_legacy: bool = False, stop_event=None):
for resp in completions_common(body, is_legacy, stream=True, stop_event=stop_event):
yield resp

View file

@ -3,6 +3,7 @@ import json
import logging
import os
import socket
import threading
import traceback
from collections import deque
from threading import Thread
@ -24,7 +25,7 @@ from extensions.openai.utils import _start_cloudflared
from modules import shared
from modules.logging_colors import logger
from modules.models import unload_model
from modules.text_generation import stop_everything_event
from modules.text_generation import stop_everything_event # used by /v1/internal/stop-generation
from .typing import (
ChatCompletionRequest,
@ -58,10 +59,6 @@ params = {
}
streaming_semaphore = asyncio.Semaphore(1)
image_generation_semaphore = asyncio.Semaphore(1)
def verify_api_key(authorization: str = Header(None)) -> None:
expected_api_key = shared.args.api_key
if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
@ -113,28 +110,30 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
is_legacy = "/generate" in path
if request_data.stream:
async def generator():
async with streaming_semaphore:
try:
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
async for resp in iterate_in_threadpool(response):
disconnected = await request.is_disconnected()
if disconnected:
break
stop_event = threading.Event()
yield {"data": json.dumps(resp)}
finally:
stop_everything_event()
response.close()
return
async def generator():
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
try:
async for resp in iterate_in_threadpool(response):
disconnected = await request.is_disconnected()
if disconnected:
break
yield {"data": json.dumps(resp)}
finally:
stop_event.set()
response.close()
return EventSourceResponse(generator()) # SSE streaming
else:
stop_event = threading.Event()
response = await asyncio.to_thread(
OAIcompletions.completions,
to_dict(request_data),
is_legacy=is_legacy
is_legacy=is_legacy,
stop_event=stop_event
)
return JSONResponse(response)
@ -146,28 +145,30 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
is_legacy = "/generate" in path
if request_data.stream:
async def generator():
async with streaming_semaphore:
try:
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
async for resp in iterate_in_threadpool(response):
disconnected = await request.is_disconnected()
if disconnected:
break
stop_event = threading.Event()
yield {"data": json.dumps(resp)}
finally:
stop_everything_event()
response.close()
return
async def generator():
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
try:
async for resp in iterate_in_threadpool(response):
disconnected = await request.is_disconnected()
if disconnected:
break
yield {"data": json.dumps(resp)}
finally:
stop_event.set()
response.close()
return EventSourceResponse(generator()) # SSE streaming
else:
stop_event = threading.Event()
response = await asyncio.to_thread(
OAIcompletions.chat_completions,
to_dict(request_data),
is_legacy=is_legacy
is_legacy=is_legacy,
stop_event=stop_event
)
return JSONResponse(response)
@ -232,9 +233,8 @@ async def handle_audio_transcription(request: Request):
async def handle_image_generation(request_data: ImageGenerationRequest):
import extensions.openai.images as OAIimages
async with image_generation_semaphore:
response = await asyncio.to_thread(OAIimages.generations, request_data)
return JSONResponse(response)
response = await asyncio.to_thread(OAIimages.generations, request_data)
return JSONResponse(response)
@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)

View file

@ -1,3 +1,5 @@
import queue
import threading
import traceback
from pathlib import Path
from typing import Any, List, Tuple
@ -34,6 +36,55 @@ except Exception:
traceback.print_exc()
class ConcurrentGenerator:
    def __init__(self, generator):
        self.generator = generator
        self.lock = threading.Lock()
        self.job_queues = {}
        self.active = True
        self.has_jobs = threading.Event()
        self.thread = threading.Thread(target=self._iterate_loop, daemon=True)
        self.thread.start()

    def _iterate_loop(self):
        while self.active:
            self.has_jobs.wait(timeout=0.5)
            with self.lock:
                if self.generator.num_remaining_jobs() == 0:
                    self.has_jobs.clear()
                    continue

                results = self.generator.iterate()
                for result in results:
                    job = result["job"]
                    q = self.job_queues.get(job)
                    if q:
                        q.put(result)

                    if result.get("eos"):
                        self.job_queues.pop(job, None)

                if not self.job_queues:
                    self.has_jobs.clear()

    def submit(self, job) -> queue.Queue:
        q = queue.Queue()
        with self.lock:
            self.job_queues[job] = q
            self.generator.enqueue(job)
            self.has_jobs.set()

        return q

    def cancel(self, job):
        with self.lock:
            if job in self.job_queues:
                self.generator.cancel(job)
                self.job_queues[job].put(None)
                del self.job_queues[job]

    def stop(self):
        self.active = False
        self.has_jobs.set()
        self.thread.join(timeout=5)
class Exllamav3Model:
def __init__(self):
pass
@ -167,6 +218,7 @@ class Exllamav3Model:
result.cache = cache
result.tokenizer = tokenizer
result.generator = generator
result.parallel_generator = ConcurrentGenerator(generator)
result.config = config
result.max_tokens = max_tokens
result.vision_model = vision_model
@ -346,27 +398,47 @@ class Exllamav3Model:
)
# Stream generation
self.generator.enqueue(job)
response_text = ""
stop_event = state.get('stop_event')
try:
while self.generator.num_remaining_jobs():
if shared.stop_everything:
break
results = self.generator.iterate()
for result in results:
if "eos" in result and result["eos"]:
if stop_event:
# Concurrent path for API requests
result_queue = self.parallel_generator.submit(job)
try:
while True:
if stop_event.is_set() or shared.stop_everything:
break
try:
result = result_queue.get(timeout=0.1)
except queue.Empty:
continue
if result is None or result.get("eos"):
break
chunk = result.get("text", "")
if chunk:
response_text += chunk
yield response_text
finally:
self.parallel_generator.cancel(job)
else:
# Original single-request path (WebUI)
self.generator.enqueue(job)
try:
while self.generator.num_remaining_jobs():
if shared.stop_everything:
break
finally:
self.generator.clear_queue()
results = self.generator.iterate()
for result in results:
if "eos" in result and result["eos"]:
break
chunk = result.get("text", "")
if chunk:
response_text += chunk
yield response_text
finally:
self.generator.clear_queue()
def generate(self, prompt, state):
output = ""
@ -429,6 +501,13 @@ class Exllamav3Model:
def unload(self):
logger.info("Unloading ExLlamaV3 model components...")
if hasattr(self, 'parallel_generator') and self.parallel_generator is not None:
try:
self.parallel_generator.stop()
except Exception as e:
logger.warning(f"Error stopping parallel generator: {e}")
self.parallel_generator = None
if hasattr(self, 'vision_model') and self.vision_model is not None:
try:
del self.vision_model

View file

@ -217,8 +217,9 @@ class LlamaServer:
full_text = ""
# Process the streaming response
stop_event = state.get('stop_event')
for line in response.iter_lines():
if shared.stop_everything:
if shared.stop_everything or (stop_event and stop_event.is_set()):
break
if not line:
@ -410,6 +411,7 @@ class LlamaServer:
cmd += ["--spec-ngram-size-n", str(shared.args.spec_ngram_size_n)]
cmd += ["--spec-ngram-size-m", str(shared.args.spec_ngram_size_m)]
cmd += ["--spec-ngram-min-hits", str(shared.args.spec_ngram_min_hits)]
cmd += ["--parallel", str(shared.args.parallel)]
if shared.args.streaming_llm:
cmd += ["--cache-reuse", "1"]
cmd += ["--swa-full"]

View file

@ -23,6 +23,7 @@ loaders_and_params = OrderedDict({
'no_mmap',
'mlock',
'numa',
'parallel',
'model_draft',
'draft_max',
'gpu_layers_draft',

View file

@ -102,6 +102,7 @@ group.add_argument('--ubatch-size', type=int, default=1024, help='Maximum number
group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
# Transformers/Accelerate

View file

@ -22,13 +22,23 @@ def generate_reply(*args, **kwargs):
from modules.models import load_model
shared.model, shared.tokenizer = load_model(shared.model_name)
shared.generation_lock.acquire()
state = args[1] if len(args) > 1 else kwargs.get('state', {})
use_parallel = (
state.get('stop_event') is not None
and shared.model.__class__.__name__ in ['Exllamav3Model', 'LlamaServer']
and (shared.model.__class__.__name__ != 'LlamaServer' or shared.args.parallel > 1)
)
if not use_parallel:
shared.generation_lock.acquire()
try:
for result in _generate_reply(*args, **kwargs):
yield result
finally:
models.last_generation_time = time.time()
shared.generation_lock.release()
if not use_parallel:
shared.generation_lock.release()
def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
@ -68,7 +78,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
reply = ''
is_stream = state['stream']
if len(all_stop_strings) > 0 and not state['stream']:
stop_event_ref = state.pop('stop_event', None)
state = copy.deepcopy(state)
if stop_event_ref is not None:
state['stop_event'] = stop_event_ref
state['stream'] = True
# Generate
@ -99,7 +112,8 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
yield reply
last_update = time.monotonic()
if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
stop_event = state.get('stop_event')
if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything) or (stop_event and stop_event.is_set()):
break
if not is_chat:
@ -474,7 +488,10 @@ def generate_reply_custom(question, original_question, state, stopping_strings=N
For models that do not use the transformers library for sampling
"""
stop_event_ref = state.pop('stop_event', None)
state = copy.deepcopy(state)
if stop_event_ref is not None:
state['stop_event'] = stop_event_ref
state['seed'] = set_manual_seed(state['seed'])
t0 = time.time()
reply = ''

View file

@ -151,6 +151,7 @@ def list_model_elements():
'no_mmap',
'mlock',
'numa',
'parallel',
'use_double_quant',
'bf16',
'enable_tp',

View file

@ -90,6 +90,7 @@ def create_ui():
with gr.Column():
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
shared.gradio['parallel'] = gr.Slider(label="parallel", minimum=1, step=1, maximum=64, value=shared.args.parallel, info='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
shared.gradio['ubatch_size'] = gr.Slider(label="ubatch_size", minimum=1, maximum=4096, step=1, value=shared.args.ubatch_size)
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')