From bbd43d9463cded0aae3ce4e1ced1519693bebf9a Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 12 Mar 2026 14:54:05 -0700 Subject: [PATCH] UI: Correctly propagate truncation_length when ctx_size is auto --- modules/llama_cpp_server.py | 8 +++++++- modules/models.py | 2 ++ modules/ui_model_menu.py | 6 +++++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index a3e431ac..192aa9e4 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -36,6 +36,7 @@ class LlamaServer: self.process = None self.session = requests.Session() self.vocabulary_size = None + self.n_ctx = None self.bos_token = "" self.last_prompt_token_count = 0 @@ -320,12 +321,17 @@ class LlamaServer: self.vocabulary_size = model_info["meta"]["n_vocab"] def _get_bos_token(self): - """Get and store the model's BOS token.""" + """Get and store the model's BOS token and context size.""" url = f"http://127.0.0.1:{self.port}/props" response = self.session.get(url).json() if "bos_token" in response: self.bos_token = response["bos_token"] + # Get actual n_ctx from the server (important when --fit auto-selects it) + n_ctx = response.get("default_generation_settings", {}).get("n_ctx") + if n_ctx: + self.n_ctx = n_ctx + def _is_port_available(self, port): """Check if a port is available for use.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: diff --git a/modules/models.py b/modules/models.py index 48d68b0b..d83b98d7 100644 --- a/modules/models.py +++ b/modules/models.py @@ -54,6 +54,8 @@ def load_model(model_name, loader=None): if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp': if shared.args.ctx_size > 0: shared.settings['truncation_length'] = shared.args.ctx_size + elif loader == 'llama.cpp' and hasattr(model, 'n_ctx') and model.n_ctx: + shared.settings['truncation_length'] = model.n_ctx shared.is_multimodal = False if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'): diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 7e91f1ce..5c83096f 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -388,7 +388,11 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur def update_truncation_length(current_length, state): if 'loader' in state: if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp': - return state['ctx_size'] + if state['ctx_size'] > 0: + return state['ctx_size'] + + # ctx_size == 0 means auto: use the actual value from the server + return shared.settings['truncation_length'] return current_length