Improve SuperboogaV2 with Date/Time Embeddings, GPU Support, and Multiple File Formats (#6748)

Alireza Ghasemi 2025-02-18 02:38:15 +01:00 committed by GitHub
parent 12f6f7ba9f
commit 01f20d2d9f
8 changed files with 227 additions and 38 deletions


@@ -1,5 +1,41 @@
-# superboogav2
-For a description, please see the comments in this Pull Request:
-https://github.com/oobabooga/text-generation-webui/pull/3272
+# SuperboogaV2
+
+Enhance your LLM with additional information from text, URLs, and files for more accurate and context-aware responses.
+
+---
+
+## Installation and Activation
+
+1. Start the conda environment by running `cmd_windows.bat` or the equivalent for your system in the root directory of `text-generation-webui`.
+2. Install the necessary packages:
+   ```
+   pip install -r extensions/superboogav2/requirements.txt
+   ```
+3. Activate the extension in the `Session` tab of the web UI.
+4. Click on `Apply flags/extensions and restart`. Optionally, save the configuration by clicking on `Save UI defaults to settings.yaml`.
+
+## Usage and Features
+
+After activation, you can scroll further down in the chat UI to reveal the SuperboogaV2 interface. Here, you can add extra information to your chats through text input, multiple URLs, or multiple files, subject to the context window limit of your model.
+
+The extra information and the current date and time are provided to the model as embeddings that persist across conversations. To clear them, click the `Clear Data` button and start a new chat. You can adjust the text extraction parameters and other options in the `Settings` tab.
+
+## Supported File Formats
+
+SuperboogaV2 uses MuPDF, pandas, python-docx, and python-pptx to extract text from various file formats, including:
+
+- TXT
+- PDF
+- EPUB
+- HTML
+- CSV
+- ODT/ODS/ODP
+- DOCX/PPTX/XLSX
+
+## Additional Information
+
+SuperboogaV2 processes your data into context-aware chunks, applies cleaning techniques, and stores the results as embeddings to minimize redundant computation. Relevance is determined using distance calculations and by prioritizing recent information.
+
+For a detailed description and more information, refer to the comments in this pull request: [https://github.com/oobabooga/text-generation-webui/pull/3272](https://github.com/oobabooga/text-generation-webui/pull/3272)
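As an illustrative aside, here is a minimal sketch of the chunk, embed, and rank-by-distance flow that the README describes, using sentence-transformers directly. The chunking and scoring below are simplified stand-ins, not the extension's actual implementation:

```python
# Illustrative sketch only; the real pipeline lives in the extension's data_processor.py and chromadb.py.
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")  # same model the extension uses

# 1. Naive fixed-size chunking as a stand-in (the extension uses separator/regex-aware chunking).
corpus = "Long reference text goes here ..."
chunks = [corpus[i:i + 700] for i in range(0, len(corpus), 700)]

# 2. Embed the chunks once and keep the vectors around.
chunk_embeddings = model.encode(chunks, normalize_embeddings=True)

# 3. At query time, embed the question and rank chunks by cosine distance.
query_embedding = model.encode(["What does the text say about X?"], normalize_embeddings=True)[0]
distances = 1.0 - chunk_embeddings @ query_embedding  # cosine distance, since vectors are normalized
best_chunks = [chunks[i] for i in np.argsort(distances)[:3]]  # most relevant chunks to inject into the prompt
```

In the extension itself, the vectors are stored in a ChromaDB collection so they persist across messages and are only recomputed for text that has not been embedded before.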


@@ -1,7 +1,7 @@
import math
import random
import threading

+import torch
import chromadb
import numpy as np
import posthog
@@ -16,9 +16,6 @@ from modules.text_generation import decode, encode
posthog.capture = lambda *args, **kwargs: None

-embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")

class Info:
    def __init__(self, start_index, text_with_context, distance, id):
        self.text_with_context = text_with_context
@@ -77,11 +74,23 @@
class ChromaCollector():
    def __init__(self):
-        name = ''.join(random.choice('ab') for _ in range(10))
+        name = "".join(random.choice("ab") for _ in range(10))
        self.name = name
-        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
-        self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
+        self.embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
+            "sentence-transformers/all-mpnet-base-v2",
+            device=("cuda" if torch.cuda.is_available() else "cpu"),
+        )
+        chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
+        self.collection = chroma_client.create_collection(
+            name=self.name,
+            embedding_function=self.embedder,
+            metadata={
+                "hnsw:search_ef": 200,
+                "hnsw:construction_ef": 200,
+                "hnsw:M": 64,
+            },
+        )

        self.ids = []
        self.id_to_info = {}
@@ -110,7 +119,7 @@ class ChromaCollector():
        # If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead.
        if non_existing_texts:
-            non_existing_embeddings = embedder(non_existing_texts)
+            non_existing_embeddings = self.embedder(non_existing_texts)
            for text, embedding in zip(non_existing_texts, non_existing_embeddings):
                self.embeddings_cache[text] = embedding
@@ -139,7 +148,7 @@ class ChromaCollector():
            id_ = new_ids[i]
            metadata = metadatas[i] if metadatas is not None else None
            embedding = self.embeddings_cache.get(text)
-            if embedding:
+            if embedding is not None and embedding.any():
                existing_texts.append(text)
                existing_embeddings.append(embedding)
                existing_ids.append(id_)
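A brief illustration of the cache check above: under the updated dependencies the cached embeddings appear to be NumPy arrays rather than plain lists (the new code calls `.any()` on them), and a bare `if embedding:` on a multi-element array raises a `ValueError`, hence the explicit check. The values below are made up:

```python
import numpy as np

embedding = np.array([0.12, -0.03, 0.27])  # stand-in for a cached embedding vector

# "if embedding:" would raise: ValueError: The truth value of an array with more
# than one element is ambiguous. The explicit check used above avoids that:
if embedding is not None and embedding.any():
    print("reuse cached embedding")
```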
@@ -323,6 +332,8 @@ class ChromaCollector():
    def delete(self, ids_to_delete: list[str], where: dict):
        with self.lock:
            ids_to_delete = self.collection.get(ids=ids_to_delete, where=where)['ids']
+            if not ids_to_delete:
+                return
            self.collection.delete(ids=ids_to_delete, where=where)

            # Remove the deleted ids from self.ids and self.id_to_info
@@ -335,12 +346,7 @@ class ChromaCollector():
    def clear(self):
        with self.lock:
-            self.chroma_client.reset()
-            self.ids = []
-            self.chroma_client.delete_collection(name=self.name)
-            self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)
+            self.__init__()  # reinitialize the collector

            logger.info('Successfully cleared all records and reset chromaDB.')


@@ -127,6 +127,9 @@
        "default": "\n\n<<document end>>\n\n"
    },
    "manual": {
+        "default": false
+    },
+    "add_date_time": {
        "default": true
    },
    "add_chat_to_data": {


@@ -6,6 +6,7 @@ It will only include full words.
import bisect
import re
+from datetime import datetime

import extensions.superboogav2.parameters as parameters
@@ -154,6 +155,13 @@ def process_and_add_to_collector(corpus: str, collector: ChromaCollector, clear_
    data_chunks_with_context = []
    data_chunk_starting_indices = []

+    if parameters.get_add_date_time():
+        now = datetime.now()
+        date_time_chunk = f"Current time is {now.strftime('%H:%M:%S')}. Today is {now.strftime('%A')}. The current date is {now.strftime('%Y-%m-%d')}."
+        data_chunks.append(date_time_chunk)
+        data_chunks_with_context.append(date_time_chunk)
+        data_chunk_starting_indices.append(0)
+
    # Handling chunk_regex
    if parameters.get_chunk_regex():
        if parameters.get_chunk_separator():


@@ -39,11 +39,11 @@ def _markdown_hyperparams():
# Convert numpy types to python types.
def _convert_np_types(params):
    for key in params:
-        if type(params[key]) == np.bool_:
+        if isinstance(params[key], np.bool_):
            params[key] = bool(params[key])
-        elif type(params[key]) == np.int64:
+        elif isinstance(params[key], np.int64):
            params[key] = int(params[key])
-        elif type(params[key]) == np.float64:
+        elif isinstance(params[key], np.float64):
            params[key] = float(params[key])

    return params


@@ -251,6 +251,10 @@ def get_is_manual() -> bool:
    return bool(Parameters.getInstance().hyperparameters['manual']['default'])


+def get_add_date_time() -> bool:
+    return bool(Parameters.getInstance().hyperparameters['add_date_time']['default'])
+
+
def get_add_chat_to_data() -> bool:
    return bool(Parameters.getInstance().hyperparameters['add_chat_to_data']['default'])
@@ -331,6 +335,10 @@ def set_manual(value: bool):
    Parameters.getInstance().hyperparameters['manual']['default'] = value


+def set_add_date_time(value: bool):
+    Parameters.getInstance().hyperparameters['add_date_time']['default'] = value
+
+
def set_add_chat_to_data(value: bool):
    Parameters.getInstance().hyperparameters['add_chat_to_data']['default'] = value


@@ -1,10 +1,16 @@
-beautifulsoup4==4.12.2
+beautifulsoup4==4.13.3
-chromadb==0.4.24
+chromadb==0.6.3
lxml
+nltk
optuna
-pandas==2.0.3
+pandas
-posthog==2.4.2
+posthog==3.13.0
-sentence_transformers==2.2.2
+sentence_transformers==3.3.1
spacy
pytextrank
num2words
+PyMuPDF
+python-docx
+python-pptx
+openpyxl
+odfpy


@@ -9,6 +9,13 @@ os.environ['NLTK_DATA'] = str(Path("extensions/superboogav2/nltk_data").resolve(
import codecs
import textwrap

+import docx
+import pptx
+import fitz
+fitz.TOOLS.mupdf_display_errors(False)
+import pandas as pd
+from odf.opendocument import load
+from odf.draw import Page
import gradio as gr
@@ -46,11 +53,123 @@ def _feed_data_into_collector(corpus):
    yield '### Done.'


-def _feed_file_into_collector(file):
-    yield '### Reading and processing the input dataset...'
-    text = file.decode('utf-8')
-    process_and_add_to_collector(text, collector, False, create_metadata_source('file'))
-    yield '### Done.'
+def _feed_file_into_collector(files):
+    if not files:
+        logger.warning("No files selected.")
+        return
+
+    def read_binary_file(file_path):
+        try:
+            with open(file_path, 'rb') as f:
+                return f.read()
+        except Exception:
+            logger.error(f"Failed to read {file_path}.")
+            return None
+
+    def extract_with_utf8(text):
+        try:
+            return text.decode('utf-8')
+        except Exception:
+            return ""
+
+    def extract_with_fitz(file_content):
+        try:
+            with fitz.open(stream=file_content, filetype=None) as doc:
+                num_pages = doc.page_count
+                text = "\n".join(block[4] for page in doc for block in page.get_text("blocks") if block[6] == 0)
+                logger.info(f"Extracted text from {num_pages} pages with fitz.")
+                return text
+        except Exception:
+            return ""
+
+    def extract_with_docx(file_path):
+        try:
+            paragraphs = docx.Document(file_path).paragraphs
+            text = "\n".join(para.text for para in paragraphs)
+            logger.info(f"Extracted text from {len(paragraphs)} paragraphs with docx.")
+            return text
+        except Exception:
+            return ""
+
+    def extract_with_pptx(file_path):
+        try:
+            slides = pptx.Presentation(file_path).slides
+            text = "\n".join(
+                shape.text for slide in slides for shape in slide.shapes if hasattr(shape, "text")
+            )
+            logger.info(f"Extracted text from {len(slides)} slides with pptx.")
+            return text
+        except Exception:
+            return ""
+
+    def extract_with_odf(file_path):
+        if not file_path.endswith(".odp"):
+            return ""
+        try:
+            doc = load(file_path)
+            text_content = []
+
+            def extract_text(element):
+                parts = []
+                if hasattr(element, "childNodes"):
+                    for node in element.childNodes:
+                        if node.nodeType == node.TEXT_NODE:
+                            parts.append(node.data)
+                        else:
+                            parts.append(extract_text(node))
+                return "".join(parts)
+
+            for slide in doc.getElementsByType(Page):
+                slide_text = extract_text(slide)
+                if slide_text.strip():
+                    text_content.append(slide_text.strip())
+
+            text = "\n".join(text_content)
+            logger.info(f"Extracted text from {len(doc.getElementsByType(Page))} slides with odf.")
+            return text
+        except Exception as e:
+            logger.error(f"Failed to extract text from {file_path}: {str(e)}")
+            return ""
+
+    def extract_with_pandas(file_path):
+        try:
+            df = pd.read_excel(file_path)
+            text = "\n".join(str(cell) for col in df.columns for cell in df[col])
+            logger.info(f"Extracted text from {df.shape[0]}x{df.shape[1]} cells with pandas.")
+            return text
+        except Exception:
+            return ""
+
+    for index, file in enumerate(files, start=1):
+        file_name = os.path.basename(file)
+        logger.info(f"Processing {file_name}...")
+
+        file_content = read_binary_file(file)
+        if not file_content:
+            continue
+
+        text_extractors = [
+            lambda: extract_with_utf8(file_content),
+            lambda: extract_with_fitz(file_content),
+            lambda: extract_with_docx(file),
+            lambda: extract_with_pptx(file),
+            lambda: extract_with_odf(file),
+            lambda: extract_with_pandas(file),
+        ]
+
+        for extractor in text_extractors:
+            text = extractor()
+            if text:
+                break
+
+        if not text:
+            logger.error(f"Failed to extract text from {file_name}, unsupported format.")
+            continue
+
+        process_and_add_to_collector(text, collector, False, create_metadata_source(f"file-{index}"))
+
+    logger.info("Done.")
+    yield "### Done."


def _feed_url_into_collector(urls):
@@ -107,7 +226,7 @@ def _get_optimizable_settings() -> list:
def _apply_settings(optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
-                    preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
+                    preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, add_date_time, postfix, data_separator, prefix, max_token_count,
                    chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup):

    logger.debug('Applying settings.')
@@ -124,6 +243,7 @@ def _apply_settings(optimization_steps, time_power, time_steepness, significant_
    parameters.set_injection_strategy(injection_strategy)
    parameters.set_add_chat_to_data(add_chat_to_data)
    parameters.set_manual(manual)
+    parameters.set_add_date_time(add_date_time)
    parameters.set_postfix(codecs.decode(postfix, 'unicode_escape'))
    parameters.set_data_separator(codecs.decode(data_separator, 'unicode_escape'))
    parameters.set_prefix(codecs.decode(prefix, 'unicode_escape'))
@@ -237,11 +357,11 @@ def ui():
            url_input = gr.Textbox(lines=10, label='Input URLs', info='Enter one or more URLs separated by newline characters.')
            strong_cleanup = gr.Checkbox(value=parameters.get_is_strong_cleanup(), label='Strong cleanup', info='Only keeps html elements that look like long-form text.')
            threads = gr.Number(value=parameters.get_num_threads(), label='Threads', info='The number of threads to use while downloading the URLs.', precision=0)
-            update_url = gr.Button('Load data')
+            update_urls = gr.Button('Load data')

        with gr.Tab("File input"):
-            file_input = gr.File(label='Input file', type='binary')
-            update_file = gr.Button('Load data')
+            file_input = gr.File(label="Input file", type="filepath", file_count="multiple")
+            update_files = gr.Button('Load data')

        with gr.Tab("Settings"):
            with gr.Accordion("Processing settings", open=True):
@@ -258,6 +378,7 @@ def ui():
                postfix = gr.Textbox(value=codecs.encode(parameters.get_postfix(), 'unicode_escape').decode(), label='Postfix', info='What to put after the injection point.')
            with gr.Row():
                manual = gr.Checkbox(value=parameters.get_is_manual(), label="Is Manual", info="Manually specify when to use ChromaDB. Insert `!c` at the start or end of the message to trigger a query.", visible=shared.is_chat())
+                add_date_time = gr.Checkbox(value=parameters.get_add_date_time(), label="Add date and time to Data", info="Make the current date and time available to the model.", visible=shared.is_chat())
                add_chat_to_data = gr.Checkbox(value=parameters.get_add_chat_to_data(), label="Add Chat to Data", info="Automatically feed the chat history as you chat.", visible=shared.is_chat())
                injection_strategy = gr.Radio(choices=[parameters.PREPEND_TO_LAST, parameters.APPEND_TO_LAST, parameters.HIJACK_LAST_IN_CONTEXT], value=parameters.get_injection_strategy(), label='Injection Strategy', info='Where to inject the messages in chat or instruct mode.', visible=shared.is_chat())
            with gr.Row():
@@ -313,14 +434,14 @@ def ui():
    last_updated = gr.Markdown()

    all_params = [optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
-                  preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
+                  preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, add_date_time, postfix, data_separator, prefix, max_token_count,
                  chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup]
    optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
                          preprocess_pipeline, chunk_count, context_len, chunk_len]

    update_data.click(_feed_data_into_collector, [data_input], last_updated, show_progress=False)
-    update_url.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False)
-    update_file.click(_feed_file_into_collector, [file_input], last_updated, show_progress=False)
+    update_urls.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False)
+    update_files.click(_feed_file_into_collector, [file_input], last_updated, show_progress=False)
    benchmark_button.click(_begin_benchmark, [], last_updated, show_progress=True)
    optimize_button.click(_begin_optimization, [], [last_updated] + optimizable_params, show_progress=True)
    clear_button.click(_clear_data, [], last_updated, show_progress=False)
@@ -339,6 +460,7 @@ def ui():
    api_on.input(fn=_apply_settings, inputs=all_params, show_progress=False)
    injection_strategy.input(fn=_apply_settings, inputs=all_params, show_progress=False)
    add_chat_to_data.input(fn=_apply_settings, inputs=all_params, show_progress=False)
+    add_date_time.input(fn=_apply_settings, inputs=all_params, show_progress=False)
    manual.input(fn=_apply_settings, inputs=all_params, show_progress=False)
    postfix.input(fn=_apply_settings, inputs=all_params, show_progress=False)
    data_separator.input(fn=_apply_settings, inputs=all_params, show_progress=False)