This commit is contained in:
oobabooga 2025-08-11 12:32:17 -07:00
parent b10d525bf7
commit 999471256c
2 changed files with 5 additions and 8 deletions

View file

@@ -3,6 +3,7 @@ import traceback
from pathlib import Path
import torch
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
@@ -15,7 +16,6 @@ from exllamav2 import (
ExLlamaV2Tokenizer
)
from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
from modules import shared
from modules.logging_colors import logger
from modules.text_generation import get_max_prompt_length

View file

@@ -2,12 +2,9 @@ import traceback
from pathlib import Path
from typing import Any, List, Tuple
import torch
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from exllamav3.generator import Job
from modules import shared
from exllamav3.generator.sampler import (
CustomSampler,
SS_Argmax,
@@ -19,13 +16,13 @@ from exllamav3.generator.sampler import (
SS_TopK,
SS_TopP
)
from modules import shared
from modules.image_utils import (
convert_image_attachments_to_pil,
convert_openai_messages_to_images
)
from modules.logging_colors import logger
from modules.text_generation import get_max_prompt_length
from modules.torch_utils import clear_torch_cache
try:
import flash_attn
@@ -205,13 +202,13 @@ class Exllamav3Model:
penalty_range = state['repetition_penalty_range']
if penalty_range <= 0:
penalty_range = int(10e7) # Use large number for "full context"
rep_decay = 0 # Not a configurable parameter
rep_decay = 0 # Not a configurable parameter
# Add penalty samplers if they are active
if state['repetition_penalty'] != 1.0:
unordered_samplers.append(SS_RepP(state['repetition_penalty'], penalty_range, rep_decay))
unordered_samplers.append(SS_RepP(state['repetition_penalty'], penalty_range, rep_decay))
if state['presence_penalty'] != 0.0 or state['frequency_penalty'] != 0.0:
unordered_samplers.append(SS_PresFreqP(state['presence_penalty'], state['frequency_penalty'], penalty_range, rep_decay))
unordered_samplers.append(SS_PresFreqP(state['presence_penalty'], state['frequency_penalty'], penalty_range, rep_decay))
# Standard samplers
if state['top_k'] > 0: