Add attachments support (text files, PDF documents) (#7005)

This commit is contained in:
oobabooga 2025-05-21 00:36:20 -03:00 committed by GitHub
parent 5d00574a56
commit 409a48d6bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 233 additions and 12 deletions

View file

@ -157,7 +157,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs.get('impersonate', False)
_continue = kwargs.get('_continue', False)
also_return_rows = kwargs.get('also_return_rows', False)
history = kwargs.get('history', state['history'])['internal']
history_data = kwargs.get('history', state['history'])
history = history_data['internal']
metadata = history_data.get('metadata', {})
# Templates
chat_template_str = state['chat_template_str']
@ -196,11 +198,13 @@ def generate_chat_prompt(user_input, state, **kwargs):
messages.append({"role": "system", "content": context})
insert_pos = len(messages)
for entry in reversed(history):
for i, entry in enumerate(reversed(history)):
user_msg = entry[0].strip()
assistant_msg = entry[1].strip()
tool_msg = entry[2].strip() if len(entry) > 2 else ''
row_idx = len(history) - i - 1
if tool_msg:
messages.insert(insert_pos, {"role": "tool", "content": tool_msg})
@ -208,10 +212,40 @@ def generate_chat_prompt(user_input, state, **kwargs):
messages.insert(insert_pos, {"role": "assistant", "content": assistant_msg})
if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
messages.insert(insert_pos, {"role": "user", "content": user_msg})
# Check for user message attachments in metadata
user_key = f"user_{row_idx}"
enhanced_user_msg = user_msg
# Add attachment content if present
if user_key in metadata and "attachments" in metadata[user_key]:
attachments_text = ""
for attachment in metadata[user_key]["attachments"]:
filename = attachment.get("name", "file")
content = attachment.get("content", "")
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
if attachments_text:
enhanced_user_msg = f"{user_msg}\n\nATTACHMENTS:{attachments_text}"
messages.insert(insert_pos, {"role": "user", "content": enhanced_user_msg})
user_input = user_input.strip()
if user_input and not impersonate and not _continue:
# For the current user input being processed, check if we need to add attachments
if not impersonate and not _continue and len(history_data.get('metadata', {})) > 0:
current_row_idx = len(history)
user_key = f"user_{current_row_idx}"
if user_key in metadata and "attachments" in metadata[user_key]:
attachments_text = ""
for attachment in metadata[user_key]["attachments"]:
filename = attachment.get("name", "file")
content = attachment.get("content", "")
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
if attachments_text:
user_input = f"{user_input}\n\nATTACHMENTS:{attachments_text}"
messages.append({"role": "user", "content": user_input})
def make_prompt(messages):
@ -280,7 +314,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Resort to truncating the user input
else:
user_message = messages[-1]['content']
# Bisect the truncation point
@ -393,7 +426,74 @@ def add_message_version(history, row_idx, is_current=True):
history['metadata'][key]["current_version_index"] = len(history['metadata'][key]["versions"]) - 1
def add_message_attachment(history, row_idx, file_path, is_user=True):
    """Attach a file's extracted text to a message in the history metadata.

    The attachment is appended to the ``attachments`` list stored under
    ``history['metadata']['user_<row_idx>']`` (or ``assistant_<row_idx>``
    when *is_user* is false) as a dict with ``name``, ``type`` and
    ``content`` keys.

    Args:
        history: Chat history dict; a ``metadata`` mapping is created on it
            if it does not exist yet.
        row_idx: Index of the history row the attachment belongs to.
        file_path: Path of the file to attach. Files with a ``.pdf`` suffix
            go through ``extract_pdf_text``; every other file is read as
            UTF-8 text.
        is_user: Attach to the user side of the row (default) or to the
            assistant side.

    Returns:
        The extracted text, or ``None`` if the file could not be processed.
        In that case the error is logged and no attachment is appended.
    """
    metadata = history.setdefault('metadata', {})
    key = f"{'user' if is_user else 'assistant'}_{row_idx}"

    # Explicit membership test so a timestamp is only generated for new entries.
    if key not in metadata:
        metadata[key] = {"timestamp": get_current_timestamp()}

    attachments = metadata[key].setdefault("attachments", [])

    path = Path(file_path)
    filename = path.name

    try:
        if path.suffix.lower() == '.pdf':
            content = extract_pdf_text(path)
            file_type = "application/pdf"
        else:
            # 'utf-8-sig' decodes plain UTF-8 unchanged but drops a leading
            # byte-order mark, which would otherwise leak into the prompt as
            # a stray '\ufeff' character.
            with open(path, 'r', encoding='utf-8-sig') as f:
                content = f.read()
            file_type = "text/plain"
    except Exception as e:
        logger.error(f"Error processing attachment {filename}: {e}")
        return None

    attachments.append({
        "name": filename,
        "type": file_type,
        "content": content,
    })
    return content  # Returned so callers can reuse the extracted text
def extract_pdf_text(pdf_path):
    """Extract the text of every page of a PDF file.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        The text of all pages, each page followed by a blank line, joined
        together and stripped of surrounding whitespace. If opening or
        parsing the file fails, the error is logged and a bracketed
        ``[Error extracting PDF text: ...]`` placeholder string is returned
        instead of raising.
    """
    # Function-local import: PyPDF2 is loaded only when this function runs.
    import PyPDF2

    try:
        with open(pdf_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            # Extraction happens inside the `with` block, while the
            # underlying file handle is still open.
            page_texts = [page.extract_text() + "\n\n" for page in pdf_reader.pages]

        return "".join(page_texts).strip()
    except Exception as e:
        logger.error(f"Error extracting text from PDF: {e}")
        return f"[Error extracting PDF text: {str(e)}]"
def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
# Handle dict format with text and files
files = []
if isinstance(text, dict):
files = text.get('files', [])
text = text.get('text', '')
history = state['history']
output = copy.deepcopy(history)
output = apply_extensions('history', output)
@ -411,12 +511,18 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
if not (regenerate or _continue):
visible_text = html.escape(text)
# Process file attachments and store in metadata
row_idx = len(output['internal'])
# Add attachments to metadata only, not modifying the message text
for file_path in files:
add_message_attachment(output, row_idx, file_path, is_user=True)
# Apply extensions
text, visible_text = apply_extensions('chat_input', text, visible_text, state)
text = apply_extensions('input', text, state, is_chat=True)
# Current row index
row_idx = len(output['internal'])
output['internal'].append([text, ''])
output['visible'].append([visible_text, ''])
# Add metadata with timestamp
@ -1215,7 +1321,7 @@ def handle_replace_last_reply_click(text, state):
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
return [history, html, ""]
return [history, html, {"text": "", "files": []}]
def handle_send_dummy_message_click(text, state):
@ -1223,7 +1329,7 @@ def handle_send_dummy_message_click(text, state):
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
return [history, html, ""]
return [history, html, {"text": "", "files": []}]
def handle_send_dummy_reply_click(text, state):
@ -1231,7 +1337,7 @@ def handle_send_dummy_reply_click(text, state):
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
return [history, html, ""]
return [history, html, {"text": "", "files": []}]
def handle_remove_last_click(state):
@ -1239,7 +1345,7 @@ def handle_remove_last_click(state):
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
return [history, html, last_input]
return [history, html, {"text": last_input, "files": []}]
def handle_unique_id_select(state):