diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py index 6486e438..b159d9ce 100644 --- a/modules/exllamav2_hf.py +++ b/modules/exllamav2_hf.py @@ -4,10 +4,6 @@ from pathlib import Path from typing import Any, Dict, Optional, Union import torch -from torch.nn import CrossEntropyLoss -from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel -from transformers.modeling_outputs import CausalLMOutputWithPast - from exllamav2 import ( ExLlamaV2, ExLlamaV2Cache, @@ -18,6 +14,15 @@ from exllamav2 import ( ExLlamaV2Cache_TP, ExLlamaV2Config ) +from torch.nn import CrossEntropyLoss +from transformers import ( + GenerationConfig, + GenerationMixin, + PretrainedConfig, + PreTrainedModel +) +from transformers.modeling_outputs import CausalLMOutputWithPast + from modules import shared from modules.logging_colors import logger @@ -28,7 +33,7 @@ except Exception: traceback.print_exc() -class Exllamav2HF(PreTrainedModel): +class Exllamav2HF(PreTrainedModel, GenerationMixin): def __init__(self, config: ExLlamaV2Config): super().__init__(PretrainedConfig()) self.ex_config = config diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py index 0f742fa2..2d9c493a 100644 --- a/modules/exllamav3_hf.py +++ b/modules/exllamav3_hf.py @@ -6,7 +6,12 @@ from typing import Any, Dict, Optional, Union import torch from exllamav3 import Cache, Config, Model from torch.nn import CrossEntropyLoss -from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel +from transformers import ( + GenerationConfig, + GenerationMixin, + PretrainedConfig, + PreTrainedModel +) from transformers.modeling_outputs import CausalLMOutputWithPast from modules import shared @@ -19,7 +24,7 @@ except Exception: traceback.print_exc() -class Exllamav3HF(PreTrainedModel): +class Exllamav3HF(PreTrainedModel, GenerationMixin): def __init__(self, model_dir): super().__init__(PretrainedConfig()) self.generation_config = GenerationConfig() diff --git a/modules/models.py b/modules/models.py index c4dfa149..2c23462a 100644 --- a/modules/models.py +++ b/modules/models.py @@ -1,3 +1,4 @@ +import sys import time from pathlib import Path @@ -34,6 +35,10 @@ def load_model(model_name, loader=None): logger.error('The path to the model does not exist. Exiting.') raise ValueError + if loader != 'llama.cpp' and 'sampler_hijack' not in sys.modules: + from modules import sampler_hijack + sampler_hijack.hijack_samplers() + shared.args.loader = loader output = load_func_map[loader](model_name) if type(output) is tuple: diff --git a/modules/transformers_loader.py b/modules/transformers_loader.py index 5512f061..add3be66 100644 --- a/modules/transformers_loader.py +++ b/modules/transformers_loader.py @@ -22,13 +22,11 @@ from transformers import ( ) import modules.shared as shared -from modules import sampler_hijack from modules.logging_colors import logger from modules.text_generation import get_reply_from_output_ids from modules.torch_utils import get_device transformers.logging.set_verbosity_error() -sampler_hijack.hijack_samplers() local_rank = None if shared.args.deepspeed: