Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2026-01-31 04:44:44 +01:00)
Fix the API without streaming or without 'sampler_priority' (closes #6851)
This commit is contained in:
parent
5135523429
commit
4fabd729c9
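The commit addresses two API-only code paths: requests that do not stream, and requests that leave 'sampler_priority' unset. A minimal reproduction sketch of such a request, assuming the project's OpenAI-compatible API on its default port 5000 (endpoint, port, and field values are assumptions for illustration, not taken from this diff):

```python
import requests

# Hypothetical non-streaming completion request that omits 'sampler_priority',
# i.e. the kind of call reported in #6851. Endpoint and port are assumed
# defaults of the OpenAI-compatible API, not part of this commit.
response = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    json={
        "prompt": "Once upon a time",
        "max_tokens": 64,
        "stream": False,  # non-streaming path
        # note: no "sampler_priority" key
    },
    timeout=120,
)
print(response.json()["choices"][0]["text"])
```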
```diff
@@ -61,20 +61,6 @@ class LlamaServer:
             dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
         dry_sequence_breakers = json.loads(dry_sequence_breakers)
 
-        # Prepare the sampler order
-        samplers = state["sampler_priority"]
-        samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
-        penalty_found = False
-        filtered_samplers = []
-        for s in samplers:
-            if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
-                filtered_samplers.append(s.strip())
-            elif not penalty_found and s.strip() == "repetition_penalty":
-                filtered_samplers.append("penalties")
-                penalty_found = True
-
-        samplers = filtered_samplers
-
         # Move temperature to the end if temperature_last is true and temperature exists in the list
         if state["temperature_last"] and "temperature" in samplers:
             samplers.remove("temperature")
```
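The block removed above is not dropped; the same filtering reappears later in prepare_payload behind a guard (next hunk). Its job is to map the UI's sampler-priority list onto the sampler names llama-server accepts, folding repetition_penalty into the single "penalties" stage. A standalone sketch of that filtering, extracted purely for illustration (the helper name is hypothetical):

```python
# Sampler names that llama-server understands directly (from the list above).
LLAMA_CPP_SAMPLERS = {"dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"}


def filter_sampler_priority(samplers):
    """Keep only names llama-server accepts; collapse 'repetition_penalty'
    into the single 'penalties' stage, mirroring the removed block."""
    if isinstance(samplers, str):
        samplers = samplers.split("\n")

    filtered, penalty_found = [], False
    for s in samplers:
        name = s.strip()
        if name in LLAMA_CPP_SAMPLERS:
            filtered.append(name)
        elif not penalty_found and name == "repetition_penalty":
            filtered.append("penalties")
            penalty_found = True
    return filtered


print(filter_sampler_priority("repetition_penalty\ntop_k\ntop_p\ntemperature"))
# ['penalties', 'top_k', 'top_p', 'temperature']
```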
```diff
@@ -106,20 +92,31 @@ class LlamaServer:
             "grammar": state["grammar_string"],
             "seed": state["seed"],
             "ignore_eos": state["ban_eos_token"],
-            "samplers": samplers,
         }
 
+        # Sampler order
+        if state["sampler_priority"]:
+            samplers = state["sampler_priority"]
+            samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
+            filtered_samplers = []
+
+            penalty_found = False
+            for s in samplers:
+                if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
+                    filtered_samplers.append(s.strip())
+                elif not penalty_found and s.strip() == "repetition_penalty":
+                    filtered_samplers.append("penalties")
+                    penalty_found = True
+
+            payload["samplers"] = filtered_samplers
+
         if state['custom_token_bans']:
             to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
             payload["logit_bias"] = to_ban
 
         return payload
 
-    def generate_with_streaming(
-        self,
-        prompt,
-        state,
-    ):
+    def generate_with_streaming(self, prompt, state):
         url = f"http://localhost:{self.port}/completion"
         payload = self.prepare_payload(state)
 
```
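With this move, the sampler-order code only runs when state["sampler_priority"] is non-empty, so an API request that leaves it empty no longer produces an empty "samplers" list in the payload; the key is omitted and llama-server falls back to its default sampler order. A simplified sketch of that effect (build_payload and its fields are illustrative, not the project's real function):

```python
def build_payload(state):
    """Toy version of the guarded payload construction shown above."""
    payload = {"prompt": state["prompt"], "temperature": state["temperature"]}

    # Only forward an explicit order; otherwise let llama-server decide.
    if state.get("sampler_priority"):
        samplers = state["sampler_priority"]
        samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
        payload["samplers"] = [s.strip() for s in samplers if s.strip()]

    return payload


# UI-style request: an explicit order is forwarded.
print(build_payload({"prompt": "hi", "temperature": 0.7,
                     "sampler_priority": "top_k\ntop_p\ntemperature"}))
# {'prompt': 'hi', 'temperature': 0.7, 'samplers': ['top_k', 'top_p', 'temperature']}

# API-style request with the key empty or missing: "samplers" is left out.
print(build_payload({"prompt": "hi", "temperature": 0.7}))
# {'prompt': 'hi', 'temperature': 0.7}
```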
```diff
@@ -178,6 +175,13 @@ class LlamaServer:
                         print(f"Problematic line: {line}")
                         continue
 
+    def generate(self, prompt, state):
+        output = ""
+        for output in self.generate_with_streaming(prompt, state):
+            pass
+
+        return output
+
     def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
         """Get the logits/probabilities for the next token after a prompt"""
         url = f"http://localhost:{self.port}/completion"
```
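The new generate() gives non-streaming callers a blocking path that reuses the streaming generator: the for/pass loop drains it and keeps the last value it yielded. This returns the full completion only because the generator yields the accumulated text so far rather than per-token deltas, which is what the pattern relies on. A minimal illustration with a stand-in generator (fake_stream is hypothetical, not part of the project):

```python
def fake_stream():
    """Stand-in for a streaming generator that yields cumulative text."""
    text = ""
    for token in ["Hello", ",", " world", "!"]:
        text += token
        yield text  # each yield is the text so far, not a single token


# Same pattern as the added generate(): drain the generator, keep the last value.
output = ""
for output in fake_stream():
    pass

print(output)  # Hello, world!
```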