ExLlamaV2: Add speculative decoding (#6899)

oobabooga 2025-04-25 00:11:04 -03:00 committed by GitHub
parent 8f2493cc60
commit ae1fe87365
2 changed files with 53 additions and 2 deletions
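
Context for the change (not part of the diff): in speculative decoding, a small draft model proposes a short run of tokens and the large target model verifies them in a single batched forward pass, so every accepted draft token costs far less than a full target-model step. A toy greedy-verification sketch, with hypothetical `target_next`/`draft_next` callables standing in for the two models:

```python
# Toy illustration of the idea behind this commit (not webui code): a cheap
# draft model proposes k tokens, the target model verifies them, and tokens
# are accepted only while both models agree (greedy variant).
def speculative_step(target_next, draft_next, ids, k=5):
    """target_next/draft_next: fn(token_ids) -> greedy next-token id.
    Returns ids extended by 1..k accepted/corrected tokens."""
    proposal = list(ids)
    for _ in range(k):
        proposal.append(draft_next(proposal))

    accepted = list(ids)
    for tok in proposal[len(ids):]:
        verified = target_next(accepted)  # in practice: one batched forward pass over all k positions
        accepted.append(verified)         # target's token is always kept
        if verified != tok:               # first disagreement ends the accepted run
            break
    return accepted
```

In a real implementation the target model scores all k proposed positions in one batch, which is where the speedup comes from; the loop above is unrolled only for clarity.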

modules/exllamav2.py

@@ -85,7 +85,44 @@ class Exllamav2Model:
             model.load_autosplit(cache)
 
         tokenizer = ExLlamaV2Tokenizer(config)
-        generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
+
+        # Initialize draft model for speculative decoding
+        draft_model = None
+        draft_cache = None
+        if shared.args.model_draft and shared.args.model_draft.lower() not in ["none", ""]:
+            logger.info(f"Loading draft model for speculative decoding: {shared.args.model_draft}")
+
+            # Find the draft model path
+            draft_path = Path(shared.args.model_draft)
+            if not draft_path.exists():
+                draft_path = Path(f'{shared.args.model_dir}') / Path(shared.args.model_draft)
+
+            draft_config = ExLlamaV2Config()
+            draft_config.model_dir = str(draft_path)
+            draft_config.prepare()
+            draft_config.arch_compat_overrides()
+
+            # Set context size for draft model
+            if shared.args.ctx_size_draft > 0:
+                draft_config.max_seq_len = shared.args.ctx_size_draft
+            else:
+                draft_config.max_seq_len = config.max_seq_len
+
+            draft_model = ExLlamaV2(draft_config)
+            draft_cache = cache_type(draft_model, lazy=True)
+            draft_model.load_autosplit(draft_cache)
+            logger.info(f"Draft model loaded successfully with max_draft={shared.args.draft_max}")
+
+        generator = ExLlamaV2StreamingGenerator(
+            model,
+            cache,
+            tokenizer,
+            draft_model=draft_model,
+            draft_cache=draft_cache,
+            num_speculative_tokens=shared.args.draft_max if draft_model is not None else 0
+        )
 
         result = self()
         result.model = model
@@ -93,6 +130,8 @@ class Exllamav2Model:
         result.tokenizer = tokenizer
         result.generator = generator
         result.loras = None
+        result.draft_model = draft_model
+        result.draft_cache = draft_cache
         return result, result
 
     def encode(self, string, **kwargs):
@@ -179,6 +218,10 @@ class Exllamav2Model:
         else:
             max_new_tokens = state['max_new_tokens']
 
+        # Reset speculative decoding stats if using a draft model
+        if hasattr(self, 'draft_model') and self.draft_model is not None:
+            self.generator.reset_sd_stats()
+
         self.generator.begin_stream(ids, settings, loras=self.loras)
 
         decoded_text = ''
@@ -190,6 +233,11 @@ class Exllamav2Model:
             decoded_text += chunk
             yield decoded_text
 
+        # Log speculative decoding stats if using draft model
+        if hasattr(self, 'draft_model') and self.draft_model is not None:
+            efficiency, accuracy, total_tokens, total_draft_tokens, accepted_draft_tokens = self.generator.get_sd_stats()
+            logger.info(f"Speculative decoding: accepted={accepted_draft_tokens}/{total_draft_tokens} tokens")
+
     def generate(self, prompt, state):
         output = ''
         for output in self.generate_with_streaming(prompt, state):
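
For orientation, here is how the same flow looks outside the webui, as a minimal sketch: load a target and a smaller draft model, pass both to ExLlamaV2StreamingGenerator, stream, then read the acceptance stats. The model paths, prompt, and 500-character cutoff are placeholders; the exllamav2 calls themselves (prepare/arch_compat_overrides, load_autosplit, the draft_model/draft_cache/num_speculative_tokens keywords, reset_sd_stats/get_sd_stats) are the same ones exercised in the diff above. The draft model must share the target model's tokenizer vocabulary for its proposals to be verifiable.

```python
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler

def load(model_dir, max_seq_len=None):
    """Load an ExLlamaV2 model with a lazily-allocated cache."""
    config = ExLlamaV2Config()
    config.model_dir = model_dir
    config.prepare()
    config.arch_compat_overrides()
    if max_seq_len:
        config.max_seq_len = max_seq_len
    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)
    return config, model, cache

# Placeholder paths: any pair of models sharing a vocabulary works
config, model, cache = load("/models/Llama-3-70B-exl2")
_, draft_model, draft_cache = load("/models/Llama-3-8B-exl2",
                                   max_seq_len=config.max_seq_len)

tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2StreamingGenerator(
    model, cache, tokenizer,
    draft_model=draft_model,
    draft_cache=draft_cache,
    num_speculative_tokens=5,  # same role as draft_max in the diff
)

settings = ExLlamaV2Sampler.Settings()
ids = tokenizer.encode("Once upon a time")
generator.reset_sd_stats()
generator.begin_stream(ids, settings)

text = ""
while True:
    chunk, eos, _ = generator.stream()
    text += chunk
    if eos or len(text) > 500:
        break

efficiency, accuracy, total_tokens, total_draft_tokens, accepted = generator.get_sd_stats()
print(f"accepted {accepted}/{total_draft_tokens} draft tokens")
```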

modules/loaders.py

@@ -25,7 +25,7 @@ loaders_and_params = OrderedDict({
         'gpu_layers_draft',
         'device_draft',
         'ctx_size_draft',
-        'speculative_decoding_accordion'
+        'speculative_decoding_accordion',
     ],
     'Transformers': [
         'gpu_split',
@@ -82,6 +82,9 @@ loaders_and_params = OrderedDict({
         'no_xformers',
         'no_sdpa',
         'exllamav2_info',
+        'model_draft',
+        'ctx_size_draft',
+        'speculative_decoding_accordion',
     ],
     'HQQ': [
         'hqq_backend',
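
The loaders.py change is what surfaces the new controls in the UI: loaders_and_params maps each loader name to the interface element names shown when that loader is selected, so appending 'model_draft', 'ctx_size_draft', and 'speculative_decoding_accordion' under 'ExLlamav2' mirrors the draft options already present in the loader list touched by the first hunk (the one that only gains a trailing comma). A minimal sketch of the gating pattern; the get_visible_params helper is hypothetical and stands in for the real module's Gradio visibility logic:

```python
from collections import OrderedDict

# Sketch of the gating pattern in loaders.py: each loader name maps to the
# UI parameter names it exposes. Only the ExLlamav2 entries shown in the
# diff are real; get_visible_params is a hypothetical helper for clarity.
loaders_and_params = OrderedDict({
    'ExLlamav2': [
        'exllamav2_info',
        'model_draft',                     # draft model name/path
        'ctx_size_draft',                  # draft context size (0 = inherit main model's)
        'speculative_decoding_accordion',  # accordion grouping the options above
    ],
})

def get_visible_params(loader):
    """Return the UI elements to show for the selected loader."""
    return loaders_and_params.get(loader, [])

assert 'model_draft' in get_visible_params('ExLlamav2')
```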