Add native logit_bias and logprobs support for ExLlamav3

This commit is contained in:
oobabooga 2026-03-10 11:03:00 -03:00
parent 8aeaa76365
commit 3304b57bdf
2 changed files with 82 additions and 12 deletions

View file

@ -37,13 +37,13 @@ def load_chat_template_file(filepath):
return text
def get_logprobs_from_llama_cpp():
"""Read logprobs captured from the llama.cpp server response."""
def get_logprobs_from_backend():
"""Read logprobs captured from llama.cpp or ExLlamav3 native backend."""
if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
return None
# Convert llama.cpp format to {token: logprob} dict
# llama.cpp returns: [{"token": "text", "logprob": -0.5, "top_logprobs": [{"token": "t", "logprob": -0.1}, ...]}, ...]
# Both backends store data in shared.model.last_completion_probabilities
# Format: [{"top_logprobs": [{"token": "text", "logprob": -0.5}, ...]}, ...]
result = {}
for entry in shared.model.last_completion_probabilities:
top = entry.get('top_logprobs', entry.get('top_probs', []))
@ -90,8 +90,8 @@ def process_parameters(body, is_legacy=False):
elif isinstance(body['stop'], list):
generate_params['custom_stopping_strings'] = body['stop']
# For llama.cpp, logit_bias and logprobs are forwarded natively via prepare_payload()
if shared.args.loader != 'llama.cpp':
# For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
from transformers import LogitsProcessorList
from modules.transformers_loader import (
@ -527,9 +527,9 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
if logprob_proc:
completion_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
elif shared.args.loader == 'llama.cpp':
llama_logprobs = get_logprobs_from_llama_cpp()
completion_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
backend_logprobs = get_logprobs_from_backend()
completion_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
else:
completion_logprobs = None
@ -576,9 +576,9 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
# begin streaming
if logprob_proc:
chunk_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
elif shared.args.loader == 'llama.cpp':
llama_logprobs = get_logprobs_from_llama_cpp()
chunk_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
backend_logprobs = get_logprobs_from_backend()
chunk_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
else:
chunk_logprobs = None

View file

@ -1,3 +1,4 @@
import math
import queue
import threading
import traceback
@ -9,6 +10,7 @@ import torch
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from exllamav3.generator import Job
from exllamav3.generator.filter import Filter
from exllamav3.generator.sampler import (
CustomSampler,
SS_AdaptiveP,
@ -36,6 +38,29 @@ except Exception:
traceback.print_exc()
class LogitBiasFilter(Filter):
    """Filter subclass that applies a static additive logit bias mask."""

    def __init__(self, tokenizer, logit_bias_dict):
        # No trigger/prefix semantics: this filter is always active and never completes.
        super().__init__(tokenizer=tokenizer, trigger_token=None, prefix_str=None, eos_after_completed=False)
        self.logit_bias_dict = logit_bias_dict
        # Lazily-built bias tensor, reused across sampling steps.
        self._mask = None

    def reset(self):
        # Stateless between generations: nothing to reset.
        pass

    def accept_token(self, token):
        # The bias does not depend on previously sampled tokens.
        pass

    def is_completed(self):
        # The bias applies to every step, so this filter never finishes.
        return False

    def use_background_worker(self):
        return False

    def get_next_logit_mask(self):
        """Return the (1, vocab_size) additive mask, building it on first use."""
        if self._mask is None:
            # NOTE(review): vocab_size / logits_dtype appear to come from the
            # Filter base class — confirm against the exllamav3 API.
            mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype)
            for raw_id, bias in self.logit_bias_dict.items():
                idx = int(raw_id)
                # Ignore ids outside the vocabulary (OpenAI clients may send any key).
                if 0 <= idx < self.vocab_size:
                    mask[0, idx] = bias
            self._mask = mask
        return self._mask
class ConcurrentGenerator:
def __init__(self, generator):
self.generator = generator
@ -98,6 +123,10 @@ class Exllamav3Model:
def __init__(self):
pass
@property
def device(self) -> torch.device:
    """Device the model is treated as living on.

    NOTE(review): hard-coded to device index 0 — presumably the first CUDA
    device; confirm this holds for multi-GPU / CPU-offload configurations.
    """
    return torch.device(0)
@classmethod
def from_pretrained(cls, path_to_model):
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
@ -390,6 +419,16 @@ class Exllamav3Model:
if eos_id is not None:
stop_conditions.append(eos_id)
# Build filters for logit_bias (OpenAI API)
filters = []
logit_bias = state.get('logit_bias')
if logit_bias:
filters.append(LogitBiasFilter(self.tokenizer, logit_bias))
# Logprobs support (OpenAI API)
logprobs = state.get('logprobs', 0) or 0
return_top_tokens = logprobs if logprobs > 0 else 0
seed = state.get('seed', -1)
job = Job(
input_ids=input_ids,
@ -399,11 +438,15 @@ class Exllamav3Model:
sampler=sampler,
seed=seed if seed >= 0 else None,
stop_conditions=stop_conditions if stop_conditions else None,
filters=filters if filters else None,
return_top_tokens=return_top_tokens,
return_probs=return_top_tokens > 0,
)
# Stream generation
response_text = ""
stop_event = state.get('stop_event')
self.last_completion_probabilities = []
result_queue = self.parallel_generator.submit(job)
try:
@ -415,14 +458,41 @@ class Exllamav3Model:
except queue.Empty:
continue
if result is None or result.get("eos"):
# Capture logprobs from the final eos result too
if result is not None and return_top_tokens > 0:
self._capture_logprobs(result)
break
chunk = result.get("text", "")
# Capture logprobs from streaming results
if return_top_tokens > 0:
self._capture_logprobs(result)
if chunk:
response_text += chunk
yield response_text
finally:
self.parallel_generator.cancel(job)
def _capture_logprobs(self, result):
    """Convert ExLlamav3 top-k token data to the shared logprobs format.

    Appends one ``{"top_logprobs": [{"token": str, "logprob": float}, ...]}``
    entry per generated position to ``self.last_completion_probabilities``.
    Silently does nothing when the job result carries no top-k data.
    """
    tokens = result.get("top_k_tokens")
    probs = result.get("top_k_probs")
    if tokens is None or probs is None:
        return

    piece_list = self.tokenizer.get_id_to_piece_list(True)
    # Tensors are (batch, seq_len, k); only batch index 0 is read here.
    _, seq_len, k = tokens.shape
    for pos in range(seq_len):
        alternatives = []
        for rank in range(k):
            tok_id = tokens[0, pos, rank].item()
            p = probs[0, pos, rank].item()
            # Fall back to a placeholder when the id is past the piece table.
            piece = piece_list[tok_id] if tok_id < len(piece_list) else f"<{tok_id}>"
            alternatives.append({
                "token": piece,
                # Probabilities of exactly 0 map to -inf rather than raising.
                "logprob": math.log(p) if p > 0 else float("-inf"),
            })
        self.last_completion_probabilities.append({"top_logprobs": alternatives})
def generate(self, prompt, state):
output = ""
for chunk in self.generate_with_streaming(prompt, state):