mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-01-31 04:44:44 +01:00
llama.cpp: Fix the sampler priority handling
This commit is contained in:
parent
5ad080ff25
commit
71ae05e0a4
|
|
@ -56,17 +56,6 @@ class LlamaServer:
|
|||
return result.get("content", "")
|
||||
|
||||
def prepare_payload(self, state):
|
||||
# Prepare DRY
|
||||
dry_sequence_breakers = state['dry_sequence_breakers']
|
||||
if not dry_sequence_breakers.startswith("["):
|
||||
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
|
||||
dry_sequence_breakers = json.loads(dry_sequence_breakers)
|
||||
|
||||
# Move temperature to the end if temperature_last is true and temperature exists in the list
|
||||
if state["temperature_last"] and "temperature" in samplers:
|
||||
samplers.remove("temperature")
|
||||
samplers.append("temperature")
|
||||
|
||||
payload = {
|
||||
"temperature": state["temperature"] if not state["dynamic_temperature"] else (state["dynatemp_low"] + state["dynatemp_high"]) / 2,
|
||||
"dynatemp_range": 0 if not state["dynamic_temperature"] else (state["dynatemp_high"] - state["dynatemp_low"]) / 2,
|
||||
|
|
@ -84,7 +73,6 @@ class LlamaServer:
|
|||
"dry_base": state["dry_base"],
|
||||
"dry_allowed_length": state["dry_allowed_length"],
|
||||
"dry_penalty_last_n": state["repetition_penalty_range"],
|
||||
"dry_sequence_breakers": dry_sequence_breakers,
|
||||
"xtc_probability": state["xtc_probability"],
|
||||
"xtc_threshold": state["xtc_threshold"],
|
||||
"mirostat": state["mirostat_mode"],
|
||||
|
|
@ -95,6 +83,14 @@ class LlamaServer:
|
|||
"ignore_eos": state["ban_eos_token"],
|
||||
}
|
||||
|
||||
# DRY
|
||||
dry_sequence_breakers = state['dry_sequence_breakers']
|
||||
if not dry_sequence_breakers.startswith("["):
|
||||
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
|
||||
|
||||
dry_sequence_breakers = json.loads(dry_sequence_breakers)
|
||||
payload["dry_sequence_breakers"] = dry_sequence_breakers
|
||||
|
||||
# Sampler order
|
||||
if state["sampler_priority"]:
|
||||
samplers = state["sampler_priority"]
|
||||
|
|
@ -109,6 +105,11 @@ class LlamaServer:
|
|||
filtered_samplers.append("penalties")
|
||||
penalty_found = True
|
||||
|
||||
# Move temperature to the end if temperature_last is true and temperature exists in the list
|
||||
if state["temperature_last"] and "temperature" in samplers:
|
||||
samplers.remove("temperature")
|
||||
samplers.append("temperature")
|
||||
|
||||
payload["samplers"] = filtered_samplers
|
||||
|
||||
if state['custom_token_bans']:
|
||||
|
|
|
|||
Loading…
Reference in a new issue