Refactor the transformers loader (#6859)

oobabooga 2025-04-20 13:33:47 -03:00 committed by GitHub
parent 6ba0164c70
commit ae02ffc605
18 changed files with 464 additions and 528 deletions
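
The common thread in the hunks below: torch, transformers, the grammar utilities, and the sampler hijack no longer sit at the module's top level; they are imported lazily inside the functions that need them (and partly re-exported from the new modules.transformers_loader), so the module shown here can be imported without loading the Transformers stack. A minimal sketch of that deferred-import pattern, assuming only that torch is installed; the function name is illustrative and not part of the commit:

# Deferred import: the heavy dependency is only loaded when the function
# that needs it actually runs, so importing the surrounding module stays cheap.
def seed_torch_lazily(seed):
    import torch  # paid for only when a torch-based loader is in use

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    return seed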


@@ -7,33 +7,18 @@ import time
import traceback

import numpy as np
-import torch
-import transformers
-from transformers import (
-    LogitsProcessorList,
-    is_torch_npu_available,
-    is_torch_xpu_available
-)

import modules.shared as shared
-from modules import models, sampler_hijack
-from modules.callbacks import (
-    Iteratorize,
-    Stream,
-    _StopEverythingStoppingCriteria
-)
+from modules import models
+from modules.callbacks import Iteratorize
from modules.extensions import apply_extensions
-from modules.grammar.grammar_utils import initialize_grammar
-from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
from modules.html_generator import generate_basic_html
from modules.logging_colors import logger
-from modules.models import clear_torch_cache, get_device, load_model
-
-sampler_hijack.hijack_samplers()


def generate_reply(*args, **kwargs):
    if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
+        from modules.models import load_model
        shared.model, shared.tokenizer = load_model(shared.model_name)

    shared.generation_lock.acquire()
@@ -46,7 +31,6 @@ def generate_reply(*args, **kwargs):
def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
# Find the appropriate generation function
generate_func = apply_extensions('custom_generate_reply')
if generate_func is None:
@@ -80,7 +64,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
all_stop_strings += st
shared.stop_everything = False
-seed = set_manual_seed(state['seed'])
last_update = -1
reply = ''
is_stream = state['stream']
@@ -93,7 +76,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
min_update_interval = 1 / state['max_updates_second']
# Generate
-for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
+for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat):
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
if escape_html:
reply = html.escape(reply)
@@ -132,44 +115,55 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
    if shared.tokenizer is None:
        raise ValueError('No tokenizer is loaded')

-    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']:
-        if shared.model.__class__.__name__ == 'LlamaServer':
-            input_ids = shared.tokenizer.encode(str(prompt), add_bos_token=add_bos_token)
-        else:
-            input_ids = shared.tokenizer.encode(str(prompt))
-        if shared.model.__class__.__name__ not in ['Exllamav2Model']:
-            input_ids = np.array(input_ids).reshape(1, len(input_ids))
-    else:
-        input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
-        if hasattr(shared.tokenizer, 'bos_token_id') and shared.tokenizer.bos_token_id is not None:
-            if add_bos_token:
-                if (len(input_ids[0]) > 0 and input_ids[0][0] != shared.tokenizer.bos_token_id) or len(input_ids[0]) == 0:
-                    # Add a missing bos token (it may not have been added due to faulty model metadata)
-                    bos_tensor = torch.tensor([[shared.tokenizer.bos_token_id]])
-                    input_ids = torch.cat((bos_tensor, input_ids), 1)
-                # Prevent double bos token due to jinja templates with <s> somewhere
-                while len(input_ids[0]) > 1 and input_ids[0][0] == shared.tokenizer.bos_token_id and input_ids[0][1] == shared.tokenizer.bos_token_id:
-                    input_ids = input_ids[:, 1:]
-            else:
-                # Remove any bos token that may have been added
-                while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id:
-                    input_ids = input_ids[:, 1:]
-    if truncation_length is not None:
-        input_ids = input_ids[:, -truncation_length:]
-    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
-        return input_ids
-    else:
-        device = get_device()
-        if device:
-            return input_ids.to(device)
-        return input_ids
+    # llama.cpp case
+    if shared.model.__class__.__name__ == 'LlamaServer':
+        input_ids = shared.tokenizer.encode(str(prompt), add_bos_token=add_bos_token)
+        input_ids = np.array(input_ids).reshape(1, len(input_ids))
+        if truncation_length is not None:
+            input_ids = input_ids[:, -truncation_length:]
+        return input_ids
+    # All other model types
+    else:
+        import torch
+        from modules.torch_utils import get_device
+        if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel']:
+            input_ids = shared.tokenizer.encode(str(prompt))
+            if shared.model.__class__.__name__ != 'Exllamav2Model':
+                input_ids = np.array(input_ids).reshape(1, len(input_ids))
+        else:
+            input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
+            if hasattr(shared.tokenizer, 'bos_token_id') and shared.tokenizer.bos_token_id is not None:
+                if add_bos_token:
+                    # Add BOS token if missing
+                    if (len(input_ids[0]) > 0 and input_ids[0][0] != shared.tokenizer.bos_token_id) or len(input_ids[0]) == 0:
+                        bos_tensor = torch.tensor([[shared.tokenizer.bos_token_id]])
+                        input_ids = torch.cat((bos_tensor, input_ids), 1)
+                    # Prevent double BOS tokens from jinja templates
+                    while len(input_ids[0]) > 1 and input_ids[0][0] == shared.tokenizer.bos_token_id and input_ids[0][1] == shared.tokenizer.bos_token_id:
+                        input_ids = input_ids[:, 1:]
+                else:
+                    # Remove BOS tokens when not wanted
+                    while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id:
+                        input_ids = input_ids[:, 1:]
+        # Handling truncation
+        if truncation_length is not None:
+            input_ids = input_ids[:, -truncation_length:]
+        if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
+            return input_ids
+        else:
+            device = get_device()
+            if device:
+                return input_ids.to(device)
+            return input_ids
def decode(output_ids, skip_special_tokens=True):
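
The encode() hunk above keeps the BOS normalization behaviour: with add_bos_token set, a missing BOS id is prepended (faulty model metadata sometimes drops it) and doubled BOS ids, which jinja chat templates hardcoding <s> can introduce, are collapsed; without it, leading BOS ids are stripped. A standalone sketch of that normalization on a plain list of token ids, assuming nothing beyond the standard library; the helper name is illustrative, not the commit's code:

def normalize_bos(ids, bos_id, add_bos):
    if add_bos:
        # Add a missing BOS token
        if not ids or ids[0] != bos_id:
            ids = [bos_id] + ids
        # Collapse a doubled BOS token, e.g. from a template that hardcodes <s>
        while len(ids) > 1 and ids[0] == bos_id and ids[1] == bos_id:
            ids = ids[1:]
    else:
        # Remove any BOS token that was added
        while ids and ids[0] == bos_id:
            ids = ids[1:]
    return ids

print(normalize_bos([1, 1, 15043], bos_id=1, add_bos=True))   # [1, 15043]
print(normalize_bos([1, 15043], bos_id=1, add_bos=False))     # [15043]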
@@ -221,6 +215,9 @@ def formatted_outputs(reply, model_name):
def set_manual_seed(seed):
+import torch
+from transformers import is_torch_npu_available, is_torch_xpu_available
seed = int(seed)
if seed == -1:
seed = random.randint(1, 2**31)
@@ -285,10 +282,26 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
return reply
-def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
+def generate_reply_HF(question, original_question, state, stopping_strings=None, is_chat=False):
+    import torch
+    import transformers
+    from transformers import LogitsProcessorList
+
+    from modules.grammar.grammar_utils import initialize_grammar
+    from modules.grammar.logits_process import (
+        GrammarConstrainedLogitsProcessor
+    )
+    from modules.torch_utils import clear_torch_cache, get_device
+    from modules.transformers_loader import (
+        Stream,
+        _StopEverythingStoppingCriteria
+    )
+
    if shared.args.loader == 'Transformers':
        clear_torch_cache()

+    seed = set_manual_seed(state['seed'])

    generate_params = {}
    for k in [
        'temperature',
@@ -458,11 +471,15 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
return
-def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
+def generate_reply_custom(question, original_question, state, stopping_strings=None, is_chat=False):
    """
    For models that do not use the transformers library for sampling
    """
-    seed = set_manual_seed(state['seed'])
+    seed = state['seed']
+    if shared.args.loader != 'llama.cpp':
+        print(shared.args.loader)
+        seed = set_manual_seed(seed)

    t0 = time.time()
    reply = ''
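
The last hunk moves seeding into generate_reply_custom itself: the seed is read from state, and set_manual_seed is called only for loaders other than llama.cpp, which presumably applies the seed on the server side. A rough sketch of that per-loader seeding, again with an illustrative helper name rather than the commit's code:

import random

def resolve_seed(seed, loader):
    # -1 means "pick a random seed", as in set_manual_seed() above
    if seed == -1:
        seed = random.randint(1, 2**31)

    # Only torch-based loaders need the in-process RNG seeded
    if loader != 'llama.cpp':
        import torch
        torch.manual_seed(seed)

    return seed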