From 7f06aec3a1fc2e6924d87035483ce10ce65af058 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 9 Oct 2025 11:24:25 -0700
Subject: [PATCH] exllamav3: Implement the logits function for
 /v1/internal/logits

---
 modules/exllamav3.py | 45 ++++++++++++++++++++++++++++++++++++++++++--
 modules/logits.py    |  3 ++-
 modules/models.py    |  3 +--
 3 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/modules/exllamav3.py b/modules/exllamav3.py
index f7078028..d884bbf7 100644
--- a/modules/exllamav3.py
+++ b/modules/exllamav3.py
@@ -2,6 +2,8 @@ import traceback
 from pathlib import Path
 from typing import Any, List, Tuple
 
+import torch
+
 from exllamav3 import Cache, Config, Generator, Model, Tokenizer
 from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
 from exllamav3.generator import Job
@@ -16,7 +18,6 @@ from exllamav3.generator.sampler import (
     SS_TopK,
     SS_TopP
 )
-
 from modules import shared
 from modules.image_utils import (
     convert_image_attachments_to_pil,
@@ -171,7 +172,7 @@ class Exllamav3Model:
         result.draft_model = draft_model
         result.draft_cache = draft_cache
 
-        return result
+        return result, result
 
     def is_multimodal(self) -> bool:
         """Check if this model supports multimodal input."""
@@ -367,11 +368,51 @@ class Exllamav3Model:
 
         return output
 
+    def get_logits(self, token_ids, **kwargs):
+        """
+        Process a batch of token_ids and return the logits for the last token.
+        This will reset and overwrite the model's cache.
+        """
+        # Initialize a single params dictionary that will be updated in-place
+        params = {
+            "cache": self.cache,
+            "reconstruct": False,
+            "attn_mode": "flash_attn",
+            "batch_shape": (1, self.max_tokens),
+            "past_len": 0
+        }
+        params.update(kwargs)
+
+        # Process prefix tokens to fill the cache and generate recurrent state
+        if token_ids.shape[-1] > 1:
+            prefix_ids = token_ids[:, :-1]
+
+            # This forward call updates the 'params' dict with the recurrent state
+            self.model.forward(
+                input_ids=prefix_ids,
+                params=params
+            )
+
+            # Update past_len for the next call
+            params["past_len"] = prefix_ids.shape[-1]
+
+        # Process the last token, now using the state-filled 'params' dict
+        last_token_ids = token_ids[:, -1:]
+        logits = self.model.forward(
+            input_ids=last_token_ids,
+            params=params
+        )
+
+        return logits.float().cpu()
+
     def encode(self, string, **kwargs):
         add_bos = kwargs.pop('add_bos', True)
         return self.tokenizer.encode(string, add_bos=add_bos, **kwargs)
 
     def decode(self, ids, **kwargs):
+        if isinstance(ids, torch.Tensor) and ids.dim() == 0:
+            ids = ids.view(1)
+
         return self.tokenizer.decode(ids, **kwargs)
 
     @property
diff --git a/modules/logits.py b/modules/logits.py
index 56a20572..d668e44e 100644
--- a/modules/logits.py
+++ b/modules/logits.py
@@ -71,6 +71,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
         from modules.torch_utils import get_device
 
         is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
+        is_non_hf_exllamav3 = shared.model.__class__.__name__ == 'Exllamav3Model'
 
         if not use_samplers:
             state = {'stream': True}
@@ -88,7 +89,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
 
             scores = sampler_hijack.global_scores[-1]
         else:
-            if is_non_hf_exllamav2:
+            if is_non_hf_exllamav2 or is_non_hf_exllamav3:
                 device = get_device()
                 tokens = shared.tokenizer.encode(prompt)
                 if device:
diff --git a/modules/models.py b/modules/models.py
index 9535ea82..8c0f1c37 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -104,8 +104,7 @@ def ExLlamav3_HF_loader(model_name):
 def ExLlamav3_loader(model_name):
     from modules.exllamav3 import Exllamav3Model
 
-    model = Exllamav3Model.from_pretrained(model_name)
-    tokenizer = model.tokenizer
+    model, tokenizer = Exllamav3Model.from_pretrained(model_name)
     return model, tokenizer
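
Note (illustration, not part of the patch): a minimal sketch of how the new
get_logits() path might be exercised directly once an ExLlamaV3 model is
loaded, mirroring what modules/logits.py does after this change. The model
directory name is hypothetical and the top-k post-processing is illustrative
only; from_pretrained() now returns the wrapper object twice (as both model
and tokenizer), following the existing Exllamav2Model convention.

    import torch
    from modules.exllamav3 import Exllamav3Model

    # Hypothetical model directory resolved against the configured model_dir
    model, tokenizer = Exllamav3Model.from_pretrained("MyModel-exl3")

    # encode() returns a tensor of token ids with a leading batch dimension
    tokens = tokenizer.encode("The quick brown fox")

    # get_logits() refills the cache and returns float32 logits on the CPU;
    # indexing with [-1][-1] yields the vocabulary-sized vector for the last
    # token, as done in modules/logits.py
    scores = model.get_logits(tokens)[-1][-1]

    # Turn the last-token logits into the top-5 next-token probabilities;
    # decoding a 0-dim tensor exercises the new view(1) branch in decode()
    probs = torch.softmax(scores, dim=-1)
    top_probs, top_ids = torch.topk(probs, 5)
    for prob, token_id in zip(top_probs.tolist(), top_ids.tolist()):
        print(f"{tokenizer.decode(torch.tensor(token_id))!r}: {prob:.3f}")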