From 4809ddfeb85e8b8d28bb617366c86fd8037815ee Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 11 Aug 2025 07:35:22 -0700 Subject: [PATCH] Exllamav3: small sampler fixes --- modules/exllamav3.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/modules/exllamav3.py b/modules/exllamav3.py index 8f686669..5c142ec2 100644 --- a/modules/exllamav3.py +++ b/modules/exllamav3.py @@ -6,7 +6,9 @@ import torch from exllamav3 import Cache, Config, Generator, Model, Tokenizer from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant from exllamav3.generator import Job -from exllamav3.generator.sampler import ( + +from modules import shared +from modules.exllamav3_custom_sampler import ( CustomSampler, SS_Argmax, SS_MinP, @@ -17,8 +19,6 @@ from exllamav3.generator.sampler import ( SS_TopK, SS_TopP ) - -from modules import shared from modules.image_utils import ( convert_image_attachments_to_pil, convert_openai_messages_to_images @@ -194,7 +194,6 @@ class Exllamav3Model: # Process images and modify prompt (ExLlamaV3-specific) prompt, image_embeddings = self._process_images_for_generation(prompt, state) - # -- Manually build and sort the sampler stack -- # Greedy decoding is a special case if state['temperature'] == 0: sampler = CustomSampler([SS_Argmax()]) @@ -205,7 +204,7 @@ class Exllamav3Model: # Penalties penalty_range = state['repetition_penalty_range'] if penalty_range <= 0: - penalty_range = -1 # ExllamaV3 uses -1 for whole context + penalty_range = int(10e7) # Use large number for "full context" rep_decay = 0 # Not a configurable parameter # Add penalty samplers if they are active @@ -222,7 +221,7 @@ class Exllamav3Model: if state['min_p'] > 0.0: unordered_samplers.append(SS_MinP(state['min_p'])) - # Temperature + # Temperature (SS_NoOp is returned if temp is 1.0) unordered_samplers.append(SS_Temperature(state['temperature'])) # 2. Define the mapping from class names to the priority list keys @@ -246,7 +245,7 @@ class Exllamav3Model: def custom_sort_key(sampler_obj): class_name = sampler_obj.__class__.__name__ nickname = class_name_to_nickname.get(class_name) - if nickname in sampler_priority: + if nickname and nickname in sampler_priority: return sampler_priority.index(nickname) return -1 @@ -255,7 +254,6 @@ class Exllamav3Model: # 5. Add the final sampling stage and build the sampler ordered_samplers.append(SS_Sample()) sampler = CustomSampler(ordered_samplers) - # -- End of sampler building -- # Encode prompt with embeddings (ExLlamaV3-specific) input_ids = self.tokenizer.encode(