From 8320190184c285f6a909c11ee040962a498e6e4f Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 21 Apr 2025 18:32:23 -0700 Subject: [PATCH] Fix the exllamav2_HF and exllamav3_HF loaders --- modules/sampler_hijack.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index ee871a6e..dfdb6914 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -15,6 +15,9 @@ from modules import shared from modules.logging_colors import logger from modules.torch_utils import get_device +original_init = transformers.GenerationConfig.__init__ +original_get_logits_processor = transformers.GenerationMixin._get_logits_processor + global_scores = None @@ -484,7 +487,7 @@ def get_logits_processor_patch(self, **kwargs): generation_config.temperature = float(generation_config.temperature) # Must be float # Get the original warpers - warpers = self._get_logits_processor_old(**kwargs) + warpers = original_get_logits_processor(self, **kwargs) for i in range(len(warpers) - 1, -1, -1): # Replace temperature with our modified class. @@ -674,7 +677,7 @@ def get_logits_processor_patch(self, **kwargs): def generation_config_init_patch(self, **kwargs): - self.__init___old(**kwargs) + original_init(self, **kwargs) self.min_p = kwargs.pop("min_p", 0.0) self.dynamic_temperature = kwargs.pop("dynamic_temperature", False) self.dynatemp_low = kwargs.pop("dynatemp_low", 1) @@ -702,8 +705,5 @@ def generation_config_init_patch(self, **kwargs): def hijack_samplers(): - transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch - - transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__ transformers.GenerationConfig.__init__ = generation_config_init_patch