Cleanup modules/chat.py

This commit is contained in:
oobabooga 2026-03-18 21:05:42 -07:00
parent 779e7611ff
commit dde1764763

View file

@ -70,9 +70,7 @@ def update_message_metadata(metadata_dict, role, index, **fields):
if key not in metadata_dict:
metadata_dict[key] = {}
# Update with provided fields
for field_name, field_value in fields.items():
metadata_dict[key][field_name] = field_value
metadata_dict[key].update(fields)
jinja_env = ImmutableSandboxedEnvironment(
@ -212,6 +210,24 @@ def _expand_tool_sequence(tool_seq):
return messages
def _format_attachments(attachments, include_text=True):
"""Build image ref and text attachment strings from a list of attachments."""
attachments_text = ""
image_refs = ""
for attachment in attachments:
if attachment.get("type") == "image":
image_refs += "<__media__>"
elif include_text:
filename = attachment.get("name", "file")
content = attachment.get("content", "")
if attachment.get("type") == "text/html" and attachment.get("url"):
attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
else:
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
return image_refs, attachments_text
def generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs.get('impersonate', False)
_continue = kwargs.get('_continue', False)
@ -328,41 +344,19 @@ def generate_chat_prompt(user_input, state, **kwargs):
messages.insert(insert_pos, msg_dict)
# Handle Seed-OSS
elif '<seed:think>' in assistant_msg:
# Handle <think> blocks (Kimi, DeepSeek, Qwen, etc.) and Seed-OSS
elif '<think>' in assistant_msg or '<seed:think>' in assistant_msg:
open_tag = '<think>' if '<think>' in assistant_msg else '<seed:think>'
close_tag = '</think>' if open_tag == '<think>' else '</seed:think>'
thinking_content = ""
final_content = assistant_msg
# Extract thinking content if present
if '<seed:think>' in assistant_msg:
parts = assistant_msg.split('<seed:think>', 1)
if len(parts) > 1:
potential_content = parts[1]
if '</seed:think>' in potential_content:
thinking_content = potential_content.split('</seed:think>', 1)[0].strip()
final_content = parts[0] + potential_content.split('</seed:think>', 1)[1]
else:
thinking_content = potential_content.strip()
final_content = parts[0]
# Insert as structured message
msg_dict = {"role": "assistant", "content": final_content.strip()}
if thinking_content:
msg_dict["reasoning_content"] = thinking_content
messages.insert(insert_pos, msg_dict)
# Handle <think> blocks (Kimi, DeepSeek, Qwen, etc.)
elif '<think>' in assistant_msg:
thinking_content = ""
final_content = assistant_msg
parts = assistant_msg.split('<think>', 1)
parts = assistant_msg.split(open_tag, 1)
if len(parts) > 1:
potential_content = parts[1]
if '</think>' in potential_content:
thinking_content = potential_content.split('</think>', 1)[0].strip()
final_content = parts[0] + potential_content.split('</think>', 1)[1]
if close_tag in potential_content:
thinking_content = potential_content.split(close_tag, 1)[0].strip()
final_content = parts[0] + potential_content.split(close_tag, 1)[1]
else:
thinking_content = potential_content.strip()
final_content = parts[0]
@ -399,22 +393,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Add attachment content if present AND if past attachments are enabled
if user_key in metadata and "attachments" in metadata[user_key]:
attachments_text = ""
image_refs = ""
for attachment in metadata[user_key]["attachments"]:
if attachment.get("type") == "image":
# Add image reference for multimodal models
image_refs += "<__media__>"
elif state.get('include_past_attachments', True):
# Handle text/PDF attachments
filename = attachment.get("name", "file")
content = attachment.get("content", "")
if attachment.get("type") == "text/html" and attachment.get("url"):
attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
else:
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
image_refs, attachments_text = _format_attachments(
metadata[user_key]["attachments"],
include_text=state.get('include_past_attachments', True)
)
if image_refs:
enhanced_user_msg = f"{image_refs}\n\n{enhanced_user_msg}"
if attachments_text:
@ -427,37 +409,18 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Check if we have attachments
if not (impersonate or _continue):
has_attachments = False
if len(history_data.get('metadata', {})) > 0:
current_row_idx = len(history)
user_key = f"user_{current_row_idx}"
has_attachments = user_key in metadata and "attachments" in metadata[user_key]
current_row_idx = len(history)
user_key = f"user_{current_row_idx}"
has_attachments = user_key in metadata and "attachments" in metadata[user_key]
if user_input or has_attachments:
# For the current user input being processed, check if we need to add attachments
if len(history_data.get('metadata', {})) > 0:
current_row_idx = len(history)
user_key = f"user_{current_row_idx}"
if user_key in metadata and "attachments" in metadata[user_key]:
attachments_text = ""
image_refs = ""
for attachment in metadata[user_key]["attachments"]:
if attachment.get("type") == "image":
image_refs += "<__media__>"
else:
filename = attachment.get("name", "file")
content = attachment.get("content", "")
if attachment.get("type") == "text/html" and attachment.get("url"):
attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
else:
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
if image_refs:
user_input = f"{image_refs}\n\n{user_input}"
if attachments_text:
user_input += f"\n\nATTACHMENTS:\n{attachments_text}"
if has_attachments:
image_refs, attachments_text = _format_attachments(metadata[user_key]["attachments"])
if image_refs:
user_input = f"{image_refs}\n\n{user_input}"
if attachments_text:
user_input += f"\n\nATTACHMENTS:\n{attachments_text}"
messages.append({"role": "user", "content": user_input})
@ -609,7 +572,6 @@ def count_prompt_tokens(text_input, state):
try:
# Handle dict format with text and files
files = []
if isinstance(text_input, dict):
files = text_input.get('files', [])
text = text_input.get('text', '')
@ -647,7 +609,6 @@ def count_prompt_tokens(text_input, state):
def get_stopping_strings(state):
stopping_strings = []
renderers = []
if state['mode'] in ['instruct', 'chat-instruct']: