diff --git a/modules/api/completions.py b/modules/api/completions.py
index 587ad6ea..a15e1f86 100644
--- a/modules/api/completions.py
+++ b/modules/api/completions.py
@@ -91,17 +91,14 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
     import torch
 
     if loader == 'ExLlamav3' and hasattr(model, 'model') and hasattr(model, 'cache'):
-        # Native ExLlamav3: call the underlying Model.forward() directly
+        # Native ExLlamav3: call the underlying Model.forward() directly in one
+        # uncached pass (flash_attn_nc); no cache state is written or reused
         input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long)
+        input_ids_tensor = input_ids_tensor.view(-1).cpu()
         with torch.no_grad():
             logits = model.model.forward(
-                input_ids=input_ids_tensor,
-                params={
-                    "attn_mode": "flash_attn",
-                    "cache": model.cache,
-                    "past_len": 0,
-                    "batch_shape": (1, model.max_tokens),
-                }
+                input_ids=input_ids_tensor.view(1, -1),
+                params={"attn_mode": "flash_attn_nc"}
             ).float().cpu()
 
     elif hasattr(model, 'forward'):
diff --git a/modules/exllamav3.py b/modules/exllamav3.py
index 3782a693..7556a908 100644
--- a/modules/exllamav3.py
+++ b/modules/exllamav3.py
@@ -530,39 +530,14 @@ class Exllamav3Model:
     def get_logits(self, token_ids, **kwargs):
         """
         Process a batch of token_ids and return the logits for the last token.
-        This will reset and overwrite the model's cache.
+        Uses flash_attn_nc (no cache) for correct results with recurrent models.
         """
 
-        # Initialize a single params dictionary that will be updated in-place
-        params = {
-            "cache": self.cache,
-            "reconstruct": False,
-            "attn_mode": "flash_attn",
-            "batch_shape": (1, self.max_tokens),
-            "past_len": 0
-        }
-        params.update(kwargs)
-
-        # Process prefix tokens to fill the cache and generate recurrent state
-        if token_ids.shape[-1] > 1:
-            prefix_ids = token_ids[:, :-1]
-
-            # This forward call updates the 'params' dict with the recurrent state
-            self.model.forward(
-                input_ids=prefix_ids,
-                params=params
-            )
-
-            # Update past_len for the next call
-            params["past_len"] = prefix_ids.shape[-1]
-
-        # Process the last token, now using the state-filled 'params' dict
-        last_token_ids = token_ids[:, -1:]
         logits = self.model.forward(
-            input_ids=last_token_ids,
-            params=params
+            input_ids=token_ids,
+            params={"attn_mode": "flash_attn_nc"}
         )
-        return logits.float().cpu()
+        return logits[:, -1:, :].float().cpu()
 
     def encode(self, string, **kwargs):
         add_bos = kwargs.pop('add_bos', True)
diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py
index e0ad5002..5e634e22 100644
--- a/modules/exllamav3_hf.py
+++ b/modules/exllamav3_hf.py
@@ -26,6 +26,9 @@ except Exception:
 class Exllamav3HF(PreTrainedModel, GenerationMixin):
     def __init__(self, model_dir):
         hf_config = PretrainedConfig.from_pretrained(model_dir)
+        # Ensure text_config is a proper object, not a dict (fixes qwen3_5_moe + transformers compat)
+        if isinstance(getattr(hf_config, 'text_config', None), dict):
+            hf_config.text_config = PretrainedConfig(**hf_config.text_config)
         super().__init__(hf_config)
         exl3_config = Config.from_directory(model_dir)
 
@@ -199,30 +202,11 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
                 }
             ).to(input_ids.device).float()
         else:
-            # Labels path: use cache for cross-chunk attention.
-            tokens_to_process = seq_tensor
-            all_logits = None
-            current_len = 0
-
-            for i in range(0, tokens_to_process.shape[0], max_chunk_size):
-                chunk = tokens_to_process[i:i + max_chunk_size]
-                chunk_logits = self.ex_model.forward(
-                    input_ids=chunk.view(1, -1),
-                    params={
-                        "attn_mode": "flash_attn",
-                        "cache": ex_cache,
-                        "past_len": current_len,
-                        "batch_shape": (1, self.max_tokens),
-                    }
-                ).float()
-                current_len += chunk.shape[0]
-
-                if all_logits is None:
-                    all_logits = chunk_logits
-                else:
-                    all_logits = torch.cat([all_logits, chunk_logits], dim=1)
-
-            logits = all_logits
+            # Labels path: single pass without cache for correct logits
+            logits = self.ex_model.forward(
+                input_ids=seq_tensor.view(1, -1),
+                params={"attn_mode": "flash_attn_nc"}
+            ).float().cpu()
 
         if is_negative:
             self.past_seq_negative = seq_tensor