mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-09 15:13:56 +01:00
ExLlamav3: fix draft cache size to match main cache
This commit is contained in:
parent
6ff111d18e
commit
baf4e13ff1
|
|
@@ -175,23 +175,8 @@ class Exllamav3Model:
|
|||
logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
|
||||
else:
|
||||
draft_config = Config.from_directory(str(draft_path))
|
||||
|
||||
# Set context size for draft model with 256-multiple validation
|
||||
if shared.args.ctx_size_draft > 0:
|
||||
draft_max_tokens = shared.args.ctx_size_draft
|
||||
else:
|
||||
draft_max_tokens = shared.args.ctx_size
|
||||
|
||||
# Validate draft model context size is a multiple of 256
|
||||
if draft_max_tokens % 256 != 0:
|
||||
adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256
|
||||
logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}")
|
||||
draft_max_tokens = adjusted_draft_tokens
|
||||
|
||||
draft_config.max_seq_len = draft_max_tokens
|
||||
|
||||
draft_model = Model.from_config(draft_config)
|
||||
draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs)
|
||||
draft_cache = Cache(draft_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
|
||||
|
||||
draft_load_params = {'progressbar': True}
|
||||
if split:
|
||||
|
|
|
|||
Loading…
Reference in a new issue