API: Speed up chat completions by ~85ms per request

This commit is contained in:
oobabooga 2026-03-05 18:36:07 -08:00
parent 249bd6eea2
commit 4c406e024f
2 changed files with 18 additions and 6 deletions

View file

@ -76,6 +76,18 @@ jinja_env = ImmutableSandboxedEnvironment(
)
jinja_env.globals["strftime_now"] = strftime_now
_template_cache = {}
def get_compiled_template(template_str):
"""Cache compiled Jinja2 templates keyed by their source string."""
compiled = _template_cache.get(template_str)
if compiled is None:
compiled = jinja_env.from_string(template_str)
_template_cache[template_str] = compiled
return compiled
def str_presenter(dumper, data):
"""
@ -106,8 +118,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
if state['mode'] != 'instruct':
chat_template_str = replace_character_names(chat_template_str, state['name1'], state['name2'])
instruction_template = jinja_env.from_string(state['instruction_template_str'])
chat_template = jinja_env.from_string(chat_template_str)
instruction_template = get_compiled_template(state['instruction_template_str'])
chat_template = get_compiled_template(chat_template_str)
instruct_renderer = partial(
instruction_template.render,
@ -481,12 +493,12 @@ def get_stopping_strings(state):
renderers = []
if state['mode'] in ['instruct', 'chat-instruct']:
template = jinja_env.from_string(state['instruction_template_str'])
template = get_compiled_template(state['instruction_template_str'])
renderer = partial(template.render, add_generation_prompt=False, bos_token=shared.bos_token, eos_token=shared.eos_token)
renderers.append(renderer)
if state['mode'] in ['chat']:
template = jinja_env.from_string(state['chat_template_str'])
template = get_compiled_template(state['chat_template_str'])
renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
renderers.append(renderer)