Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2026-03-18 03:14:39 +01:00.
UI: Correctly propagate truncation_length when ctx_size is auto
This commit is contained in:
parent
3e6bd1a310
commit
bbd43d9463
|
|
@ -36,6 +36,7 @@ class LlamaServer:
|
|||
self.process = None
|
||||
self.session = requests.Session()
|
||||
self.vocabulary_size = None
|
||||
self.n_ctx = None
|
||||
self.bos_token = "<s>"
|
||||
self.last_prompt_token_count = 0
|
||||
|
||||
|
|
@ -320,12 +321,17 @@ class LlamaServer:
|
|||
self.vocabulary_size = model_info["meta"]["n_vocab"]
|
||||
|
||||
def _get_bos_token(self):
|
||||
"""Get and store the model's BOS token."""
|
||||
"""Get and store the model's BOS token and context size."""
|
||||
url = f"http://127.0.0.1:{self.port}/props"
|
||||
response = self.session.get(url).json()
|
||||
if "bos_token" in response:
|
||||
self.bos_token = response["bos_token"]
|
||||
|
||||
# Get actual n_ctx from the server (important when --fit auto-selects it)
|
||||
n_ctx = response.get("default_generation_settings", {}).get("n_ctx")
|
||||
if n_ctx:
|
||||
self.n_ctx = n_ctx
|
||||
|
||||
def _is_port_available(self, port):
|
||||
"""Check if a port is available for use."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@ def load_model(model_name, loader=None):
|
|||
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
|
||||
if shared.args.ctx_size > 0:
|
||||
shared.settings['truncation_length'] = shared.args.ctx_size
|
||||
elif loader == 'llama.cpp' and hasattr(model, 'n_ctx') and model.n_ctx:
|
||||
shared.settings['truncation_length'] = model.n_ctx
|
||||
|
||||
shared.is_multimodal = False
|
||||
if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'):
|
||||
|
|
|
|||
|
|
@ -388,7 +388,11 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
|
|||
def update_truncation_length(current_length, state):
    """Return the truncation length to show in the UI after a model load.

    For loaders that derive context size from ctx_size (exllama*, llama.cpp):
      - if the user set an explicit ctx_size (> 0), use it;
      - if ctx_size == 0 ("auto"), fall back to shared.settings['truncation_length'],
        which load_model() updates with the actual value reported by the server.
    For any other loader (or when 'loader' is absent), keep the current value.
    """
    if 'loader' in state:
        if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
            if state['ctx_size'] > 0:
                return state['ctx_size']

            # ctx_size == 0 means auto: use the actual value from the server
            return shared.settings['truncation_length']

    return current_length
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue