Exllamav3: small sampler fixes

This commit is contained in:
oobabooga 2025-08-11 07:35:22 -07:00
parent 4d8dbbab64
commit 4809ddfeb8

View file

@ -6,7 +6,9 @@ import torch
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from exllamav3.generator import Job
from exllamav3.generator.sampler import (
from modules import shared
from modules.exllamav3_custom_sampler import (
CustomSampler,
SS_Argmax,
SS_MinP,
@ -17,8 +19,6 @@ from exllamav3.generator.sampler import (
SS_TopK,
SS_TopP
)
from modules import shared
from modules.image_utils import (
convert_image_attachments_to_pil,
convert_openai_messages_to_images
@ -194,7 +194,6 @@ class Exllamav3Model:
# Process images and modify prompt (ExLlamaV3-specific)
prompt, image_embeddings = self._process_images_for_generation(prompt, state)
# -- Manually build and sort the sampler stack --
# Greedy decoding is a special case
if state['temperature'] == 0:
sampler = CustomSampler([SS_Argmax()])
@ -205,7 +204,7 @@ class Exllamav3Model:
# Penalties
penalty_range = state['repetition_penalty_range']
if penalty_range <= 0:
penalty_range = -1 # ExllamaV3 uses -1 for whole context
penalty_range = int(10e7) # Use large number for "full context"
rep_decay = 0 # Not a configurable parameter
# Add penalty samplers if they are active
@ -222,7 +221,7 @@ class Exllamav3Model:
if state['min_p'] > 0.0:
unordered_samplers.append(SS_MinP(state['min_p']))
# Temperature
# Temperature (SS_NoOp is returned if temp is 1.0)
unordered_samplers.append(SS_Temperature(state['temperature']))
# 2. Define the mapping from class names to the priority list keys
@ -246,7 +245,7 @@ class Exllamav3Model:
def custom_sort_key(sampler_obj):
class_name = sampler_obj.__class__.__name__
nickname = class_name_to_nickname.get(class_name)
if nickname in sampler_priority:
if nickname and nickname in sampler_priority:
return sampler_priority.index(nickname)
return -1
@ -255,7 +254,6 @@ class Exllamav3Model:
# 5. Add the final sampling stage and build the sampler
ordered_samplers.append(SS_Sample())
sampler = CustomSampler(ordered_samplers)
# -- End of sampler building --
# Encode prompt with embeddings (ExLlamaV3-specific)
input_ids = self.tokenizer.encode(