mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-18 03:14:39 +01:00
Fix ExLlamaV3 EOS handling, load order, and perplexity evaluation
- Use config.eos_token_id_list for all EOS tokens as stop conditions (fixes models like Llama-3 that define multiple EOS token IDs)
- Load vision/draft models before the main model so autosplit accounts for their VRAM usage
- Fix loss computation in ExLlamav3_HF: use the cache across chunks so sequences longer than 2048 tokens get correct perplexity values
This commit is contained in:
parent
39e6c997cc
commit
7f485274eb
|
|
@ -158,8 +158,21 @@ class Exllamav3Model:
|
|||
load_params['tensor_p'] = True
|
||||
load_params['tp_backend'] = shared.args.tp_backend
|
||||
|
||||
model.load(**load_params)
|
||||
tokenizer = Tokenizer.from_config(config)
|
||||
# Load vision and draft before the main model so autosplit
|
||||
# accounts for their VRAM usage.
|
||||
|
||||
# Load vision model component (ExLlamaV3 native)
|
||||
vision_model = None
|
||||
if "vision_config" in config.config_dict:
|
||||
logger.info("Vision component detected in model config. Attempting to load...")
|
||||
try:
|
||||
vision_model = Model.from_config(config, component="vision")
|
||||
vision_model.load(progressbar=True)
|
||||
logger.info("Vision model loaded successfully.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
|
||||
else:
|
||||
logger.info("No vision component in model config. Skipping multimodal setup.")
|
||||
|
||||
# Initialize draft model for speculative decoding
|
||||
draft_model = None
|
||||
|
|
@ -185,18 +198,9 @@ class Exllamav3Model:
|
|||
draft_model.load(**draft_load_params)
|
||||
logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")
|
||||
|
||||
# Load vision model component (ExLlamaV3 native)
|
||||
vision_model = None
|
||||
if "vision_config" in config.config_dict:
|
||||
logger.info("Vision component detected in model config. Attempting to load...")
|
||||
try:
|
||||
vision_model = Model.from_config(config, component="vision")
|
||||
vision_model.load(progressbar=True)
|
||||
logger.info("Vision model loaded successfully.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
|
||||
else:
|
||||
logger.info("No vision component in model config. Skipping multimodal setup.")
|
||||
# Load main model last
|
||||
model.load(**load_params)
|
||||
tokenizer = Tokenizer.from_config(config)
|
||||
|
||||
generator = Generator(
|
||||
model=model,
|
||||
|
|
@ -379,11 +383,12 @@ class Exllamav3Model:
|
|||
else:
|
||||
max_new_tokens = state['max_new_tokens']
|
||||
|
||||
# Get stop conditions
|
||||
# Use full EOS token list from config (may contain multiple IDs)
|
||||
stop_conditions = []
|
||||
if not state['ban_eos_token']:
|
||||
if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
|
||||
stop_conditions.append(self.tokenizer.eos_token_id)
|
||||
for eos_id in self.config.eos_token_id_list:
|
||||
if eos_id is not None:
|
||||
stop_conditions.append(eos_id)
|
||||
|
||||
seed = state.get('seed', -1)
|
||||
job = Job(
|
||||
|
|
|
|||
|
|
@ -201,19 +201,23 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
|
|||
}
|
||||
).to(input_ids.device).float()
|
||||
else:
|
||||
# When processing with labels, handle as a complete sequence
|
||||
# Process in chunks if the number of tokens is large
|
||||
# Labels path: use cache for cross-chunk attention.
|
||||
tokens_to_process = seq_tensor
|
||||
all_logits = None
|
||||
current_len = 0
|
||||
|
||||
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
|
||||
chunk = tokens_to_process[i:i + max_chunk_size]
|
||||
chunk_logits = self.ex_model.forward(
|
||||
input_ids=chunk.view(1, -1),
|
||||
params={
|
||||
"attn_mode": "flash_attn_nc",
|
||||
"attn_mode": "flash_attn",
|
||||
"cache": ex_cache,
|
||||
"past_len": current_len,
|
||||
"batch_shape": (1, self.max_tokens),
|
||||
}
|
||||
).float()
|
||||
current_len += chunk.shape[0]
|
||||
|
||||
if all_logits is None:
|
||||
all_logits = chunk_logits
|
||||
|
|
|
|||
Loading…
Reference in a new issue