mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-08 06:33:51 +01:00
API: Speed up chat completions by ~85ms per request
This commit is contained in:
parent
249bd6eea2
commit
4c406e024f
|
|
@ -288,8 +288,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
return chunk
|
||||
|
||||
# generate reply #######################################
|
||||
prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
|
||||
if prompt_only:
|
||||
prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
|
||||
yield {'prompt': prompt}
|
||||
return
|
||||
|
||||
|
|
@ -335,7 +335,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
if len(tool_calls) > 0:
|
||||
break
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if len(tool_calls) > 0:
|
||||
|
|
|
|||
|
|
@ -76,6 +76,18 @@ jinja_env = ImmutableSandboxedEnvironment(
|
|||
)
|
||||
jinja_env.globals["strftime_now"] = strftime_now
|
||||
|
||||
# Compiled-template cache, keyed by the raw Jinja2 template source string.
# Templates come from a small, fixed set of user/instruction templates, so
# the cache stays tiny while avoiding a re-parse on every request.
_template_cache = {}


def get_compiled_template(template_str):
    """Return a compiled Jinja2 template for *template_str*.

    Compiles with the module-level ``jinja_env`` on the first call for a
    given source string and serves the cached template object thereafter.
    """
    try:
        return _template_cache[template_str]
    except KeyError:
        template = jinja_env.from_string(template_str)
        _template_cache[template_str] = template
        return template
|
||||
|
||||
|
||||
def str_presenter(dumper, data):
|
||||
"""
|
||||
|
|
@ -106,8 +118,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
if state['mode'] != 'instruct':
|
||||
chat_template_str = replace_character_names(chat_template_str, state['name1'], state['name2'])
|
||||
|
||||
instruction_template = jinja_env.from_string(state['instruction_template_str'])
|
||||
chat_template = jinja_env.from_string(chat_template_str)
|
||||
instruction_template = get_compiled_template(state['instruction_template_str'])
|
||||
chat_template = get_compiled_template(chat_template_str)
|
||||
|
||||
instruct_renderer = partial(
|
||||
instruction_template.render,
|
||||
|
|
@ -481,12 +493,12 @@ def get_stopping_strings(state):
|
|||
renderers = []
|
||||
|
||||
if state['mode'] in ['instruct', 'chat-instruct']:
|
||||
template = jinja_env.from_string(state['instruction_template_str'])
|
||||
template = get_compiled_template(state['instruction_template_str'])
|
||||
renderer = partial(template.render, add_generation_prompt=False, bos_token=shared.bos_token, eos_token=shared.eos_token)
|
||||
renderers.append(renderer)
|
||||
|
||||
if state['mode'] in ['chat']:
|
||||
template = jinja_env.from_string(state['chat_template_str'])
|
||||
template = get_compiled_template(state['chat_template_str'])
|
||||
renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
|
||||
renderers.append(renderer)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue