Improve SuperboogaV2 with Date/Time Embeddings, GPU Support, and Multiple File Formats (#6748)
This commit is contained in:
parent 12f6f7ba9f
commit 01f20d2d9f
@@ -1,5 +1,41 @@
# superboogav2
# SuperboogaV2

For a description, please see the comments in this Pull Request:
Enhance your LLM with additional information from text, URLs, and files for more accurate and context-aware responses.

https://github.com/oobabooga/text-generation-webui/pull/3272

---

## Installation and Activation

1. Start the conda environment by running `cmd_windows.bat` (or the equivalent for your system) in the root directory of `text-generation-webui`.
2. Install the necessary packages:
```
pip install -r extensions/superboogav2/requirements.txt
```
3. Activate the extension in the `Session` tab of the web UI.
4. Click on `Apply flags/extensions and restart`. Optionally, save the configuration by clicking on `Save UI defaults to settings.yaml`.

## Usage and Features

After activation, scroll further down in the chat UI to reveal the SuperboogaV2 interface. Here you can add extra information to your chats through text input, multiple URLs, or multiple files, subject to the context window limit of your model.

The extra information and the current date and time are provided to the model as embeddings that persist across conversations. To clear them, click the `Clear Data` button and start a new chat. You can adjust the text extraction parameters and other options under `Settings`.

## Supported File Formats

SuperboogaV2 uses MuPDF, pandas, python-docx, and python-pptx to extract text from various file formats, including the following (a simplified sketch of the extraction fallback appears after this list):

- TXT
- PDF
- EPUB
- HTML
- CSV
- ODT/ODS/ODP
- DOCX/PPTX/XLSX
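
Internally, file ingestion tries a chain of extractors and keeps the first one that returns text. Below is a simplified, self-contained sketch of that fallback idea, showing only the UTF-8 and PyMuPDF steps; the helper name and the two-step chain are illustrative, not the extension's exact code.

```python
import fitz  # PyMuPDF


def extract_text(path: str) -> str:
    """Try plain UTF-8 first, then fall back to PyMuPDF (illustrative sketch)."""
    raw = open(path, "rb").read()
    try:
        return raw.decode("utf-8")  # plain-text files (TXT, CSV, ...)
    except UnicodeDecodeError:
        pass
    try:
        # PyMuPDF can open PDF, EPUB, and HTML documents, among others.
        with fitz.open(stream=raw, filetype=None) as doc:
            return "\n".join(page.get_text() for page in doc)
    except Exception:
        return ""  # unsupported format
```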

## Additional Information

SuperboogaV2 processes your data into context-aware chunks, applies cleaning techniques, and stores them as embeddings to minimize redundant computations. Relevance is determined using distance calculations and prioritization of recent information.
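
As a rough illustration of this distance-plus-recency ranking (this is not the extension's exact scoring code; the `added_at` metadata field, the weighting formula, and the function name are simplified assumptions):

```python
import time


def rank_chunks(collection, query: str, n_results: int = 5, time_power: float = 1.0):
    """Query a ChromaDB collection and softly prefer recently added chunks."""
    result = collection.query(query_texts=[query], n_results=n_results)
    now = time.time()
    scored = []
    for doc, dist, meta in zip(result["documents"][0], result["distances"][0], result["metadatas"][0]):
        age_hours = (now - (meta or {}).get("added_at", now)) / 3600
        penalty = (1 + age_hours) ** time_power  # older chunks are penalized
        scored.append((dist * penalty, doc))
    return [doc for _, doc in sorted(scored)]  # lowest combined score first
```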

For a detailed description and more information, refer to the comments in this pull request: [https://github.com/oobabooga/text-generation-webui/pull/3272](https://github.com/oobabooga/text-generation-webui/pull/3272)

@@ -1,7 +1,7 @@
import math
import random
import threading

import torch
import chromadb
import numpy as np
import posthog
@@ -16,9 +16,6 @@ from modules.text_generation import decode, encode
posthog.capture = lambda *args, **kwargs: None


embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")


class Info:
    def __init__(self, start_index, text_with_context, distance, id):
        self.text_with_context = text_with_context
@@ -77,11 +74,23 @@ class Info:

class ChromaCollector():
    def __init__(self):
        name = ''.join(random.choice('ab') for _ in range(10))
        name = "".join(random.choice("ab") for _ in range(10))

        self.name = name
        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
        self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
        self.embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
            "sentence-transformers/all-mpnet-base-v2",
            device=("cuda" if torch.cuda.is_available() else "cpu"),
        )
        chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
        self.collection = chroma_client.create_collection(
            name=self.name,
            embedding_function=self.embedder,
            metadata={
                "hnsw:search_ef": 200,
                "hnsw:construction_ef": 200,
                "hnsw:M": 64,
            },
        )

        self.ids = []
        self.id_to_info = {}
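
For context, a minimal standalone sketch of the GPU-aware embedding setup introduced above (the model name matches the diff; the sample sentences are just illustrative):

```python
import torch
from chromadb.utils import embedding_functions

# Use the GPU when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
    "sentence-transformers/all-mpnet-base-v2",
    device=device,
)

# The embedding function is callable: a list of texts in, a list of vectors out.
vectors = embedder(["hello world", "another chunk"])
print(len(vectors), len(vectors[0]))  # 2 vectors, 768 dimensions for this model
```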
@@ -110,7 +119,7 @@

        # If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead.
        if non_existing_texts:
            non_existing_embeddings = embedder(non_existing_texts)
            non_existing_embeddings = self.embedder(non_existing_texts)
            for text, embedding in zip(non_existing_texts, non_existing_embeddings):
                self.embeddings_cache[text] = embedding
@@ -139,7 +148,7 @@
            id_ = new_ids[i]
            metadata = metadatas[i] if metadatas is not None else None
            embedding = self.embeddings_cache.get(text)
            if embedding:
            if embedding is not None and embedding.any():
                existing_texts.append(text)
                existing_embeddings.append(embedding)
                existing_ids.append(id_)
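
A quick illustration of why this check changed (the values are illustrative, not from the diff): the truth value of a multi-element NumPy array is ambiguous, so `if embedding:` raises instead of testing whether a cached vector exists.

```python
import numpy as np

emb = np.zeros(768)  # a cached embedding vector (size is illustrative)
# bool(emb) raises "ValueError: The truth value of an array with more than one
# element is ambiguous", so the code now uses `is not None` plus `.any()`.
print(emb is not None and emb.any())  # False for an all-zero vector
```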
@@ -323,6 +332,8 @@
    def delete(self, ids_to_delete: list[str], where: dict):
        with self.lock:
            ids_to_delete = self.collection.get(ids=ids_to_delete, where=where)['ids']
            if not ids_to_delete:
                return
            self.collection.delete(ids=ids_to_delete, where=where)

            # Remove the deleted ids from self.ids and self.id_to_info
@@ -335,12 +346,7 @@

    def clear(self):
        with self.lock:
            self.chroma_client.reset()

            self.ids = []
            self.chroma_client.delete_collection(name=self.name)
            self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)

            self.__init__()  # reinitialize the collector
            logger.info('Successfully cleared all records and reset chromaDB.')
@@ -127,6 +127,9 @@
        "default": "\n\n<<document end>>\n\n"
    },
    "manual": {
        "default": false
    },
    "add_date_time": {
        "default": true
    },
    "add_chat_to_data": {
@@ -6,6 +6,7 @@ It will only include full words.

import bisect
import re
from datetime import datetime

import extensions.superboogav2.parameters as parameters
@@ -154,6 +155,13 @@ def process_and_add_to_collector(corpus: str, collector: ChromaCollector, clear_
    data_chunks_with_context = []
    data_chunk_starting_indices = []

    if parameters.get_add_date_time():
        now = datetime.now()
        date_time_chunk = f"Current time is {now.strftime('%H:%M:%S')}. Today is {now.strftime('%A')}. The current date is {now.strftime('%Y-%m-%d')}."
        data_chunks.append(date_time_chunk)
        data_chunks_with_context.append(date_time_chunk)
        data_chunk_starting_indices.append(0)

    # Handling chunk_regex
    if parameters.get_chunk_regex():
        if parameters.get_chunk_separator():
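
For reference, the date/time chunk added above renders like this (the fixed timestamp is a made-up example):

```python
from datetime import datetime

now = datetime(2025, 1, 6, 14, 30, 5)  # hypothetical timestamp for illustration
print(f"Current time is {now.strftime('%H:%M:%S')}. "
      f"Today is {now.strftime('%A')}. "
      f"The current date is {now.strftime('%Y-%m-%d')}.")
# -> Current time is 14:30:05. Today is Monday. The current date is 2025-01-06.
```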
@@ -39,11 +39,11 @@ def _markdown_hyperparams():
# Convert numpy types to python types.
def _convert_np_types(params):
    for key in params:
        if type(params[key]) == np.bool_:
        if isinstance(params[key], np.bool_):
            params[key] = bool(params[key])
        elif type(params[key]) == np.int64:
        elif isinstance(params[key], np.int64):
            params[key] = int(params[key])
        elif type(params[key]) == np.float64:
        elif isinstance(params[key], np.float64):
            params[key] = float(params[key])
    return params
@@ -251,6 +251,10 @@ def get_is_manual() -> bool:
    return bool(Parameters.getInstance().hyperparameters['manual']['default'])


def get_add_date_time() -> bool:
    return bool(Parameters.getInstance().hyperparameters['add_date_time']['default'])


def get_add_chat_to_data() -> bool:
    return bool(Parameters.getInstance().hyperparameters['add_chat_to_data']['default'])
@@ -331,6 +335,10 @@ def set_manual(value: bool):
    Parameters.getInstance().hyperparameters['manual']['default'] = value


def set_add_date_time(value: bool):
    Parameters.getInstance().hyperparameters['add_date_time']['default'] = value


def set_add_chat_to_data(value: bool):
    Parameters.getInstance().hyperparameters['add_chat_to_data']['default'] = value
@@ -1,10 +1,16 @@
beautifulsoup4==4.12.2
chromadb==0.4.24
beautifulsoup4==4.13.3
chromadb==0.6.3
lxml
nltk
optuna
pandas==2.0.3
posthog==2.4.2
sentence_transformers==2.2.2
pandas
posthog==3.13.0
sentence_transformers==3.3.1
spacy
pytextrank
num2words
PyMuPDF
python-docx
python-pptx
openpyxl
odfpy
@@ -9,6 +9,13 @@ os.environ['NLTK_DATA'] = str(Path("extensions/superboogav2/nltk_data").resolve(

import codecs
import textwrap
import docx
import pptx
import fitz
fitz.TOOLS.mupdf_display_errors(False)
import pandas as pd
from odf.opendocument import load
from odf.draw import Page

import gradio as gr
@@ -46,11 +53,123 @@ def _feed_data_into_collector(corpus):
    yield '### Done.'


def _feed_file_into_collector(file):
    yield '### Reading and processing the input dataset...'
    text = file.decode('utf-8')
    process_and_add_to_collector(text, collector, False, create_metadata_source('file'))
    yield '### Done.'
def _feed_file_into_collector(files):
    if not files:
        logger.warning("No files selected.")
        return

    def read_binary_file(file_path):
        try:
            with open(file_path, 'rb') as f:
                return f.read()
        except Exception:
            logger.error(f"Failed to read {file_path}.")
            return None

    def extract_with_utf8(text):
        try:
            return text.decode('utf-8')
        except Exception:
            return ""

    def extract_with_fitz(file_content):
        try:
            with fitz.open(stream=file_content, filetype=None) as doc:
                num_pages = doc.page_count
                text = "\n".join(block[4] for page in doc for block in page.get_text("blocks") if block[6] == 0)
                logger.info(f"Extracted text from {num_pages} pages with fitz.")
                return text
        except Exception:
            return ""

    def extract_with_docx(file_path):
        try:
            paragraphs = docx.Document(file_path).paragraphs
            text = "\n".join(para.text for para in paragraphs)
            logger.info(f"Extracted text from {len(paragraphs)} paragraphs with docx.")
            return text
        except Exception:
            return ""

    def extract_with_pptx(file_path):
        try:
            slides = pptx.Presentation(file_path).slides
            text = "\n".join(
                shape.text for slide in slides for shape in slide.shapes if hasattr(shape, "text")
            )
            logger.info(f"Extracted text from {len(slides)} slides with pptx.")
            return text
        except Exception:
            return ""

    def extract_with_odf(file_path):
        if not file_path.endswith(".odp"):
            return ""
        try:
            doc = load(file_path)
            text_content = []

            def extract_text(element):
                parts = []
                if hasattr(element, "childNodes"):
                    for node in element.childNodes:
                        if node.nodeType == node.TEXT_NODE:
                            parts.append(node.data)
                        else:
                            parts.append(extract_text(node))
                return "".join(parts)

            for slide in doc.getElementsByType(Page):
                slide_text = extract_text(slide)
                if slide_text.strip():
                    text_content.append(slide_text.strip())

            text = "\n".join(text_content)
            logger.info(f"Extracted text from {len(doc.getElementsByType(Page))} slides with odf.")
            return text
        except Exception as e:
            logger.error(f"Failed to extract text from {file_path}: {str(e)}")
            return ""

    def extract_with_pandas(file_path):
        try:
            df = pd.read_excel(file_path)
            text = "\n".join(str(cell) for col in df.columns for cell in df[col])
            logger.info(f"Extracted text from {df.shape[0]}x{df.shape[1]} cells with pandas.")
            return text
        except Exception:
            return ""

    for index, file in enumerate(files, start=1):
        file_name = os.path.basename(file)
        logger.info(f"Processing {file_name}...")

        file_content = read_binary_file(file)
        if not file_content:
            continue

        text_extractors = [
            lambda: extract_with_utf8(file_content),
            lambda: extract_with_fitz(file_content),
            lambda: extract_with_docx(file),
            lambda: extract_with_pptx(file),
            lambda: extract_with_odf(file),
            lambda: extract_with_pandas(file),
        ]

        for extractor in text_extractors:
            text = extractor()
            if text:
                break

        if not text:
            logger.error(f"Failed to extract text from {file_name}, unsupported format.")
            continue

        process_and_add_to_collector(text, collector, False, create_metadata_source(f"file-{index}"))

    logger.info("Done.")
    yield "### Done."


def _feed_url_into_collector(urls):
@@ -107,7 +226,7 @@ def _get_optimizable_settings() -> list:


def _apply_settings(optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
                    preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
                    preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, add_date_time, postfix, data_separator, prefix, max_token_count,
                    chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup):
    logger.debug('Applying settings.')
@@ -124,6 +243,7 @@ def _apply_settings(optimization_steps, time_power, time_steepness, significant_
    parameters.set_injection_strategy(injection_strategy)
    parameters.set_add_chat_to_data(add_chat_to_data)
    parameters.set_manual(manual)
    parameters.set_add_date_time(add_date_time)
    parameters.set_postfix(codecs.decode(postfix, 'unicode_escape'))
    parameters.set_data_separator(codecs.decode(data_separator, 'unicode_escape'))
    parameters.set_prefix(codecs.decode(prefix, 'unicode_escape'))
@@ -237,11 +357,11 @@ def ui():
            url_input = gr.Textbox(lines=10, label='Input URLs', info='Enter one or more URLs separated by newline characters.')
            strong_cleanup = gr.Checkbox(value=parameters.get_is_strong_cleanup(), label='Strong cleanup', info='Only keeps html elements that look like long-form text.')
            threads = gr.Number(value=parameters.get_num_threads(), label='Threads', info='The number of threads to use while downloading the URLs.', precision=0)
            update_url = gr.Button('Load data')
            update_urls = gr.Button('Load data')

        with gr.Tab("File input"):
            file_input = gr.File(label='Input file', type='binary')
            update_file = gr.Button('Load data')
            file_input = gr.File(label="Input file", type="filepath", file_count="multiple")
            update_files = gr.Button('Load data')

        with gr.Tab("Settings"):
            with gr.Accordion("Processing settings", open=True):
@@ -258,6 +378,7 @@ def ui():
                postfix = gr.Textbox(value=codecs.encode(parameters.get_postfix(), 'unicode_escape').decode(), label='Postfix', info='What to put after the injection point.')
            with gr.Row():
                manual = gr.Checkbox(value=parameters.get_is_manual(), label="Is Manual", info="Manually specify when to use ChromaDB. Insert `!c` at the start or end of the message to trigger a query.", visible=shared.is_chat())
                add_date_time = gr.Checkbox(value=parameters.get_add_date_time(), label="Add date and time to Data", info="Make the current date and time available to the model.", visible=shared.is_chat())
                add_chat_to_data = gr.Checkbox(value=parameters.get_add_chat_to_data(), label="Add Chat to Data", info="Automatically feed the chat history as you chat.", visible=shared.is_chat())
                injection_strategy = gr.Radio(choices=[parameters.PREPEND_TO_LAST, parameters.APPEND_TO_LAST, parameters.HIJACK_LAST_IN_CONTEXT], value=parameters.get_injection_strategy(), label='Injection Strategy', info='Where to inject the messages in chat or instruct mode.', visible=shared.is_chat())
            with gr.Row():
@@ -313,14 +434,14 @@ def ui():
    last_updated = gr.Markdown()

    all_params = [optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
                  preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
                  preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, add_date_time, postfix, data_separator, prefix, max_token_count,
                  chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup]
    optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
                          preprocess_pipeline, chunk_count, context_len, chunk_len]

    update_data.click(_feed_data_into_collector, [data_input], last_updated, show_progress=False)
    update_url.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False)
    update_file.click(_feed_file_into_collector, [file_input], last_updated, show_progress=False)
    update_urls.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False)
    update_files.click(_feed_file_into_collector, [file_input], last_updated, show_progress=False)
    benchmark_button.click(_begin_benchmark, [], last_updated, show_progress=True)
    optimize_button.click(_begin_optimization, [], [last_updated] + optimizable_params, show_progress=True)
    clear_button.click(_clear_data, [], last_updated, show_progress=False)
@@ -339,6 +460,7 @@ def ui():
    api_on.input(fn=_apply_settings, inputs=all_params, show_progress=False)
    injection_strategy.input(fn=_apply_settings, inputs=all_params, show_progress=False)
    add_chat_to_data.input(fn=_apply_settings, inputs=all_params, show_progress=False)
    add_date_time.input(fn=_apply_settings, inputs=all_params, show_progress=False)
    manual.input(fn=_apply_settings, inputs=all_params, show_progress=False)
    postfix.input(fn=_apply_settings, inputs=all_params, show_progress=False)
    data_separator.input(fn=_apply_settings, inputs=all_params, show_progress=False)