mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-10 00:53:39 +00:00
Add ExLlamaV3 support (#6832)
This commit is contained in:
parent
0b3503c91f
commit
8b8d39ec4e
13 changed files with 322 additions and 54 deletions
179
modules/exllamav3_hf.py
Normal file
179
modules/exllamav3_hf.py
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
import os
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from exllamav3 import Cache, Config, Model
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
except Exception:
|
||||
logger.warning('Failed to load flash-attention due to the following error:\n')
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
class Exllamav3HF(PreTrainedModel):
|
||||
def __init__(self, model_dir):
|
||||
super().__init__(PretrainedConfig())
|
||||
self.generation_config = GenerationConfig()
|
||||
|
||||
config = Config.from_directory(model_dir)
|
||||
self.ex_model = Model.from_config(config)
|
||||
|
||||
# Calculate the closest multiple of 256 at or above the chosen value
|
||||
max_tokens = shared.args.max_seq_len
|
||||
if max_tokens % 256 != 0:
|
||||
adjusted_tokens = ((max_tokens // 256) + 1) * 256
|
||||
logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
|
||||
max_tokens = adjusted_tokens
|
||||
|
||||
self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens)
|
||||
|
||||
# Create load parameters dictionary
|
||||
load_params = {'progressbar': True}
|
||||
if shared.args.gpu_split:
|
||||
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
|
||||
load_params['use_per_device'] = split
|
||||
|
||||
self.ex_model.load(**load_params)
|
||||
self.past_seq = None
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def _validate_model_class(self):
|
||||
pass
|
||||
|
||||
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
||||
pass
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||
return {'input_ids': input_ids, **kwargs}
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(0)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
use_cache = kwargs.get('use_cache', True)
|
||||
labels = kwargs.get('labels', None)
|
||||
past_key_values = kwargs.get('past_key_values', None)
|
||||
|
||||
if len(args) > 0:
|
||||
if not shared.args.cfg_cache:
|
||||
logger.error("Please enable the cfg-cache option to use CFG with ExLlamav3_HF.")
|
||||
return
|
||||
|
||||
input_ids = args[0]
|
||||
is_negative = True
|
||||
past_seq = self.past_seq_negative
|
||||
ex_cache = self.ex_cache_negative
|
||||
else:
|
||||
input_ids = kwargs['input_ids']
|
||||
is_negative = False
|
||||
past_seq = self.past_seq
|
||||
ex_cache = self.ex_cache
|
||||
|
||||
seq = input_ids[0].tolist()
|
||||
if is_negative and past_key_values is not None:
|
||||
seq = past_key_values + seq
|
||||
|
||||
seq_tensor = torch.tensor(seq)
|
||||
reset = True
|
||||
|
||||
# Make the forward call
|
||||
if labels is None:
|
||||
if past_seq is not None:
|
||||
min_length = min(past_seq.shape[0], seq_tensor.shape[0])
|
||||
indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
|
||||
if len(indices) > 0:
|
||||
longest_prefix = indices[0].item()
|
||||
else:
|
||||
longest_prefix = min_length
|
||||
|
||||
if longest_prefix > 0:
|
||||
reset = False
|
||||
current_len = longest_prefix
|
||||
if len(seq_tensor) - longest_prefix > 1:
|
||||
self.ex_model.forward(
|
||||
input_ids=seq_tensor[longest_prefix:-1].view(1, -1),
|
||||
params={
|
||||
"attn_mode": "flash_attn",
|
||||
"cache": ex_cache,
|
||||
"past_len": longest_prefix,
|
||||
"batch_shape": (1, self.max_tokens)
|
||||
}
|
||||
)
|
||||
|
||||
current_len = longest_prefix + len(seq_tensor) - longest_prefix - 1
|
||||
|
||||
if reset:
|
||||
if len(seq_tensor) > 1:
|
||||
self.ex_model.forward(
|
||||
input_ids=seq_tensor[:-1].view(1, -1),
|
||||
params={
|
||||
"attn_mode": "flash_attn",
|
||||
"cache": ex_cache,
|
||||
"past_len": 0,
|
||||
"batch_shape": (1, self.max_tokens)
|
||||
}
|
||||
)
|
||||
|
||||
current_len = len(seq_tensor) - 1
|
||||
else:
|
||||
current_len = 0
|
||||
|
||||
logits = self.ex_model.forward(
|
||||
input_ids=seq_tensor[-1:].view(1, -1),
|
||||
params={
|
||||
"attn_mode": "flash_attn",
|
||||
"cache": ex_cache,
|
||||
"past_len": current_len,
|
||||
"batch_shape": (1, self.max_tokens)
|
||||
}
|
||||
).to(input_ids.device).float()
|
||||
else:
|
||||
logits = self.ex_model.forward(
|
||||
input_ids=seq_tensor.view(1, -1),
|
||||
params={
|
||||
"attn_mode": "flash_attn",
|
||||
"cache": ex_cache,
|
||||
"past_len": 0,
|
||||
"batch_shape": (1, self.max_tokens)
|
||||
}
|
||||
).float()
|
||||
|
||||
if is_negative:
|
||||
self.past_seq_negative = seq_tensor
|
||||
else:
|
||||
self.past_seq = seq_tensor
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, logits.shape[-1])
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
||||
assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
|
||||
if isinstance(pretrained_model_name_or_path, str):
|
||||
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
|
||||
|
||||
pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
|
||||
|
||||
return Exllamav3HF(pretrained_model_name_or_path)
|
||||
|
|
@ -23,7 +23,6 @@ loaders_and_params = OrderedDict({
|
|||
'use_double_quant',
|
||||
'use_eager_attention',
|
||||
'bf16',
|
||||
|
||||
'trust_remote_code',
|
||||
'no_use_fast',
|
||||
],
|
||||
|
|
@ -76,6 +75,13 @@ loaders_and_params = OrderedDict({
|
|||
'no_use_fast',
|
||||
'llamacpp_HF_info',
|
||||
],
|
||||
'ExLlamav3_HF': [
|
||||
'max_seq_len',
|
||||
'gpu_split',
|
||||
'cfg_cache',
|
||||
'trust_remote_code',
|
||||
'no_use_fast',
|
||||
],
|
||||
'ExLlamav2_HF': [
|
||||
'max_seq_len',
|
||||
'cache_type',
|
||||
|
|
@ -174,30 +180,38 @@ def transformers_samplers():
|
|||
loaders_samplers = {
|
||||
'Transformers': transformers_samplers(),
|
||||
'HQQ': transformers_samplers(),
|
||||
'ExLlamav2': {
|
||||
'ExLlamav3_HF': {
|
||||
'temperature',
|
||||
'dynatemp_low',
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'smoothing_factor',
|
||||
'smoothing_curve',
|
||||
'min_p',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'typical_p',
|
||||
'xtc_threshold',
|
||||
'xtc_probability',
|
||||
'epsilon_cutoff',
|
||||
'eta_cutoff',
|
||||
'tfs',
|
||||
'top_a',
|
||||
'top_n_sigma',
|
||||
'dry_multiplier',
|
||||
'dry_allowed_length',
|
||||
'dry_base',
|
||||
'repetition_penalty',
|
||||
'frequency_penalty',
|
||||
'presence_penalty',
|
||||
'encoder_repetition_penalty',
|
||||
'no_repeat_ngram_size',
|
||||
'repetition_penalty_range',
|
||||
'guidance_scale',
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'do_sample',
|
||||
'dynamic_temperature',
|
||||
'temperature_last',
|
||||
'auto_max_new_tokens',
|
||||
|
|
@ -205,8 +219,12 @@ loaders_samplers = {
|
|||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'seed',
|
||||
'sampler_priority',
|
||||
'custom_token_bans',
|
||||
'negative_prompt',
|
||||
'dry_sequence_breakers',
|
||||
'grammar_string',
|
||||
'grammar_file_row',
|
||||
},
|
||||
'ExLlamav2_HF': {
|
||||
'temperature',
|
||||
|
|
@ -254,6 +272,40 @@ loaders_samplers = {
|
|||
'grammar_string',
|
||||
'grammar_file_row',
|
||||
},
|
||||
'ExLlamav2': {
|
||||
'temperature',
|
||||
'dynatemp_low',
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'smoothing_factor',
|
||||
'min_p',
|
||||
'top_p',
|
||||
'top_k',
|
||||
'typical_p',
|
||||
'xtc_threshold',
|
||||
'xtc_probability',
|
||||
'tfs',
|
||||
'top_a',
|
||||
'dry_multiplier',
|
||||
'dry_allowed_length',
|
||||
'dry_base',
|
||||
'repetition_penalty',
|
||||
'frequency_penalty',
|
||||
'presence_penalty',
|
||||
'repetition_penalty_range',
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'dynamic_temperature',
|
||||
'temperature_last',
|
||||
'auto_max_new_tokens',
|
||||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'seed',
|
||||
'custom_token_bans',
|
||||
'dry_sequence_breakers',
|
||||
},
|
||||
'llama.cpp': {
|
||||
'temperature',
|
||||
'min_p',
|
||||
|
|
|
|||
|
|
@ -69,8 +69,9 @@ def load_model(model_name, loader=None):
|
|||
'Transformers': huggingface_loader,
|
||||
'llama.cpp': llamacpp_loader,
|
||||
'llamacpp_HF': llamacpp_HF_loader,
|
||||
'ExLlamav2': ExLlamav2_loader,
|
||||
'ExLlamav3_HF': ExLlamav3_HF_loader,
|
||||
'ExLlamav2_HF': ExLlamav2_HF_loader,
|
||||
'ExLlamav2': ExLlamav2_loader,
|
||||
'HQQ': HQQ_loader,
|
||||
'TensorRT-LLM': TensorRT_LLM_loader,
|
||||
}
|
||||
|
|
@ -304,11 +305,10 @@ def llamacpp_HF_loader(model_name):
|
|||
return model
|
||||
|
||||
|
||||
def ExLlamav2_loader(model_name):
|
||||
from modules.exllamav2 import Exllamav2Model
|
||||
def ExLlamav3_HF_loader(model_name):
|
||||
from modules.exllamav3_hf import Exllamav3HF
|
||||
|
||||
model, tokenizer = Exllamav2Model.from_pretrained(model_name)
|
||||
return model, tokenizer
|
||||
return Exllamav3HF.from_pretrained(model_name)
|
||||
|
||||
|
||||
def ExLlamav2_HF_loader(model_name):
|
||||
|
|
@ -317,6 +317,13 @@ def ExLlamav2_HF_loader(model_name):
|
|||
return Exllamav2HF.from_pretrained(model_name)
|
||||
|
||||
|
||||
def ExLlamav2_loader(model_name):
|
||||
from modules.exllamav2 import Exllamav2Model
|
||||
|
||||
model, tokenizer = Exllamav2Model.from_pretrained(model_name)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def HQQ_loader(model_name):
|
||||
try:
|
||||
from hqq.core.quantize import HQQBackend, HQQLinear
|
||||
|
|
|
|||
|
|
@ -158,14 +158,14 @@ def infer_loader(model_name, model_settings):
|
|||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if not path_to_model.exists():
|
||||
loader = None
|
||||
elif (path_to_model / 'quantize_config.json').exists(): # Old GPTQ metadata file
|
||||
loader = 'ExLlamav2_HF'
|
||||
elif len(list(path_to_model.glob('*.gguf'))) > 0 and path_to_model.is_dir() and (path_to_model / 'tokenizer_config.json').exists():
|
||||
loader = 'llamacpp_HF'
|
||||
elif len(list(path_to_model.glob('*.gguf'))) > 0:
|
||||
loader = 'llama.cpp'
|
||||
elif re.match(r'.*\.gguf', model_name.lower()):
|
||||
loader = 'llama.cpp'
|
||||
elif re.match(r'.*exl3', model_name.lower()):
|
||||
loader = 'ExLlamav3_HF'
|
||||
elif re.match(r'.*exl2', model_name.lower()):
|
||||
loader = 'ExLlamav2_HF'
|
||||
elif re.match(r'.*-hqq', model_name.lower()):
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ group.add_argument('--idle-timeout', type=int, default=0, help='Unload model aft
|
|||
|
||||
# Model loader
|
||||
group = parser.add_argument_group('Model loader')
|
||||
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, ExLlamav2, HQQ, TensorRT-LLM.')
|
||||
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, HQQ, TensorRT-LLM.')
|
||||
|
||||
# Transformers/Accelerate
|
||||
group = parser.add_argument_group('Transformers/Accelerate')
|
||||
|
|
@ -273,6 +273,8 @@ def fix_loader_name(name):
|
|||
return 'ExLlamav2'
|
||||
elif name in ['exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf', 'exllama2-hf', 'exllama2_hf', 'exllama-2-hf', 'exllama_2_hf', 'exllama-2_hf']:
|
||||
return 'ExLlamav2_HF'
|
||||
elif name in ['exllamav3-hf', 'exllamav3_hf', 'exllama-v3-hf', 'exllama_v3_hf', 'exllama-v3_hf', 'exllama3-hf', 'exllama3_hf', 'exllama-3-hf', 'exllama_3_hf', 'exllama-3_hf']:
|
||||
return 'ExLlamav3_HF'
|
||||
elif name in ['hqq']:
|
||||
return 'HQQ'
|
||||
elif name in ['tensorrt', 'tensorrtllm', 'tensorrt_llm', 'tensorrt-llm', 'tensort', 'tensortllm']:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue