mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-06 15:13:38 +00:00
Fix ExLlamav3 OOM on prompt logprobs and qwen3_5_moe HF compat
This commit is contained in:
parent
328534b762
commit
4073164be0
3 changed files with 17 additions and 61 deletions
|
|
@ -91,17 +91,14 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
|
|||
import torch
|
||||
|
||||
if loader == 'ExLlamav3' and hasattr(model, 'model') and hasattr(model, 'cache'):
|
||||
# Native ExLlamav3: call the underlying Model.forward() directly
|
||||
# Native ExLlamav3: call the underlying Model.forward() in chunks
|
||||
# to avoid OOM from giant logits tensors (seq_len * vocab_size * 4 bytes)
|
||||
input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long)
|
||||
input_ids_tensor = input_ids_tensor.view(-1).cpu()
|
||||
with torch.no_grad():
|
||||
logits = model.model.forward(
|
||||
input_ids=input_ids_tensor,
|
||||
params={
|
||||
"attn_mode": "flash_attn",
|
||||
"cache": model.cache,
|
||||
"past_len": 0,
|
||||
"batch_shape": (1, model.max_tokens),
|
||||
}
|
||||
input_ids=input_ids_tensor.view(1, -1),
|
||||
params={"attn_mode": "flash_attn_nc"}
|
||||
).float().cpu()
|
||||
|
||||
elif hasattr(model, 'forward'):
|
||||
|
|
|
|||
|
|
@ -530,39 +530,14 @@ class Exllamav3Model:
|
|||
def get_logits(self, token_ids, **kwargs):
|
||||
"""
|
||||
Process a batch of token_ids and return the logits for the last token.
|
||||
This will reset and overwrite the model's cache.
|
||||
Uses flash_attn_nc (no cache) for correct results with recurrent models.
|
||||
"""
|
||||
# Initialize a single params dictionary that will be updated in-place
|
||||
params = {
|
||||
"cache": self.cache,
|
||||
"reconstruct": False,
|
||||
"attn_mode": "flash_attn",
|
||||
"batch_shape": (1, self.max_tokens),
|
||||
"past_len": 0
|
||||
}
|
||||
params.update(kwargs)
|
||||
|
||||
# Process prefix tokens to fill the cache and generate recurrent state
|
||||
if token_ids.shape[-1] > 1:
|
||||
prefix_ids = token_ids[:, :-1]
|
||||
|
||||
# This forward call updates the 'params' dict with the recurrent state
|
||||
self.model.forward(
|
||||
input_ids=prefix_ids,
|
||||
params=params
|
||||
)
|
||||
|
||||
# Update past_len for the next call
|
||||
params["past_len"] = prefix_ids.shape[-1]
|
||||
|
||||
# Process the last token, now using the state-filled 'params' dict
|
||||
last_token_ids = token_ids[:, -1:]
|
||||
logits = self.model.forward(
|
||||
input_ids=last_token_ids,
|
||||
params=params
|
||||
input_ids=token_ids,
|
||||
params={"attn_mode": "flash_attn_nc"}
|
||||
)
|
||||
|
||||
return logits.float().cpu()
|
||||
return logits[:, -1:, :].float().cpu()
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
add_bos = kwargs.pop('add_bos', True)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@ except Exception:
|
|||
class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
||||
def __init__(self, model_dir):
|
||||
hf_config = PretrainedConfig.from_pretrained(model_dir)
|
||||
# Ensure text_config is a proper object, not a dict (fixes qwen3_5_moe + transformers compat)
|
||||
if isinstance(getattr(hf_config, 'text_config', None), dict):
|
||||
hf_config.text_config = PretrainedConfig(**hf_config.text_config)
|
||||
super().__init__(hf_config)
|
||||
|
||||
exl3_config = Config.from_directory(model_dir)
|
||||
|
|
@ -199,30 +202,11 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
|||
}
|
||||
).to(input_ids.device).float()
|
||||
else:
|
||||
# Labels path: use cache for cross-chunk attention.
|
||||
tokens_to_process = seq_tensor
|
||||
all_logits = None
|
||||
current_len = 0
|
||||
|
||||
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
|
||||
chunk = tokens_to_process[i:i + max_chunk_size]
|
||||
chunk_logits = self.ex_model.forward(
|
||||
input_ids=chunk.view(1, -1),
|
||||
params={
|
||||
"attn_mode": "flash_attn",
|
||||
"cache": ex_cache,
|
||||
"past_len": current_len,
|
||||
"batch_shape": (1, self.max_tokens),
|
||||
}
|
||||
).float()
|
||||
current_len += chunk.shape[0]
|
||||
|
||||
if all_logits is None:
|
||||
all_logits = chunk_logits
|
||||
else:
|
||||
all_logits = torch.cat([all_logits, chunk_logits], dim=1)
|
||||
|
||||
logits = all_logits
|
||||
# Labels path: single pass without cache for correct logits
|
||||
logits = self.ex_model.forward(
|
||||
input_ids=seq_tensor.view(1, -1),
|
||||
params={"attn_mode": "flash_attn_nc"}
|
||||
).float().cpu()
|
||||
|
||||
if is_negative:
|
||||
self.past_seq_negative = seq_tensor
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue