Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2026-01-03 23:30:06 +01:00
ExLlamaV2: Add speculative decoding (#6899)
parent 8f2493cc60
commit ae1fe87365
modules/exllamav2.py

@@ -85,7 +85,44 @@ class Exllamav2Model:
         model.load_autosplit(cache)
 
         tokenizer = ExLlamaV2Tokenizer(config)
-        generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
+
+        # Initialize draft model for speculative decoding
+        draft_model = None
+        draft_cache = None
+
+        if shared.args.model_draft and shared.args.model_draft.lower() not in ["none", ""]:
+            logger.info(f"Loading draft model for speculative decoding: {shared.args.model_draft}")
+
+            # Find the draft model path
+            draft_path = Path(shared.args.model_draft)
+            if not draft_path.exists():
+                draft_path = Path(f'{shared.args.model_dir}') / Path(shared.args.model_draft)
+
+            draft_config = ExLlamaV2Config()
+            draft_config.model_dir = str(draft_path)
+            draft_config.prepare()
+            draft_config.arch_compat_overrides()
+
+            # Set context size for draft model
+            if shared.args.ctx_size_draft > 0:
+                draft_config.max_seq_len = shared.args.ctx_size_draft
+            else:
+                draft_config.max_seq_len = config.max_seq_len
+
+            draft_model = ExLlamaV2(draft_config)
+            draft_cache = cache_type(draft_model, lazy=True)
+            draft_model.load_autosplit(draft_cache)
+
+            logger.info(f"Draft model loaded successfully with max_draft={shared.args.draft_max}")
+
+        generator = ExLlamaV2StreamingGenerator(
+            model,
+            cache,
+            tokenizer,
+            draft_model=draft_model,
+            draft_cache=draft_cache,
+            num_speculative_tokens=shared.args.draft_max if draft_model is not None else 0
+        )
 
         result = self()
         result.model = model
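For readers new to the feature, the sketch below is an illustrative toy of the greedy draft-and-verify idea that the draft_model / num_speculative_tokens arguments above switch on. It is not exllamav2's internal implementation; draft_next and target_next are hypothetical stand-ins for "run a model one step and take the argmax token", and token ids are plain Python ints.

def speculative_step(token_ids, draft_next, target_next, num_speculative_tokens=4):
    # 1) The cheap draft model proposes a short continuation greedily.
    proposals, ctx = [], list(token_ids)
    for _ in range(num_speculative_tokens):
        t = draft_next(ctx)
        proposals.append(t)
        ctx.append(t)

    # 2) The target model checks each proposed position (a real engine scores
    #    them all in a single forward pass). The longest agreeing prefix is
    #    kept, plus one token chosen by the target itself.
    accepted, ctx = [], list(token_ids)
    accepted_draft = 0
    for t in proposals:
        expected = target_next(ctx)
        accepted.append(expected)          # the target's choice is always kept
        if expected != t:                  # first disagreement ends the round
            break
        accepted_draft += 1
        ctx.append(t)
    else:
        accepted.append(target_next(ctx))  # every draft accepted: one bonus token

    return token_ids + accepted, len(proposals), accepted_draft

# Toy usage with hypothetical stand-in "models" (each maps a context to a next token id):
ids, proposed, accepted = speculative_step([1, 2, 3], lambda c: len(c) % 7, lambda c: len(c) % 5)

Even in the worst case each round emits at least one target-model token, so greedy output is identical to running the target model alone; every accepted draft token simply removes one sequential target-model pass.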
@@ -93,6 +130,8 @@ class Exllamav2Model:
         result.tokenizer = tokenizer
         result.generator = generator
         result.loras = None
+        result.draft_model = draft_model
+        result.draft_cache = draft_cache
         return result, result
 
     def encode(self, string, **kwargs):
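The loader reads only three new settings in the hunks above. A minimal sketch of their meaning, assuming nothing beyond those hunks (the model name is hypothetical, and the CLI flags or UI fields that populate shared.args are outside this excerpt):

# Hypothetical illustration of the shared.args fields the loading code reads
shared.args.model_draft = "Llama-3.2-1B-exl2"  # hypothetical name; resolved under shared.args.model_dir if the path does not exist
shared.args.ctx_size_draft = 0                 # values <= 0 fall back to the main model's max_seq_len
shared.args.draft_max = 4                      # forwarded to the generator as num_speculative_tokens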
@@ -179,6 +218,10 @@ class Exllamav2Model:
         else:
            max_new_tokens = state['max_new_tokens']
 
+        # Reset speculative decoding stats if using a draft model
+        if hasattr(self, 'draft_model') and self.draft_model is not None:
+            self.generator.reset_sd_stats()
+
         self.generator.begin_stream(ids, settings, loras=self.loras)
 
         decoded_text = ''
@@ -190,6 +233,11 @@ class Exllamav2Model:
             decoded_text += chunk
             yield decoded_text
 
+        # Log speculative decoding stats if using draft model
+        if hasattr(self, 'draft_model') and self.draft_model is not None:
+            efficiency, accuracy, total_tokens, total_draft_tokens, accepted_draft_tokens = self.generator.get_sd_stats()
+            logger.info(f"Speculative decoding: accepted={accepted_draft_tokens}/{total_draft_tokens} tokens")
+
     def generate(self, prompt, state):
         output = ''
         for output in self.generate_with_streaming(prompt, state):
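The single log line above discards most of the stats tuple. If a fuller per-request summary were wanted, a sketch like the following would do, assuming only the tuple shape already unpacked in the hunk; report_sd_stats is a hypothetical helper, not part of this commit:

def report_sd_stats(generator, logger):
    # Summarize one generation's speculative decoding stats (names follow the unpacking above).
    efficiency, accuracy, total_tokens, total_draft_tokens, accepted_draft_tokens = generator.get_sd_stats()
    if total_draft_tokens == 0:
        logger.info("Speculative decoding: no draft tokens proposed")
        return
    acceptance = accepted_draft_tokens / total_draft_tokens
    logger.info(
        f"Speculative decoding: accepted={accepted_draft_tokens}/{total_draft_tokens} "
        f"({acceptance:.1%}), efficiency={efficiency:.2f}, accuracy={accuracy:.2f}, "
        f"output tokens={total_tokens}"
    )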
modules/loaders.py

@@ -25,7 +25,7 @@ loaders_and_params = OrderedDict({
         'gpu_layers_draft',
         'device_draft',
         'ctx_size_draft',
-        'speculative_decoding_accordion'
+        'speculative_decoding_accordion',
     ],
     'Transformers': [
         'gpu_split',
@@ -82,6 +82,9 @@ loaders_and_params = OrderedDict({
         'no_xformers',
         'no_sdpa',
         'exllamav2_info',
+        'model_draft',
+        'ctx_size_draft',
+        'speculative_decoding_accordion',
     ],
     'HQQ': [
         'hqq_backend',
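loaders_and_params maps each loader name to the UI elements it exposes, so the effect of these two hunks can be checked with a short sketch like the one below. The loader key names themselves sit outside the changed lines, so the check avoids hard-coding them; it assumes it is run from the repository root.

# Sketch: list the loaders that now expose draft-model settings
from modules.loaders import loaders_and_params

for loader, params in loaders_and_params.items():
    draft_params = [p for p in params if 'draft' in p]
    if draft_params:
        print(loader, draft_params)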