From a5855c345cc3e361bc8a436daf995fe6a2a5dd33 Mon Sep 17 00:00:00 2001 From: oobabooga Date: Mon, 7 Apr 2025 21:42:33 -0300 Subject: [PATCH] Set context lengths to at most 8192 by default (to prevent out of memory errors) (#6835) --- modules/models_settings.py | 7 +++++-- modules/shared.py | 6 +++--- modules/ui_model_menu.py | 10 +++++----- modules/ui_parameters.py | 2 +- settings-template.yaml | 2 +- 5 files changed, 15 insertions(+), 12 deletions(-) diff --git a/modules/models_settings.py b/modules/models_settings.py index 8d658523..b67d28a0 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -17,6 +17,7 @@ def get_fallback_settings(): 'compress_pos_emb': 1, 'alpha_value': 1, 'truncation_length': shared.settings['truncation_length'], + 'truncation_length_info': shared.settings['truncation_length'], 'skip_special_tokens': shared.settings['skip_special_tokens'], 'custom_stopping_strings': shared.settings['custom_stopping_strings'], } @@ -53,7 +54,8 @@ def get_model_metadata(model): for k in metadata: if k.endswith('context_length'): - model_settings['n_ctx'] = metadata[k] + model_settings['n_ctx'] = min(metadata[k], 8192) + model_settings['truncation_length_info'] = metadata[k] elif k.endswith('rope.freq_base'): model_settings['rope_freq_base'] = metadata[k] elif k.endswith('rope.scale_linear'): @@ -89,7 +91,8 @@ def get_model_metadata(model): for k in ['max_position_embeddings', 'model_max_length', 'max_seq_len']: if k in metadata: model_settings['truncation_length'] = metadata[k] - model_settings['max_seq_len'] = metadata[k] + model_settings['truncation_length_info'] = metadata[k] + model_settings['max_seq_len'] = min(metadata[k], 8192) if 'rope_theta' in metadata: model_settings['rope_freq_base'] = metadata['rope_theta'] diff --git a/modules/shared.py b/modules/shared.py index ea6c581a..77bd7639 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -53,7 +53,7 @@ settings = { 'skip_special_tokens': True, 'stream': True, 'static_cache': False, - 'truncation_length': 2048, + 'truncation_length': 8192, 'seed': -1, 'custom_stopping_strings': '', 'custom_token_bans': '', @@ -117,7 +117,7 @@ group.add_argument('--quant_type', type=str, default='nf4', help='quant_type for group = parser.add_argument_group('llama.cpp') group.add_argument('--flash-attn', action='store_true', help='Use flash-attention.') group.add_argument('--tensorcores', action='store_true', help='NVIDIA only: use llama-cpp-python compiled without GGML_CUDA_FORCE_MMQ. This may improve performance on newer cards.') -group.add_argument('--n_ctx', type=int, default=2048, help='Size of the prompt context.') +group.add_argument('--n_ctx', type=int, default=8192, help='Size of the prompt context.') group.add_argument('--threads', type=int, default=0, help='Number of threads to use.') group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.') group.add_argument('--no_mul_mat_q', action='store_true', help='Disable the mulmat kernels.') @@ -139,7 +139,7 @@ group.add_argument('--tokenizer-dir', type=str, help='Load the tokenizer from th group = parser.add_argument_group('ExLlamaV2') group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.') group.add_argument('--autosplit', action='store_true', help='Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.') -group.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.') +group.add_argument('--max_seq_len', type=int, default=8192, help='Maximum sequence length.') group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.') group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.') group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.') diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 1264a9fd..c23b991a 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -200,8 +200,10 @@ def create_event_handlers(): def load_model_wrapper(selected_model, loader, autoload=False): + settings = get_model_metadata(selected_model) + if not autoload: - yield f"The settings for `{selected_model}` have been updated.\n\nClick on \"Load\" to load it." + yield "### {}\n\n- Settings updated: Click \"Load\" to load the model\n- Max sequence length: {}".format(selected_model, settings['truncation_length_info']) return if selected_model == 'None': @@ -214,11 +216,9 @@ def load_model_wrapper(selected_model, loader, autoload=False): shared.model, shared.tokenizer = load_model(selected_model, loader) if shared.model is not None: - output = f"Successfully loaded `{selected_model}`." - - settings = get_model_metadata(selected_model) + output = f"Successfully loaded `{selected_model}`.\n\n" if 'instruction_template' in settings: - output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template']) + output += '- It seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.\n'.format(settings['instruction_template']) yield output else: diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 846fcfe7..c3245a9d 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -87,7 +87,7 @@ def create_ui(default_preset): shared.gradio['static_cache'] = gr.Checkbox(value=shared.settings['static_cache'], label='Static KV cache', info='Use a static cache for improved performance.') with gr.Column(): - shared.gradio['truncation_length'] = gr.Number(precision=0, step=256, value=get_truncation_length(), label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.') + shared.gradio['truncation_length'] = gr.Number(precision=0, step=256, value=get_truncation_length(), label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length.') shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)') shared.gradio['sampler_priority'] = gr.Textbox(value=generate_params['sampler_priority'], lines=12, label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar']) diff --git a/settings-template.yaml b/settings-template.yaml index 74935a60..0343df0a 100644 --- a/settings-template.yaml +++ b/settings-template.yaml @@ -25,7 +25,7 @@ add_bos_token: true skip_special_tokens: true stream: true static_cache: false -truncation_length: 2048 +truncation_length: 8192 seed: -1 custom_stopping_strings: '' custom_token_bans: ''