Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2025-12-06 07:12:10 +01:00)

Use --ctx-size to specify the context size for all loaders
Old flags are still recognized as alternatives.
parent faababc4ea
commit d4b1e31c49

@@ -3,6 +3,7 @@ import traceback
 from pathlib import Path

 import torch

 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Cache,

@@ -15,7 +16,6 @@ from exllamav2 import (
     ExLlamaV2Tokenizer
 )
 from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator

 from modules import shared
 from modules.logging_colors import logger
 from modules.text_generation import get_max_prompt_length

@@ -40,7 +40,7 @@ class Exllamav2Model:
 config.model_dir = str(path_to_model)
 config.prepare()

-config.max_seq_len = shared.args.max_seq_len
+config.max_seq_len = shared.args.ctx_size
 config.scale_pos_emb = shared.args.compress_pos_emb
 config.scale_alpha_value = shared.args.alpha_value
 config.no_flash_attn = shared.args.no_flash_attn
@@ -4,6 +4,15 @@ from pathlib import Path
 from typing import Any, Dict, Optional, Union

 import torch
+from torch.nn import CrossEntropyLoss
+from transformers import (
+    GenerationConfig,
+    GenerationMixin,
+    PretrainedConfig,
+    PreTrainedModel
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast

 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Cache,

@@ -14,15 +23,6 @@ from exllamav2 import (
     ExLlamaV2Cache_TP,
     ExLlamaV2Config
 )
-from torch.nn import CrossEntropyLoss
-from transformers import (
-    GenerationConfig,
-    GenerationMixin,
-    PretrainedConfig,
-    PreTrainedModel
-)
-from transformers.modeling_outputs import CausalLMOutputWithPast

 from modules import shared
 from modules.logging_colors import logger

@@ -192,7 +192,7 @@ class Exllamav2HF(PreTrainedModel, GenerationMixin):
 config.model_dir = str(pretrained_model_name_or_path)
 config.prepare()

-config.max_seq_len = shared.args.max_seq_len
+config.max_seq_len = shared.args.ctx_size
 config.scale_pos_emb = shared.args.compress_pos_emb
 config.scale_alpha_value = shared.args.alpha_value
 config.no_flash_attn = shared.args.no_flash_attn
@@ -33,7 +33,7 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
 self.ex_model = Model.from_config(config)

 # Calculate the closest multiple of 256 at or above the chosen value
-max_tokens = shared.args.max_seq_len
+max_tokens = shared.args.ctx_size
 if max_tokens % 256 != 0:
     adjusted_tokens = ((max_tokens // 256) + 1) * 256
     logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
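A quick worked example of the rounding performed in the hunk above (the starting value is illustrative, not taken from the commit):

# Illustrative only: round a context size up to the next multiple of 256,
# mirroring the adjustment shown in the hunk above.
max_tokens = 10000                               # e.g. a user-chosen ctx_size
if max_tokens % 256 != 0:
    max_tokens = ((max_tokens // 256) + 1) * 256
print(max_tokens)  # 10240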
@@ -254,7 +254,7 @@ class LlamaServer:
 cmd = [
     self.server_path,
     "--model", self.model_path,
-    "--ctx-size", str(shared.args.n_ctx),
+    "--ctx-size", str(shared.args.ctx_size),
     "--n-gpu-layers", str(shared.args.n_gpu_layers),
     "--batch-size", str(shared.args.batch_size),
     "--port", str(self.port),
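For reference, a minimal sketch of how a command list like the one above is typically assembled and launched; the binary and model paths are placeholders, not values from the repository, and numeric options are wrapped in str() because subprocess arguments must be strings.

# Hedged sketch with placeholder paths; not the project's actual launch code.
import subprocess

ctx_size = 8192
cmd = [
    "./llama-server",                    # placeholder server binary
    "--model", "/models/example.gguf",   # placeholder model path
    "--ctx-size", str(ctx_size),         # numeric flags are stringified
    "--port", "8080",
]
print(" ".join(cmd))       # inspect the resulting command line
# subprocess.Popen(cmd)    # launching requires a real binary and model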
@@ -9,7 +9,7 @@ loaders_and_params = OrderedDict({
 'threads',
 'threads_batch',
 'batch_size',
-'n_ctx',
+'ctx_size',
 'cache_type',
 'tensor_split',
 'extra_flags',

@@ -48,14 +48,14 @@ loaders_and_params = OrderedDict({
 'no_use_fast',
 ],
 'ExLlamav3_HF': [
-'max_seq_len',
+'ctx_size',
 'gpu_split',
 'cfg_cache',
 'trust_remote_code',
 'no_use_fast',
 ],
 'ExLlamav2_HF': [
-'max_seq_len',
+'ctx_size',
 'cache_type',
 'gpu_split',
 'alpha_value',

@@ -71,7 +71,7 @@ loaders_and_params = OrderedDict({
 'no_use_fast',
 ],
 'ExLlamav2': [
-'max_seq_len',
+'ctx_size',
 'cache_type',
 'gpu_split',
 'alpha_value',

@@ -93,7 +93,7 @@ loaders_and_params = OrderedDict({
 'no_use_fast',
 ],
 'TensorRT-LLM': [
-'max_seq_len',
+'ctx_size',
 'cpp_runner',
 'tensorrt_llm_info',
 ]
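For orientation, an abridged sketch of the mapping these hunks edit: each loader name points to the list of UI parameters shown for it, and 'ctx_size' replaces the old loader-specific 'n_ctx' / 'max_seq_len' entries. The loader key for the first hunk is not visible in the diff, so 'llama.cpp' below is an assumption, and the lists are shortened.

# Abridged sketch of loaders_and_params; entries are taken from the hunks
# above, but the real dictionary contains more loaders and more parameters.
from collections import OrderedDict

loaders_and_params = OrderedDict({
    'llama.cpp': ['threads', 'threads_batch', 'batch_size', 'ctx_size', 'cache_type'],  # key assumed
    'ExLlamav2_HF': ['ctx_size', 'cache_type', 'gpu_split', 'alpha_value'],
    'TensorRT-LLM': ['ctx_size', 'cpp_runner', 'tensorrt_llm_info'],
})

print(loaders_and_params['TensorRT-LLM'])  # ['ctx_size', 'cpp_runner', 'tensorrt_llm_info']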
@@ -52,10 +52,8 @@ def load_model(model_name, loader=None):
 tokenizer = load_tokenizer(model_name)

 shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
-if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt'):
-    shared.settings['truncation_length'] = shared.args.max_seq_len
-elif loader == 'llama.cpp':
-    shared.settings['truncation_length'] = shared.args.n_ctx
+if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
+    shared.settings['truncation_length'] = shared.args.ctx_size

 logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
 logger.info(f"LOADER: \"{loader}\"")
@@ -11,8 +11,7 @@ def get_fallback_settings():
 return {
     'bf16': False,
     'use_eager_attention': False,
-    'max_seq_len': 2048,
-    'n_ctx': 2048,
+    'ctx_size': 2048,
     'rope_freq_base': 0,
     'compress_pos_emb': 1,
     'alpha_value': 1,

@@ -59,7 +58,7 @@ def get_model_metadata(model):

 for k in metadata:
     if k.endswith('context_length'):
-        model_settings['n_ctx'] = min(metadata[k], 8192)
+        model_settings['ctx_size'] = min(metadata[k], 8192)
         model_settings['truncation_length_info'] = metadata[k]
     elif k.endswith('rope.freq_base'):
         model_settings['rope_freq_base'] = metadata[k]

@@ -97,7 +96,7 @@ def get_model_metadata(model):
 if k in metadata:
     model_settings['truncation_length'] = metadata[k]
     model_settings['truncation_length_info'] = metadata[k]
-    model_settings['max_seq_len'] = min(metadata[k], 8192)
+    model_settings['ctx_size'] = min(metadata[k], 8192)

 if 'rope_theta' in metadata:
     model_settings['rope_freq_base'] = metadata['rope_theta']
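An illustrative run of the metadata handling above, assuming GGUF-style key names (the numbers are made up): a model may advertise a very large native context, but the suggested ctx_size is capped at 8192 while the full value is kept as truncation_length_info.

# Illustrative values only; mirrors the loop in the hunk above.
metadata = {'llama.context_length': 131072, 'llama.rope.freq_base': 1000000.0}
model_settings = {}
for k in metadata:
    if k.endswith('context_length'):
        model_settings['ctx_size'] = min(metadata[k], 8192)
        model_settings['truncation_length_info'] = metadata[k]
    elif k.endswith('rope.freq_base'):
        model_settings['rope_freq_base'] = metadata[k]

print(model_settings)
# {'ctx_size': 8192, 'truncation_length_info': 131072, 'rope_freq_base': 1000000.0}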
@@ -116,7 +116,6 @@ group.add_argument('--quant_type', type=str, default='nf4', help='quant_type for
 # llama.cpp
 group = parser.add_argument_group('llama.cpp')
 group.add_argument('--flash-attn', action='store_true', help='Use flash-attention.')
-group.add_argument('--n_ctx', type=int, default=8192, help='Size of the prompt context.')
 group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
 group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
 group.add_argument('--batch-size', type=int, default=256, help='Maximum number of prompt tokens to batch together when calling llama_eval.')

@@ -130,6 +129,11 @@ group.add_argument('--row-split', action='store_true', help='Split the model by
 group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1;flag2;flag3=value3". Example: "override-tensor=exps=CPU"')
 group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')

+# Cache
+group = parser.add_argument_group('Context and cache management')
+group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, help='Context size in tokens.')
+group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')

 # Speculative decoding
 group = parser.add_argument_group('Speculative decoding')
 group.add_argument('--model-draft', type=str, default=None, help='Path to the draft model for speculative decoding.')
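The new --ctx-size definition above is what keeps the old flags working: argparse accepts several option strings for a single argument, and the first long option determines the attribute name (ctx_size). A minimal, self-contained sketch, not the project's actual parser:

# Minimal sketch of the aliasing: '--n_ctx' and '--max_seq_len' are extra
# option strings for the same argument, so the value always lands in
# args.ctx_size (the dest derived from '--ctx-size').
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192,
                    help='Context size in tokens.')

args = parser.parse_args(['--max_seq_len', '4096'])
print(args.ctx_size)  # 4096 -- the old flag still sets the new attribute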
@@ -142,7 +146,6 @@ group.add_argument('--ctx-size-draft', type=int, default=0, help='Size of the pr
 group = parser.add_argument_group('ExLlamaV2')
 group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
 group.add_argument('--autosplit', action='store_true', help='Autosplit the model tensors across the available GPUs. This causes --gpu-split to be ignored.')
-group.add_argument('--max_seq_len', type=int, default=8192, help='Maximum sequence length.')
 group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.')
 group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
 group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.')
@@ -1,15 +1,15 @@
 from pathlib import Path

+import tensorrt_llm
 import torch
+from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp

-import tensorrt_llm
 from modules import shared
 from modules.logging_colors import logger
 from modules.text_generation import (
     get_max_prompt_length,
     get_reply_from_output_ids
 )
-from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp


 class TensorRTLLMModel:
@@ -35,7 +35,7 @@ class TensorRTLLMModel:
 logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
 runner_kwargs.update(
     max_batch_size=1,
-    max_input_len=shared.args.max_seq_len - 512,
+    max_input_len=shared.args.ctx_size - 512,
     max_output_len=512,
     max_beam_width=1,
     max_attention_window_size=None,
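The subtraction above keeps the prompt and the reserved reply inside a single context window; a small illustration using the default value visible in this diff:

# The prompt budget plus the reserved output length add up to ctx_size.
ctx_size = 8192                              # default --ctx-size in this commit
max_output_len = 512
max_input_len = ctx_size - max_output_len    # 7680 tokens left for the prompt
assert max_input_len + max_output_len == ctx_size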
@@ -110,8 +110,7 @@ def list_model_elements():
 'threads_batch',
 'batch_size',
 'hqq_backend',
-'n_ctx',
-'max_seq_len',
+'ctx_size',
 'cache_type',
 'tensor_split',
 'extra_flags',

@@ -51,8 +51,7 @@ def create_ui():
 shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
 shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
 shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
-shared.gradio['n_ctx'] = gr.Number(label="n_ctx", precision=0, step=256, value=shared.args.n_ctx, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768.')
-shared.gradio['max_seq_len'] = gr.Number(label='max_seq_len', precision=0, step=256, value=shared.args.max_seq_len, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768.')
+shared.gradio['ctx_size'] = gr.Number(label='ctx_size', precision=0, step=256, value=shared.args.ctx_size, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768, 65536.')
 shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
 shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
 shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')

@@ -92,7 +91,7 @@ def create_ui():
 shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
 shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
 shared.gradio['exllamav2_info'] = gr.Markdown("ExLlamav2_HF is recommended over ExLlamav2 for better integration with extensions and more consistent sampling behavior across loaders.")
-shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `max_seq_len` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
+shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')

 # Speculative decoding
 with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
@@ -247,10 +246,8 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur

 def update_truncation_length(current_length, state):
     if 'loader' in state:
-        if state['loader'].lower().startswith('exllama'):
-            return state['max_seq_len']
-        elif state['loader'] == 'llama.cpp':
-            return state['n_ctx']
+        if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
+            return state['ctx_size']

     return current_length
@@ -121,10 +121,8 @@ def create_event_handlers():


 def get_truncation_length():
-    if 'max_seq_len' in shared.provided_arguments or shared.args.max_seq_len != shared.args_defaults.max_seq_len:
-        return shared.args.max_seq_len
-    elif 'n_ctx' in shared.provided_arguments or shared.args.n_ctx != shared.args_defaults.n_ctx:
-        return shared.args.n_ctx
+    if 'ctx_size' in shared.provided_arguments or shared.args.ctx_size != shared.args_defaults.ctx_size:
+        return shared.args.ctx_size
     else:
         return shared.settings['truncation_length']
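The check above treats ctx_size as user-specified if it was passed on the command line or differs from the parser default; shared.provided_arguments is the project's own bookkeeping, so the sketch below only demonstrates the default-comparison half of the test.

# Hedged sketch: detect whether --ctx-size (or an old alias) was changed from
# its default by comparing against a pristine parse of no arguments.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192)

args = parser.parse_args(['--n_ctx', '16384'])   # old alias, still recognized
args_defaults = parser.parse_args([])

if args.ctx_size != args_defaults.ctx_size:
    print('ctx_size explicitly set:', args.ctx_size)        # 16384
else:
    print('falling back to the stored truncation_length')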