oobabooga 2025-08-11 12:32:17 -07:00
parent b10d525bf7
commit 999471256c
2 changed files with 5 additions and 8 deletions

Changed file 1 of 2:

@@ -3,6 +3,7 @@ import traceback
from pathlib import Path
import torch
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
@@ -15,7 +16,6 @@ from exllamav2 import (
    ExLlamaV2Tokenizer
)
from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
from modules import shared
from modules.logging_colors import logger
from modules.text_generation import get_max_prompt_length
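These imports cover the exllamav2 loading path used by this module. As a rough sketch of how the pieces typically fit together (not taken from this commit; ExLlamaV2Config is assumed to be among the elided imports, and the model path and lazy/autosplit choices are placeholders):

    from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Tokenizer
    from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator

    config = ExLlamaV2Config()
    config.model_dir = "/path/to/model"  # placeholder path
    config.prepare()

    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy=True)  # lazy cache so the model can be auto-split across GPUs
    model.load_autosplit(cache)

    tokenizer = ExLlamaV2Tokenizer(config)
    generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
    settings = ExLlamaV2Sampler.Settings()  # sampling settings to pass to begin_stream_ex()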

Changed file 2 of 2:

@@ -2,12 +2,9 @@ import traceback
from pathlib import Path
from typing import Any, List, Tuple
-import torch
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from exllamav3.generator import Job
-from modules import shared
from exllamav3.generator.sampler import (
    CustomSampler,
    SS_Argmax,
@@ -19,13 +16,13 @@ from exllamav3.generator.sampler import (
    SS_TopK,
    SS_TopP
)
+from modules import shared
from modules.image_utils import (
    convert_image_attachments_to_pil,
    convert_openai_messages_to_images
)
from modules.logging_colors import logger
from modules.text_generation import get_max_prompt_length
-from modules.torch_utils import clear_torch_cache

try:
    import flash_attn
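The try block above is cut off by the hunk boundary. A minimal sketch of the usual optional-dependency check it implies (the has_flash_attn flag name is an assumption, not something shown in this diff):

    try:
        import flash_attn  # optional: enables flash-attention kernels when the package is installed
        has_flash_attn = True
    except ModuleNotFoundError:
        has_flash_attn = False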
@@ -205,13 +202,13 @@ class Exllamav3Model:
        penalty_range = state['repetition_penalty_range']
        if penalty_range <= 0:
            penalty_range = int(10e7)  # Use large number for "full context"
        rep_decay = 0  # Not a configurable parameter
        # Add penalty samplers if they are active
        if state['repetition_penalty'] != 1.0:
            unordered_samplers.append(SS_RepP(state['repetition_penalty'], penalty_range, rep_decay))
        if state['presence_penalty'] != 0.0 or state['frequency_penalty'] != 0.0:
            unordered_samplers.append(SS_PresFreqP(state['presence_penalty'], state['frequency_penalty'], penalty_range, rep_decay))
        # Standard samplers
        if state['top_k'] > 0: