From 4fabd729c9c05b2a5ee423bd22e06792962fa703 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 18 Apr 2025 17:25:22 -0700
Subject: [PATCH] Fix the API without streaming or without 'sampler_priority'
 (closes #6851)

---
 modules/llama_cpp_server.py | 44 ++++++++++++++++++++-----------------
 1 file changed, 24 insertions(+), 20 deletions(-)

diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index 8180f974..123f9471 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -61,20 +61,6 @@ class LlamaServer:
             dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
         dry_sequence_breakers = json.loads(dry_sequence_breakers)
 
-        # Prepare the sampler order
-        samplers = state["sampler_priority"]
-        samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
-        penalty_found = False
-        filtered_samplers = []
-        for s in samplers:
-            if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
-                filtered_samplers.append(s.strip())
-            elif not penalty_found and s.strip() == "repetition_penalty":
-                filtered_samplers.append("penalties")
-                penalty_found = True
-
-        samplers = filtered_samplers
-
         # Move temperature to the end if temperature_last is true and temperature exists in the list
         if state["temperature_last"] and "temperature" in samplers:
             samplers.remove("temperature")
@@ -106,20 +92,31 @@ class LlamaServer:
             "grammar": state["grammar_string"],
             "seed": state["seed"],
             "ignore_eos": state["ban_eos_token"],
-            "samplers": samplers,
         }
 
+        # Sampler order
+        if state["sampler_priority"]:
+            samplers = state["sampler_priority"]
+            samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
+            filtered_samplers = []
+
+            penalty_found = False
+            for s in samplers:
+                if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
+                    filtered_samplers.append(s.strip())
+                elif not penalty_found and s.strip() == "repetition_penalty":
+                    filtered_samplers.append("penalties")
+                    penalty_found = True
+
+            payload["samplers"] = filtered_samplers
+
         if state['custom_token_bans']:
             to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
             payload["logit_bias"] = to_ban
 
         return payload
 
-    def generate_with_streaming(
-        self,
-        prompt,
-        state,
-    ):
+    def generate_with_streaming(self, prompt, state):
         url = f"http://localhost:{self.port}/completion"
         payload = self.prepare_payload(state)
 
@@ -178,6 +175,13 @@ class LlamaServer:
                 print(f"Problematic line: {line}")
                 continue
 
+    def generate(self, prompt, state):
+        output = ""
+        for output in self.generate_with_streaming(prompt, state):
+            pass
+
+        return output
+
     def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
         """Get the logits/probabilities for the next token after a prompt"""
         url = f"http://localhost:{self.port}/completion"
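
The core of this change is twofold: prepare_payload() now only includes a "samplers" key when state["sampler_priority"] is non-empty, so API requests that leave out 'sampler_priority' no longer break, and a blocking generate() method is added for the non-streaming API path by draining generate_with_streaming(). The sketch below is illustrative and not part of the patch: it mirrors the conditional "samplers" handling outside the class so it can run standalone. The build_payload() name and the example state dict are made up for this sketch, and the comment about the server falling back to its own default sampler order when the key is absent is an assumption based on the commit's intent.

    # Standalone sketch of the patched "samplers" handling (hypothetical
    # build_payload() helper; only the filtering rules come from the diff).
    def build_payload(state):
        payload = {
            "seed": state["seed"],
            "ignore_eos": state["ban_eos_token"],
        }

        # Only send "samplers" when 'sampler_priority' was actually provided;
        # otherwise the key is omitted entirely (assumed to let the llama.cpp
        # server fall back to its default sampler order).
        if state["sampler_priority"]:
            samplers = state["sampler_priority"]
            samplers = samplers.split("\n") if isinstance(samplers, str) else samplers

            filtered_samplers = []
            penalty_found = False
            for s in samplers:
                if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
                    filtered_samplers.append(s.strip())
                elif not penalty_found and s.strip() == "repetition_penalty":
                    # repetition_penalty is mapped onto the combined "penalties" sampler
                    filtered_samplers.append("penalties")
                    penalty_found = True

            payload["samplers"] = filtered_samplers

        return payload

    # A request without 'sampler_priority' simply omits the key:
    print(build_payload({"seed": -1, "ban_eos_token": False, "sampler_priority": ""}))
    # {'seed': -1, 'ignore_eos': False}

    # The new non-streaming path just drains the streaming generator and keeps
    # the last yielded value, e.g. (assuming `server` is a running LlamaServer):
    #     text = server.generate(prompt, state)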