tiny improvements to llama_cpp_server.py

This commit is contained in:
oobabooga 2026-03-04 15:53:24 -08:00
parent 83cc207ef7
commit da3010c3ed

View file

@ -208,8 +208,9 @@ class LlamaServer:
# Make the generation request
response = self.session.post(url, json=payload, stream=True)
try:
-if response.status_code == 400 and response.json()["error"]["type"] == "exceed_context_size_error":
+if response.status_code == 400 and response.json().get("error", {}).get("type") == "exceed_context_size_error":
logger.error("The request exceeds the available context size, try increasing it")
return
else:
response.raise_for_status() # Raise an exception for HTTP errors
@ -286,6 +287,8 @@ class LlamaServer:
return result["completion_probabilities"][0]["top_probs"]
else:
return result["completion_probabilities"][0]["top_logprobs"]
time.sleep(0.05)
else:
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
@ -390,7 +393,7 @@ class LlamaServer:
else:
model_file = sorted(path.glob('*.gguf'))[0]
-cmd += ["--model-draft", model_file]
+cmd += ["--model-draft", str(model_file)]
if shared.args.draft_max > 0:
cmd += ["--draft-max", str(shared.args.draft_max)]
if shared.args.gpu_layers_draft > 0:
@ -417,8 +420,11 @@ class LlamaServer:
extra_flags = extra_flags[1:-1].strip()
for flag_item in extra_flags.split(','):
flag_item = flag_item.strip()
if '=' in flag_item:
flag, value = flag_item.split('=', 1)
flag = flag.strip()
value = value.strip()
if len(flag) <= 3:
cmd += [f"-{flag}", value]
else: