ExLlamav3: fix draft cache size to match main cache

oobabooga 2026-03-07 22:34:48 -03:00
parent 6ff111d18e
commit baf4e13ff1

@@ -175,23 +175,8 @@ class Exllamav3Model:
                 logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
             else:
                 draft_config = Config.from_directory(str(draft_path))
-                # Set context size for draft model with 256-multiple validation
-                if shared.args.ctx_size_draft > 0:
-                    draft_max_tokens = shared.args.ctx_size_draft
-                else:
-                    draft_max_tokens = shared.args.ctx_size
-
-                # Validate draft model context size is a multiple of 256
-                if draft_max_tokens % 256 != 0:
-                    adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256
-                    logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}")
-                    draft_max_tokens = adjusted_draft_tokens
-
-                draft_config.max_seq_len = draft_max_tokens
                 draft_model = Model.from_config(draft_config)
-                draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs)
+                draft_cache = Cache(draft_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
                 draft_load_params = {'progressbar': True}
                 if split:
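
The point of the change: the draft cache was previously sized from its own ctx_size_draft setting, which could leave it a different size than the main cache, while speculative decoding advances both caches together. After this commit the draft Cache simply reuses the main cache's max_tokens, and the draft_config.max_seq_len override is dropped along with the rest of the removed block. Below is a minimal sketch of the resulting pattern, using only the exllamav3 names that appear in the diff (Config.from_directory, Model.from_config, Cache); the ctx_size value and model paths are illustrative, and the layer_type/cache_kwargs arguments from the real loader are omitted for brevity.

# Minimal sketch of the post-commit pattern; assumes the exllamav3 names
# shown in the diff. ctx_size and the paths are illustrative values.
from exllamav3 import Cache, Config, Model

ctx_size = 10000  # requested context length (illustrative)

# Round the cache size up to a multiple of 256, as the removed
# draft-model block did for its own size; 10000 -> 10240 here.
max_tokens = ctx_size
if max_tokens % 256 != 0:
    max_tokens = ((max_tokens // 256) + 1) * 256

# Main model and cache.
config = Config.from_directory("/models/main")          # illustrative path
model = Model.from_config(config)
cache = Cache(model, max_num_tokens=max_tokens)

# Draft model: reuse the same max_tokens so both caches stay in lockstep
# during speculative decoding, instead of sizing the draft cache
# independently from ctx_size_draft.
draft_config = Config.from_directory("/models/draft")   # illustrative path
draft_model = Model.from_config(draft_config)
draft_cache = Cache(draft_model, max_num_tokens=max_tokens)

A side effect of sharing one size is that the 256-multiple validation only needs to run once, on the main max_tokens, which is why the whole per-draft rounding block could be deleted rather than duplicated.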