Add extension modifier function for bot reply token stream

This commit is contained in:
Th-Underscore 2025-12-08 16:01:42 -05:00
parent 85269d7fbb
commit a8e3dc36b3
No known key found for this signature in database
GPG key ID: 8D0551EF8593B2F0
3 changed files with 37 additions and 11 deletions

View file

@ -81,7 +81,7 @@ def iterator():
# Extension functions that map string -> string
def _apply_string_extensions(function_name, text, state, is_chat=False):
def _apply_string_extensions(function_name, text, state, is_chat=False, **extra_kwargs):
for extension, _ in iterator():
if hasattr(extension, function_name):
func = getattr(extension, function_name)
@ -89,23 +89,22 @@ def _apply_string_extensions(function_name, text, state, is_chat=False):
# Handle old extensions without the 'state' arg or
# the 'is_chat' kwarg
count = 0
has_chat = False
for k in signature(func).parameters:
func_params = signature(func).parameters
kwargs = {}
for k in func_params:
if k == 'is_chat':
has_chat = True
kwargs['is_chat'] = is_chat
elif k in extra_kwargs:
kwargs[k] = extra_kwargs[k]
else:
count += 1
if count == 2:
if count >= 2:
args = [text, state]
else:
args = [text]
if has_chat:
kwargs = {'is_chat': is_chat}
else:
kwargs = {}
text = func(*args, **kwargs)
return text
@ -231,6 +230,7 @@ EXTENSION_MAP = {
"input": partial(_apply_string_extensions, "input_modifier"),
"output": partial(_apply_string_extensions, "output_modifier"),
"chat_input": _apply_chat_input_extensions,
"output_stream": partial(_apply_string_extensions, "output_stream_modifier"),
"state": _apply_state_modifier_extensions,
"history": _apply_history_modifier_extensions,
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),

View file

@ -77,6 +77,16 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat):
cur_time = time.monotonic()
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
try:
reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=False)
except Exception:
try:
logger.error('Error in streaming extension hook')
except Exception:
pass
traceback.print_exc()
if escape_html:
reply = html.escape(reply)
@ -102,6 +112,15 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
break
try:
reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=True)
except Exception:
try:
logger.error('Error in streaming extension hook')
except Exception:
pass
traceback.print_exc()
if not is_chat:
reply = apply_extensions('output', reply, state)