diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index 8180f974..123f9471 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -61,20 +61,6 @@ class LlamaServer:
             dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
         dry_sequence_breakers = json.loads(dry_sequence_breakers)

-        # Prepare the sampler order
-        samplers = state["sampler_priority"]
-        samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
-        penalty_found = False
-        filtered_samplers = []
-        for s in samplers:
-            if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
-                filtered_samplers.append(s.strip())
-            elif not penalty_found and s.strip() == "repetition_penalty":
-                filtered_samplers.append("penalties")
-                penalty_found = True
-
-        samplers = filtered_samplers
-
         # Move temperature to the end if temperature_last is true and temperature exists in the list
         if state["temperature_last"] and "temperature" in samplers:
             samplers.remove("temperature")
@@ -106,20 +92,31 @@ class LlamaServer:
             "grammar": state["grammar_string"],
             "seed": state["seed"],
             "ignore_eos": state["ban_eos_token"],
-            "samplers": samplers,
         }

+        # Sampler order
+        if state["sampler_priority"]:
+            samplers = state["sampler_priority"]
+            samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
+            filtered_samplers = []
+
+            penalty_found = False
+            for s in samplers:
+                if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
+                    filtered_samplers.append(s.strip())
+                elif not penalty_found and s.strip() == "repetition_penalty":
+                    filtered_samplers.append("penalties")
+                    penalty_found = True
+
+            payload["samplers"] = filtered_samplers
+
         if state['custom_token_bans']:
             to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
             payload["logit_bias"] = to_ban

         return payload

-    def generate_with_streaming(
-        self,
-        prompt,
-        state,
-    ):
+    def generate_with_streaming(self, prompt, state):
         url = f"http://localhost:{self.port}/completion"
         payload = self.prepare_payload(state)

@@ -178,6 +175,13 @@ class LlamaServer:
                     print(f"Problematic line: {line}")
                     continue

+    def generate(self, prompt, state):
+        output = ""
+        for output in self.generate_with_streaming(prompt, state):
+            pass
+
+        return output
+
     def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
         """Get the logits/probabilities for the next token after a prompt"""
         url = f"http://localhost:{self.port}/completion"
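
The behavioral change in `prepare_payload()` is that `"samplers"` is now only added to the payload when `state["sampler_priority"]` is non-empty, so an empty priority setting leaves llama.cpp's default sampler order untouched. The filtering loop itself is unchanged: known sampler names pass through, and the first `repetition_penalty` entry is mapped to llama.cpp's `penalties` sampler. A standalone sketch of that filtering, with the `filter_samplers` helper being hypothetical and extracted purely for illustration:

```python
# Hypothetical helper mirroring the filtering loop in prepare_payload().
ALLOWED = ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]

def filter_samplers(sampler_priority):
    """Keep known sampler names; map the first 'repetition_penalty' to 'penalties'."""
    samplers = sampler_priority.split("\n") if isinstance(sampler_priority, str) else sampler_priority
    filtered_samplers = []
    penalty_found = False
    for s in samplers:
        if s.strip() in ALLOWED:
            filtered_samplers.append(s.strip())
        elif not penalty_found and s.strip() == "repetition_penalty":
            filtered_samplers.append("penalties")
            penalty_found = True
    return filtered_samplers

print(filter_samplers("repetition_penalty\ndry\ntop_k\ntop_p\ntemperature"))
# ['penalties', 'dry', 'top_k', 'top_p', 'temperature']
```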
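
The new `generate()` method is a blocking wrapper: it drains `generate_with_streaming()` and returns the last yielded value, which works because each yield appears to carry the cumulative output so far. A minimal, self-contained illustration of that pattern, where the `fake_stream` generator is a stand-in for `generate_with_streaming()` and not part of the diff:

```python
# fake_stream is a hypothetical stand-in: like generate_with_streaming(),
# each yield is the full text accumulated up to that point.
def fake_stream():
    text = ""
    for token in ["Hello", ",", " world", "!"]:
        text += token
        yield text

# Same pattern as the new generate(): exhaust the generator, keep the last value.
output = ""
for output in fake_stream():
    pass

print(output)  # "Hello, world!"
```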