exllamav3: Implement the logits function for /v1/internal/logits

oobabooga 2025-10-09 11:24:25 -07:00
parent 218dc01b51
commit 7f06aec3a1
3 changed files with 46 additions and 5 deletions
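For context, /v1/internal/logits is the local API route that returns the most likely next tokens for a prompt; this commit supplies the backend for the non-HF ExLlamaV3 loader. A minimal sketch of exercising the endpoint (host, port, and payload fields assume a default local install and may differ):

import requests

# Assumes the API server is running locally on its default port
url = "http://127.0.0.1:5000/v1/internal/logits"
payload = {
    "prompt": "The capital of France is",
    "top_logits": 10,       # number of candidate next tokens to return
    "use_samplers": False,  # raw model logits rather than post-sampler scores
}

response = requests.post(url, json=payload)
print(response.json())  # expected: a mapping of candidate tokens to scores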

modules/exllamav3.py

@@ -2,6 +2,8 @@ import traceback
 from pathlib import Path
 from typing import Any, List, Tuple
 
+import torch
+
 from exllamav3 import Cache, Config, Generator, Model, Tokenizer
 from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
 from exllamav3.generator import Job
@@ -16,7 +18,6 @@ from exllamav3.generator.sampler import (
     SS_TopK,
     SS_TopP
 )
-
 from modules import shared
 from modules.image_utils import (
     convert_image_attachments_to_pil,
@@ -171,7 +172,7 @@ class Exllamav3Model:
         result.draft_model = draft_model
         result.draft_cache = draft_cache
 
-        return result
+        return result, result
 
     def is_multimodal(self) -> bool:
         """Check if this model supports multimodal input."""
@@ -367,11 +368,51 @@
 
         return output
 
+    def get_logits(self, token_ids, **kwargs):
+        """
+        Process a batch of token_ids and return the logits for the last token.
+        This will reset and overwrite the model's cache.
+        """
+
+        # Initialize a single params dictionary that will be updated in-place
+        params = {
+            "cache": self.cache,
+            "reconstruct": False,
+            "attn_mode": "flash_attn",
+            "batch_shape": (1, self.max_tokens),
+            "past_len": 0
+        }
+        params.update(kwargs)
+
+        # Process prefix tokens to fill the cache and generate recurrent state
+        if token_ids.shape[-1] > 1:
+            prefix_ids = token_ids[:, :-1]
+
+            # This forward call updates the 'params' dict with the recurrent state
+            self.model.forward(
+                input_ids=prefix_ids,
+                params=params
+            )
+
+            # Update past_len for the next call
+            params["past_len"] = prefix_ids.shape[-1]
+
+        # Process the last token, now using the state-filled 'params' dict
+        last_token_ids = token_ids[:, -1:]
+        logits = self.model.forward(
+            input_ids=last_token_ids,
+            params=params
+        )
+
+        return logits.float().cpu()
+
     def encode(self, string, **kwargs):
         add_bos = kwargs.pop('add_bos', True)
         return self.tokenizer.encode(string, add_bos=add_bos, **kwargs)
 
     def decode(self, ids, **kwargs):
+        if isinstance(ids, torch.Tensor) and ids.dim() == 0:
+            ids = ids.view(1)
         return self.tokenizer.decode(ids, **kwargs)
 
     @property
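The new get_logits() takes a [1, seq_len] tensor of token ids, prefills the cache with everything but the last token, then runs a single-token forward pass and returns float32 logits on CPU. A minimal usage sketch, assuming `model` is an already-loaded Exllamav3Model:

import torch

token_ids = model.encode("The capital of France is")  # tensor of shape [1, seq_len]
# Note: callers such as modules/logits.py move the tokens to the model's
# device first; that step is omitted here for brevity.
logits = model.get_logits(token_ids)                  # logits for the last position

probs = torch.softmax(logits[0, -1], dim=-1)
top = torch.topk(probs, k=5)
for p, tok in zip(top.values.tolist(), top.indices):
    # tok is a 0-dim tensor: exactly the case the new decode() guard handles
    print(f"{p:.4f} {model.decode(tok)!r}")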

modules/logits.py

@@ -71,6 +71,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
     from modules.torch_utils import get_device
 
     is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
+    is_non_hf_exllamav3 = shared.model.__class__.__name__ == 'Exllamav3Model'
 
     if not use_samplers:
         state = {'stream': True}
@@ -88,7 +89,7 @@
         scores = sampler_hijack.global_scores[-1]
     else:
-        if is_non_hf_exllamav2:
+        if is_non_hf_exllamav2 or is_non_hf_exllamav3:
             device = get_device()
             tokens = shared.tokenizer.encode(prompt)
             if device:
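This routes the non-HF ExLlamaV3 model through the same raw-logits branch that the non-HF ExLlamaV2 model already used. Condensed, that branch does roughly the following; the lines after `if device:` are inferred from the surrounding code, not shown in the hunk:

if is_non_hf_exllamav2 or is_non_hf_exllamav3:
    device = get_device()
    tokens = shared.tokenizer.encode(prompt)
    if device:
        tokens = tokens.to(device)                    # inferred continuation
    scores = shared.model.get_logits(tokens)[-1][-1]  # inferred: last-position logits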

modules/models.py

@@ -104,8 +104,7 @@ def ExLlamav3_HF_loader(model_name):
 def ExLlamav3_loader(model_name):
     from modules.exllamav3 import Exllamav3Model
 
-    model = Exllamav3Model.from_pretrained(model_name)
-    tokenizer = model.tokenizer
+    model, tokenizer = Exllamav3Model.from_pretrained(model_name)
     return model, tokenizer
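This loader change pairs with the `return result, result` change above: Exllamav3Model now stands in for its own tokenizer, since it exposes the encode()/decode() methods that callers expect from shared.tokenizer. In effect:

model, tokenizer = Exllamav3Model.from_pretrained(model_name)
assert model is tokenizer  # one object serves both roles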