Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2025-12-06 07:12:10 +01:00
exllamav3: Implement the logits function for /v1/internal/logits
parent 218dc01b51
commit 7f06aec3a1

@@ -2,6 +2,8 @@ import traceback
 from pathlib import Path
 from typing import Any, List, Tuple

+import torch
+
 from exllamav3 import Cache, Config, Generator, Model, Tokenizer
 from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
 from exllamav3.generator import Job

@@ -16,7 +18,6 @@ from exllamav3.generator.sampler import (
     SS_TopK,
     SS_TopP
 )
-
 from modules import shared
 from modules.image_utils import (
     convert_image_attachments_to_pil,

@@ -171,7 +172,7 @@ class Exllamav3Model:
         result.draft_model = draft_model
         result.draft_cache = draft_cache

-        return result
+        return result, result

     def is_multimodal(self) -> bool:
         """Check if this model supports multimodal input."""

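With this change, from_pretrained returns the wrapper object twice, so callers can unpack it as both model and tokenizer; the wrapper presumably stands in for the tokenizer through its own encode/decode methods (see the loader hunk at the end of this diff). A minimal sketch of the calling convention, assuming the web UI environment is set up; the model directory name is a placeholder:

from modules.exllamav3 import Exllamav3Model

# "MyModel-exl3" is a hypothetical model folder under the configured model directory.
model, tokenizer = Exllamav3Model.from_pretrained("MyModel-exl3")
assert model is tokenizer  # both names refer to the same Exllamav3Model wrapper
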
@@ -367,11 +368,51 @@ class Exllamav3Model:

         return output

+    def get_logits(self, token_ids, **kwargs):
+        """
+        Process a batch of token_ids and return the logits for the last token.
+        This will reset and overwrite the model's cache.
+        """
+        # Initialize a single params dictionary that will be updated in-place
+        params = {
+            "cache": self.cache,
+            "reconstruct": False,
+            "attn_mode": "flash_attn",
+            "batch_shape": (1, self.max_tokens),
+            "past_len": 0
+        }
+        params.update(kwargs)
+
+        # Process prefix tokens to fill the cache and generate recurrent state
+        if token_ids.shape[-1] > 1:
+            prefix_ids = token_ids[:, :-1]
+
+            # This forward call updates the 'params' dict with the recurrent state
+            self.model.forward(
+                input_ids=prefix_ids,
+                params=params
+            )
+
+            # Update past_len for the next call
+            params["past_len"] = prefix_ids.shape[-1]
+
+        # Process the last token, now using the state-filled 'params' dict
+        last_token_ids = token_ids[:, -1:]
+        logits = self.model.forward(
+            input_ids=last_token_ids,
+            params=params
+        )
+
+        return logits.float().cpu()
+
     def encode(self, string, **kwargs):
         add_bos = kwargs.pop('add_bos', True)
         return self.tokenizer.encode(string, add_bos=add_bos, **kwargs)

     def decode(self, ids, **kwargs):
+        if isinstance(ids, torch.Tensor) and ids.dim() == 0:
+            ids = ids.view(1)
+
         return self.tokenizer.decode(ids, **kwargs)

     @property

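The new get_logits method runs the prompt prefix through the model once to fill the cache, then does a second forward call for the final position and returns that position's logits on the CPU. A minimal usage sketch, assuming a model has already been loaded through this wrapper; the prompt, the top-k value, and the shape comment are illustrative assumptions:

import torch

# `model` is assumed to be a loaded Exllamav3Model instance (which also acts as its tokenizer).
token_ids = model.encode("The capital of France is", add_bos=True)

logits = model.get_logits(token_ids)   # assumed shape: (1, 1, vocab_size), already on CPU
probs = torch.softmax(logits[0, -1], dim=-1)

top_probs, top_ids = torch.topk(probs, k=5)
for p, tok_id in zip(top_probs.tolist(), top_ids.tolist()):
    # Passing a 0-dim tensor exercises the new view(1) handling in decode()
    print(f"{p:.4f}  {model.decode(torch.tensor(tok_id))!r}")
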
@@ -71,6 +71,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
     from modules.torch_utils import get_device

     is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
+    is_non_hf_exllamav3 = shared.model.__class__.__name__ == 'Exllamav3Model'

     if not use_samplers:
         state = {'stream': True}

@@ -88,7 +89,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur

         scores = sampler_hijack.global_scores[-1]
     else:
-        if is_non_hf_exllamav2:
+        if is_non_hf_exllamav2 or is_non_hf_exllamav3:
             device = get_device()
             tokens = shared.tokenizer.encode(prompt)
             if device:

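The hunk stops before the end of the branch; judging from the existing ExLlamaV2 path, the remainder presumably moves the tokens to the device and keeps the last-token scores from get_logits. A self-contained sketch of that behaviour, not taken from this diff (the helper name is made up):

from modules import shared
from modules.torch_utils import get_device

def last_token_scores(prompt):
    """Hypothetical mirror of the non-HF branch: tokenize, move to device, keep last-token logits."""
    device = get_device()
    tokens = shared.tokenizer.encode(prompt)
    if device:
        tokens = tokens.to(device)
    # get_logits returns CPU logits; index down to the final position's score vector
    return shared.model.get_logits(tokens)[-1][-1]
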
@@ -104,8 +104,7 @@ def ExLlamav3_HF_loader(model_name):
 def ExLlamav3_loader(model_name):
     from modules.exllamav3 import Exllamav3Model

-    model = Exllamav3Model.from_pretrained(model_name)
-    tokenizer = model.tokenizer
+    model, tokenizer = Exllamav3Model.from_pretrained(model_name)

     return model, tokenizer

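The commit title ties this work to the /v1/internal/logits endpoint, which reaches _get_next_logits above. A sketch of querying it once the server runs with the API enabled; the port and the request fields ("prompt", "top_logits", "use_samplers") are assumptions about the endpoint's usual shape, not taken from this diff:

import requests

response = requests.post(
    "http://127.0.0.1:5000/v1/internal/logits",
    json={"prompt": "The capital of France is", "top_logits": 10, "use_samplers": False},
    timeout=60,
)
print(response.json())  # expected: a mapping from candidate tokens to their scores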