diff --git a/modules/chat.py b/modules/chat.py index 8bac680c..495fe934 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -538,6 +538,27 @@ def extract_pdf_text(pdf_path): return f"[Error extracting PDF text: {str(e)}]" +def generate_search_query(user_message, state): + """Generate a search query from user message using the LLM""" + # Augment the user message with search instruction + augmented_message = f"{user_message}\n\n=====\n\nPlease turn the message above into a short web search query in the same language as the message. Respond with only the search query, nothing else." + + # Use a minimal state for search query generation but keep the full history + search_state = state.copy() + search_state['max_new_tokens'] = 64 + search_state['auto_max_new_tokens'] = False + search_state['enable_thinking'] = False + + # Generate the full prompt using existing history + augmented message + formatted_prompt = generate_chat_prompt(augmented_message, search_state) + + query = "" + for reply in generate_reply(formatted_prompt, search_state, stopping_strings=[], is_chat=True): + query = reply.strip() + + return query + + def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False): # Handle dict format with text and files files = [] @@ -570,7 +591,9 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess add_message_attachment(output, row_idx, file_path, is_user=True) # Add web search results as attachments if enabled - add_web_search_attachments(output, row_idx, text, state) + if state.get('enable_web_search', False): + search_query = generate_search_query(text, state) + add_web_search_attachments(output, row_idx, text, search_query, state) # Apply extensions text, visible_text = apply_extensions('chat_input', text, visible_text, state) diff --git a/modules/web_search.py b/modules/web_search.py index d3387ac9..667178c5 100644 --- a/modules/web_search.py +++ b/modules/web_search.py @@ -13,22 +13,6 @@ def get_current_timestamp(): return datetime.now().strftime('%b %d, %Y %H:%M') -def generate_search_query(user_message, state): - """Generate a search query from user message using the LLM""" - search_prompt = f"{user_message}\n\n=====\n\nPlease turn the message above into a short web search query in the same language as the message. Respond with only the search query, nothing else." - - # Use a minimal state for search query generation - search_state = state.copy() - search_state['max_new_tokens'] = 64 - search_state['temperature'] = 0.1 - - query = "" - for reply in generate_reply(search_prompt, search_state, stopping_strings=[], is_chat=False): - query = reply.strip() - - return query - - def download_web_page(url, timeout=10): """Download and extract text from a web page""" try: @@ -82,19 +66,14 @@ def perform_web_search(query, num_pages=3): return [] -def add_web_search_attachments(history, row_idx, user_message, state): +def add_web_search_attachments(history, row_idx, user_message, search_query, state): """Perform web search and add results as attachments""" - if not state.get('enable_web_search', False): + if not search_query: + logger.warning("No search query provided") return try: - # Generate search query - search_query = generate_search_query(user_message, state) - if not search_query: - logger.warning("Failed to generate search query") - return - - logger.info(f"Generated search query: {search_query}") + logger.info(f"Using search query: {search_query}") # Perform web search num_pages = int(state.get('web_search_pages', 3))