Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2026-03-09 23:23:49 +01:00
API: Add parallel request support for llama.cpp and ExLlamaV3
parent 2f08dce7b0
commit 9824c82cb6
@@ -338,6 +338,35 @@ for event in client.events():
     print()
 ```

+#### Python parallel requests example
+
+The API supports handling multiple requests in parallel. For ExLlamaV3, this works out of the box. For llama.cpp, you need to pass `--parallel N` to set the number of concurrent slots.
+
+```python
+import concurrent.futures
+import requests
+
+url = "http://127.0.0.1:5000/v1/chat/completions"
+prompts = [
+    "Write a haiku about the ocean.",
+    "Explain quantum computing in simple terms.",
+    "Tell me a joke about programmers.",
+]
+
+def send_request(prompt):
+    response = requests.post(url, json={
+        "messages": [{"role": "user", "content": prompt}],
+        "max_tokens": 200,
+    })
+    return response.json()["choices"][0]["message"]["content"]
+
+with concurrent.futures.ThreadPoolExecutor() as executor:
+    results = list(executor.map(send_request, prompts))
+
+for prompt, result in zip(prompts, results):
+    print(f"Q: {prompt}\nA: {result}\n")
+```
+
 #### Python example with API key

 Replace
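A streaming counterpart to the example above may also be useful (an editor's sketch, not part of this commit): each thread opens its own SSE stream with `"stream": true`, so several responses arrive interleaved whenever the backend has free slots. The chunk fields assume the OpenAI-style schema this endpoint mimics, and the `[DONE]` sentinel check is defensive.

```python
import concurrent.futures
import json

import requests

url = "http://127.0.0.1:5000/v1/chat/completions"

def stream_request(prompt):
    text = ""
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 200,
        "stream": True,
    }
    with requests.post(url, json=payload, stream=True) as response:
        for line in response.iter_lines():
            if not line:
                continue
            data = line.decode("utf-8").removeprefix("data: ")
            if data.strip() == "[DONE]":  # defensive: emitted by some OpenAI-compatible servers
                break
            chunk = json.loads(data)
            # delta may be missing on role-only chunks, hence the double .get()
            text += chunk["choices"][0].get("delta", {}).get("content", "") or ""
    return text

with concurrent.futures.ThreadPoolExecutor() as executor:
    for answer in executor.map(stream_request, ["Count to five.", "Name three colors."]):
        print(answer)
```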
@@ -165,7 +165,7 @@ def convert_history(history):
     }


-def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
+def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False, stop_event=None) -> dict:
     if body.get('functions', []):
         raise InvalidRequestError(message="functions is not supported.", param='functions')
@@ -211,6 +211,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:

     # generation parameters
     generate_params = process_parameters(body, is_legacy=is_legacy)
+    if stop_event is not None:
+        generate_params['stop_event'] = stop_event
     continue_ = body['continue_']

     # Instruction template
@@ -378,7 +380,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
         yield resp


-def completions_common(body: dict, is_legacy: bool = False, stream=False):
+def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None):
     object_type = 'text_completion.chunk' if stream else 'text_completion'
     created_time = int(time.time())
     cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
@@ -411,6 +413,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
     generate_params = process_parameters(body, is_legacy=is_legacy)
     max_tokens = generate_params['max_new_tokens']
     generate_params['stream'] = stream
+    if stop_event is not None:
+        generate_params['stop_event'] = stop_event
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
     suffix = body['suffix'] if body['suffix'] else ''
@@ -561,23 +565,23 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
             yield chunk


-def chat_completions(body: dict, is_legacy: bool = False) -> dict:
-    generator = chat_completions_common(body, is_legacy, stream=False)
+def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
+    generator = chat_completions_common(body, is_legacy, stream=False, stop_event=stop_event)
     return deque(generator, maxlen=1).pop()


-def stream_chat_completions(body: dict, is_legacy: bool = False):
-    for resp in chat_completions_common(body, is_legacy, stream=True):
+def stream_chat_completions(body: dict, is_legacy: bool = False, stop_event=None):
+    for resp in chat_completions_common(body, is_legacy, stream=True, stop_event=stop_event):
         yield resp


-def completions(body: dict, is_legacy: bool = False) -> dict:
-    generator = completions_common(body, is_legacy, stream=False)
+def completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
+    generator = completions_common(body, is_legacy, stream=False, stop_event=stop_event)
     return deque(generator, maxlen=1).pop()


-def stream_completions(body: dict, is_legacy: bool = False):
-    for resp in completions_common(body, is_legacy, stream=True):
+def stream_completions(body: dict, is_legacy: bool = False, stop_event=None):
+    for resp in completions_common(body, is_legacy, stream=True, stop_event=stop_event):
         yield resp
@@ -3,6 +3,7 @@ import json
 import logging
 import os
 import socket
+import threading
 import traceback
 from collections import deque
 from threading import Thread
@@ -24,7 +25,7 @@ from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger
 from modules.models import unload_model
-from modules.text_generation import stop_everything_event
+from modules.text_generation import stop_everything_event  # used by /v1/internal/stop-generation

 from .typing import (
     ChatCompletionRequest,
@@ -58,10 +59,6 @@ params = {
 }


-streaming_semaphore = asyncio.Semaphore(1)
-image_generation_semaphore = asyncio.Semaphore(1)
-
-
 def verify_api_key(authorization: str = Header(None)) -> None:
     expected_api_key = shared.args.api_key
     if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
@@ -113,28 +110,30 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
     is_legacy = "/generate" in path

     if request_data.stream:
-        async def generator():
-            async with streaming_semaphore:
-                try:
-                    response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
-                    async for resp in iterate_in_threadpool(response):
-                        disconnected = await request.is_disconnected()
-                        if disconnected:
-                            break
-
-                        yield {"data": json.dumps(resp)}
-                finally:
-                    stop_everything_event()
-                    response.close()
-                    return
+        stop_event = threading.Event()
+
+        async def generator():
+            response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
+            try:
+                async for resp in iterate_in_threadpool(response):
+                    disconnected = await request.is_disconnected()
+                    if disconnected:
+                        break
+
+                    yield {"data": json.dumps(resp)}
+            finally:
+                stop_event.set()
+                response.close()

         return EventSourceResponse(generator())  # SSE streaming

     else:
+        stop_event = threading.Event()
         response = await asyncio.to_thread(
             OAIcompletions.completions,
             to_dict(request_data),
-            is_legacy=is_legacy
+            is_legacy=is_legacy,
+            stop_event=stop_event
         )

         return JSONResponse(response)
@@ -146,28 +145,30 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
     is_legacy = "/generate" in path

     if request_data.stream:
-        async def generator():
-            async with streaming_semaphore:
-                try:
-                    response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
-                    async for resp in iterate_in_threadpool(response):
-                        disconnected = await request.is_disconnected()
-                        if disconnected:
-                            break
-
-                        yield {"data": json.dumps(resp)}
-                finally:
-                    stop_everything_event()
-                    response.close()
-                    return
+        stop_event = threading.Event()
+
+        async def generator():
+            response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
+            try:
+                async for resp in iterate_in_threadpool(response):
+                    disconnected = await request.is_disconnected()
+                    if disconnected:
+                        break
+
+                    yield {"data": json.dumps(resp)}
+            finally:
+                stop_event.set()
+                response.close()

         return EventSourceResponse(generator())  # SSE streaming

     else:
+        stop_event = threading.Event()
         response = await asyncio.to_thread(
             OAIcompletions.chat_completions,
             to_dict(request_data),
-            is_legacy=is_legacy
+            is_legacy=is_legacy,
+            stop_event=stop_event
         )

         return JSONResponse(response)
@@ -232,9 +233,8 @@ async def handle_audio_transcription(request: Request):
 async def handle_image_generation(request_data: ImageGenerationRequest):
     import extensions.openai.images as OAIimages

-    async with image_generation_semaphore:
-        response = await asyncio.to_thread(OAIimages.generations, request_data)
-        return JSONResponse(response)
+    response = await asyncio.to_thread(OAIimages.generations, request_data)
+    return JSONResponse(response)


 @app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
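The endpoint changes above drop the global `streaming_semaphore` in favor of one `threading.Event` per request: streams are no longer serialized, and a client disconnect now cancels only its own generation via `stop_event.set()` in the `finally` block, where the old code called the global `stop_everything_event()`. A minimal standalone sketch of this pattern follows (an editor's illustration with made-up names, not the project's code):

```python
import asyncio
import threading
import time

def blocking_stream(stop_event: threading.Event):
    # Stands in for a model worker; it polls its own event, not a global flag.
    for i in range(20):
        if stop_event.is_set():
            return
        time.sleep(0.05)
        yield f"token-{i}"

async def handle_request(name: str, cancel_after: float):
    stop_event = threading.Event()  # one cancellation handle per request
    stream = blocking_stream(stop_event)
    start = time.monotonic()
    try:
        while True:
            token = await asyncio.to_thread(next, stream, None)
            if token is None:
                break
            if time.monotonic() - start > cancel_after:  # simulated client disconnect
                break
            print(name, token)
    finally:
        stop_event.set()  # signal cancellation; the generator checks this on its next step

async def main():
    # Two requests run concurrently; cancelling one leaves the other running.
    await asyncio.gather(handle_request("a", 0.2), handle_request("b", 0.5))

asyncio.run(main())
```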
@@ -1,3 +1,5 @@
+import queue
+import threading
 import traceback
 from pathlib import Path
 from typing import Any, List, Tuple
@@ -34,6 +36,55 @@ except Exception:
     traceback.print_exc()


+class ConcurrentGenerator:
+    def __init__(self, generator):
+        self.generator = generator
+        self.lock = threading.Lock()
+        self.job_queues = {}
+        self.active = True
+        self.has_jobs = threading.Event()
+        self.thread = threading.Thread(target=self._iterate_loop, daemon=True)
+        self.thread.start()
+
+    def _iterate_loop(self):
+        while self.active:
+            self.has_jobs.wait(timeout=0.5)
+            with self.lock:
+                if self.generator.num_remaining_jobs() == 0:
+                    self.has_jobs.clear()
+                    continue
+                results = self.generator.iterate()
+                for result in results:
+                    job = result["job"]
+                    q = self.job_queues.get(job)
+                    if q:
+                        q.put(result)
+                    if result.get("eos"):
+                        self.job_queues.pop(job, None)
+                if not self.job_queues:
+                    self.has_jobs.clear()
+
+    def submit(self, job) -> queue.Queue:
+        q = queue.Queue()
+        with self.lock:
+            self.job_queues[job] = q
+            self.generator.enqueue(job)
+            self.has_jobs.set()
+        return q
+
+    def cancel(self, job):
+        with self.lock:
+            if job in self.job_queues:
+                self.generator.cancel(job)
+                self.job_queues[job].put(None)
+                del self.job_queues[job]
+
+    def stop(self):
+        self.active = False
+        self.has_jobs.set()
+        self.thread.join(timeout=5)
+
+
 class Exllamav3Model:
     def __init__(self):
         pass
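To make the threading model concrete, here is an editor's sketch (not part of the commit) that drives `ConcurrentGenerator` with a mock backend exposing the same duck-typed surface the class relies on: `enqueue()`, `iterate()`, `num_remaining_jobs()`, and `cancel()`. It assumes the class above and its `queue`/`threading` imports are in scope; the real ExLlamaV3 generator and job objects differ.

```python
import concurrent.futures
import time

class MockGenerator:
    """Fake batched decoder: each iterate() advances every active job one step."""

    def __init__(self):
        self._progress = {}

    def enqueue(self, job):
        self._progress[job] = 0

    def num_remaining_jobs(self):
        return len(self._progress)

    def cancel(self, job):
        self._progress.pop(job, None)

    def iterate(self):
        results = []
        for job in list(self._progress):
            self._progress[job] += 1
            done = self._progress[job] >= 3
            results.append({"job": job, "text": f"[{job}:{self._progress[job]}]", "eos": done})
            if done:
                del self._progress[job]
        time.sleep(0.01)  # simulate decode latency
        return results

gen = ConcurrentGenerator(MockGenerator())

def consume(job):
    out = ""
    q = gen.submit(job)  # per-job queue, fed by the single iterate thread
    while True:
        result = q.get()
        if result is None:  # cancel() pushes None to unblock consumers
            break
        out += result.get("text", "")
        if result.get("eos"):
            break
    return out

# Three consumers share one backend; chunks are routed by job identity.
with concurrent.futures.ThreadPoolExecutor() as pool:
    print(list(pool.map(consume, ["A", "B", "C"])))

gen.stop()
```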
@@ -167,6 +218,7 @@ class Exllamav3Model:
         result.cache = cache
         result.tokenizer = tokenizer
         result.generator = generator
+        result.parallel_generator = ConcurrentGenerator(generator)
         result.config = config
         result.max_tokens = max_tokens
         result.vision_model = vision_model
@@ -346,27 +398,47 @@
         )

-        # Stream generation
-        self.generator.enqueue(job)
-
-        response_text = ""
-
-        try:
-            while self.generator.num_remaining_jobs():
-                if shared.stop_everything:
-                    break
-
-                results = self.generator.iterate()
-                for result in results:
-                    if "eos" in result and result["eos"]:
-                        break
-
-                    chunk = result.get("text", "")
-                    if chunk:
-                        response_text += chunk
-                        yield response_text
-        finally:
-            self.generator.clear_queue()
+        response_text = ""
+        stop_event = state.get('stop_event')
+        if stop_event:
+            # Concurrent path for API requests
+            result_queue = self.parallel_generator.submit(job)
+            try:
+                while True:
+                    if stop_event.is_set() or shared.stop_everything:
+                        break
+                    try:
+                        result = result_queue.get(timeout=0.1)
+                    except queue.Empty:
+                        continue
+                    if result is None or result.get("eos"):
+                        break
+
+                    chunk = result.get("text", "")
+                    if chunk:
+                        response_text += chunk
+                        yield response_text
+            finally:
+                self.parallel_generator.cancel(job)
+        else:
+            # Original single-request path (WebUI)
+            self.generator.enqueue(job)
+            try:
+                while self.generator.num_remaining_jobs():
+                    if shared.stop_everything:
+                        break
+
+                    results = self.generator.iterate()
+                    for result in results:
+                        if "eos" in result and result["eos"]:
+                            break
+
+                        chunk = result.get("text", "")
+                        if chunk:
+                            response_text += chunk
+                            yield response_text
+            finally:
+                self.generator.clear_queue()


     def generate(self, prompt, state):
         output = ""
@@ -429,6 +501,13 @@ class Exllamav3Model:
     def unload(self):
         logger.info("Unloading ExLlamaV3 model components...")

+        if hasattr(self, 'parallel_generator') and self.parallel_generator is not None:
+            try:
+                self.parallel_generator.stop()
+            except Exception as e:
+                logger.warning(f"Error stopping parallel generator: {e}")
+            self.parallel_generator = None
+
         if hasattr(self, 'vision_model') and self.vision_model is not None:
             try:
                 del self.vision_model
@@ -217,8 +217,9 @@ class LlamaServer:
         full_text = ""

         # Process the streaming response
+        stop_event = state.get('stop_event')
         for line in response.iter_lines():
-            if shared.stop_everything:
+            if shared.stop_everything or (stop_event and stop_event.is_set()):
                 break

             if not line:
@@ -410,6 +411,7 @@ class LlamaServer:
             cmd += ["--spec-ngram-size-n", str(shared.args.spec_ngram_size_n)]
             cmd += ["--spec-ngram-size-m", str(shared.args.spec_ngram_size_m)]
             cmd += ["--spec-ngram-min-hits", str(shared.args.spec_ngram_min_hits)]
+        cmd += ["--parallel", str(shared.args.parallel)]
         if shared.args.streaming_llm:
             cmd += ["--cache-reuse", "1"]
             cmd += ["--swa-full"]
@@ -23,6 +23,7 @@ loaders_and_params = OrderedDict({
         'no_mmap',
         'mlock',
         'numa',
+        'parallel',
         'model_draft',
         'draft_max',
         'gpu_layers_draft',
@@ -102,6 +102,7 @@ group.add_argument('--ubatch-size', type=int, default=1024, help='Maximum number
 group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
 group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
 group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
+group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
 group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')

 # Transformers/Accelerate
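The slot arithmetic in the new `--parallel` help text, spelled out as a quick check (an editor's illustration, not from the commit):

```python
ctx_size = 32768   # total context passed to llama-server
parallel = 4       # number of request slots
per_slot = ctx_size // parallel  # llama.cpp divides the context evenly among slots
assert per_slot == 8192
```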
@@ -22,13 +22,23 @@ def generate_reply(*args, **kwargs):
         from modules.models import load_model
         shared.model, shared.tokenizer = load_model(shared.model_name)

-    shared.generation_lock.acquire()
+    state = args[1] if len(args) > 1 else kwargs.get('state', {})
+    use_parallel = (
+        state.get('stop_event') is not None
+        and shared.model.__class__.__name__ in ['Exllamav3Model', 'LlamaServer']
+        and (shared.model.__class__.__name__ != 'LlamaServer' or shared.args.parallel > 1)
+    )
+
+    if not use_parallel:
+        shared.generation_lock.acquire()
+
     try:
         for result in _generate_reply(*args, **kwargs):
             yield result
     finally:
         models.last_generation_time = time.time()
-        shared.generation_lock.release()
+        if not use_parallel:
+            shared.generation_lock.release()


 def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
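The `use_parallel` condition above can be read as a predicate: only API requests carry a `stop_event`, so WebUI generations always keep the global lock, and llama.cpp additionally needs more than one slot. An editor's paraphrase (not the committed code):

```python
def runs_unlocked(model_cls: str, has_stop_event: bool, parallel_slots: int) -> bool:
    if not has_stop_event:
        return False  # WebUI path: serialize behind shared.generation_lock
    if model_cls == 'Exllamav3Model':
        return True  # batching works out of the box
    return model_cls == 'LlamaServer' and parallel_slots > 1

assert not runs_unlocked('LlamaServer', True, 1)  # one slot: still serialized
assert runs_unlocked('LlamaServer', True, 4)
assert runs_unlocked('Exllamav3Model', True, 1)
assert not runs_unlocked('SomeOtherModel', True, 8)  # other loaders keep the lock
```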
@@ -68,7 +78,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
     reply = ''
     is_stream = state['stream']
     if len(all_stop_strings) > 0 and not state['stream']:
+        stop_event_ref = state.pop('stop_event', None)
         state = copy.deepcopy(state)
+        if stop_event_ref is not None:
+            state['stop_event'] = stop_event_ref
         state['stream'] = True

     # Generate
@@ -99,7 +112,8 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
                 yield reply
                 last_update = time.monotonic()

-            if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
+            stop_event = state.get('stop_event')
+            if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything) or (stop_event and stop_event.is_set()):
                 break

     if not is_chat:
@@ -474,7 +488,10 @@ def generate_reply_custom(question, original_question, state, stopping_strings=N
     For models that do not use the transformers library for sampling
     """

+    stop_event_ref = state.pop('stop_event', None)
     state = copy.deepcopy(state)
+    if stop_event_ref is not None:
+        state['stop_event'] = stop_event_ref
     state['seed'] = set_manual_seed(state['seed'])
     t0 = time.time()
     reply = ''
@@ -151,6 +151,7 @@ def list_model_elements():
         'no_mmap',
         'mlock',
         'numa',
+        'parallel',
         'use_double_quant',
         'bf16',
         'enable_tp',
@@ -90,6 +90,7 @@ def create_ui():
             with gr.Column():
                 shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
                 shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
+                shared.gradio['parallel'] = gr.Slider(label="parallel", minimum=1, step=1, maximum=64, value=shared.args.parallel, info='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
                 shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
                 shared.gradio['ubatch_size'] = gr.Slider(label="ubatch_size", minimum=1, maximum=4096, step=1, value=shared.args.ubatch_size)
                 shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')