Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2026-04-20 22:13:43 +00:00
Refactor the transformers loader (#6859)
parent 6ba0164c70
commit ae02ffc605
18 changed files with 464 additions and 528 deletions
modules/text_generation.py

@@ -7,33 +7,18 @@ import time
 import traceback
 
 import numpy as np
-import torch
-import transformers
-from transformers import (
-    LogitsProcessorList,
-    is_torch_npu_available,
-    is_torch_xpu_available
-)
 
 import modules.shared as shared
-from modules import models, sampler_hijack
-from modules.callbacks import (
-    Iteratorize,
-    Stream,
-    _StopEverythingStoppingCriteria
-)
+from modules import models
+from modules.callbacks import Iteratorize
 from modules.extensions import apply_extensions
-from modules.grammar.grammar_utils import initialize_grammar
-from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
 from modules.html_generator import generate_basic_html
 from modules.logging_colors import logger
-from modules.models import clear_torch_cache, get_device, load_model
-
-sampler_hijack.hijack_samplers()
 
 
 def generate_reply(*args, **kwargs):
     if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
+        from modules.models import load_model
         shared.model, shared.tokenizer = load_model(shared.model_name)
 
     shared.generation_lock.acquire()
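The hunk above strips every heavy import (torch, transformers, the sampler hijack, the grammar modules) from module scope; only numpy and the lightweight modules.* imports remain, and the heavy dependencies are imported inside the functions that need them. A minimal sketch of that deferred-import pattern (the toy function is illustrative, not code from this commit; repeated imports are cheap because Python caches modules in sys.modules):

def encode_toy(prompt):
    import torch  # deferred: paid on first call, cached in sys.modules afterwards

    # Toy stand-in for real tokenization, just to show the pattern
    return torch.tensor([ord(c) for c in prompt]).unsqueeze(0)

The payoff is that `import modules.text_generation` no longer drags in torch for code paths, such as the llama.cpp server backend, that never touch it.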
@@ -46,7 +31,6 @@ def generate_reply(*args, **kwargs):
 
 
 def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
-
     # Find the appropriate generation function
     generate_func = apply_extensions('custom_generate_reply')
     if generate_func is None:
@@ -80,7 +64,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
             all_stop_strings += st
 
     shared.stop_everything = False
-    seed = set_manual_seed(state['seed'])
     last_update = -1
     reply = ''
     is_stream = state['stream']
@@ -93,7 +76,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
         min_update_interval = 1 / state['max_updates_second']
 
     # Generate
-    for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
+    for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat):
         reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
         if escape_html:
             reply = html.escape(reply)
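The two hunks above finish moving seeding out of the shared dispatch path: `_generate_reply` no longer calls set_manual_seed, and the `seed` positional is dropped from the generator interface, so each backend derives its own seed from `state`. A sketch of the narrowed contract (illustrative names, not the commit's code):

import random

def backend_generate(question, original_question, state, stopping_strings=None, is_chat=False):
    # Seeding is now a backend concern; the dispatcher only forwards state
    seed = state['seed'] if state['seed'] != -1 else random.randint(1, 2**31)
    yield f'generated reply (seed={seed})'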
@@ -132,44 +115,55 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
     if shared.tokenizer is None:
         raise ValueError('No tokenizer is loaded')
 
-    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']:
-        if shared.model.__class__.__name__ == 'LlamaServer':
-            input_ids = shared.tokenizer.encode(str(prompt), add_bos_token=add_bos_token)
-        else:
-            input_ids = shared.tokenizer.encode(str(prompt))
-
-        if shared.model.__class__.__name__ not in ['Exllamav2Model']:
-            input_ids = np.array(input_ids).reshape(1, len(input_ids))
-    else:
-        input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
-
-        if hasattr(shared.tokenizer, 'bos_token_id') and shared.tokenizer.bos_token_id is not None:
-            if add_bos_token:
-                # Add BOS token if missing
-                if (len(input_ids[0]) > 0 and input_ids[0][0] != shared.tokenizer.bos_token_id) or len(input_ids[0]) == 0:
-                    bos_tensor = torch.tensor([[shared.tokenizer.bos_token_id]])
-                    input_ids = torch.cat((bos_tensor, input_ids), 1)
-
-                # Prevent double BOS tokens from jinja templates
-                while len(input_ids[0]) > 1 and input_ids[0][0] == shared.tokenizer.bos_token_id and input_ids[0][1] == shared.tokenizer.bos_token_id:
-                    input_ids = input_ids[:, 1:]
-            else:
-                # Remove BOS tokens when not wanted
-                while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id:
-                    input_ids = input_ids[:, 1:]
-
-    if truncation_length is not None:
-        input_ids = input_ids[:, -truncation_length:]
-
-    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
-        return input_ids
-    else:
-        device = get_device()
-        if device:
-            return input_ids.to(device)
-
-        return input_ids
+    # llama.cpp case
+    if shared.model.__class__.__name__ == 'LlamaServer':
+        input_ids = shared.tokenizer.encode(str(prompt), add_bos_token=add_bos_token)
+        input_ids = np.array(input_ids).reshape(1, len(input_ids))
+
+        if truncation_length is not None:
+            input_ids = input_ids[:, -truncation_length:]
+
+        return input_ids
+
+    # All other model types
+    else:
+        import torch
+
+        from modules.torch_utils import get_device
+
+        if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel']:
+            input_ids = shared.tokenizer.encode(str(prompt))
+            if shared.model.__class__.__name__ != 'Exllamav2Model':
+                input_ids = np.array(input_ids).reshape(1, len(input_ids))
+        else:
+            input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
+            if hasattr(shared.tokenizer, 'bos_token_id') and shared.tokenizer.bos_token_id is not None:
+                if add_bos_token:
+                    if (len(input_ids[0]) > 0 and input_ids[0][0] != shared.tokenizer.bos_token_id) or len(input_ids[0]) == 0:
+                        # Add a missing bos token (it may not have been added due to faulty model metadata)
+                        bos_tensor = torch.tensor([[shared.tokenizer.bos_token_id]])
+                        input_ids = torch.cat((bos_tensor, input_ids), 1)
+
+                    # Prevent double bos token due to jinja templates with <s> somewhere
+                    while len(input_ids[0]) > 1 and input_ids[0][0] == shared.tokenizer.bos_token_id and input_ids[0][1] == shared.tokenizer.bos_token_id:
+                        input_ids = input_ids[:, 1:]
+                else:
+                    # Remove any bos token that may have been added
+                    while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id:
+                        input_ids = input_ids[:, 1:]
+
+        # Handling truncation
+        if truncation_length is not None:
+            input_ids = input_ids[:, -truncation_length:]
+
+        if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
+            return input_ids
+        else:
+            device = get_device()
+            if device:
+                return input_ids.to(device)
+
+            return input_ids
 
 
 def decode(output_ids, skip_special_tokens=True):
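The BOS bookkeeping kept in the transformers branch is easy to verify in isolation. A standalone sketch of the dedup loop on a toy tensor (the BOS id of 1 is arbitrary; assumes torch is installed):

import torch

BOS = 1  # arbitrary token id for this sketch
input_ids = torch.tensor([[BOS, BOS, 42, 43]])  # doubled BOS, e.g. from a jinja template containing <s>

# Same guard as in the hunk: strip leading duplicates, keep exactly one BOS
while len(input_ids[0]) > 1 and input_ids[0][0] == BOS and input_ids[0][1] == BOS:
    input_ids = input_ids[:, 1:]

print(input_ids)  # tensor([[ 1, 42, 43]])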
@@ -221,6 +215,9 @@ def formatted_outputs(reply, model_name):
 
 
 def set_manual_seed(seed):
+    import torch
+    from transformers import is_torch_npu_available, is_torch_xpu_available
+
     seed = int(seed)
     if seed == -1:
         seed = random.randint(1, 2**31)
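The imports added here make set_manual_seed self-sufficient now that torch and the transformers device probes are gone from module scope. Its contract is unchanged: -1 requests a fresh random seed, and the seed actually used is returned. A simplified sketch (XPU/NPU branches omitted):

import random

import torch

def set_manual_seed_sketch(seed):
    seed = int(seed)
    if seed == -1:
        seed = random.randint(1, 2**31)  # -1 means: pick a fresh random seed

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    return seed  # so callers can log or reuse the seed that was applied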
@@ -285,10 +282,26 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
     return reply
 
 
-def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
+def generate_reply_HF(question, original_question, state, stopping_strings=None, is_chat=False):
+    import torch
+    import transformers
+    from transformers import LogitsProcessorList
+
+    from modules.grammar.grammar_utils import initialize_grammar
+    from modules.grammar.logits_process import (
+        GrammarConstrainedLogitsProcessor
+    )
+    from modules.torch_utils import clear_torch_cache, get_device
+    from modules.transformers_loader import (
+        Stream,
+        _StopEverythingStoppingCriteria
+    )
+
     if shared.args.loader == 'Transformers':
         clear_torch_cache()
 
+    seed = set_manual_seed(state['seed'])
+
     generate_params = {}
     for k in [
         'temperature',
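generate_reply_HF now assembles its whole sampling stack from function-local imports, including the grammar-constrained logits processor. A logits processor is just a callable over (input_ids, scores); the toy custom processor below shows the LogitsProcessorList mechanism the function builds on (this is not the repo's grammar processor):

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class BanTokenProcessor(LogitsProcessor):
    # Toy processor: force one token id to -inf at every decoding step
    def __init__(self, banned_id):
        self.banned_id = banned_id

    def __call__(self, input_ids, scores):
        scores[:, self.banned_id] = -float('inf')
        return scores

processors = LogitsProcessorList([BanTokenProcessor(banned_id=0)])
# consumed as: model.generate(..., logits_processor=processors)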
@@ -458,11 +471,15 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
         return
 
 
-def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
+def generate_reply_custom(question, original_question, state, stopping_strings=None, is_chat=False):
     """
     For models that do not use the transformers library for sampling
     """
-    seed = set_manual_seed(state['seed'])
+    seed = state['seed']
+
+    if shared.args.loader != 'llama.cpp':
+        print(shared.args.loader)
+        seed = set_manual_seed(seed)
 
     t0 = time.time()
     reply = ''
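The llama.cpp exemption above reflects where sampling happens: the llama.cpp server samples in its own process, so seeding torch on the webui side has no effect there, and the seed travels with the generation request instead. A hedged sketch of that idea against a llama.cpp server's /completion endpoint (field names follow llama.cpp's HTTP server; verify against the server version in use):

import requests

# Assumes a llama.cpp server listening on localhost:8080
response = requests.post('http://localhost:8080/completion', json={
    'prompt': 'Once upon a time',
    'n_predict': 32,
    'seed': 12345,  # applied by the server's own sampler, not by torch
})
print(response.json()['content'])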