ExLlamaV3: Add speculative decoding

This commit is contained in:
oobabooga 2025-08-12 08:50:45 -07:00
parent 0882970a94
commit 2238302b49
2 changed files with 62 additions and 0 deletions

View file

@ -85,6 +85,7 @@ class Exllamav3Model:
cache = Cache(model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
load_params = {'progressbar': True}
split = None
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
load_params['use_per_device'] = split
@ -92,6 +93,45 @@ class Exllamav3Model:
model.load(**load_params)
tokenizer = Tokenizer.from_config(config)
# Initialize draft model for speculative decoding
draft_model = None
draft_cache = None
if shared.args.model_draft and shared.args.model_draft.lower() not in ["", "none"]:
logger.info(f"Loading draft model for speculative decoding: {shared.args.model_draft}")
draft_path = Path(shared.args.model_draft)
if not draft_path.is_dir():
draft_path = Path(f'{shared.args.model_dir}') / Path(shared.args.model_draft)
if not draft_path.is_dir():
logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
else:
draft_config = Config.from_directory(str(draft_path))
# Set context size for draft model with 256-multiple validation
if shared.args.ctx_size_draft > 0:
draft_max_tokens = shared.args.ctx_size_draft
else:
draft_max_tokens = shared.args.ctx_size
# Validate draft model context size is a multiple of 256
if draft_max_tokens % 256 != 0:
adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256
logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}")
draft_max_tokens = adjusted_draft_tokens
draft_config.max_seq_len = draft_max_tokens
draft_model = Model.from_config(draft_config)
draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs)
draft_load_params = {'progressbar': True}
if split:
draft_load_params['use_per_device'] = split
draft_model.load(**draft_load_params)
logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")
# Load vision model component (ExLlamaV3 native)
vision_model = None
if "vision_config" in config.config_dict:
@ -109,6 +149,9 @@ class Exllamav3Model:
model=model,
cache=cache,
tokenizer=tokenizer,
draft_model=draft_model,
draft_cache=draft_cache,
num_speculative_tokens=shared.args.draft_max if draft_model is not None else 0,
)
result = cls()
@ -119,6 +162,8 @@ class Exllamav3Model:
result.config = config
result.max_tokens = max_tokens
result.vision_model = vision_model
result.draft_model = draft_model
result.draft_cache = draft_cache
return result
@ -289,6 +334,7 @@ class Exllamav3Model:
self.generator.enqueue(job)
response_text = ""
try:
while self.generator.num_remaining_jobs():
results = self.generator.iterate()
@ -300,6 +346,7 @@ class Exllamav3Model:
if chunk:
response_text += chunk
yield response_text
finally:
self.generator.clear_queue()
@ -331,6 +378,17 @@ class Exllamav3Model:
logger.warning(f"Error unloading vision model: {e}")
self.vision_model = None
if hasattr(self, 'draft_model') and self.draft_model is not None:
try:
self.draft_model.unload()
del self.draft_model
except Exception as e:
logger.warning(f"Error unloading draft model: {e}")
self.draft_model = None
if hasattr(self, 'draft_cache') and self.draft_cache is not None:
self.draft_cache = None
if hasattr(self, 'model') and self.model is not None:
try:
self.model.unload()

View file

@ -61,6 +61,10 @@ loaders_and_params = OrderedDict({
'ctx_size',
'cache_type',
'gpu_split',
'model_draft',
'draft_max',
'ctx_size_draft',
'speculative_decoding_accordion',
],
'ExLlamav2_HF': [
'ctx_size',