From 4c406e024fd1b0342baf35130a0a79868a4ef539 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 5 Mar 2026 18:36:07 -0800
Subject: [PATCH] API: Speed up chat completions by ~85ms per request

---
 extensions/openai/completions.py |  4 ++--
 modules/chat.py                  | 20 ++++++++++++++++----
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index de944a8f..56d2059d 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -288,8 +288,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         return chunk

     # generate reply #######################################
-    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
     if prompt_only:
+        prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
         yield {'prompt': prompt}
         return

@@ -335,7 +335,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             if len(tool_calls) > 0:
                 break

-    token_count = len(encode(prompt)[0])
+    token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
     completion_token_count = len(encode(answer)[0])
     stop_reason = "stop"
     if len(tool_calls) > 0:
diff --git a/modules/chat.py b/modules/chat.py
index 21a2104a..f9a98bb6 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -76,6 +76,18 @@ jinja_env = ImmutableSandboxedEnvironment(
 )
 jinja_env.globals["strftime_now"] = strftime_now

+_template_cache = {}
+
+
+def get_compiled_template(template_str):
+    """Cache compiled Jinja2 templates keyed by their source string."""
+    compiled = _template_cache.get(template_str)
+    if compiled is None:
+        compiled = jinja_env.from_string(template_str)
+        _template_cache[template_str] = compiled
+
+    return compiled
+

 def str_presenter(dumper, data):
     """
@@ -106,8 +118,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
     if state['mode'] != 'instruct':
         chat_template_str = replace_character_names(chat_template_str, state['name1'], state['name2'])

-    instruction_template = jinja_env.from_string(state['instruction_template_str'])
-    chat_template = jinja_env.from_string(chat_template_str)
+    instruction_template = get_compiled_template(state['instruction_template_str'])
+    chat_template = get_compiled_template(chat_template_str)

     instruct_renderer = partial(
         instruction_template.render,
@@ -481,12 +493,12 @@ def get_stopping_strings(state):
     renderers = []

     if state['mode'] in ['instruct', 'chat-instruct']:
-        template = jinja_env.from_string(state['instruction_template_str'])
+        template = get_compiled_template(state['instruction_template_str'])
         renderer = partial(template.render, add_generation_prompt=False, bos_token=shared.bos_token, eos_token=shared.eos_token)
         renderers.append(renderer)

     if state['mode'] in ['chat']:
-        template = jinja_env.from_string(state['chat_template_str'])
+        template = get_compiled_template(state['chat_template_str'])
         renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
         renderers.append(renderer)
