Add multimodal support (ExLlamaV3) (#7174)

Katehuuh 2025-08-09 04:31:16 +02:00 committed by GitHub
parent b391ac8eb1
commit 88127f46c1
13 changed files with 726 additions and 55 deletions

View file

@@ -1577,6 +1577,19 @@ strong {
margin-top: 4px;
}
.image-attachment {
flex-direction: column;
}
.image-preview {
border-radius: 16px;
margin-bottom: 5px;
object-fit: cover;
object-position: center;
border: 2px solid var(--border-color-primary);
aspect-ratio: 1 / 1;
}
button:focus {
outline: none;
}

View file

@@ -77,6 +77,24 @@ curl http://127.0.0.1:5000/v1/chat/completions \
}'
```
#### Multimodal support (ExLlamaV3)
```shell
curl http://127.0.0.1:5000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What color is this image?"},
{"type": "image_url", "image_url": {"url": "https://github.com/turboderp-org/exllamav3/blob/master/examples/media/cat.png?raw=true"}}
]
}
]
}'
```
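A minimal Python equivalent of the request above, for reference (a sketch: it assumes a running server with an ExLlamaV3 multimodal model loaded and uses the `requests` package, neither of which is part of this diff; local images can also be embedded as base64 data URLs):
```python
import base64
import requests

# Either pass an external URL (as in the curl example) or embed a local image
# as a data URL; both forms are handled by the server.
with open("cat.png", "rb") as f:  # hypothetical local file
    image_url = "data:image/png;base64," + base64.b64encode(f.read()).decode()

response = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    json={
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": "What color is this image?"},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        }],
    },
    timeout=120,
)
print(response.json()["choices"][0]["message"]["content"])
```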
#### SSE streaming
```shell

View file

@@ -7,6 +7,7 @@ import tiktoken
from pydantic import ValidationError
from extensions.openai.errors import InvalidRequestError
from extensions.openai.image_utils import convert_openai_messages_to_images
from extensions.openai.typing import ToolDefinition
from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
from modules import shared
@@ -16,6 +17,7 @@ from modules.chat import (
load_character_memoized,
load_instruction_template_memoized
)
from modules.logging_colors import logger
from modules.presets import load_preset_memoized
from modules.text_generation import decode, encode, generate_reply
@@ -82,6 +84,21 @@ def process_parameters(body, is_legacy=False):
return generate_params
def process_multimodal_content(content):
"""Extract text from OpenAI multimodal format for non-multimodal models"""
if isinstance(content, str):
return content
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict) and item.get('type') == 'text':
text_parts.append(item.get('text', ''))
return ' '.join(text_parts) if text_parts else str(content)
return str(content)
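A quick sketch of how this helper flattens the OpenAI content-list format (values assumed):
```python
mixed = [
    {"type": "text", "text": "What color is this image?"},
    {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
]
process_multimodal_content(mixed)         # -> "What color is this image?"
process_multimodal_content("plain text")  # -> "plain text"
```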
def convert_history(history):
'''
Chat histories in this program are in the format [message, reply].
@@ -99,8 +116,11 @@ def convert_history(history):
role = entry["role"]
if role == "user":
# Extract text content (images handled by model-specific code)
content = process_multimodal_content(content)
user_input = content
user_input_last = True
if current_message:
chat_dialogue.append([current_message, '', ''])
current_message = ""
@@ -126,7 +146,11 @@
if not user_input_last:
user_input = ""
return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
return user_input, system_message, {
'internal': chat_dialogue,
'visible': copy.deepcopy(chat_dialogue),
'messages': history # Store original messages for multimodal models
}
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
@@ -150,9 +174,23 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
elif m['role'] == 'function':
raise InvalidRequestError(message="role: function is not supported.", param='messages')
if 'content' not in m and "image_url" not in m:
# Handle multimodal content validation
content = m.get('content')
if content is None:
raise InvalidRequestError(message="messages: missing content", param='messages')
# Validate multimodal content structure
if isinstance(content, list):
for item in content:
if not isinstance(item, dict) or 'type' not in item:
raise InvalidRequestError(message="messages: invalid content item format", param='messages')
if item['type'] not in ['text', 'image_url']:
raise InvalidRequestError(message="messages: unsupported content type", param='messages')
if item['type'] == 'text' and 'text' not in item:
raise InvalidRequestError(message="messages: missing text in content item", param='messages')
if item['type'] == 'image_url' and ('image_url' not in item or 'url' not in item['image_url']):
raise InvalidRequestError(message="messages: missing image_url in content item", param='messages')
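For example, an item like `{"type": "audio", ...}` is rejected here with `messages: unsupported content type`; only the two shapes below pass the checks above (sketch):
```python
# Accepted content items:
{"type": "text", "text": "What color is this image?"}
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
# Items without "type", text items without "text", or image_url items without
# a nested "url" raise InvalidRequestError with param='messages'.
```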
# Chat Completions
object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
created_time = int(time.time())
@@ -336,9 +374,26 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
prompt_str = 'context' if is_legacy else 'prompt'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
if prompt_str not in body:
raise InvalidRequestError("Missing required input", param=prompt_str)
# Handle both prompt and messages format for unified multimodal support
if prompt_str not in body or body[prompt_str] is None:
if 'messages' in body:
# Convert messages format to prompt for completions endpoint
prompt_text = ""
for message in body.get('messages', []):
if isinstance(message, dict) and 'content' in message:
# Extract text content from multimodal messages
content = message['content']
if isinstance(content, str):
prompt_text += content
elif isinstance(content, list):
for item in content:
if isinstance(item, dict) and item.get('type') == 'text':
prompt_text += item.get('text', '')
# Allow empty prompts for image-only requests
body[prompt_str] = prompt_text
else:
raise InvalidRequestError("Missing required input", param=prompt_str)
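In other words, a messages-only body is flattened into a plain prompt at this point; a sketch of the conversion (field values assumed):
```python
body = {
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ],
    }],
}
# After the fallback above: body["prompt"] == "Describe this image."
# The image parts contribute no text here; they are picked up later from
# generate_params['messages'] by the multimodal branch below.
```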
# common params
generate_params = process_parameters(body, is_legacy=is_legacy)
@@ -349,9 +404,18 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
suffix = body['suffix'] if body['suffix'] else ''
echo = body['echo']
# Add messages to generate_params if present for multimodal processing
if 'messages' in body:
generate_params['messages'] = body['messages']
if not stream:
prompt_arg = body[prompt_str]
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
# Handle empty/None prompts (e.g., image-only requests)
if prompt_arg is None:
prompt_arg = ""
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and len(prompt_arg) > 0 and isinstance(prompt_arg[0], int)):
prompt_arg = [prompt_arg]
resp_list_data = []
@@ -374,7 +438,19 @@
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
# Use multimodal generation if images are present
if 'messages' in generate_params:
raw_images = convert_openai_messages_to_images(generate_params['messages'])
if raw_images:
logger.info(f"Using multimodal generation for {len(raw_images)} images")
generate_params['raw_images'] = raw_images
generator = shared.model.generate_with_streaming(prompt, generate_params)
else:
generator = generate_reply(prompt, generate_params, is_chat=False)
else:
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
for a in generator:
@@ -447,7 +523,17 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
# Use multimodal generation if images are present
if 'messages' in generate_params:
raw_images = convert_openai_messages_to_images(generate_params['messages'])
if raw_images:
logger.info(f"Using multimodal generation for {len(raw_images)} images")
generate_params['raw_images'] = raw_images
generator = shared.model.generate_with_streaming(prompt, generate_params)
else:
generator = generate_reply(prompt, generate_params, is_chat=False)
else:
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
seen_content = ''
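With the branches above in place, /v1/completions also accepts the messages format end to end; a sketch of such a call (the `requests` dependency and the local file are assumptions, not part of this diff):
```python
import base64
import requests

with open("cat.png", "rb") as f:  # hypothetical local image
    data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode()

resp = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    json={
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this picture?"},
                {"type": "image_url", "image_url": {"url": data_url}},
            ],
        }],
        "max_tokens": 200,
    },
    timeout=120,
)
print(resp.json()["choices"][0]["text"])
```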

View file

@@ -0,0 +1,97 @@
"""
Shared image processing utilities for multimodal support.
Used by both ExLlamaV3 and llama.cpp implementations.
"""
import base64
import io
from typing import Any, List, Tuple
from PIL import Image
from modules.logging_colors import logger
def decode_base64_image(base64_string: str) -> Image.Image:
"""Decodes a base64 string to a PIL Image."""
try:
if base64_string.startswith('data:image/'):
base64_string = base64_string.split(',', 1)[1]
image_data = base64.b64decode(base64_string)
image = Image.open(io.BytesIO(image_data))
return image
except Exception as e:
logger.error(f"Failed to decode base64 image: {e}")
raise ValueError(f"Invalid base64 image data: {e}")
def process_message_content(content: Any) -> Tuple[str, List[Image.Image]]:
"""
Processes message content that may contain text and images.
Returns: A tuple of (text_content, list_of_pil_images).
"""
if isinstance(content, str):
return content, []
if isinstance(content, list):
text_parts = []
images = []
for item in content:
if not isinstance(item, dict):
continue
item_type = item.get('type', '')
if item_type == 'text':
text_parts.append(item.get('text', ''))
elif item_type == 'image_url':
image_url_data = item.get('image_url', {})
image_url = image_url_data.get('url', '')
if image_url.startswith('data:image/'):
try:
images.append(decode_base64_image(image_url))
except Exception as e:
logger.warning(f"Failed to process a base64 image: {e}")
elif image_url.startswith('http'):
# Support external URLs
try:
import requests
response = requests.get(image_url, timeout=10)
response.raise_for_status()
image_data = response.content
image = Image.open(io.BytesIO(image_data))
images.append(image)
logger.info("Successfully loaded external image from URL")
except Exception as e:
logger.warning(f"Failed to fetch external image: {e}")
else:
logger.warning(f"Unsupported image URL format: {image_url[:70]}...")
return ' '.join(text_parts), images
return str(content), []
def convert_image_attachments_to_pil(image_attachments: List[dict]) -> List[Image.Image]:
"""Convert webui image_attachments format to PIL Images."""
pil_images = []
for attachment in image_attachments:
if attachment.get('type') == 'image' and 'image_data' in attachment:
try:
image = decode_base64_image(attachment['image_data'])
if image.mode != 'RGB':
image = image.convert('RGB')
pil_images.append(image)
except Exception as e:
logger.warning(f"Failed to process image attachment: {e}")
return pil_images
def convert_openai_messages_to_images(messages: List[dict]) -> List[Image.Image]:
"""Convert OpenAI messages format to PIL Images."""
all_images = []
for message in messages:
if isinstance(message, dict) and 'content' in message:
_, images = process_message_content(message['content'])
all_images.extend(images)
return all_images
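A small usage sketch of the helpers above (the message payload mirrors the API example from the docs):
```python
from extensions.openai.image_utils import convert_openai_messages_to_images

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe the image."},
        {"type": "image_url", "image_url": {"url": "https://github.com/turboderp-org/exllamav3/blob/master/examples/media/cat.png?raw=true"}},
    ],
}]
images = convert_openai_messages_to_images(messages)  # -> list of PIL.Image.Image
```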

View file

@@ -2,7 +2,7 @@ import json
import time
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator, validator
class GenerationOptions(BaseModel):
@@ -99,7 +99,8 @@ class ToolCall(BaseModel):
class CompletionRequestParams(BaseModel):
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
prompt: str | List[str]
prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
messages: List[dict] | None = Field(default=None, description="OpenAI messages format for multimodal support. Alternative to 'prompt'.")
best_of: int | None = Field(default=1, description="Unused parameter.")
echo: bool | None = False
frequency_penalty: float | None = 0
@@ -115,6 +116,17 @@ class CompletionRequestParams(BaseModel):
top_p: float | None = 1
user: str | None = Field(default=None, description="Unused parameter.")
@field_validator('prompt', 'messages')
@classmethod
def validate_prompt_or_messages(cls, v, info):
"""Ensure either 'prompt' or 'messages' is provided for completions."""
if info.field_name == 'prompt': # If we're validating 'prompt', check if neither prompt nor messages will be set
messages = info.data.get('messages')
if v is None and messages is None:
raise ValueError("Either 'prompt' or 'messages' must be provided")
return v
class CompletionRequest(GenerationOptions, CompletionRequestParams):
pass
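The two request shapes the schema now accepts, side by side (sketch):
```python
# Classic completions request:
{"prompt": "Describe a cat."}

# Multimodal alternative (no "prompt" key; the completions endpoint builds the
# prompt from the text parts of "messages"):
{"messages": [{"role": "user", "content": [
    {"type": "text", "text": "Describe this image."},
    {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
]}]}
# Omitting both fields fails validation with
# "Either 'prompt' or 'messages' must be provided".
```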

View file

@@ -271,16 +271,27 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Add attachment content if present AND if past attachments are enabled
if (state.get('include_past_attachments', True) and user_key in metadata and "attachments" in metadata[user_key]):
attachments_text = ""
for attachment in metadata[user_key]["attachments"]:
filename = attachment.get("name", "file")
content = attachment.get("content", "")
if attachment.get("type") == "text/html" and attachment.get("url"):
attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
else:
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
image_refs = ""
if attachments_text:
enhanced_user_msg = f"{user_msg}\n\nATTACHMENTS:\n{attachments_text}"
for attachment in metadata[user_key]["attachments"]:
if attachment.get("type") == "image":
# Add image reference for multimodal models
image_refs += "<__media__>"
else:
# Handle text/PDF attachments
filename = attachment.get("name", "file")
content = attachment.get("content", "")
if attachment.get("type") == "text/html" and attachment.get("url"):
attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
else:
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
if image_refs or attachments_text:
enhanced_user_msg = user_msg
if image_refs:
enhanced_user_msg += f" {image_refs}"
if attachments_text:
enhanced_user_msg += f"\n\nATTACHMENTS:\n{attachments_text}"
messages.insert(insert_pos, {"role": "user", "content": enhanced_user_msg})
@@ -301,16 +312,23 @@
if user_key in metadata and "attachments" in metadata[user_key]:
attachments_text = ""
for attachment in metadata[user_key]["attachments"]:
filename = attachment.get("name", "file")
content = attachment.get("content", "")
if attachment.get("type") == "text/html" and attachment.get("url"):
attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
else:
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
image_refs = ""
if attachments_text:
user_input = f"{user_input}\n\nATTACHMENTS:\n{attachments_text}"
for attachment in metadata[user_key]["attachments"]:
if attachment.get("type") == "image":
image_refs += "<__media__>"
else:
filename = attachment.get("name", "file")
content = attachment.get("content", "")
if attachment.get("type") == "text/html" and attachment.get("url"):
attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
else:
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
if image_refs or attachments_text:
user_input = f"{user_input} {image_refs}"
if attachments_text:
user_input += f"\n\nATTACHMENTS:\n{attachments_text}"
messages.append({"role": "user", "content": user_input})
@@ -594,29 +612,64 @@ def add_message_attachment(history, row_idx, file_path, is_user=True):
file_extension = path.suffix.lower()
try:
# Handle different file types
if file_extension == '.pdf':
# Handle image files
if file_extension in ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']:
# Convert image to base64
with open(path, 'rb') as f:
image_data = base64.b64encode(f.read()).decode('utf-8')
# Determine MIME type from extension
mime_type_map = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.webp': 'image/webp',
'.bmp': 'image/bmp',
'.gif': 'image/gif'
}
mime_type = mime_type_map.get(file_extension, 'image/jpeg')
# Format as data URL
data_url = f"data:{mime_type};base64,{image_data}"
# Generate unique image ID
image_id = len([att for att in history['metadata'][key]["attachments"] if att.get("type") == "image"]) + 1
attachment = {
"name": filename,
"type": "image",
"image_data": data_url,
"image_id": image_id,
"file_path": str(path) # For UI preview
}
elif file_extension == '.pdf':
# Process PDF file
content = extract_pdf_text(path)
file_type = "application/pdf"
attachment = {
"name": filename,
"type": "application/pdf",
"content": content,
}
elif file_extension == '.docx':
content = extract_docx_text(path)
file_type = "application/docx"
attachment = {
"name": filename,
"type": "application/docx",
"content": content,
}
else:
# Default handling for text files
with open(path, 'r', encoding='utf-8') as f:
content = f.read()
file_type = "text/plain"
# Add attachment
attachment = {
"name": filename,
"type": file_type,
"content": content,
}
attachment = {
"name": filename,
"type": "text/plain",
"content": content,
}
history['metadata'][key]["attachments"].append(attachment)
return content # Return the content for reuse
return attachment # Return the attachment for reuse
except Exception as e:
logger.error(f"Error processing attachment {filename}: {e}")
return None
@@ -759,6 +812,19 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
for file_path in files:
add_message_attachment(output, row_idx, file_path, is_user=True)
# Collect image attachments for ExLlamaV3
image_attachments = []
if 'metadata' in output:
user_key = f"user_{row_idx}"
if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
for attachment in output['metadata'][user_key]["attachments"]:
if attachment.get("type") == "image":
image_attachments.append(attachment)
# Add image attachments to state for the generation
if image_attachments:
state['image_attachments'] = image_attachments
# Add web search results as attachments if enabled
if state.get('enable_web_search', False):
search_query = generate_search_query(text, state)
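For orientation, the image attachment record stored by add_message_attachment() and later copied into state['image_attachments'] by chatbot_wrapper() looks roughly like this (values assumed):
```python
{
    "name": "cat.png",
    "type": "image",
    "image_data": "data:image/png;base64,....",  # full data URL
    "image_id": 1,
    "file_path": "/absolute/path/to/cat.png",    # str(path), used for the UI preview
}
```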

modules/exllamav3.py (new file, 313 lines added)
View file

@@ -0,0 +1,313 @@
import traceback
from pathlib import Path
from typing import Any, List, Tuple
from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from extensions.openai.image_utils import (
convert_image_attachments_to_pil,
convert_openai_messages_to_images
)
from modules import shared
from modules.logging_colors import logger
try:
import flash_attn
except Exception:
logger.warning('Failed to load flash-attention due to the following error:\n')
traceback.print_exc()
class Exllamav3Model:
def __init__(self):
pass
@classmethod
def from_pretrained(cls, path_to_model):
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
# Reset global MMTokenAllocator to prevent token ID corruption when switching models
from exllamav3.tokenizer.mm_embedding import (
FIRST_MM_EMBEDDING_INDEX,
global_allocator
)
global_allocator.next_token_index = FIRST_MM_EMBEDDING_INDEX
logger.info("Reset MMTokenAllocator for clean multimodal token allocation")
config = Config.from_directory(str(path_to_model))
model = Model.from_config(config)
# Calculate the closest multiple of 256 at or above the chosen value
max_tokens = shared.args.ctx_size
if max_tokens % 256 != 0:
adjusted_tokens = ((max_tokens // 256) + 1) * 256
logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
max_tokens = adjusted_tokens
# Parse cache type (ExLlamaV2 pattern)
cache_type = shared.args.cache_type.lower()
cache_kwargs = {}
if cache_type == 'fp16':
layer_type = CacheLayer_fp16
elif cache_type.startswith('q'):
layer_type = CacheLayer_quant
if '_' in cache_type:
# Different bits for k and v (e.g., q4_q8)
k_part, v_part = cache_type.split('_')
k_bits = int(k_part[1:])
v_bits = int(v_part[1:])
else:
# Same bits for k and v (e.g., q4)
k_bits = v_bits = int(cache_type[1:])
# Validate bit ranges
if not (2 <= k_bits <= 8 and 2 <= v_bits <= 8):
logger.warning(f"Invalid quantization bits: k_bits={k_bits}, v_bits={v_bits}. Must be between 2 and 8. Falling back to fp16.")
layer_type = CacheLayer_fp16
else:
cache_kwargs = {'k_bits': k_bits, 'v_bits': v_bits}
else:
logger.warning(f"Unrecognized cache type: {cache_type}. Falling back to fp16.")
layer_type = CacheLayer_fp16
cache = Cache(model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
load_params = {'progressbar': True}
if shared.args.gpu_split:
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
load_params['use_per_device'] = split
model.load(**load_params)
tokenizer = Tokenizer.from_config(config)
# Load vision model component (ExLlamaV3 native)
vision_model = None
try:
logger.info("Loading vision model component...")
vision_model = Model.from_config(config, component="vision")
vision_model.load(progressbar=True)
logger.info("Vision model loaded successfully")
except Exception as e:
logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
generator = Generator(
model=model,
cache=cache,
tokenizer=tokenizer,
)
result = cls()
result.model = model
result.cache = cache
result.tokenizer = tokenizer
result.generator = generator
result.config = config
result.max_tokens = max_tokens
result.vision_model = vision_model
return result
def is_multimodal(self) -> bool:
"""Check if this model supports multimodal input."""
return hasattr(self, 'vision_model') and self.vision_model is not None
def _process_images_for_generation(self, prompt: str, state: dict) -> Tuple[str, List[Any]]:
"""
Process all possible image inputs and return modified prompt + embeddings.
Returns: (processed_prompt, image_embeddings)
"""
if not self.is_multimodal():
return prompt, []
# Collect images from various sources using shared utilities
pil_images = []
# From webui image_attachments (preferred format)
if 'image_attachments' in state and state['image_attachments']:
pil_images.extend(convert_image_attachments_to_pil(state['image_attachments']))
# From OpenAI API raw_images
elif 'raw_images' in state and state['raw_images']:
pil_images.extend(state['raw_images'])
# From OpenAI API messages format
elif 'messages' in state and state['messages']:
pil_images.extend(convert_openai_messages_to_images(state['messages']))
if not pil_images:
return prompt, []
# ExLlamaV3-specific: Generate embeddings
try:
# Use pre-computed embeddings if available (proper MMEmbedding lifetime)
if 'image_embeddings' in state and state['image_embeddings']:
# Use existing embeddings - this preserves MMEmbedding lifetime
image_embeddings = state['image_embeddings']
else:
# Do not reset the cache/allocator index; it causes token ID conflicts during generation.
logger.info(f"Processing {len(pil_images)} image(s) with ExLlamaV3 vision model")
image_embeddings = [
self.vision_model.get_image_embeddings(tokenizer=self.tokenizer, image=img)
for img in pil_images
]
# ExLlamaV3-specific: Handle prompt processing with placeholders
placeholders = [ie.text_alias for ie in image_embeddings]
if '<__media__>' in prompt:
# Web chat: Replace <__media__> placeholders
for alias in placeholders:
prompt = prompt.replace('<__media__>', alias, 1)
logger.info(f"Replaced {len(placeholders)} <__media__> placeholder(s)")
else:
# API: Prepend embedding aliases
combined_placeholders = "\n".join(placeholders)
prompt = combined_placeholders + "\n" + prompt
logger.info(f"Prepended {len(placeholders)} embedding(s) to prompt")
return prompt, image_embeddings
except Exception as e:
logger.error(f"Failed to process images: {e}")
return prompt, []
def generate_with_streaming(self, prompt, state):
"""
Generate text with streaming using native ExLlamaV3 API
"""
from exllamav3 import Job
from exllamav3.generator.sampler.presets import ComboSampler
# Process images and modify prompt (ExLlamaV3-specific)
prompt, image_embeddings = self._process_images_for_generation(prompt, state)
sampler = ComboSampler(
rep_p=state.get('repetition_penalty', 1.0),
freq_p=state.get('frequency_penalty', 0.0),
pres_p=state.get('presence_penalty', 0.0),
temperature=state.get('temperature', 0.7),
min_p=state.get('min_p', 0.0),
top_k=state.get('top_k', 0),
top_p=state.get('top_p', 1.0),
)
# Encode prompt with embeddings (ExLlamaV3-specific)
if image_embeddings:
input_ids = self.tokenizer.encode(
prompt,
encode_special_tokens=True,
embeddings=image_embeddings,
)
else:
input_ids = self.tokenizer.encode(prompt, encode_special_tokens=True)
# Get stop conditions from state (webui format) - keep as strings like ExLlamaV3 examples
stop_conditions = []
if 'stopping_strings' in state and state['stopping_strings']:
# Use strings directly (ExLlamaV3 handles the conversion internally)
stop_conditions.extend(state['stopping_strings'])
# Add EOS token ID as ExLlamaV3 examples do
if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
stop_conditions.append(self.tokenizer.eos_token_id)
job = Job(
input_ids=input_ids,
max_new_tokens=state.get('max_new_tokens', 500),
decode_special_tokens=True,
embeddings=image_embeddings if image_embeddings else None,
sampler=sampler,
stop_conditions=stop_conditions if stop_conditions else None,
)
# Stream generation
self.generator.enqueue(job)
response_text = ""
try:
while self.generator.num_remaining_jobs():
results = self.generator.iterate()
for result in results:
if "eos" in result and result["eos"]:
break
chunk = result.get("text", "")
if chunk:
response_text += chunk
yield response_text
finally:
# No cleanup needed. MMEmbedding lifetime is managed by Python.
# Cache and page table resets are unnecessary and can cause token ID conflicts.
pass
def generate(self, prompt, state):
"""
Generate text using native ExLlamaV3 API (non-streaming)
"""
output = self.generator.generate(
prompt=prompt,
max_new_tokens=state.get('max_new_tokens', 500),
temperature=state.get('temperature', 0.7),
top_p=state.get('top_p', 1.0),
top_k=state.get('top_k', 0),
repetition_penalty=state.get('repetition_penalty', 1.0),
frequency_penalty=state.get('frequency_penalty', 0.0),
presence_penalty=state.get('presence_penalty', 0.0),
min_p=state.get('min_p', 0.0),
)
return output
def encode(self, string, **kwargs):
return self.tokenizer.encode(string, **kwargs)
def decode(self, ids, **kwargs):
return self.tokenizer.decode(ids, **kwargs)
@property
def last_prompt_token_count(self):
# This would need to be tracked during generation
return 0
def unload(self):
logger.info("Unloading ExLlamaV3 model components...")
if hasattr(self, 'vision_model') and self.vision_model is not None:
try:
del self.vision_model
except Exception as e:
logger.warning(f"Error unloading vision model: {e}")
self.vision_model = None
if hasattr(self, 'model') and self.model is not None:
try:
self.model.unload()
del self.model
except Exception as e:
logger.warning(f"Error unloading main model: {e}")
self.model = None
if hasattr(self, 'cache') and self.cache is not None:
self.cache = None
if hasattr(self, 'generator') and self.generator is not None:
self.generator = None
if hasattr(self, 'tokenizer') and self.tokenizer is not None:
self.tokenizer = None
# Force GPU memory cleanup
import gc
import torch
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.empty_cache()
logger.info("ExLlamaV3 model fully unloaded")
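A minimal usage sketch of the loader above, outside the webui plumbing (the model folder name and the state values are assumptions, not defaults from this diff):
```python
from modules.exllamav3 import Exllamav3Model

model = Exllamav3Model.from_pretrained("MyVisionModel-exl3")  # folder under --model-dir
state = {
    "max_new_tokens": 200,
    "temperature": 0.7,
    "top_p": 0.9,
    # Data URL in the same format add_message_attachment() produces:
    "image_attachments": [{"type": "image", "image_data": "data:image/png;base64,...."}],
}
prompt = "Describe the attached image. <__media__>"

output = ""
for output in model.generate_with_streaming(prompt, state):
    pass  # each iteration yields the cumulative text generated so far
print(output)
```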

View file

@@ -406,16 +406,27 @@ def format_message_attachments(history, role, index):
for attachment in attachments:
name = html.escape(attachment["name"])
# Make clickable if URL exists
if "url" in attachment:
name = f'<a href="{html.escape(attachment["url"])}" target="_blank" rel="noopener noreferrer">{name}</a>'
if attachment.get("type") == "image":
# Show image preview
file_path = attachment.get("file_path", "")
attachments_html += (
f'<div class="attachment-box image-attachment">'
f'<img src="file/{file_path}" alt="{name}" class="image-preview" />'
f'<div class="attachment-name">{name}</div>'
f'</div>'
)
else:
# Make clickable if URL exists (web search)
if "url" in attachment:
name = f'<a href="{html.escape(attachment["url"])}" target="_blank" rel="noopener noreferrer">{name}</a>'
attachments_html += (
f'<div class="attachment-box">'
f'<div class="attachment-icon">{attachment_svg}</div>'
f'<div class="attachment-name">{name}</div>'
f'</div>'
)
attachments_html += (
f'<div class="attachment-box">'
f'<div class="attachment-icon">{attachment_svg}</div>'
f'<div class="attachment-name">{name}</div>'
f'</div>'
)
attachments_html += '</div>'
return attachments_html

View file

@@ -55,6 +55,11 @@ loaders_and_params = OrderedDict({
'trust_remote_code',
'no_use_fast',
],
'ExLlamav3': [
'ctx_size',
'cache_type',
'gpu_split',
],
'ExLlamav2_HF': [
'ctx_size',
'cache_type',
@@ -248,6 +253,41 @@ loaders_samplers = {
'grammar_string',
'grammar_file_row',
},
'ExLlamav3': {
'temperature',
'dynatemp_low',
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
'min_p',
'top_p',
'top_k',
'typical_p',
'xtc_threshold',
'xtc_probability',
'tfs',
'top_a',
'dry_multiplier',
'dry_allowed_length',
'dry_base',
'repetition_penalty',
'frequency_penalty',
'presence_penalty',
'repetition_penalty_range',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'dynamic_temperature',
'temperature_last',
'auto_max_new_tokens',
'ban_eos_token',
'add_bos_token',
'enable_thinking',
'skip_special_tokens',
'seed',
'custom_token_bans',
'dry_sequence_breakers',
},
'ExLlamav2': {
'temperature',
'dynatemp_low',

View file

@@ -19,6 +19,7 @@ def load_model(model_name, loader=None):
'llama.cpp': llama_cpp_server_loader,
'Transformers': transformers_loader,
'ExLlamav3_HF': ExLlamav3_HF_loader,
'ExLlamav3': ExLlamav3_loader,
'ExLlamav2_HF': ExLlamav2_HF_loader,
'ExLlamav2': ExLlamav2_loader,
'TensorRT-LLM': TensorRT_LLM_loader,
@@ -88,6 +89,14 @@ def ExLlamav3_HF_loader(model_name):
return Exllamav3HF.from_pretrained(model_name)
def ExLlamav3_loader(model_name):
from modules.exllamav3 import Exllamav3Model
model = Exllamav3Model.from_pretrained(model_name)
tokenizer = model.tokenizer
return model, tokenizer
def ExLlamav2_HF_loader(model_name):
from modules.exllamav2_hf import Exllamav2HF
@@ -116,7 +125,9 @@ def unload_model(keep_model_name=False):
return
is_llamacpp = (shared.model.__class__.__name__ == 'LlamaServer')
if shared.model.__class__.__name__ == 'Exllamav3HF':
if shared.args.loader in ['ExLlamav3_HF', 'ExLlamav3']:
shared.model.unload()
elif shared.args.loader in ['ExLlamav2_HF', 'ExLlamav2'] and hasattr(shared.model, 'unload'):
shared.model.unload()
shared.model = shared.tokenizer = None
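With the loader mapping above in place, the new backend can be selected like any other (a sketch; the model folder name is an assumption):
```python
from modules import shared
from modules.models import load_model

shared.model, shared.tokenizer = load_model("MyVisionModel-exl3", loader="ExLlamav3")
```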

View file

@@ -318,6 +318,8 @@ def fix_loader_name(name):
return 'ExLlamav2_HF'
elif name in ['exllamav3-hf', 'exllamav3_hf', 'exllama-v3-hf', 'exllama_v3_hf', 'exllama-v3_hf', 'exllama3-hf', 'exllama3_hf', 'exllama-3-hf', 'exllama_3_hf', 'exllama-3_hf']:
return 'ExLlamav3_HF'
elif name in ['exllamav3']:
return 'ExLlamav3'
elif name in ['tensorrt', 'tensorrtllm', 'tensorrt_llm', 'tensorrt-llm', 'tensort', 'tensortllm']:
return 'TensorRT-LLM'

View file

@@ -40,7 +40,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
yield ''
return
if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']:
if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel']:
generate_func = generate_reply_custom
else:
generate_func = generate_reply_HF
@@ -128,9 +128,9 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
from modules.torch_utils import get_device
if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel']:
if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel']:
input_ids = shared.tokenizer.encode(str(prompt))
if shared.model.__class__.__name__ != 'Exllamav2Model':
if shared.model.__class__.__name__ not in ['Exllamav2Model', 'Exllamav3Model']:
input_ids = np.array(input_ids).reshape(1, len(input_ids))
else:
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
@@ -148,7 +148,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if truncation_length is not None:
input_ids = input_ids[:, -truncation_length:]
if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel'] or shared.args.cpu:
return input_ids
else:
device = get_device()
@@ -295,6 +295,8 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
_StopEverythingStoppingCriteria
)
# Native ExLlamav3Model handles multimodal internally - no special routing needed
if shared.args.loader == 'Transformers':
clear_torch_cache()

View file

@@ -54,7 +54,7 @@ def create_ui():
gr.HTML(value='<div class="hover-element" onclick="void(0)"><span style="width: 100px; display: block" id="hover-element-button">&#9776;</span><div class="hover-menu" id="hover-menu"></div>', elem_id='gr-hover')
with gr.Column(scale=10, elem_id='chat-input-container'):
shared.gradio['textbox'] = gr.MultimodalTextbox(label='', placeholder='Send a message', file_types=['text', '.pdf'], file_count="multiple", elem_id='chat-input', elem_classes=['add_scrollbar'])
shared.gradio['textbox'] = gr.MultimodalTextbox(label='', placeholder='Send a message', file_types=['text', '.pdf', 'image'], file_count="multiple", elem_id='chat-input', elem_classes=['add_scrollbar'])
shared.gradio['typing-dots'] = gr.HTML(value='<div class="typing"><span></span><span class="dot1"></span><span class="dot2"></span></div>', label='typing', elem_id='typing-container')
with gr.Column(scale=1, elem_id='generate-stop-container'):