diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index d70e69e6..04e644d6 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -37,13 +37,13 @@ def load_chat_template_file(filepath):
     return text


-def get_logprobs_from_llama_cpp():
-    """Read logprobs captured from the llama.cpp server response."""
+def get_logprobs_from_backend():
+    """Read logprobs captured from llama.cpp or ExLlamav3 native backend."""
     if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
         return None

-    # Convert llama.cpp format to {token: logprob} dict
-    # llama.cpp returns: [{"token": "text", "logprob": -0.5, "top_logprobs": [{"token": "t", "logprob": -0.1}, ...]}, ...]
+    # Both backends store data in shared.model.last_completion_probabilities
+    # Format: [{"top_logprobs": [{"token": "text", "logprob": -0.5}, ...]}, ...]
     result = {}
     for entry in shared.model.last_completion_probabilities:
         top = entry.get('top_logprobs', entry.get('top_probs', []))
@@ -90,8 +90,8 @@ def process_parameters(body, is_legacy=False):
         elif isinstance(body['stop'], list):
             generate_params['custom_stopping_strings'] = body['stop']

-    # For llama.cpp, logit_bias and logprobs are forwarded natively via prepare_payload()
-    if shared.args.loader != 'llama.cpp':
+    # For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
+    if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
         from transformers import LogitsProcessorList

         from modules.transformers_loader import (
@@ -527,9 +527,9 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e

         if logprob_proc:
             completion_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
-        elif shared.args.loader == 'llama.cpp':
-            llama_logprobs = get_logprobs_from_llama_cpp()
-            completion_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
+        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
+            backend_logprobs = get_logprobs_from_backend()
+            completion_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
         else:
             completion_logprobs = None

@@ -576,9 +576,9 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
     # begin streaming
     if logprob_proc:
         chunk_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
-    elif shared.args.loader == 'llama.cpp':
-        llama_logprobs = get_logprobs_from_llama_cpp()
-        chunk_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
+    elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
+        backend_logprobs = get_logprobs_from_backend()
+        chunk_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
     else:
         chunk_logprobs = None

diff --git a/modules/exllamav3.py b/modules/exllamav3.py
index aeb68564..1c682e49 100644
--- a/modules/exllamav3.py
+++ b/modules/exllamav3.py
@@ -1,3 +1,4 @@
+import math
 import queue
 import threading
 import traceback
@@ -9,6 +10,7 @@ import torch
 from exllamav3 import Cache, Config, Generator, Model, Tokenizer
 from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
 from exllamav3.generator import Job
+from exllamav3.generator.filter import Filter
 from exllamav3.generator.sampler import (
     CustomSampler,
     SS_AdaptiveP,
@@ -36,6 +38,29 @@ except Exception:
     traceback.print_exc()


+class LogitBiasFilter(Filter):
+    """Filter subclass that applies a static additive logit bias mask."""
+
+    def __init__(self, tokenizer, logit_bias_dict):
+        super().__init__(tokenizer=tokenizer, trigger_token=None, prefix_str=None, eos_after_completed=False)
+        self.logit_bias_dict = logit_bias_dict
+        self._mask = None
+
+    def reset(self): pass
+    def accept_token(self, token): pass
+    def is_completed(self): return False
+    def use_background_worker(self): return False
+
+    def get_next_logit_mask(self):
+        if self._mask is None:
+            self._mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype)
+            for token_id_str, bias in self.logit_bias_dict.items():
+                token_id = int(token_id_str)
+                if 0 <= token_id < self.vocab_size:
+                    self._mask[0, token_id] = bias
+        return self._mask
+
+
 class ConcurrentGenerator:
     def __init__(self, generator):
         self.generator = generator
@@ -98,6 +123,10 @@ class Exllamav3Model:
     def __init__(self):
         pass

+    @property
+    def device(self) -> torch.device:
+        return torch.device(0)
+
     @classmethod
     def from_pretrained(cls, path_to_model):
         path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
@@ -390,6 +419,16 @@ class Exllamav3Model:
         if eos_id is not None:
             stop_conditions.append(eos_id)

+        # Build filters for logit_bias (OpenAI API)
+        filters = []
+        logit_bias = state.get('logit_bias')
+        if logit_bias:
+            filters.append(LogitBiasFilter(self.tokenizer, logit_bias))
+
+        # Logprobs support (OpenAI API)
+        logprobs = state.get('logprobs', 0) or 0
+        return_top_tokens = logprobs if logprobs > 0 else 0
+
         seed = state.get('seed', -1)
         job = Job(
             input_ids=input_ids,
@@ -399,11 +438,15 @@ class Exllamav3Model:
             sampler=sampler,
             seed=seed if seed >= 0 else None,
             stop_conditions=stop_conditions if stop_conditions else None,
+            filters=filters if filters else None,
+            return_top_tokens=return_top_tokens,
+            return_probs=return_top_tokens > 0,
         )

         # Stream generation
         response_text = ""
         stop_event = state.get('stop_event')
+        self.last_completion_probabilities = []
         result_queue = self.parallel_generator.submit(job)

         try:
@@ -415,14 +458,41 @@ class Exllamav3Model:
                 except queue.Empty:
                     continue
                 if result is None or result.get("eos"):
+                    # Capture logprobs from the final eos result too
+                    if result is not None and return_top_tokens > 0:
+                        self._capture_logprobs(result)
                     break

                 chunk = result.get("text", "")
+
+                # Capture logprobs from streaming results
+                if return_top_tokens > 0:
+                    self._capture_logprobs(result)
+
                 if chunk:
                     response_text += chunk
                     yield response_text
         finally:
             self.parallel_generator.cancel(job)

+    def _capture_logprobs(self, result):
+        """Convert ExLlamav3 top-k token data to the shared logprobs format."""
+        top_k_tokens = result.get("top_k_tokens")
+        top_k_probs = result.get("top_k_probs")
+        if top_k_tokens is None or top_k_probs is None:
+            return
+
+        id_to_piece = self.tokenizer.get_id_to_piece_list(True)
+        # top_k_tokens shape: (batch, seq_len, k); top_k_probs has the same shape
+        for seq_idx in range(top_k_tokens.shape[1]):
+            entry = {"top_logprobs": []}
+            for k_idx in range(top_k_tokens.shape[2]):
+                token_id = top_k_tokens[0, seq_idx, k_idx].item()
+                prob = top_k_probs[0, seq_idx, k_idx].item()
+                token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
+                logprob = math.log(prob) if prob > 0 else float("-inf")
+                entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
+            self.last_completion_probabilities.append(entry)
+
     def generate(self, prompt, state):
         output = ""
         for chunk in self.generate_with_streaming(prompt, state):
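For reference, a minimal client-side sketch of how the parameters wired up above can be exercised end to end. It assumes the OpenAI-compatible API extension is listening on its default address (http://127.0.0.1:5000) with an ExLlamav3 model loaded; the prompt, the token id in logit_bias, and the bias value are placeholders, not part of this change.

    import requests

    # /v1/completions request exercising the new logprobs and logit_bias paths.
    # "logprobs": 5 asks for the top-5 alternatives per generated token;
    # "logit_bias" adds +5.0 to the logit of (hypothetical) token id 15 before sampling.
    payload = {
        "prompt": "The quick brown fox",
        "max_tokens": 16,
        "logprobs": 5,
        "logit_bias": {"15": 5.0},
    }
    resp = requests.post("http://127.0.0.1:5000/v1/completions", json=payload)

    # With the ExLlamav3 loader, the top_logprobs entries are now populated from
    # shared.model.last_completion_probabilities, in the same shape already used
    # for llama.cpp.
    print(resp.json()["choices"][0]["logprobs"])

Both parameters pass through process_parameters() untouched for these loaders, since the `not in ('llama.cpp', 'ExLlamav3')` check skips the transformers-based logits processors and lets the backend handle them natively.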