# text-generation-webui/modules/exllamav3.py
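"""ExLlamaV3 loader for text-generation-webui.

Wraps exllamav3's Model, Cache, Tokenizer and Generator to load a model
(with an optionally quantized KV cache and an optional vision component
for multimodal input), build a sampler stack from the UI state, and run
streaming generation through the native Job API.
"""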

import traceback
from pathlib import Path
from typing import Any, List, Tuple

import torch

from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from exllamav3.generator import Job
# Import the base sampler components directly from exllamav3
from exllamav3.generator.sampler import (
    CustomSampler,
    SS_Argmax,
    SS_MinP,
    SS_PresFreqP,
    SS_RepP,
    SS_Sample,
    SS_Temperature,
    SS_TopK,
    SS_TopP
)

from extensions.openai.image_utils import (
    convert_image_attachments_to_pil,
    convert_openai_messages_to_images
)
from modules import shared
from modules.logging_colors import logger
from modules.text_generation import get_max_prompt_length
from modules.torch_utils import clear_torch_cache

try:
    import flash_attn
except Exception:
    logger.warning('Failed to load flash-attention due to the following error:\n')
    traceback.print_exc()


class Exllamav3Model:
    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path_to_model):
        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)

        # Reset global MMTokenAllocator to prevent token ID corruption when switching models
        from exllamav3.tokenizer.mm_embedding import (
            FIRST_MM_EMBEDDING_INDEX,
            global_allocator
        )
        global_allocator.next_token_index = FIRST_MM_EMBEDDING_INDEX
        logger.info("Reset MMTokenAllocator for clean multimodal token allocation")

        config = Config.from_directory(str(path_to_model))
        model = Model.from_config(config)

        # Calculate the closest multiple of 256 at or above the chosen value
        max_tokens = shared.args.ctx_size
        if max_tokens % 256 != 0:
            adjusted_tokens = ((max_tokens // 256) + 1) * 256
            logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
            max_tokens = adjusted_tokens

        # Parse cache type (ExLlamaV2 pattern)
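        # Accepted forms: 'fp16', a single spec like 'q4' (same bits for K and V),
        # or a mixed spec like 'q4_q8' (K bits, then V bits, each between 2 and 8).
        # Anything else falls back to fp16 below.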
        cache_type = shared.args.cache_type.lower()
        cache_kwargs = {}
        if cache_type == 'fp16':
            layer_type = CacheLayer_fp16
        elif cache_type.startswith('q'):
            layer_type = CacheLayer_quant
            if '_' in cache_type:
                # Different bits for k and v (e.g., q4_q8)
                k_part, v_part = cache_type.split('_')
                k_bits = int(k_part[1:])
                v_bits = int(v_part[1:])
            else:
                # Same bits for k and v (e.g., q4)
                k_bits = v_bits = int(cache_type[1:])

            # Validate bit ranges
            if not (2 <= k_bits <= 8 and 2 <= v_bits <= 8):
                logger.warning(f"Invalid quantization bits: k_bits={k_bits}, v_bits={v_bits}. Must be between 2 and 8. Falling back to fp16.")
                layer_type = CacheLayer_fp16
            else:
                cache_kwargs = {'k_bits': k_bits, 'v_bits': v_bits}
        else:
            logger.warning(f"Unrecognized cache type: {cache_type}. Falling back to fp16.")
            layer_type = CacheLayer_fp16

        cache = Cache(model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)

        load_params = {'progressbar': True}
        if shared.args.gpu_split:
            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
            load_params['use_per_device'] = split

        model.load(**load_params)
        tokenizer = Tokenizer.from_config(config)

        # Load vision model component (ExLlamaV3 native)
        vision_model = None
        try:
            logger.info("Loading vision model component...")
            vision_model = Model.from_config(config, component="vision")
            vision_model.load(progressbar=True)
            logger.info("Vision model loaded successfully")
        except Exception as e:
            logger.warning(f"Vision model loading failed (multimodal disabled): {e}")

        generator = Generator(
            model=model,
            cache=cache,
            tokenizer=tokenizer,
        )

        result = cls()
        result.model = model
        result.cache = cache
        result.tokenizer = tokenizer
        result.generator = generator
        result.config = config
        result.max_tokens = max_tokens
        result.vision_model = vision_model

        return result

    def is_multimodal(self) -> bool:
        """Check if this model supports multimodal input."""
        return hasattr(self, 'vision_model') and self.vision_model is not None

    def _process_images_for_generation(self, prompt: str, state: dict) -> Tuple[str, List[Any]]:
        """
        Process all possible image inputs and return modified prompt + embeddings.
        Returns: (processed_prompt, image_embeddings)
        """
        if not self.is_multimodal():
            return prompt, []

        # Collect images from various sources using shared utilities
        pil_images = []

        # From webui image_attachments (preferred format)
        if 'image_attachments' in state and state['image_attachments']:
            pil_images.extend(convert_image_attachments_to_pil(state['image_attachments']))
        # From OpenAI API raw_images
        elif 'raw_images' in state and state['raw_images']:
            pil_images.extend(state['raw_images'])
        # From OpenAI API messages format
        elif 'messages' in state and state['messages']:
            pil_images.extend(convert_openai_messages_to_images(state['messages']))

        if not pil_images:
            return prompt, []

        # ExLlamaV3-specific: Generate embeddings
        try:
            # Use pre-computed embeddings if available (proper MMEmbedding lifetime)
            if 'image_embeddings' in state and state['image_embeddings']:
                # Use existing embeddings - this preserves MMEmbedding lifetime
                image_embeddings = state['image_embeddings']
            else:
                # Do not reset the cache/allocator index; it causes token ID conflicts during generation.
                logger.info(f"Processing {len(pil_images)} image(s) with ExLlamaV3 vision model")
                image_embeddings = [
                    self.vision_model.get_image_embeddings(tokenizer=self.tokenizer, image=img)
                    for img in pil_images
                ]

            # ExLlamaV3-specific: Handle prompt processing with placeholders
            placeholders = [ie.text_alias for ie in image_embeddings]

            if '<__media__>' in prompt:
                # Web chat: Replace <__media__> placeholders
                for alias in placeholders:
                    prompt = prompt.replace('<__media__>', alias, 1)
                logger.info(f"Replaced {len(placeholders)} <__media__> placeholder(s)")
            else:
                # API: Prepend embedding aliases
                combined_placeholders = "\n".join(placeholders)
                prompt = combined_placeholders + "\n" + prompt
                logger.info(f"Prepended {len(placeholders)} embedding(s) to prompt")

            return prompt, image_embeddings

        except Exception as e:
            logger.error(f"Failed to process images: {e}")
            return prompt, []

    def generate_with_streaming(self, prompt, state):
        """
        Generate text with streaming using native ExLlamaV3 API
        """
        # Process images and modify prompt (ExLlamaV3-specific)
        prompt, image_embeddings = self._process_images_for_generation(prompt, state)

        # -- Manually build and sort the sampler stack --

        # Greedy decoding is a special case
        if state['temperature'] == 0:
            sampler = CustomSampler([SS_Argmax()])
        else:
            # 1. Create a list of all active, unordered samplers
            unordered_samplers = []

            # Penalties
            penalty_range = state['repetition_penalty_range']
            if penalty_range <= 0:
                penalty_range = -1  # ExllamaV3 uses -1 for whole context

            rep_decay = 0  # Not a configurable parameter

            # Add penalty samplers if they are active
            if state['repetition_penalty'] != 1.0:
                unordered_samplers.append(SS_RepP(state['repetition_penalty'], penalty_range, rep_decay))
            if state['presence_penalty'] != 0.0 or state['frequency_penalty'] != 0.0:
                unordered_samplers.append(SS_PresFreqP(state['presence_penalty'], state['frequency_penalty'], penalty_range, rep_decay))

            # Standard samplers
            if state['top_k'] > 0:
                unordered_samplers.append(SS_TopK(state['top_k']))
            if state['top_p'] < 1.0:
                unordered_samplers.append(SS_TopP(state['top_p']))
            if state['min_p'] > 0.0:
                unordered_samplers.append(SS_MinP(state['min_p']))

            # Temperature
            unordered_samplers.append(SS_Temperature(state['temperature']))

            # 2. Define the mapping from class names to the priority list keys
            class_name_to_nickname = {
                'SS_RepP': 'repetition_penalty',
                'SS_PresFreqP': 'presence_frequency_penalty',
                'SS_TopK': 'top_k',
                'SS_TopP': 'top_p',
                'SS_MinP': 'min_p',
                'SS_Temperature': 'temperature',
            }

            # 3. Get the priority list and handle temperature_last
            default_priority = ['repetition_penalty', 'presence_frequency_penalty', 'top_k', 'top_p', 'min_p', 'temperature']
            sampler_priority = state.get('sampler_priority', default_priority)

            if state['temperature_last'] and 'temperature' in sampler_priority:
                sampler_priority.append(sampler_priority.pop(sampler_priority.index('temperature')))

            # 4. Sort the unordered list based on the priority list
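            # Samplers whose nickname is missing from the priority list get key -1,
            # so they sort ahead of everything listed (sorted() is stable).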
            def custom_sort_key(sampler_obj):
                class_name = sampler_obj.__class__.__name__
                nickname = class_name_to_nickname.get(class_name)
                if nickname in sampler_priority:
                    return sampler_priority.index(nickname)
                return -1

            ordered_samplers = sorted(unordered_samplers, key=custom_sort_key)

            # 5. Add the final sampling stage and build the sampler
            ordered_samplers.append(SS_Sample())
            sampler = CustomSampler(ordered_samplers)

        # -- End of sampler building --

        # Encode prompt with embeddings (ExLlamaV3-specific)
        input_ids = self.tokenizer.encode(
            prompt,
            add_bos=state['add_bos_token'],
            encode_special_tokens=True,
            embeddings=image_embeddings,
        )
        input_ids = input_ids[:, -get_max_prompt_length(state):]

        # Determine max_new_tokens
        if state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
        else:
            max_new_tokens = state['max_new_tokens']

        # Get stop conditions
        stop_conditions = []
        if not state['ban_eos_token']:
            if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
                stop_conditions.append(self.tokenizer.eos_token_id)

        job = Job(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            decode_special_tokens=not state['skip_special_tokens'],
            embeddings=image_embeddings if image_embeddings else None,
            sampler=sampler,
            stop_conditions=stop_conditions if stop_conditions else None,
        )

        # Stream generation
        self.generator.enqueue(job)
        response_text = ""

        try:
            while self.generator.num_remaining_jobs():
                results = self.generator.iterate()
                for result in results:
                    if "eos" in result and result["eos"]:
                        break

                    chunk = result.get("text", "")
                    if chunk:
                        response_text += chunk
                        yield response_text
        finally:
            # No cleanup needed. MMEmbedding lifetime is managed by Python.
            # Cache and page table resets are unnecessary and can cause token ID conflicts.
            pass

    def generate(self, prompt, state):
        output = ""
        for chunk in self.generate_with_streaming(prompt, state):
            output = chunk

        return output

    def encode(self, string, **kwargs):
        # Default add_bos to True for consistency with exllamav2 behavior
        add_bos = kwargs.pop('add_bos', True)
        return self.tokenizer.encode(string, add_bos=add_bos, **kwargs)

    def decode(self, ids, **kwargs):
        return self.tokenizer.decode(ids, **kwargs)

    @property
    def last_prompt_token_count(self):
        # This would need to be tracked during generation
        return 0

    def unload(self):
        logger.info("Unloading ExLlamaV3 model components...")

        if hasattr(self, 'vision_model') and self.vision_model is not None:
            try:
                del self.vision_model
            except Exception as e:
                logger.warning(f"Error unloading vision model: {e}")
            self.vision_model = None

        if hasattr(self, 'model') and self.model is not None:
            try:
                self.model.unload()
                del self.model
            except Exception as e:
                logger.warning(f"Error unloading main model: {e}")
            self.model = None

        if hasattr(self, 'cache') and self.cache is not None:
            self.cache = None

        if hasattr(self, 'generator') and self.generator is not None:
            self.generator = None

        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
            self.tokenizer = None
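

# Minimal usage sketch (illustrative only; assumes the webui has already populated
# `shared.args` and built a generation `state` dict the way it normally does —
# the model folder name below is hypothetical):
#
#     model = Exllamav3Model.from_pretrained("MyModel-exl3")
#     for partial in model.generate_with_streaming("Hello", state):
#         print(partial)
#     model.unload()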