From 70946c6d77225a200bfb7e6e91ca69c6d2d28328 Mon Sep 17 00:00:00 2001 From: Th-Underscore Date: Mon, 8 Dec 2025 15:35:08 -0500 Subject: [PATCH] Add extension modifier function for bot reply token stream --- docs/07 - Extensions.md | 9 ++++++++- modules/extensions.py | 1 + modules/text_generation.py | 19 +++++++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/docs/07 - Extensions.md b/docs/07 - Extensions.md index ebcd3c0e..8281aa0e 100644 --- a/docs/07 - Extensions.md +++ b/docs/07 - Extensions.md @@ -40,12 +40,13 @@ The extensions framework is based on special functions and variables that you ca | Function | Description | |-------------|-------------| | `def setup()` | Is executed when the extension gets imported. | -| `def ui()` | Creates custom gradio elements when the UI is launched. | +| `def ui()` | Creates custom gradio elements when the UI is launched. | | `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. | | `def custom_js()` | Same as above but for javascript. | | `def input_modifier(string, state, is_chat=False)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | | `def output_modifier(string, state, is_chat=False)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | | `def chat_input_modifier(text, visible_text, state)` | Modifies both the visible and internal inputs in chat mode. Can be used to hijack the chat input with custom content. | +| `def output_stream_modifier(string, state, is_chat=False, is_final=False)` | Overrides the full text mid-stream. Called for each partial token/chunk while the UI is streaming output. Includes the last generated token (is_final). | | `def bot_prefix_modifier(string, state)` | Applied in chat mode to the prefix for the bot's reply. 
| | `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. | | `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. | @@ -209,6 +210,12 @@ def output_modifier(string, state, is_chat=False): """ return string +def output_stream_modifier(string, state, is_chat=False, is_final=False): + """ + Modifies the text stream of the LLM output in realtime. + is_final is True on the last call for a given reply. + """ + return string + def custom_generate_chat_prompt(user_input, state, **kwargs): """ Replaces the function that generates the prompt from the chat history. diff --git a/modules/extensions.py b/modules/extensions.py index e0010312..5573ec65 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -231,6 +231,7 @@ EXTENSION_MAP = { "input": partial(_apply_string_extensions, "input_modifier"), "output": partial(_apply_string_extensions, "output_modifier"), "chat_input": _apply_chat_input_extensions, + "output_stream": partial(_apply_string_extensions, "output_stream_modifier"), "state": _apply_state_modifier_extensions, "history": _apply_history_modifier_extensions, "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"), diff --git a/modules/text_generation.py b/modules/text_generation.py index 27c5de7d..f17e2fea 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -77,6 +77,16 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat): cur_time = time.monotonic() reply, stop_found = apply_stopping_strings(reply, all_stop_strings) + + try: + reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=False) + except Exception: + try: + logger.error('Error in streaming extension hook') + except Exception: + pass + traceback.print_exc() + if escape_html: reply = html.escape(reply) @@ -103,6 +113,15 @@ def 
_generate_reply(question, state, stopping_strings=None, is_chat=False, escap break if not is_chat: + try: + reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=True) + except Exception: + try: + logger.error('Error in streaming extension hook') + except Exception: + pass + traceback.print_exc() + reply = apply_extensions('output', reply, state) yield reply