llama.cpp: Fix the sampler priority handling

This commit is contained in:
oobabooga 2025-04-18 18:06:36 -07:00
parent 5ad080ff25
commit 71ae05e0a4

View file

@ -56,17 +56,6 @@ class LlamaServer:
return result.get("content", "")
def prepare_payload(self, state):
# Prepare DRY
dry_sequence_breakers = state['dry_sequence_breakers']
if not dry_sequence_breakers.startswith("["):
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
dry_sequence_breakers = json.loads(dry_sequence_breakers)
# Move temperature to the end if temperature_last is true and temperature exists in the list
if state["temperature_last"] and "temperature" in samplers:
samplers.remove("temperature")
samplers.append("temperature")
payload = {
"temperature": state["temperature"] if not state["dynamic_temperature"] else (state["dynatemp_low"] + state["dynatemp_high"]) / 2,
"dynatemp_range": 0 if not state["dynamic_temperature"] else (state["dynatemp_high"] - state["dynatemp_low"]) / 2,
@ -84,7 +73,6 @@ class LlamaServer:
"dry_base": state["dry_base"],
"dry_allowed_length": state["dry_allowed_length"],
"dry_penalty_last_n": state["repetition_penalty_range"],
"dry_sequence_breakers": dry_sequence_breakers,
"xtc_probability": state["xtc_probability"],
"xtc_threshold": state["xtc_threshold"],
"mirostat": state["mirostat_mode"],
@ -95,6 +83,14 @@ class LlamaServer:
"ignore_eos": state["ban_eos_token"],
}
# DRY
dry_sequence_breakers = state['dry_sequence_breakers']
if not dry_sequence_breakers.startswith("["):
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
dry_sequence_breakers = json.loads(dry_sequence_breakers)
payload["dry_sequence_breakers"] = dry_sequence_breakers
# Sampler order
if state["sampler_priority"]:
samplers = state["sampler_priority"]
@ -109,6 +105,11 @@ class LlamaServer:
filtered_samplers.append("penalties")
penalty_found = True
# Move temperature to the end if temperature_last is true and temperature exists in the list
if state["temperature_last"] and "temperature" in samplers:
samplers.remove("temperature")
samplers.append("temperature")
payload["samplers"] = filtered_samplers
if state['custom_token_bans']: