From 616ea6966d4821357076ff0c3b0a37967b736dd1 Mon Sep 17 00:00:00 2001 From: oobabooga Date: Tue, 20 May 2025 12:51:28 -0300 Subject: [PATCH] Store previous reply versions on regenerate (#7004) --- modules/chat.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/modules/chat.py b/modules/chat.py index 13f733e9..3efc55db 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -365,6 +365,34 @@ def get_stopping_strings(state): return result +def add_message_version(history, row_idx, is_current=True): + """Add the current message as a version in the history metadata""" + if 'metadata' not in history: + history['metadata'] = {} + + if row_idx >= len(history['internal']) or not history['internal'][row_idx][1].strip(): + return # Skip if row doesn't exist or message is empty + + key = f"assistant_{row_idx}" + + # Initialize metadata structures if needed + if key not in history['metadata']: + history['metadata'][key] = {"timestamp": get_current_timestamp()} + if "versions" not in history['metadata'][key]: + history['metadata'][key]["versions"] = [] + + # Add current message as a version + history['metadata'][key]["versions"].append({ + "content": history['internal'][row_idx][1], + "visible_content": history['visible'][row_idx][1], + "timestamp": get_current_timestamp() + }) + + # Update index if this is the current version + if is_current: + history['metadata'][key]["current_version_index"] = len(history['metadata'][key]["versions"]) - 1 + + def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False): history = state['history'] output = copy.deepcopy(history) @@ -405,6 +433,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess text, visible_text = output['internal'][-1][0], output['visible'][-1][0] if regenerate: row_idx = len(output['internal']) - 1 + + # Store the existing response as a version before regenerating + add_message_version(output, row_idx, is_current=False) + if loading_message: yield { 'visible': output['visible'][:-1] + [[visible_text, shared.processing_message]], @@ -465,6 +497,11 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess if is_stream: yield output + # Add the newly generated response as a version (only for regeneration) + if regenerate: + row_idx = len(output['internal']) - 1 + add_message_version(output, row_idx, is_current=True) + output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True) yield output