Add native logit_bias and logprobs support for ExLlamav3

This commit is contained in:
oobabooga 2026-03-10 11:03:00 -03:00
parent 8aeaa76365
commit 3304b57bdf
2 changed files with 82 additions and 12 deletions

(The file diff follows below.)

@ -1,3 +1,4 @@
import math
import queue
import threading
import traceback
@ -9,6 +10,7 @@ import torch
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from exllamav3.generator import Job
from exllamav3.generator.filter import Filter
from exllamav3.generator.sampler import (
CustomSampler,
SS_AdaptiveP,
@ -36,6 +38,29 @@ except Exception:
traceback.print_exc()
class LogitBiasFilter(Filter):
    """Filter subclass that applies a static additive logit bias mask.

    The bias mask is constant for the whole generation, so it is built once
    on first request and cached for every subsequent sampling step.
    """

    def __init__(self, tokenizer, logit_bias_dict):
        # The trigger/prefix machinery of the base Filter is unused here.
        super().__init__(tokenizer=tokenizer, trigger_token=None, prefix_str=None, eos_after_completed=False)
        self.logit_bias_dict = logit_bias_dict
        self._mask = None  # lazily-built (1, vocab_size) bias tensor

    # This filter carries no per-sequence state, so the lifecycle hooks
    # are deliberate no-ops and the filter never reports completion.
    def reset(self):
        pass

    def accept_token(self, token):
        pass

    def is_completed(self):
        return False

    def use_background_worker(self):
        return False

    def get_next_logit_mask(self):
        """Return the cached additive bias mask, building it on first call.

        NOTE(review): relies on the base Filter exposing `vocab_size` and
        `logits_dtype` — confirm against the exllamav3 Filter API.
        """
        if self._mask is None:
            mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype)
            for key, bias in self.logit_bias_dict.items():
                # OpenAI-style logit_bias keys arrive as token-id strings.
                idx = int(key)
                if 0 <= idx < self.vocab_size:
                    mask[0, idx] = bias
            self._mask = mask
        return self._mask
class ConcurrentGenerator:
def __init__(self, generator):
self.generator = generator
@ -98,6 +123,10 @@ class Exllamav3Model:
def __init__(self):
# Intentionally empty: instances are populated by the from_pretrained()
# classmethod rather than constructed directly.
pass
@property
def device(self) -> torch.device:
# NOTE(review): hard-coded to device index 0 — assumes the model lives on
# the first accelerator; confirm this holds for multi-GPU setups.
return torch.device(0)
@classmethod
def from_pretrained(cls, path_to_model):
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
@ -390,6 +419,16 @@ class Exllamav3Model:
if eos_id is not None:
stop_conditions.append(eos_id)
# Build filters for logit_bias (OpenAI API)
filters = []
logit_bias = state.get('logit_bias')
if logit_bias:
filters.append(LogitBiasFilter(self.tokenizer, logit_bias))
# Logprobs support (OpenAI API)
logprobs = state.get('logprobs', 0) or 0
return_top_tokens = logprobs if logprobs > 0 else 0
seed = state.get('seed', -1)
job = Job(
input_ids=input_ids,
@ -399,11 +438,15 @@ class Exllamav3Model:
sampler=sampler,
seed=seed if seed >= 0 else None,
stop_conditions=stop_conditions if stop_conditions else None,
filters=filters if filters else None,
return_top_tokens=return_top_tokens,
return_probs=return_top_tokens > 0,
)
# Stream generation
response_text = ""
stop_event = state.get('stop_event')
self.last_completion_probabilities = []
result_queue = self.parallel_generator.submit(job)
try:
@ -415,14 +458,41 @@ class Exllamav3Model:
except queue.Empty:
continue
if result is None or result.get("eos"):
# Capture logprobs from the final eos result too
if result is not None and return_top_tokens > 0:
self._capture_logprobs(result)
break
chunk = result.get("text", "")
# Capture logprobs from streaming results
if return_top_tokens > 0:
self._capture_logprobs(result)
if chunk:
response_text += chunk
yield response_text
finally:
self.parallel_generator.cancel(job)
def _capture_logprobs(self, result):
"""Convert ExLlamav3 top-k token data to the shared logprobs format."""
top_k_tokens = result.get("top_k_tokens")
top_k_probs = result.get("top_k_probs")
if top_k_tokens is None or top_k_probs is None:
return
id_to_piece = self.tokenizer.get_id_to_piece_list(True)
# top_k_tokens shape: (batch, seq_len, k), top_k_probs same
for seq_idx in range(top_k_tokens.shape[1]):
entry = {"top_logprobs": []}
for k_idx in range(top_k_tokens.shape[2]):
token_id = top_k_tokens[0, seq_idx, k_idx].item()
prob = top_k_probs[0, seq_idx, k_idx].item()
token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
logprob = math.log(prob) if prob > 0 else float("-inf")
entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
self.last_completion_probabilities.append(entry)
def generate(self, prompt, state):
output = ""
for chunk in self.generate_with_streaming(prompt, state):