From 2238302b496a4145ee98e0eab0bf3d9f19a9c83b Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 12 Aug 2025 08:50:45 -0700
Subject: [PATCH] ExLlamaV3: Add speculative decoding

---
 modules/exllamav3.py | 58 ++++++++++++++++++++++++++++++++++++++++++++
 modules/loaders.py   |  4 +++
 2 files changed, 62 insertions(+)

diff --git a/modules/exllamav3.py b/modules/exllamav3.py
index 980230f8..7fc6c5b1 100644
--- a/modules/exllamav3.py
+++ b/modules/exllamav3.py
@@ -85,6 +85,7 @@ class Exllamav3Model:
         cache = Cache(model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
 
         load_params = {'progressbar': True}
+        split = None
         if shared.args.gpu_split:
             split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
             load_params['use_per_device'] = split
@@ -92,6 +93,45 @@ class Exllamav3Model:
         model.load(**load_params)
         tokenizer = Tokenizer.from_config(config)
 
+        # Initialize draft model for speculative decoding
+        draft_model = None
+        draft_cache = None
+        if shared.args.model_draft and shared.args.model_draft.lower() not in ["", "none"]:
+            logger.info(f"Loading draft model for speculative decoding: {shared.args.model_draft}")
+
+            draft_path = Path(shared.args.model_draft)
+            if not draft_path.is_dir():
+                draft_path = Path(f'{shared.args.model_dir}') / Path(shared.args.model_draft)
+
+            if not draft_path.is_dir():
+                logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
+            else:
+                draft_config = Config.from_directory(str(draft_path))
+
+                # Set context size for draft model with 256-multiple validation
+                if shared.args.ctx_size_draft > 0:
+                    draft_max_tokens = shared.args.ctx_size_draft
+                else:
+                    draft_max_tokens = shared.args.ctx_size
+
+                # Validate draft model context size is a multiple of 256
+                if draft_max_tokens % 256 != 0:
+                    adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256
+                    logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}")
+                    draft_max_tokens = adjusted_draft_tokens
+
+                draft_config.max_seq_len = draft_max_tokens
+
+                draft_model = Model.from_config(draft_config)
+                draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs)
+
+                draft_load_params = {'progressbar': True}
+                if split:
+                    draft_load_params['use_per_device'] = split
+
+                draft_model.load(**draft_load_params)
+                logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")
+
         # Load vision model component (ExLlamaV3 native)
         vision_model = None
         if "vision_config" in config.config_dict:
@@ -109,6 +149,9 @@ class Exllamav3Model:
             model=model,
             cache=cache,
             tokenizer=tokenizer,
+            draft_model=draft_model,
+            draft_cache=draft_cache,
+            num_speculative_tokens=shared.args.draft_max if draft_model is not None else 0,
         )
 
         result = cls()
@@ -119,6 +162,8 @@ class Exllamav3Model:
         result.config = config
         result.max_tokens = max_tokens
         result.vision_model = vision_model
+        result.draft_model = draft_model
+        result.draft_cache = draft_cache
 
         return result
 
@@ -289,6 +334,7 @@ class Exllamav3Model:
         self.generator.enqueue(job)
 
         response_text = ""
+
         try:
             while self.generator.num_remaining_jobs():
                 results = self.generator.iterate()
@@ -300,6 +346,7 @@ class Exllamav3Model:
                     if chunk:
                         response_text += chunk
                         yield response_text
+
         finally:
             self.generator.clear_queue()
 
@@ -331,6 +378,17 @@ class Exllamav3Model:
                 logger.warning(f"Error unloading vision model: {e}")
             self.vision_model = None
 
+        if hasattr(self, 'draft_model') and self.draft_model is not None:
+            try:
+                self.draft_model.unload()
+                del self.draft_model
+            except Exception as e:
+                logger.warning(f"Error unloading draft model: {e}")
+            self.draft_model = None
+
+        if hasattr(self, 'draft_cache') and self.draft_cache is not None:
+            self.draft_cache = None
+
         if hasattr(self, 'model') and self.model is not None:
             try:
                 self.model.unload()
diff --git a/modules/loaders.py b/modules/loaders.py
index feca9985..8b7e6cce 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -61,6 +61,10 @@ loaders_and_params = OrderedDict({
         'ctx_size',
         'cache_type',
         'gpu_split',
+        'model_draft',
+        'draft_max',
+        'ctx_size_draft',
+        'speculative_decoding_accordion',
     ],
     'ExLlamav2_HF': [
         'ctx_size',