Set context lengths to at most 8192 by default (to prevent out of memory errors) (#6835)

This commit is contained in:
oobabooga 2025-04-07 21:42:33 -03:00 committed by GitHub
parent f1f32386b4
commit a5855c345c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 15 additions and 12 deletions

View file

@@ -17,6 +17,7 @@ def get_fallback_settings():
'compress_pos_emb': 1,
'alpha_value': 1,
'truncation_length': shared.settings['truncation_length'],
'truncation_length_info': shared.settings['truncation_length'],
'skip_special_tokens': shared.settings['skip_special_tokens'],
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
}
@@ -53,7 +54,8 @@ def get_model_metadata(model):
for k in metadata:
if k.endswith('context_length'):
model_settings['n_ctx'] = metadata[k]
model_settings['n_ctx'] = min(metadata[k], 8192)
model_settings['truncation_length_info'] = metadata[k]
elif k.endswith('rope.freq_base'):
model_settings['rope_freq_base'] = metadata[k]
elif k.endswith('rope.scale_linear'):
@@ -89,7 +91,8 @@ def get_model_metadata(model):
for k in ['max_position_embeddings', 'model_max_length', 'max_seq_len']:
if k in metadata:
model_settings['truncation_length'] = metadata[k]
model_settings['max_seq_len'] = metadata[k]
model_settings['truncation_length_info'] = metadata[k]
model_settings['max_seq_len'] = min(metadata[k], 8192)
if 'rope_theta' in metadata:
model_settings['rope_freq_base'] = metadata['rope_theta']