Fix ExLlamaV2_HF and ExLlamaV3_HF after ae02ffc605

oobabooga 2025-04-20 11:32:48 -07:00
parent 9c59acf820
commit b3bf7a885d
4 changed files with 22 additions and 9 deletions
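
Background: as of transformers v4.50, PreTrainedModel no longer inherits from GenerationMixin, so wrapper classes that call model.generate() must mix GenerationMixin in explicitly. The class-declaration changes below do exactly that for Exllamav2HF and Exllamav3HF; the import hunks pull in GenerationMixin while regrouping the transformers imports. A minimal sketch of the pattern, with an illustrative wrapper name that is not from this repo:

import torch
from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


# Without GenerationMixin in the bases, .generate() is unavailable on
# transformers >= 4.50; mixing it in restores the generation API.
class BackendWrapperHF(PreTrainedModel, GenerationMixin):  # illustrative name
    def __init__(self):
        super().__init__(PretrainedConfig())

    def forward(self, input_ids=None, **kwargs):
        # A real wrapper would run its inference backend here; this stub
        # returns zero logits over a tiny vocabulary just to be runnable.
        batch_size, seq_len = input_ids.shape
        logits = torch.zeros(batch_size, seq_len, 8)
        return CausalLMOutputWithPast(logits=logits)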

modules/exllamav2_hf.py

@@ -4,10 +4,6 @@ from pathlib import Path
 from typing import Any, Dict, Optional, Union
 
 import torch
-from torch.nn import CrossEntropyLoss
-from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Cache,
@@ -18,6 +14,15 @@ from exllamav2 import (
     ExLlamaV2Cache_TP,
     ExLlamaV2Config
 )
+from torch.nn import CrossEntropyLoss
+from transformers import (
+    GenerationConfig,
+    GenerationMixin,
+    PretrainedConfig,
+    PreTrainedModel
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
 from modules import shared
 from modules.logging_colors import logger
@@ -28,7 +33,7 @@ except Exception:
     traceback.print_exc()
 
 
-class Exllamav2HF(PreTrainedModel):
+class Exllamav2HF(PreTrainedModel, GenerationMixin):
     def __init__(self, config: ExLlamaV2Config):
         super().__init__(PretrainedConfig())
         self.ex_config = config

modules/exllamav3_hf.py

@@ -6,7 +6,12 @@ from typing import Any, Dict, Optional, Union
 import torch
 from exllamav3 import Cache, Config, Model
 from torch.nn import CrossEntropyLoss
-from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
+from transformers import (
+    GenerationConfig,
+    GenerationMixin,
+    PretrainedConfig,
+    PreTrainedModel
+)
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from modules import shared
@@ -19,7 +24,7 @@ except Exception:
     traceback.print_exc()
 
 
-class Exllamav3HF(PreTrainedModel):
+class Exllamav3HF(PreTrainedModel, GenerationMixin):
     def __init__(self, model_dir):
         super().__init__(PretrainedConfig())
         self.generation_config = GenerationConfig()

modules/models.py

@@ -1,3 +1,4 @@
+import sys
 import time
 from pathlib import Path
 
@@ -34,6 +35,10 @@ def load_model(model_name, loader=None):
         logger.error('The path to the model does not exist. Exiting.')
         raise ValueError
 
+    if loader != 'llama.cpp' and 'sampler_hijack' not in sys.modules:
+        from modules import sampler_hijack
+        sampler_hijack.hijack_samplers()
+
     shared.args.loader = loader
     output = load_func_map[loader](model_name)
     if type(output) is tuple:

modules/transformers_loader.py

@@ -22,13 +22,11 @@ from transformers import (
 )
 
 import modules.shared as shared
-from modules import sampler_hijack
 from modules.logging_colors import logger
 from modules.text_generation import get_reply_from_output_ids
 from modules.torch_utils import get_device
 
 transformers.logging.set_verbosity_error()
-sampler_hijack.hijack_samplers()
 
 local_rank = None
 if shared.args.deepspeed:
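
The last two files move the sampler hijack from import time in the transformers loader to a guarded call inside load_model(), so it is applied lazily and only for loaders that go through the transformers sampling path. A sketch of the run-once lazy-import idiom, assuming the repo's modules.sampler_hijack is importable; note that sys.modules keys are fully qualified, i.e. 'modules.sampler_hijack' for a 'from modules import sampler_hijack':

import sys


def apply_sampler_patch_once(loader: str) -> None:  # hypothetical helper
    # llama.cpp samples internally, so it never needs the patch. For other
    # loaders, the sys.modules membership test makes this run-once: after
    # the first import, Python caches the module and the branch is skipped.
    if loader != 'llama.cpp' and 'modules.sampler_hijack' not in sys.modules:
        from modules import sampler_hijack
        sampler_hijack.hijack_samplers()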