mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-01-05 08:10:07 +01:00
Better default settings for Speculative Decoding
This commit is contained in:
parent
bf2aa19b21
commit
4ff91b6588
|
|
@ -14,7 +14,6 @@ from modules.logging_colors import logger
|
|||
model = None
|
||||
tokenizer = None
|
||||
model_name = 'None'
|
||||
draft_model_name = 'None'
|
||||
is_seq2seq = False
|
||||
model_dirty_from_training = False
|
||||
lora_names = []
|
||||
|
|
@ -138,7 +137,7 @@ group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type
|
|||
group = parser.add_argument_group('Speculative decoding')
|
||||
group.add_argument('--model-draft', type=str, default=None, help='Path to the draft model for speculative decoding.')
|
||||
group.add_argument('--draft-max', type=int, default=4, help='Number of tokens to draft for speculative decoding.')
|
||||
group.add_argument('--gpu-layers-draft', type=int, default=0, help='Number of layers to offload to the GPU for the draft model.')
|
||||
group.add_argument('--gpu-layers-draft', type=int, default=256, help='Number of layers to offload to the GPU for the draft model.')
|
||||
group.add_argument('--device-draft', type=str, default=None, help='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
|
||||
group.add_argument('--ctx-size-draft', type=int, default=0, help='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
|
||||
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ def create_ui():
|
|||
# Speculative decoding
|
||||
with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
|
||||
with gr.Row():
|
||||
shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=utils.get_available_models(), value=lambda: shared.draft_model_name, elem_classes='slim-dropdown', interactive=not mu)
|
||||
shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=utils.get_available_models(), value=lambda: shared.args.model_draft, elem_classes='slim-dropdown', interactive=not mu)
|
||||
ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
|
||||
|
||||
shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Number of tokens to draft for speculative decoding.')
|
||||
|
|
|
|||
Loading…
Reference in a new issue