Add ExLlamaV3 support (#6832)

oobabooga committed on 2025-04-09 00:07:08 -03:00
parent 0b3503c91f
commit 8b8d39ec4e
13 changed files with 322 additions and 54 deletions

modules/exllamav3_hf.py (new file, +179 lines)

@@ -0,0 +1,179 @@
import os
import traceback
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from exllamav3 import Cache, Config, Model
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from modules import shared
from modules.logging_colors import logger

try:
    import flash_attn
except Exception:
    logger.warning('Failed to load flash-attention due to the following error:\n')
    traceback.print_exc()
class Exllamav3HF(PreTrainedModel):
    def __init__(self, model_dir):
        super().__init__(PretrainedConfig())
        self.generation_config = GenerationConfig()

        config = Config.from_directory(model_dir)
        self.ex_model = Model.from_config(config)

        # Calculate the closest multiple of 256 at or above the chosen value
        max_tokens = shared.args.max_seq_len
        if max_tokens % 256 != 0:
            adjusted_tokens = ((max_tokens // 256) + 1) * 256
            logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
            max_tokens = adjusted_tokens
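        # Worked example (illustration only): max_seq_len=10000 rounds up to
        # ((10000 // 256) + 1) * 256 = 10240, while 8192 is already a
        # multiple of 256 and is used as-is.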

        self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens)

        # Create load parameters dictionary
        load_params = {'progressbar': True}
        if shared.args.gpu_split:
            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
            load_params['use_per_device'] = split
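            # Illustration: --gpu-split "16,24" becomes use_per_device=[16.0, 24.0],
            # one value per GPU (assumed to be GB, matching the documented meaning
            # of --gpu-split for the ExLlamaV2 loaders).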

        self.ex_model.load(**load_params)

        self.past_seq = None
        self.max_tokens = max_tokens

        # Negative-prompt cache/state for CFG (__call__ below relies on these
        # attributes whenever cfg_cache is enabled)
        if shared.args.cfg_cache:
            self.ex_cache_negative = Cache(self.ex_model, max_num_tokens=max_tokens)
            self.past_seq_negative = None
    def _validate_model_class(self):
        pass

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        pass

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {'input_ids': input_ids, **kwargs}

    @property
    def device(self) -> torch.device:
        return torch.device(0)
    def __call__(self, *args, **kwargs):
        use_cache = kwargs.get('use_cache', True)
        labels = kwargs.get('labels', None)
        past_key_values = kwargs.get('past_key_values', None)

        if len(args) > 0:
            if not shared.args.cfg_cache:
                logger.error("Please enable the cfg-cache option to use CFG with ExLlamav3_HF.")
                return

            input_ids = args[0]
            is_negative = True
            past_seq = self.past_seq_negative
            ex_cache = self.ex_cache_negative
        else:
            input_ids = kwargs['input_ids']
            is_negative = False
            past_seq = self.past_seq
            ex_cache = self.ex_cache

        seq = input_ids[0].tolist()
        if is_negative and past_key_values is not None:
            seq = past_key_values + seq

        seq_tensor = torch.tensor(seq)
        reset = True

        # Make the forward call
        if labels is None:
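            # Incremental decoding: find the longest prefix shared with the
            # previous call's sequence, keep its KV cache, and prefill only
            # the tokens that differ; the last token is then forwarded on its
            # own to produce the next-token logits.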
            if past_seq is not None:
                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
                if len(indices) > 0:
                    longest_prefix = indices[0].item()
                else:
                    longest_prefix = min_length

                if longest_prefix > 0:
                    reset = False
                    current_len = longest_prefix
                    if len(seq_tensor) - longest_prefix > 1:
                        self.ex_model.forward(
                            input_ids=seq_tensor[longest_prefix:-1].view(1, -1),
                            params={
                                "attn_mode": "flash_attn",
                                "cache": ex_cache,
                                "past_len": longest_prefix,
                                "batch_shape": (1, self.max_tokens)
                            }
                        )
                        current_len = longest_prefix + len(seq_tensor) - longest_prefix - 1

            if reset:
                if len(seq_tensor) > 1:
                    self.ex_model.forward(
                        input_ids=seq_tensor[:-1].view(1, -1),
                        params={
                            "attn_mode": "flash_attn",
                            "cache": ex_cache,
                            "past_len": 0,
                            "batch_shape": (1, self.max_tokens)
                        }
                    )
                    current_len = len(seq_tensor) - 1
                else:
                    current_len = 0

            logits = self.ex_model.forward(
                input_ids=seq_tensor[-1:].view(1, -1),
                params={
                    "attn_mode": "flash_attn",
                    "cache": ex_cache,
                    "past_len": current_len,
                    "batch_shape": (1, self.max_tokens)
                }
            ).to(input_ids.device).float()
        else:
            logits = self.ex_model.forward(
                input_ids=seq_tensor.view(1, -1),
                params={
                    "attn_mode": "flash_attn",
                    "cache": ex_cache,
                    "past_len": 0,
                    "batch_shape": (1, self.max_tokens)
                }
            ).float()

        if is_negative:
            self.past_seq_negative = seq_tensor
        else:
            self.past_seq = seq_tensor

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        assert len(model_args) == 0 and len(kwargs) == 0, "extra args are currently not supported"
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

        pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
        return Exllamav3HF(pretrained_model_name_or_path)
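
Since Exllamav3HF subclasses PreTrainedModel, the standard transformers generate() pipeline can drive it, with only the forward pass delegated to ExLlamaV3. A minimal usage sketch, assuming a hypothetical EXL3 quant folder under the default models/ directory with its tokenizer files alongside the weights:

from transformers import AutoTokenizer

from modules.exllamav3_hf import Exllamav3HF

# from_pretrained() resolves the folder name relative to shared.args.model_dir
model = Exllamav3HF.from_pretrained("MyModel-4.0bpw-exl3")  # hypothetical folder name
tokenizer = AutoTokenizer.from_pretrained("models/MyModel-4.0bpw-exl3")

input_ids = tokenizer("Hello,", return_tensors="pt").input_ids.to(model.device)
output_ids = model.generate(input_ids, max_new_tokens=32, do_sample=True)
print(tokenizer.decode(output_ids[0]))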

modules/loaders.py

@@ -23,7 +23,6 @@ loaders_and_params = OrderedDict({
        'use_double_quant',
        'use_eager_attention',
        'bf16',
        'trust_remote_code',
        'no_use_fast',
    ],
@@ -76,6 +75,13 @@ loaders_and_params = OrderedDict({
         'no_use_fast',
         'llamacpp_HF_info',
     ],
+    'ExLlamav3_HF': [
+        'max_seq_len',
+        'gpu_split',
+        'cfg_cache',
+        'trust_remote_code',
+        'no_use_fast',
+    ],
     'ExLlamav2_HF': [
         'max_seq_len',
         'cache_type',
@@ -174,30 +180,38 @@ def transformers_samplers():
 loaders_samplers = {
     'Transformers': transformers_samplers(),
     'HQQ': transformers_samplers(),
-    'ExLlamav2': {
+    'ExLlamav3_HF': {
         'temperature',
         'dynatemp_low',
         'dynatemp_high',
         'dynatemp_exponent',
         'smoothing_factor',
+        'smoothing_curve',
         'min_p',
         'top_p',
         'top_k',
         'typical_p',
         'xtc_threshold',
         'xtc_probability',
+        'epsilon_cutoff',
+        'eta_cutoff',
         'tfs',
         'top_a',
+        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',
         'repetition_penalty',
         'frequency_penalty',
         'presence_penalty',
+        'encoder_repetition_penalty',
+        'no_repeat_ngram_size',
         'repetition_penalty_range',
+        'guidance_scale',
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'do_sample',
         'dynamic_temperature',
         'temperature_last',
         'auto_max_new_tokens',
@@ -205,8 +219,12 @@ loaders_samplers = {
         'add_bos_token',
         'skip_special_tokens',
         'seed',
+        'sampler_priority',
         'custom_token_bans',
+        'negative_prompt',
         'dry_sequence_breakers',
+        'grammar_string',
+        'grammar_file_row',
     },
     'ExLlamav2_HF': {
         'temperature',
@@ -254,6 +272,40 @@ loaders_samplers = {
         'grammar_string',
         'grammar_file_row',
     },
+    'ExLlamav2': {
+        'temperature',
+        'dynatemp_low',
+        'dynatemp_high',
+        'dynatemp_exponent',
+        'smoothing_factor',
+        'min_p',
+        'top_p',
+        'top_k',
+        'typical_p',
+        'xtc_threshold',
+        'xtc_probability',
+        'tfs',
+        'top_a',
+        'dry_multiplier',
+        'dry_allowed_length',
+        'dry_base',
+        'repetition_penalty',
+        'frequency_penalty',
+        'presence_penalty',
+        'repetition_penalty_range',
+        'mirostat_mode',
+        'mirostat_tau',
+        'mirostat_eta',
+        'dynamic_temperature',
+        'temperature_last',
+        'auto_max_new_tokens',
+        'ban_eos_token',
+        'add_bos_token',
+        'skip_special_tokens',
+        'seed',
+        'custom_token_bans',
+        'dry_sequence_breakers',
+    },
     'llama.cpp': {
         'temperature',
         'min_p',

modules/models.py

@@ -69,8 +69,9 @@ def load_model(model_name, loader=None):
         'Transformers': huggingface_loader,
         'llama.cpp': llamacpp_loader,
         'llamacpp_HF': llamacpp_HF_loader,
-        'ExLlamav2': ExLlamav2_loader,
+        'ExLlamav3_HF': ExLlamav3_HF_loader,
         'ExLlamav2_HF': ExLlamav2_HF_loader,
+        'ExLlamav2': ExLlamav2_loader,
         'HQQ': HQQ_loader,
         'TensorRT-LLM': TensorRT_LLM_loader,
     }
@@ -304,11 +305,10 @@ def llamacpp_HF_loader(model_name):
     return model


-def ExLlamav2_loader(model_name):
-    from modules.exllamav2 import Exllamav2Model
+def ExLlamav3_HF_loader(model_name):
+    from modules.exllamav3_hf import Exllamav3HF

-    model, tokenizer = Exllamav2Model.from_pretrained(model_name)
-    return model, tokenizer
+    return Exllamav3HF.from_pretrained(model_name)


 def ExLlamav2_HF_loader(model_name):
@@ -317,6 +317,13 @@ def ExLlamav2_HF_loader(model_name):
     return Exllamav2HF.from_pretrained(model_name)


+def ExLlamav2_loader(model_name):
+    from modules.exllamav2 import Exllamav2Model
+
+    model, tokenizer = Exllamav2Model.from_pretrained(model_name)
+    return model, tokenizer
+
+
 def HQQ_loader(model_name):
     try:
         from hqq.core.quantize import HQQBackend, HQQLinear

modules/models_settings.py

@@ -158,14 +158,14 @@ def infer_loader(model_name, model_settings):
    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    if not path_to_model.exists():
        loader = None
    elif (path_to_model / 'quantize_config.json').exists():  # Old GPTQ metadata file
        loader = 'ExLlamav2_HF'
    elif len(list(path_to_model.glob('*.gguf'))) > 0 and path_to_model.is_dir() and (path_to_model / 'tokenizer_config.json').exists():
        loader = 'llamacpp_HF'
    elif len(list(path_to_model.glob('*.gguf'))) > 0:
        loader = 'llama.cpp'
    elif re.match(r'.*\.gguf', model_name.lower()):
        loader = 'llama.cpp'
    elif re.match(r'.*exl3', model_name.lower()):
        loader = 'ExLlamav3_HF'
    elif re.match(r'.*exl2', model_name.lower()):
        loader = 'ExLlamav2_HF'
    elif re.match(r'.*-hqq', model_name.lower()):
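
With the new branch in place, a folder whose name contains "exl3" autoselects the loader, while the GGUF checks still take precedence. A small illustration of the name-based fallbacks (folder names are made up):

import re

# Hypothetical names, tested in the same order as infer_loader above
for name in ['Llama-3-8B-exl3', 'Llama-3-8B-exl2', 'llama-3-8b.gguf']:
    if re.match(r'.*\.gguf', name.lower()):
        print(name, '->', 'llama.cpp')      # llama-3-8b.gguf -> llama.cpp
    elif re.match(r'.*exl3', name.lower()):
        print(name, '->', 'ExLlamav3_HF')   # Llama-3-8B-exl3 -> ExLlamav3_HF
    elif re.match(r'.*exl2', name.lower()):
        print(name, '->', 'ExLlamav2_HF')   # Llama-3-8B-exl2 -> ExLlamav2_HF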

modules/shared.py

@@ -86,7 +86,7 @@ group.add_argument('--idle-timeout', type=int, default=0, help='Unload model aft
 # Model loader
 group = parser.add_argument_group('Model loader')
-group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, ExLlamav2, HQQ, TensorRT-LLM.')
+group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, HQQ, TensorRT-LLM.')

 # Transformers/Accelerate
 group = parser.add_argument_group('Transformers/Accelerate')
@@ -273,6 +273,8 @@ def fix_loader_name(name):
         return 'ExLlamav2'
     elif name in ['exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf', 'exllama2-hf', 'exllama2_hf', 'exllama-2-hf', 'exllama_2_hf', 'exllama-2_hf']:
         return 'ExLlamav2_HF'
+    elif name in ['exllamav3-hf', 'exllamav3_hf', 'exllama-v3-hf', 'exllama_v3_hf', 'exllama-v3_hf', 'exllama3-hf', 'exllama3_hf', 'exllama-3-hf', 'exllama_3_hf', 'exllama-3_hf']:
+        return 'ExLlamav3_HF'
     elif name in ['hqq']:
         return 'HQQ'
     elif name in ['tensorrt', 'tensorrtllm', 'tensorrt_llm', 'tensorrt-llm', 'tensort', 'tensortllm']: