diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index 5f0e0128..288205a5 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -47,6 +47,7 @@ class GenerationOptions(BaseModel):
     seed: int = -1
     sampler_priority: List[str] | str | None = Field(default=None, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
     custom_token_bans: str = ""
+    show_after: str = ""
     negative_prompt: str = ''
     dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"'
     grammar_string: str = ""
diff --git a/modules/chat.py b/modules/chat.py
index 0e47da29..2852aaf3 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -412,8 +412,16 @@ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_
             yield history
         return
 
+    show_after = html.escape(state["show_after"]) if state["show_after"] else None
     for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message, for_ui=for_ui):
-        yield history
+        if show_after:
+            after = history["visible"][-1][1].partition(show_after)[2] or "*Is thinking...*"
+            yield {
+                'internal': history['internal'],
+                'visible': history['visible'][:-1] + [[history['visible'][-1][0], after]]
+            }
+        else:
+            yield history
 
 
 def character_is_loaded(state, raise_exception=False):
diff --git a/modules/shared.py b/modules/shared.py
index f1e12673..2e91f4d5 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -57,6 +57,7 @@ settings = {
     'seed': -1,
     'custom_stopping_strings': '',
     'custom_token_bans': '',
+    'show_after': '',
     'negative_prompt': '',
     'autoload_model': False,
     'dark_theme': True,
diff --git a/modules/ui.py b/modules/ui.py
index df948a14..b776e19c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -215,6 +215,7 @@ def list_interface_input_elements():
         'sampler_priority',
         'custom_stopping_strings',
         'custom_token_bans',
+        'show_after',
         'negative_prompt',
         'dry_sequence_breakers',
         'grammar_string',
diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py
index c8fd6bc7..265840ed 100644
--- a/modules/ui_parameters.py
+++ b/modules/ui_parameters.py
@@ -92,6 +92,7 @@ def create_ui(default_preset):
                     shared.gradio['sampler_priority'] = gr.Textbox(value=generate_params['sampler_priority'], lines=12, label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])
                     shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=2, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
                     shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Token bans', info='Token IDs to ban, separated by commas. The IDs can be found in the Default or Notebook tab.')
+                    shared.gradio['show_after'] = gr.Textbox(value=shared.settings['show_after'] or None, label='Show after', info='Hide the reply before this text.', placeholder="")
                     shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt', info='For CFG. Only used when guidance_scale is different than 1.', lines=3, elem_classes=['add_scrollbar'])
                     shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')
                     with gr.Row() as shared.gradio['grammar_file_row']: