diff --git a/modules/exllamav3.py b/modules/exllamav3.py
index 9ea38432..aeb68564 100644
--- a/modules/exllamav3.py
+++ b/modules/exllamav3.py
@@ -158,8 +158,21 @@ class Exllamav3Model:
             load_params['tensor_p'] = True
             load_params['tp_backend'] = shared.args.tp_backend

-        model.load(**load_params)
-        tokenizer = Tokenizer.from_config(config)
+        # Load vision and draft before the main model so autosplit
+        # accounts for their VRAM usage.
+
+        # Load vision model component (ExLlamaV3 native)
+        vision_model = None
+        if "vision_config" in config.config_dict:
+            logger.info("Vision component detected in model config. Attempting to load...")
+            try:
+                vision_model = Model.from_config(config, component="vision")
+                vision_model.load(progressbar=True)
+                logger.info("Vision model loaded successfully.")
+            except Exception as e:
+                logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
+        else:
+            logger.info("No vision component in model config. Skipping multimodal setup.")

         # Initialize draft model for speculative decoding
         draft_model = None
@@ -185,18 +198,9 @@ class Exllamav3Model:
             draft_model.load(**draft_load_params)
             logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")

-        # Load vision model component (ExLlamaV3 native)
-        vision_model = None
-        if "vision_config" in config.config_dict:
-            logger.info("Vision component detected in model config. Attempting to load...")
-            try:
-                vision_model = Model.from_config(config, component="vision")
-                vision_model.load(progressbar=True)
-                logger.info("Vision model loaded successfully.")
-            except Exception as e:
-                logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
-        else:
-            logger.info("No vision component in model config. Skipping multimodal setup.")
+        # Load the main model last
+        model.load(**load_params)
+        tokenizer = Tokenizer.from_config(config)

         generator = Generator(
             model=model,
@@ -379,11 +383,12 @@ class Exllamav3Model:
         else:
             max_new_tokens = state['max_new_tokens']

-        # Get stop conditions
+        # Use the full EOS token list from the config (may contain multiple IDs)
        stop_conditions = []
        if not state['ban_eos_token']:
-            if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
-                stop_conditions.append(self.tokenizer.eos_token_id)
+            for eos_id in self.config.eos_token_id_list:
+                if eos_id is not None:
+                    stop_conditions.append(eos_id)

        seed = state.get('seed', -1)
        job = Job(
diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py
index b4b6ad20..d3c1cb90 100644
--- a/modules/exllamav3_hf.py
+++ b/modules/exllamav3_hf.py
@@ -201,19 +201,23 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
                     }
                 ).to(input_ids.device).float()
             else:
-                # When processing with labels, handle as a complete sequence
-                # Process in chunks if the number of tokens is large
+                # Labels path: reuse the cache so each chunk attends to prior chunks.
                 tokens_to_process = seq_tensor
                 all_logits = None
+                current_len = 0

                 for i in range(0, tokens_to_process.shape[0], max_chunk_size):
                     chunk = tokens_to_process[i:i + max_chunk_size]
                     chunk_logits = self.ex_model.forward(
                         input_ids=chunk.view(1, -1),
                         params={
-                            "attn_mode": "flash_attn_nc",
+                            "attn_mode": "flash_attn",
+                            "cache": ex_cache,
+                            "past_len": current_len,
+                            "batch_shape": (1, self.max_tokens),
                         }
                     ).float()
+                    current_len += chunk.shape[0]

                     if all_logits is None:
                         all_logits = chunk_logits
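
Why the reorder in modules/exllamav3.py matters: an autosplit placement is computed from whatever VRAM is free at the moment the main model loads, so a vision or draft component loaded afterwards is invisible to the split and can push the last device over its budget. The toy below illustrates the stated rationale; free_vram, load_component, and autosplit_load are hypothetical stand-ins, not ExLlamaV3 API.

    # Toy illustration of autosplit planning. All names here are
    # hypothetical stand-ins, not part of the ExLlamaV3 API.
    free_vram = {"cuda:0": 24.0, "cuda:1": 24.0}  # GiB free per device

    def load_component(name: str, size_gib: float, device: str) -> None:
        """Reserve VRAM for a fixed-placement component (vision/draft)."""
        free_vram[device] -= size_gib

    def autosplit_load(layer_sizes: list[float]) -> dict[str, float]:
        """Greedily place layers using the VRAM that is free *right now*."""
        placed = {d: 0.0 for d in free_vram}
        devices = iter(free_vram)
        device = next(devices)
        for size in layer_sizes:
            while free_vram[device] - placed[device] < size:
                device = next(devices)  # spill over to the next GPU
            placed[device] += size
        return placed

    # Correct order: fixed components first, so the split sees reduced VRAM.
    load_component("vision", 2.5, "cuda:0")
    load_component("draft", 1.5, "cuda:0")
    print(autosplit_load([0.9] * 40))  # {'cuda:0': ~19.8, 'cuda:1': ~16.2}

With the old order, autosplit_load would fill cuda:0 to its full 24 GiB first, and the subsequent vision/draft loads on that device would run out of memory.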
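The stop-condition change addresses a concrete failure mode: many recent models declare more than one EOS ID (for example, Llama-3-style models use both <|end_of_text|>, ID 128001, and <|eot_id|>, ID 128009), while tokenizer.eos_token_id exposes only one of them, so chat generations could run past the end of a turn. A minimal sketch of the new logic, with FakeConfig standing in for the ExLlamaV3 config object:

    # Minimal sketch of the stop-condition change. FakeConfig stands in
    # for an ExLlamaV3 Config; only eos_token_id_list matters here.
    class FakeConfig:
        # e.g. Llama-3-style models declare both <|end_of_text|> and <|eot_id|>
        eos_token_id_list = [128001, 128009, None]

    def build_stop_conditions(config: FakeConfig, ban_eos_token: bool) -> list[int]:
        stop_conditions = []
        if not ban_eos_token:
            # Collect every declared EOS ID, skipping null entries.
            for eos_id in config.eos_token_id_list:
                if eos_id is not None:
                    stop_conditions.append(eos_id)
        return stop_conditions

    print(build_stop_conditions(FakeConfig(), ban_eos_token=False))  # [128001, 128009]
    print(build_stop_conditions(FakeConfig(), ban_eos_token=True))   # []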
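In modules/exllamav3_hf.py, the old flash_attn_nc path ran each chunk as an independent, cache-less forward pass, so tokens in later chunks could not attend to earlier chunks and the logits used on the labels/perplexity path were computed against a truncated context. Threading a shared cache plus a running past_len through the loop restores full causal attention across chunks. The sketch below shows only the bookkeeping; StubModel fakes the forward signature and is not the real ExLlamaV3 interface:

    import torch

    # Sketch of chunked prefill with a shared cache. StubModel fakes the
    # forward signature; the real ExLlamaV3 call also takes attn_mode and
    # batch_shape, omitted here.
    class StubModel:
        def forward(self, input_ids: torch.Tensor, cache: list, past_len: int):
            # Each chunk attends to everything already cached plus itself.
            cache.extend(input_ids[0].tolist())
            context_len = past_len + input_ids.shape[-1]
            # Fake "logits": record the context length each token could see.
            return torch.full((1, input_ids.shape[-1], 4), float(context_len))

    model, cache = StubModel(), []
    tokens, max_chunk_size = torch.arange(10), 4
    current_len, all_logits = 0, None

    for i in range(0, tokens.shape[0], max_chunk_size):
        chunk = tokens[i:i + max_chunk_size]
        chunk_logits = model.forward(chunk.view(1, -1), cache, past_len=current_len)
        current_len += chunk.shape[0]  # advance the cache write position
        all_logits = chunk_logits if all_logits is None else torch.cat(
            (all_logits, chunk_logits), dim=1)

    assert len(cache) == current_len == 10
    print(all_logits[0, :, 0])  # tensor([4., 4., 4., 4., 8., 8., 8., 8., 10., 10.])

In this toy, the cache-less variant would instead report 4, 4, 4, 4, 4, 4, 4, 4, 2, 2: every chunk sees only itself, which is exactly the context loss the patch removes.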