Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2026-01-20 15:40:23 +01:00
Exllamav3: small sampler fixes
This commit is contained in:
parent
4d8dbbab64
commit
4809ddfeb8
@@ -6,7 +6,9 @@ import torch
 from exllamav3 import Cache, Config, Generator, Model, Tokenizer
 from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
 from exllamav3.generator import Job
-from exllamav3.generator.sampler import (
+
+from modules import shared
+from modules.exllamav3_custom_sampler import (
     CustomSampler,
     SS_Argmax,
     SS_MinP,
@@ -17,8 +19,6 @@ from exllamav3.generator.sampler import (
     SS_TopK,
     SS_TopP
 )
-
-from modules import shared
 from modules.image_utils import (
     convert_image_attachments_to_pil,
     convert_openai_messages_to_images
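The diff does not show modules/exllamav3_custom_sampler.py itself; judging only from the import swap above, the new module exposes the same names the code previously took from exllamav3.generator.sampler. A minimal sketch of the shape such a module could take, assuming it simply wraps the upstream stages (the exact class list and any patched logic are assumptions, not shown in this commit):

    # modules/exllamav3_custom_sampler.py -- hypothetical minimal shape.
    # Re-export the upstream stages so call sites change only their import,
    # and subclass CustomSampler as the place to carry local fixes.
    from exllamav3.generator.sampler import (  # upstream names, per the old import
        CustomSampler as _UpstreamCustomSampler,
        SS_Argmax,
        SS_MinP,
        SS_Sample,
        SS_Temperature,
        SS_TopK,
        SS_TopP,
    )

    class CustomSampler(_UpstreamCustomSampler):
        # Patched sampling behavior would live here; otherwise upstream behavior.
        pass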
@@ -194,7 +194,6 @@ class Exllamav3Model:
         # Process images and modify prompt (ExLlamaV3-specific)
         prompt, image_embeddings = self._process_images_for_generation(prompt, state)

-        # -- Manually build and sort the sampler stack --
         # Greedy decoding is a special case
         if state['temperature'] == 0:
             sampler = CustomSampler([SS_Argmax()])
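For reference, temperature 0 short-circuits the whole stack into greedy decoding: a single argmax stage. The equivalence in plain PyTorch (illustrative only, not exllamav3 internals):

    import torch

    # Greedy decoding: always pick the highest-logit token. This is all an
    # argmax stage has to do, so no other stage needs to run.
    def greedy_pick(logits: torch.Tensor) -> int:
        return int(torch.argmax(logits, dim=-1))

    print(greedy_pick(torch.tensor([0.1, 2.5, -1.0])))  # 1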
@@ -205,7 +204,7 @@ class Exllamav3Model:
         # Penalties
         penalty_range = state['repetition_penalty_range']
         if penalty_range <= 0:
-            penalty_range = -1  # ExllamaV3 uses -1 for whole context
+            penalty_range = int(10e7)  # Use large number for "full context"
         rep_decay = 0  # Not a configurable parameter

         # Add penalty samplers if they are active
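Note that int(10e7) is 100,000,000 (10e7 equals 1e8), far larger than any realistic context length. The likely reasoning behind dropping the -1 sentinel: if the backend treats the range as a count of recent tokens and clamps it to the tokens actually seen, a huge value naturally means "whole context", whereas a negative sentinel only works if every code path special-cases it. A sketch of that clamping idea (hypothetical helper, not the exllamav3 API):

    # Clamp the configured window to what actually exists; with a huge
    # configured range, the whole context is always covered.
    def effective_window(penalty_range: int, tokens_seen: int) -> int:
        return min(penalty_range, tokens_seen)

    assert effective_window(int(10e7), 4096) == 4096  # "full context"
    assert effective_window(1024, 4096) == 1024       # limited window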
@@ -222,7 +221,7 @@ class Exllamav3Model:
         if state['min_p'] > 0.0:
             unordered_samplers.append(SS_MinP(state['min_p']))

-        # Temperature
+        # Temperature (SS_NoOp is returned if temp is 1.0)
         unordered_samplers.append(SS_Temperature(state['temperature']))

         # 2. Define the mapping from class names to the priority list keys
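The new comment records an upstream convenience: temperature 1.0 leaves the logits unchanged, so SS_Temperature(1.0) can degrade to a no-op stage (SS_NoOp). The underlying math, sketched in plain PyTorch (illustrative, not the library's code):

    import torch

    # Temperature scaling divides the logits; at T == 1.0 this is the
    # identity, which is why a no-op stage can stand in for it.
    def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
        if temperature == 1.0:
            return logits  # no-op, matching the SS_NoOp shortcut
        return logits / temperature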
@@ -246,7 +245,7 @@ class Exllamav3Model:
         def custom_sort_key(sampler_obj):
             class_name = sampler_obj.__class__.__name__
             nickname = class_name_to_nickname.get(class_name)
-            if nickname in sampler_priority:
+            if nickname and nickname in sampler_priority:
                 return sampler_priority.index(nickname)
             return -1

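The guard matters because class_name_to_nickname.get() returns None for any stage without a nickname; the new "nickname and ..." check makes the unmapped case explicit instead of relying on "None in sampler_priority" happening to be False. Unmapped stages return -1 and therefore sort ahead of everything in the priority list. A standalone demonstration with assumed nicknames and priority order:

    sampler_priority = ['top_k', 'top_p', 'min_p']  # assumed order
    class_name_to_nickname = {'SS_TopK': 'top_k',
                              'SS_TopP': 'top_p',
                              'SS_MinP': 'min_p'}

    def sort_key_demo(class_name: str) -> int:
        nickname = class_name_to_nickname.get(class_name)  # None if unmapped
        if nickname and nickname in sampler_priority:
            return sampler_priority.index(nickname)
        return -1  # unmapped stages float to the front

    names = ['SS_TopP', 'SS_Unknown', 'SS_TopK']
    print(sorted(names, key=sort_key_demo))  # ['SS_Unknown', 'SS_TopK', 'SS_TopP']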
@@ -255,7 +254,6 @@ class Exllamav3Model:
         # 5. Add the final sampling stage and build the sampler
         ordered_samplers.append(SS_Sample())
         sampler = CustomSampler(ordered_samplers)
-        # -- End of sampler building --

         # Encode prompt with embeddings (ExLlamaV3-specific)
         input_ids = self.tokenizer.encode(
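Taken together, the flow this section builds is: collect the active stages unordered, sort them by the user's sampler priority, then append SS_Sample as the terminal stage that actually draws a token. A condensed sketch, relying on the imports and custom_sort_key from the surrounding code, with example values standing in for state[...] (the single-argument SS_TopK/SS_TopP constructors are assumed from the SS_MinP pattern):

    # Condensed shape of the sampler construction above (illustrative values).
    unordered_samplers = [
        SS_TopK(40),
        SS_TopP(0.9),
        SS_MinP(0.05),
        SS_Temperature(0.7),
    ]
    ordered_samplers = sorted(unordered_samplers, key=custom_sort_key)
    ordered_samplers.append(SS_Sample())  # terminal stage: draw the token
    sampler = CustomSampler(ordered_samplers)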