mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-09 00:23:38 +00:00
Add extension modifier function for bot reply token stream
This commit is contained in:
parent
85269d7fbb
commit
a8e3dc36b3
3 changed files with 37 additions and 11 deletions
|
|
@ -81,7 +81,7 @@ def iterator():
|
|||
|
||||
|
||||
# Extension functions that map string -> string
|
||||
def _apply_string_extensions(function_name, text, state, is_chat=False):
|
||||
def _apply_string_extensions(function_name, text, state, is_chat=False, **extra_kwargs):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, function_name):
|
||||
func = getattr(extension, function_name)
|
||||
|
|
@ -89,23 +89,22 @@ def _apply_string_extensions(function_name, text, state, is_chat=False):
|
|||
# Handle old extensions without the 'state' arg or
|
||||
# the 'is_chat' kwarg
|
||||
count = 0
|
||||
has_chat = False
|
||||
for k in signature(func).parameters:
|
||||
func_params = signature(func).parameters
|
||||
kwargs = {}
|
||||
|
||||
for k in func_params:
|
||||
if k == 'is_chat':
|
||||
has_chat = True
|
||||
kwargs['is_chat'] = is_chat
|
||||
elif k in extra_kwargs:
|
||||
kwargs[k] = extra_kwargs[k]
|
||||
else:
|
||||
count += 1
|
||||
|
||||
if count == 2:
|
||||
if count >= 2:
|
||||
args = [text, state]
|
||||
else:
|
||||
args = [text]
|
||||
|
||||
if has_chat:
|
||||
kwargs = {'is_chat': is_chat}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
text = func(*args, **kwargs)
|
||||
|
||||
return text
|
||||
|
|
@ -231,6 +230,7 @@ EXTENSION_MAP = {
|
|||
"input": partial(_apply_string_extensions, "input_modifier"),
|
||||
"output": partial(_apply_string_extensions, "output_modifier"),
|
||||
"chat_input": _apply_chat_input_extensions,
|
||||
"output_stream": partial(_apply_string_extensions, "output_stream_modifier"),
|
||||
"state": _apply_state_modifier_extensions,
|
||||
"history": _apply_history_modifier_extensions,
|
||||
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
|
||||
|
|
|
|||
|
|
@ -77,6 +77,16 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
|||
for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat):
|
||||
cur_time = time.monotonic()
|
||||
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
|
||||
|
||||
try:
|
||||
reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=False)
|
||||
except Exception:
|
||||
try:
|
||||
logger.error('Error in streaming extension hook')
|
||||
except Exception:
|
||||
pass
|
||||
traceback.print_exc()
|
||||
|
||||
if escape_html:
|
||||
reply = html.escape(reply)
|
||||
|
||||
|
|
@ -102,6 +112,15 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
|||
if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
|
||||
break
|
||||
|
||||
try:
|
||||
reply = apply_extensions('output_stream', reply, state, is_chat=is_chat, is_final=True)
|
||||
except Exception:
|
||||
try:
|
||||
logger.error('Error in streaming extension hook')
|
||||
except Exception:
|
||||
pass
|
||||
traceback.print_exc()
|
||||
|
||||
if not is_chat:
|
||||
reply = apply_extensions('output', reply, state)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue