Use --ctx-size to specify the context size for all loaders

The old flags (--n_ctx and --max_seq_len) are still recognized as aliases of --ctx-size.
oobabooga 2025-04-25 16:59:03 -07:00
parent faababc4ea
commit d4b1e31c49
12 changed files with 39 additions and 45 deletions

View file

@@ -3,6 +3,7 @@ import traceback
from pathlib import Path
import torch
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
@@ -15,7 +16,6 @@ from exllamav2 import (
ExLlamaV2Tokenizer
)
from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
from modules import shared
from modules.logging_colors import logger
from modules.text_generation import get_max_prompt_length
@@ -40,7 +40,7 @@ class Exllamav2Model:
config.model_dir = str(path_to_model)
config.prepare()
config.max_seq_len = shared.args.max_seq_len
config.max_seq_len = shared.args.ctx_size
config.scale_pos_emb = shared.args.compress_pos_emb
config.scale_alpha_value = shared.args.alpha_value
config.no_flash_attn = shared.args.no_flash_attn

View file

@@ -4,6 +4,15 @@ from pathlib import Path
from typing import Any, Dict, Optional, Union
import torch
from torch.nn import CrossEntropyLoss
from transformers import (
GenerationConfig,
GenerationMixin,
PretrainedConfig,
PreTrainedModel
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
@@ -14,15 +23,6 @@ from exllamav2 import (
ExLlamaV2Cache_TP,
ExLlamaV2Config
)
from torch.nn import CrossEntropyLoss
from transformers import (
GenerationConfig,
GenerationMixin,
PretrainedConfig,
PreTrainedModel
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from modules import shared
from modules.logging_colors import logger
@@ -192,7 +192,7 @@ class Exllamav2HF(PreTrainedModel, GenerationMixin):
config.model_dir = str(pretrained_model_name_or_path)
config.prepare()
config.max_seq_len = shared.args.max_seq_len
config.max_seq_len = shared.args.ctx_size
config.scale_pos_emb = shared.args.compress_pos_emb
config.scale_alpha_value = shared.args.alpha_value
config.no_flash_attn = shared.args.no_flash_attn

View file

@@ -33,7 +33,7 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
self.ex_model = Model.from_config(config)
# Calculate the closest multiple of 256 at or above the chosen value
max_tokens = shared.args.max_seq_len
max_tokens = shared.args.ctx_size
if max_tokens % 256 != 0:
adjusted_tokens = ((max_tokens // 256) + 1) * 256
logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")

View file

@@ -254,7 +254,7 @@ class LlamaServer:
cmd = [
self.server_path,
"--model", self.model_path,
"--ctx-size", str(shared.args.n_ctx),
"--ctx-size", str(shared.args.ctx_size),
"--n-gpu-layers", str(shared.args.n_gpu_layers),
"--batch-size", str(shared.args.batch_size),
"--port", str(self.port),

View file

@@ -9,7 +9,7 @@ loaders_and_params = OrderedDict({
'threads',
'threads_batch',
'batch_size',
'n_ctx',
'ctx_size',
'cache_type',
'tensor_split',
'extra_flags',
@@ -48,14 +48,14 @@ loaders_and_params = OrderedDict({
'no_use_fast',
],
'ExLlamav3_HF': [
'max_seq_len',
'ctx_size',
'gpu_split',
'cfg_cache',
'trust_remote_code',
'no_use_fast',
],
'ExLlamav2_HF': [
'max_seq_len',
'ctx_size',
'cache_type',
'gpu_split',
'alpha_value',
@@ -71,7 +71,7 @@ loaders_and_params = OrderedDict({
'no_use_fast',
],
'ExLlamav2': [
'max_seq_len',
'ctx_size',
'cache_type',
'gpu_split',
'alpha_value',
@@ -93,7 +93,7 @@ loaders_and_params = OrderedDict({
'no_use_fast',
],
'TensorRT-LLM': [
'max_seq_len',
'ctx_size',
'cpp_runner',
'tensorrt_llm_info',
]

View file

@@ -52,10 +52,8 @@ def load_model(model_name, loader=None):
tokenizer = load_tokenizer(model_name)
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt'):
shared.settings['truncation_length'] = shared.args.max_seq_len
elif loader == 'llama.cpp':
shared.settings['truncation_length'] = shared.args.n_ctx
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
shared.settings['truncation_length'] = shared.args.ctx_size
logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
logger.info(f"LOADER: \"{loader}\"")

View file

@@ -11,8 +11,7 @@ def get_fallback_settings():
return {
'bf16': False,
'use_eager_attention': False,
'max_seq_len': 2048,
'n_ctx': 2048,
'ctx_size': 2048,
'rope_freq_base': 0,
'compress_pos_emb': 1,
'alpha_value': 1,
@@ -59,7 +58,7 @@ def get_model_metadata(model):
for k in metadata:
if k.endswith('context_length'):
model_settings['n_ctx'] = min(metadata[k], 8192)
model_settings['ctx_size'] = min(metadata[k], 8192)
model_settings['truncation_length_info'] = metadata[k]
elif k.endswith('rope.freq_base'):
model_settings['rope_freq_base'] = metadata[k]
@@ -97,7 +96,7 @@ def get_model_metadata(model):
if k in metadata:
model_settings['truncation_length'] = metadata[k]
model_settings['truncation_length_info'] = metadata[k]
model_settings['max_seq_len'] = min(metadata[k], 8192)
model_settings['ctx_size'] = min(metadata[k], 8192)
if 'rope_theta' in metadata:
model_settings['rope_freq_base'] = metadata['rope_theta']

View file

@@ -116,7 +116,6 @@ group.add_argument('--quant_type', type=str, default='nf4', help='quant_type for
# llama.cpp
group = parser.add_argument_group('llama.cpp')
group.add_argument('--flash-attn', action='store_true', help='Use flash-attention.')
group.add_argument('--n_ctx', type=int, default=8192, help='Size of the prompt context.')
group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
group.add_argument('--batch-size', type=int, default=256, help='Maximum number of prompt tokens to batch together when calling llama_eval.')
@@ -130,6 +129,11 @@ group.add_argument('--row-split', action='store_true', help='Split the model by
group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1;flag2;flag3=value3". Example: "override-tensor=exps=CPU"')
group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
# Cache
group = parser.add_argument_group('Context and cache management')
group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, help='Context size in tokens.')
group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
# Speculative decoding
group = parser.add_argument_group('Speculative decoding')
group.add_argument('--model-draft', type=str, default=None, help='Path to the draft model for speculative decoding.')
@@ -142,7 +146,6 @@ group.add_argument('--ctx-size-draft', type=int, default=0, help='Size of the pr
group = parser.add_argument_group('ExLlamaV2')
group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
group.add_argument('--autosplit', action='store_true', help='Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.')
group.add_argument('--max_seq_len', type=int, default=8192, help='Maximum sequence length.')
group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.')
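The single --ctx-size definition above is what keeps the removed flags usable: argparse accepts several option strings for one argument and derives the destination attribute from the first long option, so --n_ctx and --max_seq_len both write to args.ctx_size. A minimal standalone sketch of that behavior (not the webui's full parser):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('Context and cache management')
# dest is derived from '--ctx-size' (dashes become underscores), so the
# deprecated spellings land in the same attribute.
group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len',
                   type=int, default=8192, help='Context size in tokens.')

for argv in (['--ctx-size', '16384'], ['--n_ctx', '16384'], ['--max_seq_len', '16384']):
    print(argv[0], '->', parser.parse_args(argv).ctx_size)   # prints 16384 for all three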

View file

@@ -1,15 +1,15 @@
from pathlib import Path
import tensorrt_llm
import torch
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp
import tensorrt_llm
from modules import shared
from modules.logging_colors import logger
from modules.text_generation import (
get_max_prompt_length,
get_reply_from_output_ids
)
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp
class TensorRTLLMModel:
@@ -35,7 +35,7 @@ class TensorRTLLMModel:
logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
runner_kwargs.update(
max_batch_size=1,
max_input_len=shared.args.max_seq_len - 512,
max_input_len=shared.args.ctx_size - 512,
max_output_len=512,
max_beam_width=1,
max_attention_window_size=None,

View file

@@ -110,8 +110,7 @@ def list_model_elements():
'threads_batch',
'batch_size',
'hqq_backend',
'n_ctx',
'max_seq_len',
'ctx_size',
'cache_type',
'tensor_split',
'extra_flags',

View file

@@ -51,8 +51,7 @@ def create_ui():
shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
shared.gradio['n_ctx'] = gr.Number(label="n_ctx", precision=0, step=256, value=shared.args.n_ctx, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768.')
shared.gradio['max_seq_len'] = gr.Number(label='max_seq_len', precision=0, step=256, value=shared.args.max_seq_len, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768.')
shared.gradio['ctx_size'] = gr.Number(label='ctx_size', precision=0, step=256, value=shared.args.ctx_size, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768, 65536.')
shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
@@ -92,7 +91,7 @@ def create_ui():
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
shared.gradio['exllamav2_info'] = gr.Markdown("ExLlamav2_HF is recommended over ExLlamav2 for better integration with extensions and more consistent sampling behavior across loaders.")
shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `max_seq_len` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
# Speculative decoding
with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
@@ -247,10 +246,8 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
def update_truncation_length(current_length, state):
if 'loader' in state:
if state['loader'].lower().startswith('exllama'):
return state['max_seq_len']
elif state['loader'] == 'llama.cpp':
return state['n_ctx']
if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
return state['ctx_size']
return current_length

View file

@@ -121,10 +121,8 @@ def create_event_handlers():
def get_truncation_length():
if 'max_seq_len' in shared.provided_arguments or shared.args.max_seq_len != shared.args_defaults.max_seq_len:
return shared.args.max_seq_len
elif 'n_ctx' in shared.provided_arguments or shared.args.n_ctx != shared.args_defaults.n_ctx:
return shared.args.n_ctx
if 'ctx_size' in shared.provided_arguments or shared.args.ctx_size != shared.args_defaults.ctx_size:
return shared.args.ctx_size
else:
return shared.settings['truncation_length']
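The check above treats ctx_size as user-chosen when the flag name was passed explicitly or when its value differs from the parser defaults. A rough standalone sketch of that pattern, assuming provided_arguments and args_defaults are built roughly as below (the webui's actual bookkeeping may differ):

import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192)

args = parser.parse_args()
args_defaults = parser.parse_args([])        # same parser on an empty argv -> pure defaults
# Flag names the user actually typed, normalized to attribute form (assumption, for illustration).
provided_arguments = [
    a.split('=')[0].lstrip('-').replace('-', '_') for a in sys.argv[1:] if a.startswith('--')
]

def get_truncation_length(fallback=2048):
    if 'ctx_size' in provided_arguments or args.ctx_size != args_defaults.ctx_size:
        return args.ctx_size                 # explicitly set, or changed from the default
    return fallback                          # otherwise keep the stored setting

print(get_truncation_length())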