Set context lengths to at most 8192 by default (to prevent out of memory errors) (#6835)

oobabooga 2025-04-07 21:42:33 -03:00 committed by GitHub
parent f1f32386b4
commit a5855c345c
5 changed files with 15 additions and 12 deletions
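
In practice the change is a clamp: the context length configured on the loader (n_ctx / max_seq_len) is capped at 8192 to avoid out-of-memory errors on load, while the model's advertised context length is kept separately (truncation_length_info) so the UI can still report it. Below is a minimal sketch of that behaviour, assuming the native context length has already been read from the model metadata; the key names follow the hunks below, while MAX_DEFAULT_CTX and apply_context_cap are illustrative names only, not part of the commit:

MAX_DEFAULT_CTX = 8192  # illustrative constant; the diffs use the literal 8192


def apply_context_cap(model_settings, native_context_length):
    # Cap the working context handed to the loader...
    model_settings['n_ctx'] = min(native_context_length, MAX_DEFAULT_CTX)
    model_settings['max_seq_len'] = min(native_context_length, MAX_DEFAULT_CTX)
    # ...but keep the model's native value for display purposes.
    model_settings['truncation_length_info'] = native_context_length
    return model_settings


# Example: a model advertising a 131072-token context loads with 8192 by
# default, while the UI can still show 131072 as the maximum sequence length.
print(apply_context_cap({}, 131072))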

View file

@@ -17,6 +17,7 @@ def get_fallback_settings():
         'compress_pos_emb': 1,
         'alpha_value': 1,
         'truncation_length': shared.settings['truncation_length'],
+        'truncation_length_info': shared.settings['truncation_length'],
         'skip_special_tokens': shared.settings['skip_special_tokens'],
         'custom_stopping_strings': shared.settings['custom_stopping_strings'],
     }
@@ -53,7 +54,8 @@ def get_model_metadata(model):
         for k in metadata:
             if k.endswith('context_length'):
-                model_settings['n_ctx'] = metadata[k]
+                model_settings['n_ctx'] = min(metadata[k], 8192)
+                model_settings['truncation_length_info'] = metadata[k]
             elif k.endswith('rope.freq_base'):
                 model_settings['rope_freq_base'] = metadata[k]
             elif k.endswith('rope.scale_linear'):
@@ -89,7 +91,8 @@ def get_model_metadata(model):
         for k in ['max_position_embeddings', 'model_max_length', 'max_seq_len']:
             if k in metadata:
                 model_settings['truncation_length'] = metadata[k]
-                model_settings['max_seq_len'] = metadata[k]
+                model_settings['truncation_length_info'] = metadata[k]
+                model_settings['max_seq_len'] = min(metadata[k], 8192)

         if 'rope_theta' in metadata:
             model_settings['rope_freq_base'] = metadata['rope_theta']

View file

@@ -53,7 +53,7 @@ settings = {
     'skip_special_tokens': True,
     'stream': True,
     'static_cache': False,
-    'truncation_length': 2048,
+    'truncation_length': 8192,
     'seed': -1,
     'custom_stopping_strings': '',
     'custom_token_bans': '',
@@ -117,7 +117,7 @@ group.add_argument('--quant_type', type=str, default='nf4', help='quant_type for
 group = parser.add_argument_group('llama.cpp')
 group.add_argument('--flash-attn', action='store_true', help='Use flash-attention.')
 group.add_argument('--tensorcores', action='store_true', help='NVIDIA only: use llama-cpp-python compiled without GGML_CUDA_FORCE_MMQ. This may improve performance on newer cards.')
-group.add_argument('--n_ctx', type=int, default=2048, help='Size of the prompt context.')
+group.add_argument('--n_ctx', type=int, default=8192, help='Size of the prompt context.')
 group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
 group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
 group.add_argument('--no_mul_mat_q', action='store_true', help='Disable the mulmat kernels.')
@@ -139,7 +139,7 @@ group.add_argument('--tokenizer-dir', type=str, help='Load the tokenizer from th
 group = parser.add_argument_group('ExLlamaV2')
 group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
 group.add_argument('--autosplit', action='store_true', help='Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.')
-group.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
+group.add_argument('--max_seq_len', type=int, default=8192, help='Maximum sequence length.')
 group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
 group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
 group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.')
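
These are only the new defaults; an explicit flag still selects a larger context. A small stand-alone sketch of that interaction (the two argument definitions are copied from the hunks above; the rest is illustrative):

import argparse

# Stripped-down parser mirroring the two arguments whose defaults changed above.
parser = argparse.ArgumentParser()
parser.add_argument('--n_ctx', type=int, default=8192, help='Size of the prompt context.')
parser.add_argument('--max_seq_len', type=int, default=8192, help='Maximum sequence length.')

print(parser.parse_args([]).n_ctx)                    # 8192: the new default applies
print(parser.parse_args(['--n_ctx', '32768']).n_ctx)  # 32768: an explicit flag still wins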

View file

@@ -200,8 +200,10 @@ def create_event_handlers():
 def load_model_wrapper(selected_model, loader, autoload=False):
+    settings = get_model_metadata(selected_model)
+
     if not autoload:
-        yield f"The settings for `{selected_model}` have been updated.\n\nClick on \"Load\" to load it."
+        yield "### {}\n\n- Settings updated: Click \"Load\" to load the model\n- Max sequence length: {}".format(selected_model, settings['truncation_length_info'])
         return

     if selected_model == 'None':
@@ -214,11 +216,9 @@ def load_model_wrapper(selected_model, loader, autoload=False):
             shared.model, shared.tokenizer = load_model(selected_model, loader)

             if shared.model is not None:
-                output = f"Successfully loaded `{selected_model}`."
-
-                settings = get_model_metadata(selected_model)
+                output = f"Successfully loaded `{selected_model}`.\n\n"
                 if 'instruction_template' in settings:
-                    output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template'])
+                    output += '- It seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.\n'.format(settings['instruction_template'])

                 yield output
             else:
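
The "settings updated" message yielded above is where the new truncation_length_info value surfaces to the user. An illustrative rendering of that message follows; the format string is taken from the hunk, while the model name and length are made-up example values:

selected_model = 'Example-7B'       # hypothetical model name
truncation_length_info = 32768      # hypothetical native context length
print("### {}\n\n- Settings updated: Click \"Load\" to load the model\n- Max sequence length: {}".format(selected_model, truncation_length_info))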

View file

@@ -87,7 +87,7 @@ def create_ui(default_preset):
                 shared.gradio['static_cache'] = gr.Checkbox(value=shared.settings['static_cache'], label='Static KV cache', info='Use a static cache for improved performance.')
             with gr.Column():
-                shared.gradio['truncation_length'] = gr.Number(precision=0, step=256, value=get_truncation_length(), label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
+                shared.gradio['truncation_length'] = gr.Number(precision=0, step=256, value=get_truncation_length(), label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length.')
                 shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
                 shared.gradio['sampler_priority'] = gr.Textbox(value=generate_params['sampler_priority'], lines=12, label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])

View file

@@ -25,7 +25,7 @@ add_bos_token: true
 skip_special_tokens: true
 stream: true
 static_cache: false
-truncation_length: 2048
+truncation_length: 8192
 seed: -1
 custom_stopping_strings: ''
 custom_token_bans: ''