Mirror of https://github.com/oobabooga/text-generation-webui.git
Add native logit_bias and logprobs support for ExLlamav3
This commit is contained in:
parent 8aeaa76365
commit 3304b57bdf
@@ -37,13 +37,13 @@ def load_chat_template_file(filepath):
     return text


-def get_logprobs_from_llama_cpp():
-    """Read logprobs captured from the llama.cpp server response."""
+def get_logprobs_from_backend():
+    """Read logprobs captured from llama.cpp or ExLlamav3 native backend."""
     if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
         return None

-    # Convert llama.cpp format to {token: logprob} dict
-    # llama.cpp returns: [{"token": "text", "logprob": -0.5, "top_logprobs": [{"token": "t", "logprob": -0.1}, ...]}, ...]
+    # Both backends store data in shared.model.last_completion_probabilities
+    # Format: [{"top_logprobs": [{"token": "text", "logprob": -0.5}, ...]}, ...]
     result = {}
     for entry in shared.model.last_completion_probabilities:
         top = entry.get('top_logprobs', entry.get('top_probs', []))
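The hunk above cuts off after the `top = ...` line, so the inner loop body here is an assumption; this is a minimal standalone sketch of the flattening the function performs, using made-up sample data in the shared format described by the new comments:

# Hypothetical sample in the shared format (values invented)
last_completion_probabilities = [
    {"top_logprobs": [{"token": "Hello", "logprob": -0.1}, {"token": "Hi", "logprob": -2.3}]},
    {"top_logprobs": [{"token": " world", "logprob": -0.4}]},
]

result = {}
for entry in last_completion_probabilities:
    top = entry.get('top_logprobs', entry.get('top_probs', []))
    for alt in top:  # assumed continuation: collect token -> logprob pairs
        result[alt["token"]] = alt["logprob"]

print(result)  # {'Hello': -0.1, 'Hi': -2.3, ' world': -0.4}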
@@ -90,8 +90,8 @@ def process_parameters(body, is_legacy=False):
     elif isinstance(body['stop'], list):
         generate_params['custom_stopping_strings'] = body['stop']

-    # For llama.cpp, logit_bias and logprobs are forwarded natively via prepare_payload()
-    if shared.args.loader != 'llama.cpp':
+    # For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
+    if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
         from transformers import LogitsProcessorList

         from modules.transformers_loader import (
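As a usage sketch: with both parameters forwarded natively, an OpenAI-style request can carry them end to end. This assumes the API extension is listening on its default http://127.0.0.1:5000; the token id is a placeholder, not taken from the diff:

import requests

response = requests.post(
    "http://127.0.0.1:5000/v1/completions",  # default --api endpoint (assumption)
    json={
        "prompt": "Hello",
        "max_tokens": 8,
        "logprobs": 5,                  # request top-5 alternatives per generated token
        "logit_bias": {"15496": -100},  # suppress a token id (placeholder id)
    },
)
print(response.json()["choices"][0].get("logprobs"))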
@@ -527,9 +527,9 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e

         if logprob_proc:
             completion_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
-        elif shared.args.loader == 'llama.cpp':
-            llama_logprobs = get_logprobs_from_llama_cpp()
-            completion_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
+        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
+            backend_logprobs = get_logprobs_from_backend()
+            completion_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
         else:
             completion_logprobs = None

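For reference, completion_logprobs as built above wraps the {token: logprob} dict from get_logprobs_from_backend in a single-element list, so the field would be shaped roughly like this (values invented):

completion_logprobs = {
    "top_logprobs": [
        {"Hello": -0.1, "Hi": -2.3}  # alternatives for the sampled position
    ]
}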
@@ -576,9 +576,9 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
         # begin streaming
         if logprob_proc:
             chunk_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
-        elif shared.args.loader == 'llama.cpp':
-            llama_logprobs = get_logprobs_from_llama_cpp()
-            chunk_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
+        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
+            backend_logprobs = get_logprobs_from_backend()
+            chunk_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
         else:
             chunk_logprobs = None

@@ -1,3 +1,4 @@
+import math
 import queue
 import threading
 import traceback
@@ -9,6 +10,7 @@ import torch
 from exllamav3 import Cache, Config, Generator, Model, Tokenizer
 from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
 from exllamav3.generator import Job
+from exllamav3.generator.filter import Filter
 from exllamav3.generator.sampler import (
     CustomSampler,
     SS_AdaptiveP,
@@ -36,6 +38,29 @@ except Exception:
     traceback.print_exc()


+class LogitBiasFilter(Filter):
+    """Filter subclass that applies a static additive logit bias mask."""
+
+    def __init__(self, tokenizer, logit_bias_dict):
+        super().__init__(tokenizer=tokenizer, trigger_token=None, prefix_str=None, eos_after_completed=False)
+        self.logit_bias_dict = logit_bias_dict
+        self._mask = None
+
+    def reset(self): pass
+    def accept_token(self, token): pass
+    def is_completed(self): return False
+    def use_background_worker(self): return False
+
+    def get_next_logit_mask(self):
+        if self._mask is None:
+            self._mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype)
+            for token_id_str, bias in self.logit_bias_dict.items():
+                token_id = int(token_id_str)
+                if 0 <= token_id < self.vocab_size:
+                    self._mask[0, token_id] = bias
+        return self._mask
+
+
 class ConcurrentGenerator:
     def __init__(self, generator):
         self.generator = generator
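Outside the diff, the effect of get_next_logit_mask is easy to demonstrate with plain torch: the mask is simply added to the raw logits before sampling, so -100 effectively bans a token and a positive bias promotes one. A self-contained sketch with arbitrary numbers:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])  # raw scores over a toy 4-token vocab
mask = torch.zeros_like(logits)
mask[0, 0] = -100.0  # ban token 0
mask[0, 2] = 5.0     # promote token 2

print(torch.softmax(logits, dim=-1))         # token 0 dominates
print(torch.softmax(logits + mask, dim=-1))  # token 0 near zero, token 2 dominates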
@@ -98,6 +123,10 @@ class Exllamav3Model:
     def __init__(self):
         pass

+    @property
+    def device(self) -> torch.device:
+        return torch.device(0)
+
     @classmethod
     def from_pretrained(cls, path_to_model):
         path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
@@ -390,6 +419,16 @@ class Exllamav3Model:
         if eos_id is not None:
             stop_conditions.append(eos_id)

+        # Build filters for logit_bias (OpenAI API)
+        filters = []
+        logit_bias = state.get('logit_bias')
+        if logit_bias:
+            filters.append(LogitBiasFilter(self.tokenizer, logit_bias))
+
+        # Logprobs support (OpenAI API)
+        logprobs = state.get('logprobs', 0) or 0
+        return_top_tokens = logprobs if logprobs > 0 else 0
+
         seed = state.get('seed', -1)
         job = Job(
             input_ids=input_ids,
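The `or 0` on the logprobs line matters because the API layer may pass logprobs=None rather than omit the key; a quick check of the three cases:

for raw in (None, 0, 5):
    state = {"logprobs": raw}
    logprobs = state.get("logprobs", 0) or 0
    print(raw, "->", logprobs)  # None -> 0, 0 -> 0, 5 -> 5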
@@ -399,11 +438,15 @@ class Exllamav3Model:
             sampler=sampler,
             seed=seed if seed >= 0 else None,
             stop_conditions=stop_conditions if stop_conditions else None,
+            filters=filters if filters else None,
+            return_top_tokens=return_top_tokens,
+            return_probs=return_top_tokens > 0,
         )

         # Stream generation
         response_text = ""
         stop_event = state.get('stop_event')
+        self.last_completion_probabilities = []

         result_queue = self.parallel_generator.submit(job)
         try:
@@ -415,14 +458,41 @@ class Exllamav3Model:
                 except queue.Empty:
                     continue
                 if result is None or result.get("eos"):
+                    # Capture logprobs from the final eos result too
+                    if result is not None and return_top_tokens > 0:
+                        self._capture_logprobs(result)
                     break
                 chunk = result.get("text", "")
+
+                # Capture logprobs from streaming results
+                if return_top_tokens > 0:
+                    self._capture_logprobs(result)
+
                 if chunk:
                     response_text += chunk
                     yield response_text
         finally:
             self.parallel_generator.cancel(job)

+    def _capture_logprobs(self, result):
+        """Convert ExLlamav3 top-k token data to the shared logprobs format."""
+        top_k_tokens = result.get("top_k_tokens")
+        top_k_probs = result.get("top_k_probs")
+        if top_k_tokens is None or top_k_probs is None:
+            return
+
+        id_to_piece = self.tokenizer.get_id_to_piece_list(True)
+        # top_k_tokens shape: (batch, seq_len, k), top_k_probs same
+        for seq_idx in range(top_k_tokens.shape[1]):
+            entry = {"top_logprobs": []}
+            for k_idx in range(top_k_tokens.shape[2]):
+                token_id = top_k_tokens[0, seq_idx, k_idx].item()
+                prob = top_k_probs[0, seq_idx, k_idx].item()
+                token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
+                logprob = math.log(prob) if prob > 0 else float("-inf")
+                entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
+            self.last_completion_probabilities.append(entry)
+
     def generate(self, prompt, state):
         output = ""
         for chunk in self.generate_with_streaming(prompt, state):
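To sanity-check the conversion in _capture_logprobs, here is a standalone version of its inner loop with dummy tensors and a toy piece list standing in for the real tokenizer (shapes follow the comment in the diff: batch=1, seq_len=2, k=2):

import math

import torch

top_k_tokens = torch.tensor([[[0, 1], [2, 0]]])       # dummy token ids
top_k_probs = torch.tensor([[[0.9, 0.1], [0.6, 0.4]]])  # dummy probabilities
id_to_piece = ["Hello", " there", " world"]           # toy vocabulary

entries = []
for seq_idx in range(top_k_tokens.shape[1]):
    entry = {"top_logprobs": []}
    for k_idx in range(top_k_tokens.shape[2]):
        token_id = top_k_tokens[0, seq_idx, k_idx].item()
        prob = top_k_probs[0, seq_idx, k_idx].item()
        token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
        logprob = math.log(prob) if prob > 0 else float("-inf")
        entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
    entries.append(entry)

print(entries[0]["top_logprobs"])
# [{'token': 'Hello', 'logprob': -0.105...}, {'token': ' there', 'logprob': -2.302...}]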