Fix ExLlamaV3 EOS handling, load order, and perplexity evaluation

- Use config.eos_token_id_list for all EOS tokens as stop conditions
  (fixes models like Llama-3 that define multiple EOS token IDs)
- Load vision/draft models before main model so autosplit accounts
  for their VRAM usage
- Fix loss computation in Exllamav3HF: use cache across chunks so
  sequences longer than 2048 tokens get correct perplexity values
This commit is contained in:
oobabooga 2026-03-09 23:55:51 -03:00
parent 39e6c997cc
commit 7f485274eb
2 changed files with 29 additions and 20 deletions

View file

@ -158,8 +158,21 @@ class Exllamav3Model:
load_params['tensor_p'] = True
load_params['tp_backend'] = shared.args.tp_backend
model.load(**load_params)
tokenizer = Tokenizer.from_config(config)
# Load vision and draft before the main model so autosplit
# accounts for their VRAM usage.
# Load vision model component (ExLlamaV3 native)
vision_model = None
if "vision_config" in config.config_dict:
logger.info("Vision component detected in model config. Attempting to load...")
try:
vision_model = Model.from_config(config, component="vision")
vision_model.load(progressbar=True)
logger.info("Vision model loaded successfully.")
except Exception as e:
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
else:
logger.info("No vision component in model config. Skipping multimodal setup.")
# Initialize draft model for speculative decoding
draft_model = None
@ -185,18 +198,9 @@ class Exllamav3Model:
draft_model.load(**draft_load_params)
logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")
# Load vision model component (ExLlamaV3 native)
vision_model = None
if "vision_config" in config.config_dict:
logger.info("Vision component detected in model config. Attempting to load...")
try:
vision_model = Model.from_config(config, component="vision")
vision_model.load(progressbar=True)
logger.info("Vision model loaded successfully.")
except Exception as e:
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
else:
logger.info("No vision component in model config. Skipping multimodal setup.")
# Load main model last
model.load(**load_params)
tokenizer = Tokenizer.from_config(config)
generator = Generator(
model=model,
@ -379,11 +383,12 @@ class Exllamav3Model:
else:
max_new_tokens = state['max_new_tokens']
# Get stop conditions
# Use full EOS token list from config (may contain multiple IDs)
stop_conditions = []
if not state['ban_eos_token']:
if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
stop_conditions.append(self.tokenizer.eos_token_id)
for eos_id in self.config.eos_token_id_list:
if eos_id is not None:
stop_conditions.append(eos_id)
seed = state.get('seed', -1)
job = Job(

View file

@ -201,19 +201,23 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
}
).to(input_ids.device).float()
else:
# When processing with labels, handle as a complete sequence
# Process in chunks if the number of tokens is large
# Labels path: use cache for cross-chunk attention.
tokens_to_process = seq_tensor
all_logits = None
current_len = 0
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
chunk = tokens_to_process[i:i + max_chunk_size]
chunk_logits = self.ex_model.forward(
input_ids=chunk.view(1, -1),
params={
"attn_mode": "flash_attn_nc",
"attn_mode": "flash_attn",
"cache": ex_cache,
"past_len": current_len,
"batch_shape": (1, self.max_tokens),
}
).float()
current_len += chunk.shape[0]
if all_logits is None:
all_logits = chunk_logits