From 4ff91b6588deb358a802fb28caee189fa442785f Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 26 Apr 2025 17:24:40 -0700 Subject: [PATCH] Better default settings for Speculative Decoding --- modules/shared.py | 3 +-- modules/ui_model_menu.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 63bdb536..a2ff61e2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,6 @@ from modules.logging_colors import logger model = None tokenizer = None model_name = 'None' -draft_model_name = 'None' is_seq2seq = False model_dirty_from_training = False lora_names = [] @@ -138,7 +137,7 @@ group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type group = parser.add_argument_group('Speculative decoding') group.add_argument('--model-draft', type=str, default=None, help='Path to the draft model for speculative decoding.') group.add_argument('--draft-max', type=int, default=4, help='Number of tokens to draft for speculative decoding.') -group.add_argument('--gpu-layers-draft', type=int, default=0, help='Number of layers to offload to the GPU for the draft model.') +group.add_argument('--gpu-layers-draft', type=int, default=256, help='Number of layers to offload to the GPU for the draft model.') group.add_argument('--device-draft', type=str, default=None, help='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1') group.add_argument('--ctx-size-draft', type=int, default=0, help='Size of the prompt context for the draft model. If 0, uses the same as the main model.') diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index dc09c899..546200f9 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -96,7 +96,7 @@ def create_ui(): # Speculative decoding with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']: with gr.Row(): - shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=utils.get_available_models(), value=lambda: shared.draft_model_name, elem_classes='slim-dropdown', interactive=not mu) + shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=utils.get_available_models(), value=lambda: shared.args.model_draft, elem_classes='slim-dropdown', interactive=not mu) ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu) shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Number of tokens to draft for speculative decoding.')