Mirror of https://github.com/oobabooga/text-generation-webui.git
ExLlamaV3: Add speculative decoding

commit 2238302b49
parent 0882970a94
modules/exllamav3.py

@@ -85,6 +85,7 @@ class Exllamav3Model:
         cache = Cache(model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)

         load_params = {'progressbar': True}
+        split = None
         if shared.args.gpu_split:
             split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
             load_params['use_per_device'] = split
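Note: initializing split = None up front lets the draft-model block below reuse the same GPU allocation. The gpu_split string is a comma-separated list of per-device allocations; a minimal standalone sketch of the same parsing (the helper name is illustrative, not from the diff):

def parse_gpu_split(gpu_split):
    # "18,24" -> [18.0, 24.0]; empty string or None -> no manual split
    if not gpu_split:
        return None
    return [float(alloc) for alloc in gpu_split.split(",")]

assert parse_gpu_split("18,24") == [18.0, 24.0]
assert parse_gpu_split("") is None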
@@ -92,6 +93,45 @@ class Exllamav3Model:
         model.load(**load_params)
         tokenizer = Tokenizer.from_config(config)

+        # Initialize draft model for speculative decoding
+        draft_model = None
+        draft_cache = None
+        if shared.args.model_draft and shared.args.model_draft.lower() not in ["", "none"]:
+            logger.info(f"Loading draft model for speculative decoding: {shared.args.model_draft}")
+
+            draft_path = Path(shared.args.model_draft)
+            if not draft_path.is_dir():
+                draft_path = Path(f'{shared.args.model_dir}') / Path(shared.args.model_draft)
+
+            if not draft_path.is_dir():
+                logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.")
+            else:
+                draft_config = Config.from_directory(str(draft_path))
+
+                # Set context size for draft model with 256-multiple validation
+                if shared.args.ctx_size_draft > 0:
+                    draft_max_tokens = shared.args.ctx_size_draft
+                else:
+                    draft_max_tokens = shared.args.ctx_size
+
+                # Validate draft model context size is a multiple of 256
+                if draft_max_tokens % 256 != 0:
+                    adjusted_draft_tokens = ((draft_max_tokens // 256) + 1) * 256
+                    logger.warning(f"Draft model max_num_tokens must be a multiple of 256. Adjusting from {draft_max_tokens} to {adjusted_draft_tokens}")
+                    draft_max_tokens = adjusted_draft_tokens
+
+                draft_config.max_seq_len = draft_max_tokens
+
+                draft_model = Model.from_config(draft_config)
+                draft_cache = Cache(draft_model, max_num_tokens=draft_max_tokens, layer_type=layer_type, **cache_kwargs)
+
+                draft_load_params = {'progressbar': True}
+                if split:
+                    draft_load_params['use_per_device'] = split
+
+                draft_model.load(**draft_load_params)
+                logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}")
+
         # Load vision model component (ExLlamaV3 native)
         vision_model = None
         if "vision_config" in config.config_dict:
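The context-size validation rounds any non-multiple up to the next multiple of 256; for example, a requested draft context of 10000 becomes 10240. A standalone sketch of the same arithmetic:

def round_up_to_multiple(n, base=256):
    # 10000 -> 10240; 8192 is already aligned and stays 8192
    if n % base != 0:
        n = ((n // base) + 1) * base
    return n

assert round_up_to_multiple(10000) == 10240
assert round_up_to_multiple(8192) == 8192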
@@ -109,6 +149,9 @@ class Exllamav3Model:
             model=model,
             cache=cache,
             tokenizer=tokenizer,
+            draft_model=draft_model,
+            draft_cache=draft_cache,
+            num_speculative_tokens=shared.args.draft_max if draft_model is not None else 0,
         )

         result = cls()
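For context: speculative decoding has the small draft model propose up to num_speculative_tokens tokens per step, and the main model verifies them; only tokens agreeing with the main model's own choices are kept, so output quality matches the main model while accepted tokens come almost for free. A simplified, runnable toy of the greedy accept/reject loop (the lambda "models" are stand-ins, not the ExLlamaV3 API; the real generator verifies all positions in one batched forward pass):

# Toy next-token predictors standing in for real draft/target LMs.
draft_next  = lambda ctx: (sum(ctx) + 1) % 7      # cheap, sometimes wrong
target_next = lambda ctx: (sum(ctx) * 3 + 1) % 7  # authoritative

def speculative_step(tokens, k=4):
    # 1) Draft proposes k tokens autoregressively (k cheap serial passes).
    ctx = list(tokens)
    proposal = []
    for _ in range(k):
        t = draft_next(ctx)
        proposal.append(t)
        ctx.append(t)

    # 2) Target checks each proposed position (batched in practice).
    accepted = []
    ctx = list(tokens)
    for drafted in proposal:
        correct = target_next(ctx)
        if drafted == correct:
            accepted.append(drafted)   # match: token accepted "for free"
            ctx.append(drafted)
        else:
            accepted.append(correct)   # first mismatch: keep target's token, stop
            break
    return tokens + accepted

print(speculative_step([1, 2, 3]))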
@@ -119,6 +162,8 @@ class Exllamav3Model:
         result.config = config
         result.max_tokens = max_tokens
         result.vision_model = vision_model
+        result.draft_model = draft_model
+        result.draft_cache = draft_cache

         return result
@@ -289,6 +334,7 @@ class Exllamav3Model:
         self.generator.enqueue(job)

         response_text = ""

+        try:
             while self.generator.num_remaining_jobs():
                 results = self.generator.iterate()
@@ -300,6 +346,7 @@ class Exllamav3Model:
                     if chunk:
                         response_text += chunk
                         yield response_text

+        finally:
+            self.generator.clear_queue()
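The new try/finally guarantees the queue is cleared even when the caller abandons the stream early: closing a running Python generator raises GeneratorExit at the paused yield, and the finally block still executes. A minimal standalone illustration of that semantics:

def stream():
    try:
        for i in range(100):
            yield i
    finally:
        print("cleanup runs")  # executes even if the consumer stops early

g = stream()
next(g)
g.close()  # raises GeneratorExit inside the generator -> "cleanup runs"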
@@ -331,6 +378,17 @@ class Exllamav3Model:
                 logger.warning(f"Error unloading vision model: {e}")
             self.vision_model = None

+        if hasattr(self, 'draft_model') and self.draft_model is not None:
+            try:
+                self.draft_model.unload()
+                del self.draft_model
+            except Exception as e:
+                logger.warning(f"Error unloading draft model: {e}")
+            self.draft_model = None
+
+        if hasattr(self, 'draft_cache') and self.draft_cache is not None:
+            self.draft_cache = None
+
         if hasattr(self, 'model') and self.model is not None:
             try:
                 self.model.unload()
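The del-then-reassign-None pattern matters because PyTorch only frees VRAM once the last reference to a tensor is gone. A minimal illustration (requires a CUDA device; the 2 GiB size is arbitrary):

import torch

x = torch.empty(1024, 1024, 512, device="cuda")  # ~2 GiB of float32 VRAM
del x                          # drop the only reference
torch.cuda.empty_cache()       # return the freed blocks to the driver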
modules/loaders.py

@@ -61,6 +61,10 @@ loaders_and_params = OrderedDict({
         'ctx_size',
         'cache_type',
         'gpu_split',
+        'model_draft',
+        'draft_max',
+        'ctx_size_draft',
+        'speculative_decoding_accordion',
     ],
     'ExLlamav2_HF': [
         'ctx_size',
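Assuming these options surface as the usual dashed command-line flags matching the attribute names (model_draft as --model-draft and so on; the spellings are inferred, not shown in this diff, and the model names below are placeholders), a launch might look like:

python server.py --loader ExLlamav3 --model my-70b-model \
    --model-draft my-1b-draft-model --draft-max 5 --ctx-size-draft 8192

The draft model should share the target model's tokenizer/vocabulary for its proposals to be verifiable, and --draft-max caps the tokens proposed per speculative step.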