From ae1fe8736549b8bea2dee4daebad66f5828860c8 Mon Sep 17 00:00:00 2001
From: oobabooga
Date: Fri, 25 Apr 2025 00:11:04 -0300
Subject: [PATCH] ExLlamaV2: Add speculative decoding (#6899)

---
 modules/exllamav2.py | 50 +++++++++++++++++++++++++++++++++++++++++++-
 modules/loaders.py   |  5 ++++-
 2 files changed, 53 insertions(+), 2 deletions(-)

diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index 0289bb21..7d79e516 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -85,7 +85,44 @@ class Exllamav2Model:
         model.load_autosplit(cache)
 
         tokenizer = ExLlamaV2Tokenizer(config)
-        generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
+
+        # Initialize draft model for speculative decoding
+        draft_model = None
+        draft_cache = None
+
+        if shared.args.model_draft and shared.args.model_draft.lower() not in ["none", ""]:
+            logger.info(f"Loading draft model for speculative decoding: {shared.args.model_draft}")
+
+            # Find the draft model path
+            draft_path = Path(shared.args.model_draft)
+            if not draft_path.exists():
+                draft_path = Path(f'{shared.args.model_dir}') / Path(shared.args.model_draft)
+
+            draft_config = ExLlamaV2Config()
+            draft_config.model_dir = str(draft_path)
+            draft_config.prepare()
+            draft_config.arch_compat_overrides()
+
+            # Set context size for draft model
+            if shared.args.ctx_size_draft > 0:
+                draft_config.max_seq_len = shared.args.ctx_size_draft
+            else:
+                draft_config.max_seq_len = config.max_seq_len
+
+            draft_model = ExLlamaV2(draft_config)
+            draft_cache = cache_type(draft_model, lazy=True)
+            draft_model.load_autosplit(draft_cache)
+
+            logger.info(f"Draft model loaded successfully with max_draft={shared.args.draft_max}")
+
+        generator = ExLlamaV2StreamingGenerator(
+            model,
+            cache,
+            tokenizer,
+            draft_model=draft_model,
+            draft_cache=draft_cache,
+            num_speculative_tokens=shared.args.draft_max if draft_model is not None else 0
+        )
 
         result = self()
         result.model = model
@@ -93,6 +130,8 @@ class Exllamav2Model:
         result.tokenizer = tokenizer
         result.generator = generator
         result.loras = None
+        result.draft_model = draft_model
+        result.draft_cache = draft_cache
         return result, result
 
     def encode(self, string, **kwargs):
@@ -179,6 +218,10 @@ class Exllamav2Model:
         else:
             max_new_tokens = state['max_new_tokens']
 
+        # Reset speculative decoding stats if using a draft model
+        if hasattr(self, 'draft_model') and self.draft_model is not None:
+            self.generator.reset_sd_stats()
+
         self.generator.begin_stream(ids, settings, loras=self.loras)
 
         decoded_text = ''
@@ -190,6 +233,11 @@ class Exllamav2Model:
             decoded_text += chunk
             yield decoded_text
 
+        # Log speculative decoding stats if using draft model
+        if hasattr(self, 'draft_model') and self.draft_model is not None:
+            efficiency, accuracy, total_tokens, total_draft_tokens, accepted_draft_tokens = self.generator.get_sd_stats()
+            logger.info(f"Speculative decoding: accepted={accepted_draft_tokens}/{total_draft_tokens} tokens")
+
     def generate(self, prompt, state):
         output = ''
         for output in self.generate_with_streaming(prompt, state):
diff --git a/modules/loaders.py b/modules/loaders.py
index 167b2c98..d256e1e7 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -25,7 +25,7 @@ loaders_and_params = OrderedDict({
         'gpu_layers_draft',
         'device_draft',
         'ctx_size_draft',
-        'speculative_decoding_accordion'
+        'speculative_decoding_accordion',
     ],
     'Transformers': [
         'gpu_split',
@@ -82,6 +82,9 @@ loaders_and_params = OrderedDict({
         'no_xformers',
         'no_sdpa',
         'exllamav2_info',
+        'model_draft',
+        'ctx_size_draft',
+        'speculative_decoding_accordion',
     ],
     'HQQ': [
         'hqq_backend',