Fix the API without streaming or without 'sampler_priority' (closes #6851)

This commit is contained in:
oobabooga 2025-04-18 17:25:22 -07:00
parent 5135523429
commit 4fabd729c9

View file

@ -61,20 +61,6 @@ class LlamaServer:
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
dry_sequence_breakers = json.loads(dry_sequence_breakers)
# Prepare the sampler order
samplers = state["sampler_priority"]
samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
penalty_found = False
filtered_samplers = []
for s in samplers:
if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
filtered_samplers.append(s.strip())
elif not penalty_found and s.strip() == "repetition_penalty":
filtered_samplers.append("penalties")
penalty_found = True
samplers = filtered_samplers
# Move temperature to the end if temperature_last is true and temperature exists in the list
if state["temperature_last"] and "temperature" in samplers:
samplers.remove("temperature")
@ -106,20 +92,31 @@ class LlamaServer:
"grammar": state["grammar_string"],
"seed": state["seed"],
"ignore_eos": state["ban_eos_token"],
"samplers": samplers,
}
# Sampler order
if state["sampler_priority"]:
samplers = state["sampler_priority"]
samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
filtered_samplers = []
penalty_found = False
for s in samplers:
if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
filtered_samplers.append(s.strip())
elif not penalty_found and s.strip() == "repetition_penalty":
filtered_samplers.append("penalties")
penalty_found = True
payload["samplers"] = filtered_samplers
if state['custom_token_bans']:
to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
payload["logit_bias"] = to_ban
return payload
def generate_with_streaming(
self,
prompt,
state,
):
def generate_with_streaming(self, prompt, state):
url = f"http://localhost:{self.port}/completion"
payload = self.prepare_payload(state)
@ -178,6 +175,13 @@ class LlamaServer:
print(f"Problematic line: {line}")
continue
def generate(self, prompt, state):
output = ""
for output in self.generate_with_streaming(prompt, state):
pass
return output
def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
"""Get the logits/probabilities for the next token after a prompt"""
url = f"http://localhost:{self.port}/completion"