mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-07 23:53:40 +00:00
API: Add parallel request support for llama.cpp and ExLlamaV3
This commit is contained in:
parent
2f08dce7b0
commit
9824c82cb6
10 changed files with 198 additions and 63 deletions
|
|
@@ -1,3 +1,5 @@
|
|||
import queue
|
||||
import threading
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Tuple
|
||||
|
|
@@ -34,6 +36,55 @@ except Exception:
|
|||
traceback.print_exc()
|
||||
|
||||
|
||||
class ConcurrentGenerator:
    """Pump a batched ExLlamaV3-style generator on a background thread.

    Multiple requests can each ``submit()`` a job; results produced by the
    shared ``generator.iterate()`` loop are fanned out to a per-job
    ``queue.Queue``. A ``None`` sentinel on a job's queue means the stream
    is finished (cancelled or shut down); an ``eos`` result means it
    completed normally.
    """

    def __init__(self, generator):
        # Batched generator exposing enqueue/iterate/num_remaining_jobs/cancel
        # (assumed ExLlamaV3 dynamic-generator API — TODO confirm).
        self.generator = generator
        self.lock = threading.Lock()
        self.job_queues = {}  # job -> queue.Queue of result dicts
        self.active = True
        # Set whenever jobs are pending, so the pump thread can sleep otherwise.
        self.has_jobs = threading.Event()
        self.thread = threading.Thread(target=self._iterate_loop, daemon=True)
        self.thread.start()

    def _iterate_loop(self):
        """Background pump: route each iterate() result to its job's queue."""
        while self.active:
            # Timeout keeps the loop responsive to self.active going False.
            self.has_jobs.wait(timeout=0.5)
            with self.lock:
                if self.generator.num_remaining_jobs() == 0:
                    self.has_jobs.clear()
                    continue
                results = self.generator.iterate()
                for result in results:
                    job = result["job"]
                    q = self.job_queues.get(job)
                    if q:
                        q.put(result)
                    if result.get("eos"):
                        # Job finished; forget its queue (consumer saw "eos").
                        self.job_queues.pop(job, None)
                if not self.job_queues:
                    self.has_jobs.clear()

    def submit(self, job) -> queue.Queue:
        """Enqueue *job* on the shared generator; return its result queue."""
        q = queue.Queue()
        with self.lock:
            self.job_queues[job] = q
            self.generator.enqueue(job)
            self.has_jobs.set()
        return q

    def cancel(self, job):
        """Cancel *job* and wake its consumer with a None sentinel."""
        with self.lock:
            if job in self.job_queues:
                self.generator.cancel(job)
                self.job_queues[job].put(None)
                del self.job_queues[job]

    def stop(self):
        """Stop the pump thread and release any consumers still waiting.

        Fix: previously, queues for jobs still outstanding at shutdown never
        received a sentinel, so their consumers could poll forever. Send each
        remaining queue a ``None`` so blocked readers terminate.
        """
        self.active = False
        self.has_jobs.set()  # wake the pump so it notices self.active
        self.thread.join(timeout=5)
        with self.lock:
            for q in self.job_queues.values():
                q.put(None)
            self.job_queues.clear()
|
||||
|
||||
|
||||
class Exllamav3Model:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
|
@@ -167,6 +218,7 @@ class Exllamav3Model:
|
|||
result.cache = cache
|
||||
result.tokenizer = tokenizer
|
||||
result.generator = generator
|
||||
result.parallel_generator = ConcurrentGenerator(generator)
|
||||
result.config = config
|
||||
result.max_tokens = max_tokens
|
||||
result.vision_model = vision_model
|
||||
|
|
@@ -346,27 +398,47 @@ class Exllamav3Model:
|
|||
)
|
||||
|
||||
# Stream generation
|
||||
self.generator.enqueue(job)
|
||||
|
||||
response_text = ""
|
||||
stop_event = state.get('stop_event')
|
||||
|
||||
try:
|
||||
while self.generator.num_remaining_jobs():
|
||||
if shared.stop_everything:
|
||||
break
|
||||
|
||||
results = self.generator.iterate()
|
||||
for result in results:
|
||||
if "eos" in result and result["eos"]:
|
||||
if stop_event:
|
||||
# Concurrent path for API requests
|
||||
result_queue = self.parallel_generator.submit(job)
|
||||
try:
|
||||
while True:
|
||||
if stop_event.is_set() or shared.stop_everything:
|
||||
break
|
||||
try:
|
||||
result = result_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
if result is None or result.get("eos"):
|
||||
break
|
||||
|
||||
chunk = result.get("text", "")
|
||||
if chunk:
|
||||
response_text += chunk
|
||||
yield response_text
|
||||
finally:
|
||||
self.parallel_generator.cancel(job)
|
||||
else:
|
||||
# Original single-request path (WebUI)
|
||||
self.generator.enqueue(job)
|
||||
try:
|
||||
while self.generator.num_remaining_jobs():
|
||||
if shared.stop_everything:
|
||||
break
|
||||
|
||||
finally:
|
||||
self.generator.clear_queue()
|
||||
results = self.generator.iterate()
|
||||
for result in results:
|
||||
if "eos" in result and result["eos"]:
|
||||
break
|
||||
|
||||
chunk = result.get("text", "")
|
||||
if chunk:
|
||||
response_text += chunk
|
||||
yield response_text
|
||||
finally:
|
||||
self.generator.clear_queue()
|
||||
|
||||
def generate(self, prompt, state):
|
||||
output = ""
|
||||
|
|
@@ -429,6 +501,13 @@ class Exllamav3Model:
|
|||
def unload(self):
|
||||
logger.info("Unloading ExLlamaV3 model components...")
|
||||
|
||||
if hasattr(self, 'parallel_generator') and self.parallel_generator is not None:
|
||||
try:
|
||||
self.parallel_generator.stop()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping parallel generator: {e}")
|
||||
self.parallel_generator = None
|
||||
|
||||
if hasattr(self, 'vision_model') and self.vision_model is not None:
|
||||
try:
|
||||
del self.vision_model
|
||||
|
|
|
|||
|
|
@@ -217,8 +217,9 @@ class LlamaServer:
|
|||
full_text = ""
|
||||
|
||||
# Process the streaming response
|
||||
stop_event = state.get('stop_event')
|
||||
for line in response.iter_lines():
|
||||
if shared.stop_everything:
|
||||
if shared.stop_everything or (stop_event and stop_event.is_set()):
|
||||
break
|
||||
|
||||
if not line:
|
||||
|
|
@@ -410,6 +411,7 @@ class LlamaServer:
|
|||
cmd += ["--spec-ngram-size-n", str(shared.args.spec_ngram_size_n)]
|
||||
cmd += ["--spec-ngram-size-m", str(shared.args.spec_ngram_size_m)]
|
||||
cmd += ["--spec-ngram-min-hits", str(shared.args.spec_ngram_min_hits)]
|
||||
cmd += ["--parallel", str(shared.args.parallel)]
|
||||
if shared.args.streaming_llm:
|
||||
cmd += ["--cache-reuse", "1"]
|
||||
cmd += ["--swa-full"]
|
||||
|
|
|
|||
|
|
@@ -23,6 +23,7 @@ loaders_and_params = OrderedDict({
|
|||
'no_mmap',
|
||||
'mlock',
|
||||
'numa',
|
||||
'parallel',
|
||||
'model_draft',
|
||||
'draft_max',
|
||||
'gpu_layers_draft',
|
||||
|
|
|
|||
|
|
@@ -102,6 +102,7 @@ group.add_argument('--ubatch-size', type=int, default=1024, help='Maximum number
|
|||
group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
|
||||
group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
|
||||
group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
|
||||
group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
|
||||
group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
|
||||
|
||||
# Transformers/Accelerate
|
||||
|
|
|
|||
|
|
@@ -22,13 +22,23 @@ def generate_reply(*args, **kwargs):
|
|||
from modules.models import load_model
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
shared.generation_lock.acquire()
|
||||
state = args[1] if len(args) > 1 else kwargs.get('state', {})
|
||||
use_parallel = (
|
||||
state.get('stop_event') is not None
|
||||
and shared.model.__class__.__name__ in ['Exllamav3Model', 'LlamaServer']
|
||||
and (shared.model.__class__.__name__ != 'LlamaServer' or shared.args.parallel > 1)
|
||||
)
|
||||
|
||||
if not use_parallel:
|
||||
shared.generation_lock.acquire()
|
||||
|
||||
try:
|
||||
for result in _generate_reply(*args, **kwargs):
|
||||
yield result
|
||||
finally:
|
||||
models.last_generation_time = time.time()
|
||||
shared.generation_lock.release()
|
||||
if not use_parallel:
|
||||
shared.generation_lock.release()
|
||||
|
||||
|
||||
def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
|
||||
|
|
@@ -68,7 +78,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
|||
reply = ''
|
||||
is_stream = state['stream']
|
||||
if len(all_stop_strings) > 0 and not state['stream']:
|
||||
stop_event_ref = state.pop('stop_event', None)
|
||||
state = copy.deepcopy(state)
|
||||
if stop_event_ref is not None:
|
||||
state['stop_event'] = stop_event_ref
|
||||
state['stream'] = True
|
||||
|
||||
# Generate
|
||||
|
|
@@ -99,7 +112,8 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
|||
yield reply
|
||||
last_update = time.monotonic()
|
||||
|
||||
if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
|
||||
stop_event = state.get('stop_event')
|
||||
if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything) or (stop_event and stop_event.is_set()):
|
||||
break
|
||||
|
||||
if not is_chat:
|
||||
|
|
@@ -474,7 +488,10 @@ def generate_reply_custom(question, original_question, state, stopping_strings=N
|
|||
For models that do not use the transformers library for sampling
|
||||
"""
|
||||
|
||||
stop_event_ref = state.pop('stop_event', None)
|
||||
state = copy.deepcopy(state)
|
||||
if stop_event_ref is not None:
|
||||
state['stop_event'] = stop_event_ref
|
||||
state['seed'] = set_manual_seed(state['seed'])
|
||||
t0 = time.time()
|
||||
reply = ''
|
||||
|
|
|
|||
|
|
@@ -151,6 +151,7 @@ def list_model_elements():
|
|||
'no_mmap',
|
||||
'mlock',
|
||||
'numa',
|
||||
'parallel',
|
||||
'use_double_quant',
|
||||
'bf16',
|
||||
'enable_tp',
|
||||
|
|
|
|||
|
|
@@ -90,6 +90,7 @@ def create_ui():
|
|||
with gr.Column():
|
||||
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
|
||||
shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
|
||||
shared.gradio['parallel'] = gr.Slider(label="parallel", minimum=1, step=1, maximum=64, value=shared.args.parallel, info='Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.')
|
||||
shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
|
||||
shared.gradio['ubatch_size'] = gr.Slider(label="ubatch_size", minimum=1, maximum=4096, step=1, value=shared.args.ubatch_size)
|
||||
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue