diff --git a/modules/exllamav3.py b/modules/exllamav3.py index d9772682..9ea38432 100644 --- a/modules/exllamav3.py +++ b/modules/exllamav3.py @@ -175,23 +175,8 @@ class Exllamav3Model: logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.") else: draft_config = Config.from_directory(str(draft_path)) - - # Set context size for draft model with 256-multiple validation - if shared.args.ctx_size_draft > 0: - draft_max_tokens = shared.args.ctx_size_draft - else: - draft_max_tokens = shared.args.ctx_size - - # Validate draft model context size is a multiple of 256 - if draft_max_tokens % 256 != 0: - adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256 - logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}") - draft_max_tokens = adjusted_draft_tokens - - draft_config.max_seq_len = draft_max_tokens - draft_model = Model.from_config(draft_config) - draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs) + draft_cache = Cache(draft_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs) draft_load_params = {'progressbar': True} if split: