mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-08 16:13:41 +00:00
Add native logit_bias and logprobs support for ExLlamav3
This commit is contained in:
parent
8aeaa76365
commit
3304b57bdf
2 changed files with 82 additions and 12 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import math
|
||||
import queue
|
||||
import threading
|
||||
import traceback
|
||||
|
|
@ -9,6 +10,7 @@ import torch
|
|||
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
|
||||
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
|
||||
from exllamav3.generator import Job
|
||||
from exllamav3.generator.filter import Filter
|
||||
from exllamav3.generator.sampler import (
|
||||
CustomSampler,
|
||||
SS_AdaptiveP,
|
||||
|
|
@ -36,6 +38,29 @@ except Exception:
|
|||
traceback.print_exc()
|
||||
|
||||
|
||||
class LogitBiasFilter(Filter):
    """Filter subclass that applies a static additive logit bias mask."""

    def __init__(self, tokenizer, logit_bias_dict):
        super().__init__(tokenizer=tokenizer, trigger_token=None, prefix_str=None, eos_after_completed=False)
        # Mapping of token-id (as produced by the OpenAI-style API, keys may be
        # strings) -> additive bias value.
        self.logit_bias_dict = logit_bias_dict
        self._mask = None  # built lazily on first use, then reused

    # The bias is constant for the whole generation, so the per-token
    # lifecycle hooks of the Filter interface are all no-ops.
    def reset(self):
        pass

    def accept_token(self, token):
        pass

    def is_completed(self):
        return False

    def use_background_worker(self):
        return False

    def get_next_logit_mask(self):
        """Return the (1, vocab_size) additive bias mask, building it once.

        NOTE(review): assumes self.vocab_size and self.logits_dtype are
        provided by the Filter base class — confirm against exllamav3.
        """
        if self._mask is None:
            mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype)
            for key, bias in self.logit_bias_dict.items():
                token_id = int(key)
                # Ids outside the vocabulary are skipped rather than raising.
                if 0 <= token_id < self.vocab_size:
                    mask[0, token_id] = bias
            self._mask = mask
        return self._mask
|
||||
|
||||
|
||||
class ConcurrentGenerator:
|
||||
def __init__(self, generator):
    """Wrap a generator for shared/concurrent use.

    Args:
        generator: the underlying generator instance (presumably an
            exllamav3 Generator — confirm against callers).
    """
    self.generator = generator
|
||||
|
|
@ -98,6 +123,10 @@ class Exllamav3Model:
|
|||
def __init__(self):
    """Intentionally empty constructor: instances are expected to be
    built and populated elsewhere (see the from_pretrained() classmethod)."""
    pass
|
||||
|
||||
@property
def device(self) -> torch.device:
    """Torch device the model is assumed to live on (ordinal device 0)."""
    primary = torch.device(0)
    return primary
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path_to_model):
|
||||
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
|
||||
|
|
@ -390,6 +419,16 @@ class Exllamav3Model:
|
|||
if eos_id is not None:
|
||||
stop_conditions.append(eos_id)
|
||||
|
||||
# Build filters for logit_bias (OpenAI API)
|
||||
filters = []
|
||||
logit_bias = state.get('logit_bias')
|
||||
if logit_bias:
|
||||
filters.append(LogitBiasFilter(self.tokenizer, logit_bias))
|
||||
|
||||
# Logprobs support (OpenAI API)
|
||||
logprobs = state.get('logprobs', 0) or 0
|
||||
return_top_tokens = logprobs if logprobs > 0 else 0
|
||||
|
||||
seed = state.get('seed', -1)
|
||||
job = Job(
|
||||
input_ids=input_ids,
|
||||
|
|
@ -399,11 +438,15 @@ class Exllamav3Model:
|
|||
sampler=sampler,
|
||||
seed=seed if seed >= 0 else None,
|
||||
stop_conditions=stop_conditions if stop_conditions else None,
|
||||
filters=filters if filters else None,
|
||||
return_top_tokens=return_top_tokens,
|
||||
return_probs=return_top_tokens > 0,
|
||||
)
|
||||
|
||||
# Stream generation
|
||||
response_text = ""
|
||||
stop_event = state.get('stop_event')
|
||||
self.last_completion_probabilities = []
|
||||
|
||||
result_queue = self.parallel_generator.submit(job)
|
||||
try:
|
||||
|
|
@ -415,14 +458,41 @@ class Exllamav3Model:
|
|||
except queue.Empty:
|
||||
continue
|
||||
if result is None or result.get("eos"):
|
||||
# Capture logprobs from the final eos result too
|
||||
if result is not None and return_top_tokens > 0:
|
||||
self._capture_logprobs(result)
|
||||
break
|
||||
chunk = result.get("text", "")
|
||||
|
||||
# Capture logprobs from streaming results
|
||||
if return_top_tokens > 0:
|
||||
self._capture_logprobs(result)
|
||||
|
||||
if chunk:
|
||||
response_text += chunk
|
||||
yield response_text
|
||||
finally:
|
||||
self.parallel_generator.cancel(job)
|
||||
|
||||
def _capture_logprobs(self, result):
|
||||
"""Convert ExLlamav3 top-k token data to the shared logprobs format."""
|
||||
top_k_tokens = result.get("top_k_tokens")
|
||||
top_k_probs = result.get("top_k_probs")
|
||||
if top_k_tokens is None or top_k_probs is None:
|
||||
return
|
||||
|
||||
id_to_piece = self.tokenizer.get_id_to_piece_list(True)
|
||||
# top_k_tokens shape: (batch, seq_len, k), top_k_probs same
|
||||
for seq_idx in range(top_k_tokens.shape[1]):
|
||||
entry = {"top_logprobs": []}
|
||||
for k_idx in range(top_k_tokens.shape[2]):
|
||||
token_id = top_k_tokens[0, seq_idx, k_idx].item()
|
||||
prob = top_k_probs[0, seq_idx, k_idx].item()
|
||||
token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
|
||||
logprob = math.log(prob) if prob > 0 else float("-inf")
|
||||
entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
|
||||
self.last_completion_probabilities.append(entry)
|
||||
|
||||
def generate(self, prompt, state):
|
||||
output = ""
|
||||
for chunk in self.generate_with_streaming(prompt, state):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue