From 9824c82cb65cb28953c038d36710ecc82df0eb1d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 5 Mar 2026 16:49:58 -0800
Subject: [PATCH] API: Add parallel request support for llama.cpp and ExLlamaV3

---
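Each API request now carries its own threading.Event in state['stop_event'],
so a client disconnect cancels just that request instead of calling
stop_everything_event() globally. The generation lock is bypassed for API
requests when the backend can serve them concurrently (ExLlamaV3 always,
llama.cpp when --parallel > 1); Web UI generations still take the lock. For
ExLlamaV3, a ConcurrentGenerator pumps generator.iterate() on a single
background thread and fans each result out to a per-job queue.

A quick way to check that requests actually overlap (a sketch, assuming a
model is loaded, the server is on the default port, and, for llama.cpp, it
was started with --parallel >= 2):

    # If the two requests run in parallel, the wall-clock total is close to
    # the slower request rather than the sum of both.
    import concurrent.futures
    import time

    import requests

    url = "http://127.0.0.1:5000/v1/chat/completions"

    def timed_request(prompt):
        t0 = time.time()
        requests.post(url, json={
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 100,
        })
        return time.time() - t0

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        t0 = time.time()
        durations = list(executor.map(timed_request, ["Count to 50.", "Count to 50."]))
        wall = time.time() - t0

    print(f"per-request: {[f'{d:.1f}s' for d in durations]}, wall clock: {wall:.1f}s")
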
 docs/12 - OpenAI API.md          |  37 ++++++++++++
 extensions/openai/completions.py |  24 ++++---
 extensions/openai/script.py      |  72 ++++++++++-----------
 modules/exllamav3.py             | 104 ++++++++++++++++++++++++++----
 modules/llama_cpp_server.py      |   4 +-
 modules/loaders.py               |   1 +
 modules/shared.py                |   1 +
 modules/text_generation.py       |  23 ++++++-
 modules/ui.py                    |   1 +
 modules/ui_model_menu.py         |   1 +
 10 files changed, 206 insertions(+), 62 deletions(-)

diff --git a/docs/12 - OpenAI API.md b/docs/12 - OpenAI API.md
index cd5757f6..fc444b15 100644
--- a/docs/12 - OpenAI API.md
+++ b/docs/12 - OpenAI API.md
@@ -338,6 +338,43 @@ for event in client.events():
     print()
 ```
 
+#### Python parallel requests example
+
+The API supports handling multiple requests in parallel. For ExLlamaV3, this works out of the box. For llama.cpp, you need to pass `--parallel N` to set the number of concurrent slots.
+
+```python
+import concurrent.futures
+import requests
+
+url = "http://127.0.0.1:5000/v1/chat/completions"
+prompts = [
+    "Write a haiku about the ocean.",
+    "Explain quantum computing in simple terms.",
+    "Tell me a joke about programmers.",
+]
+
+def send_request(prompt):
+    response = requests.post(url, json={
+        "messages": [{"role": "user", "content": prompt}],
+        "max_tokens": 200,
+    })
+    return response.json()["choices"][0]["message"]["content"]
+
+with concurrent.futures.ThreadPoolExecutor() as executor:
+    results = list(executor.map(send_request, prompts))
+
+for prompt, result in zip(prompts, results):
+    print(f"Q: {prompt}\nA: {result}\n")
+```
+
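+If llama.cpp was started with `--parallel 4`, you can cap the client-side concurrency to match the number of slots by setting `max_workers` on the executor (this reuses `send_request` and `prompts` from the example above):
+
+```python
+# Send at most 4 requests at a time, matching the number of server slots.
+with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+    results = list(executor.map(send_request, prompts))
+```
+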
 #### Python example with API key
 
 Replace
diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index c3037d0c..de944a8f 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -165,7 +165,7 @@ def convert_history(history):
     }
 
 
-def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
+def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False, stop_event=None) -> dict:
     if body.get('functions', []):
         raise InvalidRequestError(message="functions is not supported.", param='functions')
 
@@ -211,6 +211,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
 
     # generation parameters
     generate_params = process_parameters(body, is_legacy=is_legacy)
+    if stop_event is not None:
+        generate_params['stop_event'] = stop_event
     continue_ = body['continue_']
 
     # Instruction template
@@ -378,7 +380,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         yield resp
 
 
-def completions_common(body: dict, is_legacy: bool = False, stream=False):
+def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None):
     object_type = 'text_completion.chunk' if stream else 'text_completion'
     created_time = int(time.time())
     cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
@@ -411,6 +413,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
     generate_params = process_parameters(body, is_legacy=is_legacy)
     max_tokens = generate_params['max_new_tokens']
     generate_params['stream'] = stream
+    if stop_event is not None:
+        generate_params['stop_event'] = stop_event
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
     suffix = body['suffix'] if body['suffix'] else ''
@@ -561,23 +565,23 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
             yield chunk
 
 
-def chat_completions(body: dict, is_legacy: bool = False) -> dict:
-    generator = chat_completions_common(body, is_legacy, stream=False)
+def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
+    generator = chat_completions_common(body, is_legacy, stream=False, stop_event=stop_event)
     return deque(generator, maxlen=1).pop()
 
 
-def stream_chat_completions(body: dict, is_legacy: bool = False):
-    for resp in chat_completions_common(body, is_legacy, stream=True):
+def stream_chat_completions(body: dict, is_legacy: bool = False, stop_event=None):
+    for resp in chat_completions_common(body, is_legacy, stream=True, stop_event=stop_event):
         yield resp
 
 
-def completions(body: dict, is_legacy: bool = False) -> dict:
-    generator = completions_common(body, is_legacy, stream=False)
+def completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
+    generator = completions_common(body, is_legacy, stream=False, stop_event=stop_event)
     return deque(generator, maxlen=1).pop()
 
 
-def stream_completions(body: dict, is_legacy: bool = False):
-    for resp in completions_common(body, is_legacy, stream=True):
+def stream_completions(body: dict, is_legacy: bool = False, stop_event=None):
+    for resp in completions_common(body, is_legacy, stream=True, stop_event=stop_event):
         yield resp
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index edb22c22..7a30e311 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -3,6 +3,7 @@ import json
 import logging
 import os
 import socket
+import threading
 import traceback
 from collections import deque
 from threading import Thread
@@ -24,7 +25,7 @@ from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger
 from modules.models import unload_model
-from modules.text_generation import stop_everything_event
+from modules.text_generation import stop_everything_event  # used by /v1/internal/stop-generation
 
 from .typing import (
     ChatCompletionRequest,
@@ -58,10 +59,6 @@ params = {
 }
 
 
-streaming_semaphore = asyncio.Semaphore(1)
-image_generation_semaphore = asyncio.Semaphore(1)
-
-
 def verify_api_key(authorization: str = Header(None)) -> None:
     expected_api_key = shared.args.api_key
     if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
@@ -113,28 +110,30 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
     is_legacy = "/generate" in path
 
     if request_data.stream:
-        async def generator():
-            async with streaming_semaphore:
-                try:
-                    response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
-                    async for resp in iterate_in_threadpool(response):
-                        disconnected = await request.is_disconnected()
-                        if disconnected:
-                            break
+        stop_event = threading.Event()
 
-                        yield {"data": json.dumps(resp)}
-                finally:
-                    stop_everything_event()
-                    response.close()
-                    return
+        async def generator():
+            response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
+            try:
+                async for resp in iterate_in_threadpool(response):
+                    disconnected = await request.is_disconnected()
+                    if disconnected:
+                        break
+
+                    yield {"data": json.dumps(resp)}
+            finally:
+                stop_event.set()
+                response.close()
 
         return EventSourceResponse(generator())  # SSE streaming
 
     else:
+        stop_event = threading.Event()
         response = await asyncio.to_thread(
             OAIcompletions.completions,
             to_dict(request_data),
-            is_legacy=is_legacy
+            is_legacy=is_legacy,
+            stop_event=stop_event
         )
         return JSONResponse(response)
 
@@ -146,28 +145,30 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
     is_legacy = "/generate" in path
 
     if request_data.stream:
-        async def generator():
-            async with streaming_semaphore:
-                try:
-                    response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
-                    async for resp in iterate_in_threadpool(response):
-                        disconnected = await request.is_disconnected()
-                        if disconnected:
-                            break
+        stop_event = threading.Event()
 
-                        yield {"data": json.dumps(resp)}
-                finally:
-                    stop_everything_event()
-                    response.close()
-                    return
+        async def generator():
+            response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
+            try:
+                async for resp in iterate_in_threadpool(response):
+                    disconnected = await request.is_disconnected()
+                    if disconnected:
+                        break
+
+                    yield {"data": json.dumps(resp)}
+            finally:
+                stop_event.set()
+                response.close()
 
         return EventSourceResponse(generator())  # SSE streaming
 
     else:
+        stop_event = threading.Event()
         response = await asyncio.to_thread(
             OAIcompletions.chat_completions,
             to_dict(request_data),
-            is_legacy=is_legacy
+            is_legacy=is_legacy,
+            stop_event=stop_event
         )
         return JSONResponse(response)
 
@@ -232,9 +233,8 @@ async def handle_audio_transcription(request: Request):
 async def handle_image_generation(request_data: ImageGenerationRequest):
     import extensions.openai.images as OAIimages
 
-    async with image_generation_semaphore:
-        response = await asyncio.to_thread(OAIimages.generations, request_data)
-        return JSONResponse(response)
+    response = await asyncio.to_thread(OAIimages.generations, request_data)
+    return JSONResponse(response)
 
 
 @app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
diff --git a/modules/exllamav3.py b/modules/exllamav3.py
index eca5dde0..af5745bc 100644
--- a/modules/exllamav3.py
+++ b/modules/exllamav3.py
@@ -1,3 +1,5 @@
+import queue
+import threading
 import traceback
 from pathlib import Path
 from typing import Any, List, Tuple
@@ -34,6 +36,55 @@ except Exception:
     traceback.print_exc()
 
 
+class ConcurrentGenerator:
+    def __init__(self, generator):
+        self.generator = generator
+        self.lock = threading.Lock()
+        self.job_queues = {}
+        self.active = True
+        self.has_jobs = threading.Event()
+        self.thread = threading.Thread(target=self._iterate_loop, daemon=True)
+        self.thread.start()
+
+    def _iterate_loop(self):
+        while self.active:
+            self.has_jobs.wait(timeout=0.5)
+            with self.lock:
+                if self.generator.num_remaining_jobs() == 0:
+                    self.has_jobs.clear()
+                    continue
+                results = self.generator.iterate()
+                for result in results:
+                    job = result["job"]
+                    q = self.job_queues.get(job)
+                    if q:
+                        q.put(result)
+                    if result.get("eos"):
+                        self.job_queues.pop(job, None)
+                if not self.job_queues:
+                    self.has_jobs.clear()
+
+    def submit(self, job) -> queue.Queue:
+        q = queue.Queue()
+        with self.lock:
+            self.job_queues[job] = q
+            self.generator.enqueue(job)
+            self.has_jobs.set()
+        return q
+
+    def cancel(self, job):
+        with self.lock:
+            if job in self.job_queues:
+                self.generator.cancel(job)
+                self.job_queues[job].put(None)
+                del self.job_queues[job]
+
+    def stop(self):
+        self.active = False
+        self.has_jobs.set()
+        self.thread.join(timeout=5)
+
+
 class Exllamav3Model:
     def __init__(self):
         pass
@@ -167,6 +218,7 @@ class Exllamav3Model:
         result.cache = cache
         result.tokenizer = tokenizer
         result.generator = generator
+        result.parallel_generator = ConcurrentGenerator(generator)
         result.config = config
         result.max_tokens = max_tokens
         result.vision_model = vision_model
@@ -346,27 +398,48 @@ class Exllamav3Model:
         )
 
         # Stream generation
-        self.generator.enqueue(job)
         response_text = ""
+        stop_event = state.get('stop_event')
 
-        try:
-            while self.generator.num_remaining_jobs():
-                if shared.stop_everything:
-                    break
-
-                results = self.generator.iterate()
-                for result in results:
-                    if "eos" in result and result["eos"]:
+        if stop_event:
+            # Concurrent path for API requests
+            result_queue = self.parallel_generator.submit(job)
+            try:
+                while True:
+                    if stop_event.is_set() or shared.stop_everything:
+                        break
+                    try:
+                        result = result_queue.get(timeout=0.1)
+                    except queue.Empty:
+                        continue
+                    if result is None or result.get("eos"):
                         break
-
                     chunk = result.get("text", "")
                     if chunk:
                         response_text += chunk
                         yield response_text
+            finally:
+                self.parallel_generator.cancel(job)
+        else:
+            # Original single-request path (WebUI)
+            self.generator.enqueue(job)
+            try:
+                while self.generator.num_remaining_jobs():
+                    if shared.stop_everything:
+                        break
 
-        finally:
-            self.generator.clear_queue()
+                    results = self.generator.iterate()
+                    for result in results:
+                        if "eos" in result and result["eos"]:
+                            break
+
+                        chunk = result.get("text", "")
+                        if chunk:
+                            response_text += chunk
+                            yield response_text
+            finally:
+                self.generator.clear_queue()
 
     def generate(self, prompt, state):
         output = ""
@@ -429,6 +502,13 @@ class Exllamav3Model:
     def unload(self):
         logger.info("Unloading ExLlamaV3 model components...")
 
+        if hasattr(self, 'parallel_generator') and self.parallel_generator is not None:
+            try:
+                self.parallel_generator.stop()
+            except Exception as e:
+                logger.warning(f"Error stopping parallel generator: {e}")
+            self.parallel_generator = None
+
         if hasattr(self, 'vision_model') and self.vision_model is not None:
             try:
                 del self.vision_model
diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index 2af9aa8a..017c5d2a 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -217,8 +217,9 @@ class LlamaServer:
         full_text = ""
 
         # Process the streaming response
+        stop_event = state.get('stop_event')
         for line in response.iter_lines():
-            if shared.stop_everything:
+            if shared.stop_everything or (stop_event and stop_event.is_set()):
                 break
 
             if not line:
@@ -410,6 +411,7 @@ class LlamaServer:
         cmd += ["--spec-ngram-size-n", str(shared.args.spec_ngram_size_n)]
         cmd += ["--spec-ngram-size-m", str(shared.args.spec_ngram_size_m)]
         cmd += ["--spec-ngram-min-hits", str(shared.args.spec_ngram_min_hits)]
+        cmd += ["--parallel", str(shared.args.parallel)]
         if shared.args.streaming_llm:
             cmd += ["--cache-reuse", "1"]
             cmd += ["--swa-full"]
diff --git a/modules/loaders.py b/modules/loaders.py
index 9923f116..0348c939 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -23,6 +23,7 @@ loaders_and_params = OrderedDict({
         'no_mmap',
         'mlock',
         'numa',
+        'parallel',
         'model_draft',
         'draft_max',
         'gpu_layers_draft',
diff --git a/modules/shared.py b/modules/shared.py
index 8b820e2d..ec6e23b9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -102,6 +102,7 @@ group.add_argument('--ubatch-size', type=int, default=1024, help='Maximum number
 group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
 group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
 group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
+group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
 group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
 
 # Transformers/Accelerate
diff --git a/modules/text_generation.py b/modules/text_generation.py
index cace3f7c..6e0e67a1 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -22,13 +22,23 @@ def generate_reply(*args, **kwargs):
         from modules.models import load_model
         shared.model, shared.tokenizer = load_model(shared.model_name)
 
-    shared.generation_lock.acquire()
+    state = args[1] if len(args) > 1 else kwargs.get('state', {})
+    use_parallel = (
+        state.get('stop_event') is not None
+        and shared.model.__class__.__name__ in ['Exllamav3Model', 'LlamaServer']
+        and (shared.model.__class__.__name__ != 'LlamaServer' or shared.args.parallel > 1)
+    )
+
+    if not use_parallel:
+        shared.generation_lock.acquire()
+
     try:
         for result in _generate_reply(*args, **kwargs):
             yield result
     finally:
         models.last_generation_time = time.time()
-        shared.generation_lock.release()
+        if not use_parallel:
+            shared.generation_lock.release()
 
 
 def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
@@ -68,7 +78,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
     reply = ''
     is_stream = state['stream']
     if len(all_stop_strings) > 0 and not state['stream']:
+        stop_event_ref = state.pop('stop_event', None)
         state = copy.deepcopy(state)
+        if stop_event_ref is not None:
+            state['stop_event'] = stop_event_ref
         state['stream'] = True
 
     # Generate
@@ -99,7 +112,8 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
                 yield reply
                 last_update = time.monotonic()
 
-            if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
+            stop_event = state.get('stop_event')
+            if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything) or (stop_event and stop_event.is_set()):
                 break
 
     if not is_chat:
@@ -474,7 +488,10 @@ def generate_reply_custom(question, original_question, state, stopping_strings=N
     """
     For models that do not use the transformers library for sampling
     """
+    stop_event_ref = state.pop('stop_event', None)
     state = copy.deepcopy(state)
+    if stop_event_ref is not None:
+        state['stop_event'] = stop_event_ref
     state['seed'] = set_manual_seed(state['seed'])
     t0 = time.time()
     reply = ''
diff --git a/modules/ui.py b/modules/ui.py
index e118e684..fd20e782 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -151,6 +151,7 @@ def list_model_elements():
         'no_mmap',
         'mlock',
         'numa',
+        'parallel',
         'use_double_quant',
         'bf16',
         'enable_tp',
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 3acdd062..12b5654c 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -90,6 +90,7 @@ def create_ui():
             with gr.Column():
                 shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
                 shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
+                shared.gradio['parallel'] = gr.Slider(label="parallel", minimum=1, step=1, maximum=64, value=shared.args.parallel, info='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
                 shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
                 shared.gradio['ubatch_size'] = gr.Slider(label="ubatch_size", minimum=1, maximum=4096, step=1, value=shared.args.ubatch_size)
                 shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')