Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2026-01-29 03:44:30 +01:00)

Commit 7c883ef2f0
@@ -4,6 +4,8 @@ A Gradio web UI for Large Language Models.

Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.

[Try the Deep Reason extension](https://oobabooga.gumroad.com/l/deep_reason)

| |  |
|:---:|:---:|
| |  |
@@ -1,38 +0,0 @@
'''
Converts a transformers model to safetensors format and shards it.

This makes it faster to load (because of safetensors) and lowers its RAM usage
while loading (because of sharding).

Based on the original script by 81300:

https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303
'''

import argparse
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')
parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).")
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
args = parser.parse_args()

if __name__ == '__main__':
    path = Path(args.MODEL)
    model_name = path.name

    print(f"Loading {model_name}...")
    model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if args.bf16 else torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(path)

    out_folder = args.output or Path(f"models/{model_name}_safetensors")
    print(f"Saving the converted model to {out_folder} with a maximum shard size of {args.max_shard_size}...")
    model.save_pretrained(out_folder, max_shard_size=args.max_shard_size, safe_serialization=True)
    tokenizer.save_pretrained(out_folder)
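For reference, this removed script was a standalone command-line tool; a rough usage sketch (with a hypothetical model path, and assuming the file was saved as `convert-to-safetensors.py`) would have looked like:

```
python convert-to-safetensors.py models/llama-7b --output models/llama-7b_safetensors --max-shard-size 500MB --bf16
```

The positional `MODEL` argument and the `--output`, `--max-shard-size`, and `--bf16` flags correspond to the argparse definitions shown above.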
@@ -46,7 +46,7 @@
}

.chat .user-message {
  background: #f4f4f4;
  background: #f5f5f5;
  padding: 1.5rem 1rem;
  padding-bottom: 2rem;
  border-radius: 0;
@@ -2,7 +2,7 @@
  --darker-gray: #202123;
  --dark-gray: #343541;
  --light-gray: #444654;
  --light-theme-gray: #f4f4f4;
  --light-theme-gray: #f5f5f5;
  --border-color-dark: #525252;
  --header-width: 112px;
  --selected-item-color-dark: #32333e;
@@ -135,9 +135,12 @@ When you git clone a repository, put it inside WSL and not outside. To understan

### Bonus: Port Forwarding

By default, you won't be able to access the webui from another device on your local network. You will need to set up the appropriate port forwarding using the following command (using PowerShell or Terminal with administrator privileges).
By default, you won't be able to access the webui from another device on your local network. You will need to set up the appropriate port forwarding using the following steps:

1. First, get the IP address of the WSL instance by typing `wsl hostname -I`. This will output the IP address, for example `172.20.134.111`.
2. Then, use the following command (using PowerShell or Terminal with administrator privileges) to set up port forwarding, replacing `172.20.134.111` with the IP address you obtained in step 1:

```
netsh interface portproxy add v4tov4 listenaddress=0.0.0.0 listenport=7860 connectaddress=localhost connectport=7860
netsh interface portproxy add v4tov4 listenaddress=0.0.0.0 listenport=7860 connectaddress=172.20.134.111 connectport=7860
```
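As a usage note (not part of the change itself), the forwarding rule created above can later be listed or removed with the matching portproxy subcommands:

```
netsh interface portproxy show v4tov4
netsh interface portproxy delete v4tov4 listenaddress=0.0.0.0 listenport=7860
```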
@@ -11,7 +11,7 @@ import torch
from PIL import Image

from modules import shared
from modules.models import reload_model, unload_model
from modules.models import load_model, unload_model
from modules.ui import create_refresh_button

torch._C._jit_set_profiling_mode(False)
@@ -38,7 +38,8 @@ params = {
    'cfg_scale': 7,
    'textgen_prefix': 'Please provide a detailed and vivid description of [subject]',
    'sd_checkpoint': ' ',
    'checkpoint_list': [" "]
    'checkpoint_list': [" "],
    'last_model': ""
}
@@ -46,6 +47,7 @@ def give_VRAM_priority(actor):
    global shared, params

    if actor == 'SD':
        params["last_model"] = shared.model_name
        unload_model()
        print("Requesting Auto1111 to re-load last checkpoint used...")
        response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
@@ -55,7 +57,8 @@ def give_VRAM_priority(actor):
        print("Requesting Auto1111 to vacate VRAM...")
        response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
        response.raise_for_status()
        reload_model()
        if params["last_model"]:
            shared.model, shared.tokenizer = load_model(params["last_model"])

    elif actor == 'set':
        print("VRAM mangement activated -- requesting Auto1111 to vacate VRAM...")
@@ -412,8 +412,16 @@ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_
            yield history
            return

    show_after = html.escape(state["show_after"]) if state["show_after"] else None
    for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message, for_ui=for_ui):
        yield history
        if show_after:
            after = history["visible"][-1][1].partition(show_after)[2] or "*Is thinking...*"
            yield {
                'internal': history['internal'],
                'visible': history['visible'][:-1] + [[history['visible'][-1][0], after]]
            }
        else:
            yield history


def character_is_loaded(state, raise_exception=False):
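The new `show_after` handling above hides everything before a marker string using `str.partition`; a minimal standalone sketch of that behavior, with a made-up reply and marker:

```
# Minimal sketch of the hide-before-marker logic (reply text invented for illustration).
reply = "<think>Let me reason about this first.</think>The answer is 42."
show_after = "</think>"

# partition() returns (before, separator, after); index [2] is the text after the marker.
# If the marker has not appeared in the reply yet, [2] is empty and a placeholder is shown.
visible = reply.partition(show_after)[2] or "*Is thinking...*"
print(visible)  # -> The answer is 42.
```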
@@ -106,52 +106,6 @@ def replace_blockquote(m):
    return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '')


def add_long_list_class(html):
    '''
    Adds a long-list class to <ul> or <ol> containing long <li> items.
    These will receive a smaller margin/padding in the CSS.
    '''

    # Helper function to check if a tag is within <pre> or <code>
    def is_within_block(start_idx, end_idx, block_matches):
        return any(start < start_idx < end or start < end_idx < end for start, end in block_matches)

    # Find all <pre>...</pre> and <code>...</code> blocks
    pre_blocks = [(m.start(), m.end()) for m in re.finditer(r'<pre.*?>.*?</pre>', html, re.DOTALL)]
    code_blocks = [(m.start(), m.end()) for m in re.finditer(r'<code.*?>.*?</code>', html, re.DOTALL)]
    all_blocks = pre_blocks + code_blocks

    # Pattern to find <ul>...</ul> and <ol>...</ol> blocks and their contents
    list_pattern = re.compile(r'(<[uo]l.*?>)(.*?)(</[uo]l>)', re.DOTALL)
    li_pattern = re.compile(r'<li.*?>(.*?)</li>', re.DOTALL)

    def process_list(match):
        start_idx, end_idx = match.span()
        if is_within_block(start_idx, end_idx, all_blocks):
            return match.group(0)  # Leave the block unchanged if within <pre> or <code>

        opening_tag = match.group(1)
        list_content = match.group(2)
        closing_tag = match.group(3)

        # Find all list items within this list
        li_matches = li_pattern.finditer(list_content)
        has_long_item = any(len(li_match.group(1).strip()) > 224 for li_match in li_matches)

        if has_long_item:
            # Add class="long-list" to the opening tag if it doesn't already have a class
            if 'class=' not in opening_tag:
                opening_tag = opening_tag[:-1] + ' class="long-list">'
            else:
                # If there's already a class, append long-list to it
                opening_tag = re.sub(r'class="([^"]*)"', r'class="\1 long-list"', opening_tag)

        return opening_tag + list_content + closing_tag

    # Process HTML and replace list blocks
    return list_pattern.sub(process_list, html)


@functools.lru_cache(maxsize=None)
def convert_to_markdown(string):
    if not string:
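For context, the removed helper tagged any `<ul>`/`<ol>` whose item text exceeds 224 characters with a `long-list` class; a small standalone sketch of that core check, using an invented HTML snippet:

```
import re

# Core of the removed check: does any <li> contain more than 224 characters of text?
sample = "<ul><li>" + "word " * 60 + "</li><li>short item</li></ul>"  # invented example
li_pattern = re.compile(r'<li.*?>(.*?)</li>', re.DOTALL)
has_long_item = any(len(m.group(1).strip()) > 224 for m in li_pattern.finditer(sample))
print(has_long_item)  # -> True, so this list would have received class="long-list"
```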
@@ -251,9 +205,6 @@ def convert_to_markdown(string):
    # Unescape backslashes
    html_output = html_output.replace('\\\\', '\\')

    # Add "long-list" class to <ul> or <ol> containing a long <li> item
    html_output = add_long_list_class(html_output)

    return html_output
@@ -33,7 +33,7 @@ transformers.logging.set_verbosity_error()
local_rank = None
if shared.args.deepspeed:
    import deepspeed
    from transformers.deepspeed import (
    from transformers.integrations.deepspeed import (
        HfDeepSpeedConfig,
        is_deepspeed_zero3_enabled
    )
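The hunk above switches to the newer import path; a hedged compatibility sketch (a suggestion under assumptions, not what the commit does) that would tolerate both transformers layouts:

```
# Compatibility sketch (assumption): older transformers releases exposed these helpers
# under transformers.deepspeed, newer ones moved them to transformers.integrations.deepspeed.
try:
    from transformers.integrations.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
except ImportError:
    from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
```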
@@ -57,6 +57,7 @@ settings = {
    'seed': -1,
    'custom_stopping_strings': '',
    'custom_token_bans': '',
    'show_after': '',
    'negative_prompt': '',
    'autoload_model': False,
    'dark_theme': True,
@@ -215,6 +215,7 @@ def list_interface_input_elements():
        'sampler_priority',
        'custom_stopping_strings',
        'custom_token_bans',
        'show_after',
        'negative_prompt',
        'dry_sequence_breakers',
        'grammar_string',
@@ -92,6 +92,7 @@ def create_ui(default_preset):
    shared.gradio['sampler_priority'] = gr.Textbox(value=generate_params['sampler_priority'], lines=12, label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])
    shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=2, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
    shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Token bans', info='Token IDs to ban, separated by commas. The IDs can be found in the Default or Notebook tab.')
    shared.gradio['show_after'] = gr.Textbox(value=shared.settings['show_after'] or None, label='Show after', info='Hide the reply before this text.', placeholder="</think>")
    shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt', info='For CFG. Only used when guidance_scale is different than 1.', lines=3, elem_classes=['add_scrollbar'])
    shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')
    with gr.Row() as shared.gradio['grammar_file_row']:
@@ -29,6 +29,7 @@ truncation_length: 2048
seed: -1
custom_stopping_strings: ''
custom_token_bans: ''
show_after: ''
negative_prompt: ''
autoload_model: false
dark_theme: true