API: Speed up chat completions by ~85ms per request

This commit is contained in:
oobabooga 2026-03-05 18:36:07 -08:00
parent 249bd6eea2
commit 4c406e024f
2 changed files with 18 additions and 6 deletions

View file

@ -288,8 +288,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
return chunk
# generate reply #######################################
prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
if prompt_only:
prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
yield {'prompt': prompt}
return
@ -335,7 +335,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
if len(tool_calls) > 0:
break
token_count = len(encode(prompt)[0])
token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if len(tool_calls) > 0:

View file

@ -76,6 +76,18 @@ jinja_env = ImmutableSandboxedEnvironment(
)
jinja_env.globals["strftime_now"] = strftime_now
_template_cache = {}
def get_compiled_template(template_str):
"""Cache compiled Jinja2 templates keyed by their source string."""
compiled = _template_cache.get(template_str)
if compiled is None:
compiled = jinja_env.from_string(template_str)
_template_cache[template_str] = compiled
return compiled
def str_presenter(dumper, data):
"""
@ -106,8 +118,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
if state['mode'] != 'instruct':
chat_template_str = replace_character_names(chat_template_str, state['name1'], state['name2'])
instruction_template = jinja_env.from_string(state['instruction_template_str'])
chat_template = jinja_env.from_string(chat_template_str)
instruction_template = get_compiled_template(state['instruction_template_str'])
chat_template = get_compiled_template(chat_template_str)
instruct_renderer = partial(
instruction_template.render,
@ -481,12 +493,12 @@ def get_stopping_strings(state):
renderers = []
if state['mode'] in ['instruct', 'chat-instruct']:
template = jinja_env.from_string(state['instruction_template_str'])
template = get_compiled_template(state['instruction_template_str'])
renderer = partial(template.render, add_generation_prompt=False, bos_token=shared.bos_token, eos_token=shared.eos_token)
renderers.append(renderer)
if state['mode'] in ['chat']:
template = jinja_env.from_string(state['chat_template_str'])
template = get_compiled_template(state['chat_template_str'])
renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
renderers.append(renderer)