UI: Correctly propagate truncation_length when ctx_size is auto

This commit is contained in:
oobabooga 2026-03-12 14:54:05 -07:00
parent 3e6bd1a310
commit bbd43d9463
3 changed files with 14 additions and 2 deletions

View file

@ -36,6 +36,7 @@ class LlamaServer:
self.process = None
self.session = requests.Session()
self.vocabulary_size = None
self.n_ctx = None
self.bos_token = "<s>"
self.last_prompt_token_count = 0
@ -320,12 +321,17 @@ class LlamaServer:
self.vocabulary_size = model_info["meta"]["n_vocab"]
def _get_bos_token(self):
"""Get and store the model's BOS token."""
"""Get and store the model's BOS token and context size."""
url = f"http://127.0.0.1:{self.port}/props"
response = self.session.get(url).json()
if "bos_token" in response:
self.bos_token = response["bos_token"]
# Get actual n_ctx from the server (important when --fit auto-selects it)
n_ctx = response.get("default_generation_settings", {}).get("n_ctx")
if n_ctx:
self.n_ctx = n_ctx
def _is_port_available(self, port):
"""Check if a port is available for use."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:

View file

@ -54,6 +54,8 @@ def load_model(model_name, loader=None):
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
if shared.args.ctx_size > 0:
shared.settings['truncation_length'] = shared.args.ctx_size
elif loader == 'llama.cpp' and hasattr(model, 'n_ctx') and model.n_ctx:
shared.settings['truncation_length'] = model.n_ctx
shared.is_multimodal = False
if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'):

View file

@ -388,7 +388,11 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
def update_truncation_length(current_length, state):
if 'loader' in state:
if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
return state['ctx_size']
if state['ctx_size'] > 0:
return state['ctx_size']
# ctx_size == 0 means auto: use the actual value from the server
return shared.settings['truncation_length']
return current_length