mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-09 00:23:38 +00:00
Refactor to not import gradio in --nowebui mode
This commit is contained in:
parent
970055ca00
commit
39e6c997cc
7 changed files with 232 additions and 209 deletions
|
|
@ -4,10 +4,9 @@ import re
|
|||
from math import floor
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import yaml
|
||||
|
||||
from modules import chat, loaders, metadata_gguf, shared, ui
|
||||
from modules import loaders, metadata_gguf, shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.utils import resolve_model_path
|
||||
|
||||
|
|
@ -199,7 +198,7 @@ def get_model_metadata(model):
|
|||
|
||||
# Load instruction template if defined by name rather than by value
|
||||
if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
|
||||
model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
|
||||
model_settings['instruction_template_str'] = load_instruction_template(model_settings['instruction_template'])
|
||||
|
||||
return model_settings
|
||||
|
||||
|
|
@ -228,7 +227,7 @@ def update_model_parameters(state, initial=False):
|
|||
'''
|
||||
UI: update the command-line arguments based on the interface values
|
||||
'''
|
||||
elements = ui.list_model_elements() # the names of the parameters
|
||||
elements = loaders.list_model_elements() # the names of the parameters
|
||||
|
||||
for i, element in enumerate(elements):
|
||||
if element not in state:
|
||||
|
|
@ -248,6 +247,7 @@ def apply_model_settings_to_state(model, state):
|
|||
'''
|
||||
UI: update the state variable with the model settings
|
||||
'''
|
||||
import gradio as gr
|
||||
model_settings = get_model_metadata(model)
|
||||
if 'loader' in model_settings:
|
||||
loader = model_settings.pop('loader')
|
||||
|
|
@ -290,7 +290,7 @@ def save_model_settings(model, state):
|
|||
if model_regex not in user_config:
|
||||
user_config[model_regex] = {}
|
||||
|
||||
for k in ui.list_model_elements():
|
||||
for k in loaders.list_model_elements():
|
||||
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
|
||||
user_config[model_regex][k] = state[k]
|
||||
|
||||
|
|
@ -419,3 +419,102 @@ def update_gpu_layers_and_vram(loader, model, gpu_layers, ctx_size, cache_type):
|
|||
|
||||
vram_usage = estimate_vram(model, gpu_layers, ctx_size, cache_type)
|
||||
return f"<div id=\"vram-info\"'>Estimated VRAM to load the model: <span class=\"value\">{vram_usage:.0f} MiB</span></div>"
|
||||
|
||||
|
||||
def load_instruction_template(template):
|
||||
if template == 'None':
|
||||
return ''
|
||||
|
||||
for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']:
|
||||
if filepath.exists():
|
||||
break
|
||||
else:
|
||||
return ''
|
||||
|
||||
file_contents = open(filepath, 'r', encoding='utf-8').read()
|
||||
data = yaml.safe_load(file_contents)
|
||||
if 'instruction_template' in data:
|
||||
return data['instruction_template']
|
||||
else:
|
||||
return _jinja_template_from_old_format(data)
|
||||
|
||||
|
||||
def _jinja_template_from_old_format(params, verbose=False):
|
||||
MASTER_TEMPLATE = """
|
||||
{%- set ns = namespace(found=false) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{%- set ns.found = true -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if not ns.found -%}
|
||||
{{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
|
||||
{%- endif %}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}}
|
||||
{%- else -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}}
|
||||
{%- else -%}
|
||||
{{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{-'<|PRE-ASSISTANT-GENERATE|>'-}}
|
||||
{%- endif -%}
|
||||
"""
|
||||
|
||||
if 'context' in params and '<|system-message|>' in params['context']:
|
||||
pre_system = params['context'].split('<|system-message|>')[0]
|
||||
post_system = params['context'].split('<|system-message|>')[1]
|
||||
else:
|
||||
pre_system = ''
|
||||
post_system = ''
|
||||
|
||||
pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user'])
|
||||
post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0]
|
||||
|
||||
pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1]
|
||||
pre_assistant = pre_assistant.replace('<|bot|>', params['bot'])
|
||||
post_assistant = params['turn_template'].split('<|bot-message|>')[1]
|
||||
|
||||
def preprocess(string):
|
||||
return string.replace('\n', '\\n').replace('\'', '\\\'')
|
||||
|
||||
pre_system = preprocess(pre_system)
|
||||
post_system = preprocess(post_system)
|
||||
pre_user = preprocess(pre_user)
|
||||
post_user = preprocess(post_user)
|
||||
pre_assistant = preprocess(pre_assistant)
|
||||
post_assistant = preprocess(post_assistant)
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
'\n',
|
||||
repr(pre_system) + '\n',
|
||||
repr(post_system) + '\n',
|
||||
repr(pre_user) + '\n',
|
||||
repr(post_user) + '\n',
|
||||
repr(pre_assistant) + '\n',
|
||||
repr(post_assistant) + '\n',
|
||||
)
|
||||
|
||||
result = MASTER_TEMPLATE
|
||||
if 'system_message' in params:
|
||||
result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message']))
|
||||
else:
|
||||
result = result.replace('<|SYSTEM-MESSAGE|>', '')
|
||||
|
||||
result = result.replace('<|PRE-SYSTEM|>', pre_system)
|
||||
result = result.replace('<|POST-SYSTEM|>', post_system)
|
||||
result = result.replace('<|PRE-USER|>', pre_user)
|
||||
result = result.replace('<|POST-USER|>', post_user)
|
||||
result = result.replace('<|PRE-ASSISTANT|>', pre_assistant)
|
||||
result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' '))
|
||||
result = result.replace('<|POST-ASSISTANT|>', post_assistant)
|
||||
|
||||
result = result.strip()
|
||||
|
||||
return result
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue