mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
exllamav3: Implement the logits function for /v1/internal/logits
This commit is contained in:
parent
218dc01b51
commit
7f06aec3a1
|
|
@ -2,6 +2,8 @@ import traceback
|
|||
from pathlib import Path
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
|
||||
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
|
||||
from exllamav3.generator import Job
|
||||
|
|
@ -16,7 +18,6 @@ from exllamav3.generator.sampler import (
|
|||
SS_TopK,
|
||||
SS_TopP
|
||||
)
|
||||
|
||||
from modules import shared
|
||||
from modules.image_utils import (
|
||||
convert_image_attachments_to_pil,
|
||||
|
|
@ -171,7 +172,7 @@ class Exllamav3Model:
|
|||
result.draft_model = draft_model
|
||||
result.draft_cache = draft_cache
|
||||
|
||||
return result
|
||||
return result, result
|
||||
|
||||
def is_multimodal(self) -> bool:
|
||||
"""Check if this model supports multimodal input."""
|
||||
|
|
@ -367,11 +368,51 @@ class Exllamav3Model:
|
|||
|
||||
return output
|
||||
|
||||
def get_logits(self, token_ids, **kwargs):
|
||||
"""
|
||||
Process a batch of token_ids and return the logits for the last token.
|
||||
This will reset and overwrite the model's cache.
|
||||
"""
|
||||
# Initialize a single params dictionary that will be updated in-place
|
||||
params = {
|
||||
"cache": self.cache,
|
||||
"reconstruct": False,
|
||||
"attn_mode": "flash_attn",
|
||||
"batch_shape": (1, self.max_tokens),
|
||||
"past_len": 0
|
||||
}
|
||||
params.update(kwargs)
|
||||
|
||||
# Process prefix tokens to fill the cache and generate recurrent state
|
||||
if token_ids.shape[-1] > 1:
|
||||
prefix_ids = token_ids[:, :-1]
|
||||
|
||||
# This forward call updates the 'params' dict with the recurrent state
|
||||
self.model.forward(
|
||||
input_ids=prefix_ids,
|
||||
params=params
|
||||
)
|
||||
|
||||
# Update past_len for the next call
|
||||
params["past_len"] = prefix_ids.shape[-1]
|
||||
|
||||
# Process the last token, now using the state-filled 'params' dict
|
||||
last_token_ids = token_ids[:, -1:]
|
||||
logits = self.model.forward(
|
||||
input_ids=last_token_ids,
|
||||
params=params
|
||||
)
|
||||
|
||||
return logits.float().cpu()
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
add_bos = kwargs.pop('add_bos', True)
|
||||
return self.tokenizer.encode(string, add_bos=add_bos, **kwargs)
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
if isinstance(ids, torch.Tensor) and ids.dim() == 0:
|
||||
ids = ids.view(1)
|
||||
|
||||
return self.tokenizer.decode(ids, **kwargs)
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
|
|||
from modules.torch_utils import get_device
|
||||
|
||||
is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
|
||||
is_non_hf_exllamav3 = shared.model.__class__.__name__ == 'Exllamav3Model'
|
||||
|
||||
if not use_samplers:
|
||||
state = {'stream': True}
|
||||
|
|
@ -88,7 +89,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
|
|||
|
||||
scores = sampler_hijack.global_scores[-1]
|
||||
else:
|
||||
if is_non_hf_exllamav2:
|
||||
if is_non_hf_exllamav2 or is_non_hf_exllamav3:
|
||||
device = get_device()
|
||||
tokens = shared.tokenizer.encode(prompt)
|
||||
if device:
|
||||
|
|
|
|||
|
|
@ -104,8 +104,7 @@ def ExLlamav3_HF_loader(model_name):
|
|||
def ExLlamav3_loader(model_name):
|
||||
from modules.exllamav3 import Exllamav3Model
|
||||
|
||||
model = Exllamav3Model.from_pretrained(model_name)
|
||||
tokenizer = model.tokenizer
|
||||
model, tokenizer = Exllamav3Model.from_pretrained(model_name)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue