exllamav3: Implement the logits function for /v1/internal/logits

oobabooga 2025-10-09 11:24:25 -07:00
parent 218dc01b51
commit 7f06aec3a1
3 changed files with 46 additions and 5 deletions
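For reference, a rough sketch of how the endpoint named in the commit title could be exercised once a server is running with the API enabled; the port and the request fields (prompt, use_samplers, top_logits) are assumptions about the existing logits endpoint schema, not something introduced by this commit.

import requests

response = requests.post(
    "http://127.0.0.1:5000/v1/internal/logits",  # assumed default API port
    json={
        "prompt": "The capital of France is",
        "use_samplers": False,  # raw model logits rather than post-sampler scores
        "top_logits": 10,       # number of top candidate tokens to return
    },
)
print(response.json())  # expected: a mapping of candidate tokens to scores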

modules/exllamav3.py

@@ -2,6 +2,8 @@ import traceback
from pathlib import Path
from typing import Any, List, Tuple
import torch
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from exllamav3.generator import Job
@@ -16,7 +18,6 @@ from exllamav3.generator.sampler import (
    SS_TopK,
    SS_TopP
)

from modules import shared
from modules.image_utils import (
    convert_image_attachments_to_pil,
@@ -171,7 +172,7 @@ class Exllamav3Model:
        result.draft_model = draft_model
        result.draft_cache = draft_cache
        return result
        return result, result

    def is_multimodal(self) -> bool:
        """Check if this model supports multimodal input."""
@@ -367,11 +368,51 @@ class Exllamav3Model:
        return output

    def get_logits(self, token_ids, **kwargs):
        """
        Process a batch of token_ids and return the logits for the last token.
        This will reset and overwrite the model's cache.
        """
        # Initialize a single params dictionary that will be updated in-place
        params = {
            "cache": self.cache,
            "reconstruct": False,
            "attn_mode": "flash_attn",
            "batch_shape": (1, self.max_tokens),
            "past_len": 0
        }
        params.update(kwargs)

        # Process prefix tokens to fill the cache and generate recurrent state
        if token_ids.shape[-1] > 1:
            prefix_ids = token_ids[:, :-1]

            # This forward call updates the 'params' dict with the recurrent state
            self.model.forward(
                input_ids=prefix_ids,
                params=params
            )

            # Update past_len for the next call
            params["past_len"] = prefix_ids.shape[-1]

        # Process the last token, now using the state-filled 'params' dict
        last_token_ids = token_ids[:, -1:]
        logits = self.model.forward(
            input_ids=last_token_ids,
            params=params
        )

        return logits.float().cpu()

    def encode(self, string, **kwargs):
        add_bos = kwargs.pop('add_bos', True)
        return self.tokenizer.encode(string, add_bos=add_bos, **kwargs)

    def decode(self, ids, **kwargs):
        if isinstance(ids, torch.Tensor) and ids.dim() == 0:
            ids = ids.view(1)

        return self.tokenizer.decode(ids, **kwargs)

    @property

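A minimal sketch of exercising the new method directly, assuming a model has been loaded through Exllamav3Model.from_pretrained(); the output shape, the device handling, and the top-k post-processing mirror how the exllamav2 path is consumed and are assumptions, not part of the diff.

import torch

from modules.exllamav3 import Exllamav3Model

model, tokenizer = Exllamav3Model.from_pretrained("SomeModel-exl3")  # hypothetical model name

token_ids = model.encode("The capital of France is")  # may need moving to the model's device first
logits = model.get_logits(token_ids)                  # float32 tensor on CPU

probs = torch.softmax(logits[-1][-1], dim=-1)         # next-token distribution over the vocabulary
top_probs, top_ids = torch.topk(probs, k=10)
for prob, tid in zip(top_probs, top_ids):
    print(repr(model.decode(tid)), float(prob))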
modules/logits.py

@@ -71,6 +71,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
        from modules.torch_utils import get_device

        is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
        is_non_hf_exllamav3 = shared.model.__class__.__name__ == 'Exllamav3Model'

        if not use_samplers:
            state = {'stream': True}
@@ -88,7 +89,7 @@
            scores = sampler_hijack.global_scores[-1]
        else:
            if is_non_hf_exllamav2:
            if is_non_hf_exllamav2 or is_non_hf_exllamav3:
                device = get_device()
                tokens = shared.tokenizer.encode(prompt)
                if device:

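The displayed hunk stops at the device check; presumably the branch then continues as in the pre-existing exllamav2 handling, moving the ids onto the device and taking the last-token logits row from the new method. A hedged sketch of that continuation (not shown in the diff):

tokens = tokens.to(device)
scores = shared.model.get_logits(tokens)[-1][-1]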
modules/models.py

@@ -104,8 +104,7 @@ def ExLlamav3_HF_loader(model_name):
def ExLlamav3_loader(model_name):
    from modules.exllamav3 import Exllamav3Model

    model = Exllamav3Model.from_pretrained(model_name)
    tokenizer = model.tokenizer
    model, tokenizer = Exllamav3Model.from_pretrained(model_name)
    return model, tokenizer
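Because from_pretrained() now returns the wrapper twice (the return result, result change above), the loader hands back the same object in both positions: the Exllamav3Model instance also acts as the tokenizer through its encode()/decode() methods. A small illustration, with a hypothetical model name:

from modules.exllamav3 import Exllamav3Model

model, tokenizer = Exllamav3Model.from_pretrained("SomeModel-exl3")  # hypothetical name
assert model is tokenizer        # the same wrapper object serves both roles
ids = tokenizer.encode("hello")  # encode()/decode() are provided by the wrapper itself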