Add multimodal support (ExLlamaV3) (#7174)

Repository: https://github.com/oobabooga/text-generation-webui
Commit: 88127f46c1 (parent: b391ac8eb1)
css/main.css (13 lines changed)

@@ -1577,6 +1577,19 @@ strong {
    margin-top: 4px;
}

+.image-attachment {
+    flex-direction: column;
+}
+
+.image-preview {
+    border-radius: 16px;
+    margin-bottom: 5px;
+    object-fit: cover;
+    object-position: center;
+    border: 2px solid var(--border-color-primary);
+    aspect-ratio: 1 / 1;
+}
+
button:focus {
    outline: none;
}
@@ -77,6 +77,24 @@ curl http://127.0.0.1:5000/v1/chat/completions \
  }'
```
+
+#### Multimodal support (ExLlamaV3)
+
+```shell
+curl http://127.0.0.1:5000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "messages": [
+      {
+        "role": "user",
+        "content": [
+          {"type": "text", "text": "What color is this image?"},
+          {"type": "image_url", "image_url": {"url": "https://github.com/turboderp-org/exllamav3/blob/master/examples/media/cat.png?raw=true"}}
+        ]
+      }
+    ]
+  }'
+```

#### SSE streaming

```shell
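Editor's note: for reference, the same multimodal request can be made from Python. This is an illustrative sketch, not part of the commit; it assumes the API server is reachable at 127.0.0.1:5000, that the `requests` package is installed, and that the response follows the usual OpenAI `choices[0].message.content` shape.

```python
import requests

# Same payload as the curl example above: one text part and one image_url part.
payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What color is this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://github.com/turboderp-org/exllamav3/blob/master/examples/media/cat.png?raw=true"},
                },
            ],
        }
    ]
}

response = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json=payload,
    timeout=120,
)
print(response.json()["choices"][0]["message"]["content"])
```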
@@ -7,6 +7,7 @@ import tiktoken
from pydantic import ValidationError

from extensions.openai.errors import InvalidRequestError
+from extensions.openai.image_utils import convert_openai_messages_to_images
from extensions.openai.typing import ToolDefinition
from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
from modules import shared

@@ -16,6 +17,7 @@ from modules.chat import (
    load_character_memoized,
    load_instruction_template_memoized
)
+from modules.logging_colors import logger
from modules.presets import load_preset_memoized
from modules.text_generation import decode, encode, generate_reply

@@ -82,6 +84,21 @@ def process_parameters(body, is_legacy=False):
    return generate_params


+def process_multimodal_content(content):
+    """Extract text from OpenAI multimodal format for non-multimodal models"""
+    if isinstance(content, str):
+        return content
+
+    if isinstance(content, list):
+        text_parts = []
+        for item in content:
+            if isinstance(item, dict) and item.get('type') == 'text':
+                text_parts.append(item.get('text', ''))
+        return ' '.join(text_parts) if text_parts else str(content)
+
+    return str(content)
+
+
def convert_history(history):
    '''
    Chat histories in this program are in the format [message, reply].

@@ -99,8 +116,11 @@ def convert_history(history):
        role = entry["role"]

        if role == "user":
+            # Extract text content (images handled by model-specific code)
+            content = process_multimodal_content(content)
            user_input = content
            user_input_last = True
            if current_message:
                chat_dialogue.append([current_message, '', ''])
                current_message = ""

@@ -126,7 +146,11 @@ def convert_history(history):
    if not user_input_last:
        user_input = ""

-    return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
+    return user_input, system_message, {
+        'internal': chat_dialogue,
+        'visible': copy.deepcopy(chat_dialogue),
+        'messages': history  # Store original messages for multimodal models
+    }


def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:

@@ -150,9 +174,23 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
        elif m['role'] == 'function':
            raise InvalidRequestError(message="role: function is not supported.", param='messages')

-        if 'content' not in m and "image_url" not in m:
+        # Handle multimodal content validation
+        content = m.get('content')
+        if content is None:
+            raise InvalidRequestError(message="messages: missing content", param='messages')
+
+        # Validate multimodal content structure
+        if isinstance(content, list):
+            for item in content:
+                if not isinstance(item, dict) or 'type' not in item:
+                    raise InvalidRequestError(message="messages: invalid content item format", param='messages')
+                if item['type'] not in ['text', 'image_url']:
+                    raise InvalidRequestError(message="messages: unsupported content type", param='messages')
+                if item['type'] == 'text' and 'text' not in item:
+                    raise InvalidRequestError(message="messages: missing text in content item", param='messages')
+                if item['type'] == 'image_url' and ('image_url' not in item or 'url' not in item['image_url']):
+                    raise InvalidRequestError(message="messages: missing image_url in content item", param='messages')

    # Chat Completions
    object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
    created_time = int(time.time())

@@ -336,9 +374,26 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):

    prompt_str = 'context' if is_legacy else 'prompt'

    # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
-    if prompt_str not in body:
-        raise InvalidRequestError("Missing required input", param=prompt_str)
+    # Handle both prompt and messages format for unified multimodal support
+    if prompt_str not in body or body[prompt_str] is None:
+        if 'messages' in body:
+            # Convert messages format to prompt for completions endpoint
+            prompt_text = ""
+            for message in body.get('messages', []):
+                if isinstance(message, dict) and 'content' in message:
+                    # Extract text content from multimodal messages
+                    content = message['content']
+                    if isinstance(content, str):
+                        prompt_text += content
+                    elif isinstance(content, list):
+                        for item in content:
+                            if isinstance(item, dict) and item.get('type') == 'text':
+                                prompt_text += item.get('text', '')
+
+            # Allow empty prompts for image-only requests
+            body[prompt_str] = prompt_text
+        else:
+            raise InvalidRequestError("Missing required input", param=prompt_str)

    # common params
    generate_params = process_parameters(body, is_legacy=is_legacy)

@@ -349,9 +404,18 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
    suffix = body['suffix'] if body['suffix'] else ''
    echo = body['echo']

+    # Add messages to generate_params if present for multimodal processing
+    if 'messages' in body:
+        generate_params['messages'] = body['messages']
+
    if not stream:
        prompt_arg = body[prompt_str]
-        if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
+
+        # Handle empty/None prompts (e.g., image-only requests)
+        if prompt_arg is None:
+            prompt_arg = ""
+
+        if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and len(prompt_arg) > 0 and isinstance(prompt_arg[0], int)):
            prompt_arg = [prompt_arg]

        resp_list_data = []

@@ -374,7 +438,19 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):

        # generate reply #######################################
        debug_msg({'prompt': prompt, 'generate_params': generate_params})
-        generator = generate_reply(prompt, generate_params, is_chat=False)
+
+        # Use multimodal generation if images are present
+        if 'messages' in generate_params:
+            raw_images = convert_openai_messages_to_images(generate_params['messages'])
+            if raw_images:
+                logger.info(f"Using multimodal generation for {len(raw_images)} images")
+                generate_params['raw_images'] = raw_images
+                generator = shared.model.generate_with_streaming(prompt, generate_params)
+            else:
+                generator = generate_reply(prompt, generate_params, is_chat=False)
+        else:
+            generator = generate_reply(prompt, generate_params, is_chat=False)

        answer = ''

        for a in generator:

@@ -447,7 +523,17 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):

        # generate reply #######################################
        debug_msg({'prompt': prompt, 'generate_params': generate_params})
-        generator = generate_reply(prompt, generate_params, is_chat=False)
+        # Use multimodal generation if images are present
+        if 'messages' in generate_params:
+            raw_images = convert_openai_messages_to_images(generate_params['messages'])
+            if raw_images:
+                logger.info(f"Using multimodal generation for {len(raw_images)} images")
+                generate_params['raw_images'] = raw_images
+                generator = shared.model.generate_with_streaming(prompt, generate_params)
+            else:
+                generator = generate_reply(prompt, generate_params, is_chat=False)
+        else:
+            generator = generate_reply(prompt, generate_params, is_chat=False)

        answer = ''
        seen_content = ''
extensions/openai/image_utils.py (new file, 97 lines)

@@ -0,0 +1,97 @@
"""
Shared image processing utilities for multimodal support.
Used by both ExLlamaV3 and llama.cpp implementations.
"""
import base64
import io
from typing import Any, List, Tuple

from PIL import Image

from modules.logging_colors import logger


def decode_base64_image(base64_string: str) -> Image.Image:
    """Decodes a base64 string to a PIL Image."""
    try:
        if base64_string.startswith('data:image/'):
            base64_string = base64_string.split(',', 1)[1]

        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))
        return image
    except Exception as e:
        logger.error(f"Failed to decode base64 image: {e}")
        raise ValueError(f"Invalid base64 image data: {e}")


def process_message_content(content: Any) -> Tuple[str, List[Image.Image]]:
    """
    Processes message content that may contain text and images.
    Returns: A tuple of (text_content, list_of_pil_images).
    """
    if isinstance(content, str):
        return content, []

    if isinstance(content, list):
        text_parts = []
        images = []
        for item in content:
            if not isinstance(item, dict):
                continue

            item_type = item.get('type', '')
            if item_type == 'text':
                text_parts.append(item.get('text', ''))
            elif item_type == 'image_url':
                image_url_data = item.get('image_url', {})
                image_url = image_url_data.get('url', '')

                if image_url.startswith('data:image/'):
                    try:
                        images.append(decode_base64_image(image_url))
                    except Exception as e:
                        logger.warning(f"Failed to process a base64 image: {e}")
                elif image_url.startswith('http'):
                    # Support external URLs
                    try:
                        import requests
                        response = requests.get(image_url, timeout=10)
                        response.raise_for_status()
                        image_data = response.content
                        image = Image.open(io.BytesIO(image_data))
                        images.append(image)
                        logger.info("Successfully loaded external image from URL")
                    except Exception as e:
                        logger.warning(f"Failed to fetch external image: {e}")
                else:
                    logger.warning(f"Unsupported image URL format: {image_url[:70]}...")

        return ' '.join(text_parts), images

    return str(content), []


def convert_image_attachments_to_pil(image_attachments: List[dict]) -> List[Image.Image]:
    """Convert webui image_attachments format to PIL Images."""
    pil_images = []
    for attachment in image_attachments:
        if attachment.get('type') == 'image' and 'image_data' in attachment:
            try:
                image = decode_base64_image(attachment['image_data'])
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                pil_images.append(image)
            except Exception as e:
                logger.warning(f"Failed to process image attachment: {e}")
    return pil_images


def convert_openai_messages_to_images(messages: List[dict]) -> List[Image.Image]:
    """Convert OpenAI messages format to PIL Images."""
    all_images = []
    for message in messages:
        if isinstance(message, dict) and 'content' in message:
            _, images = process_message_content(message['content'])
            all_images.extend(images)
    return all_images
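Editor's note: a quick sketch of how these helpers compose (not part of the commit). It assumes `b64_png` is a base64-encoded PNG string; `decode_base64_image` strips the `data:` prefix itself.

```python
from extensions.openai.image_utils import (
    convert_openai_messages_to_images,
    process_message_content,
)

# Hypothetical multimodal message in the OpenAI format handled above.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this picture."},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_png}"}},
        ],
    }
]

# One message: split into its text and its PIL images.
text, images = process_message_content(messages[0]["content"])
# text   -> "Describe this picture."
# images -> [<PIL.Image.Image object>]

# Whole conversation: collect every image across all messages.
all_images = convert_openai_messages_to_images(messages)
```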
@@ -2,7 +2,7 @@ import json
import time
from typing import Dict, List, Optional

-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator, validator


class GenerationOptions(BaseModel):

@@ -99,7 +99,8 @@ class ToolCall(BaseModel):

class CompletionRequestParams(BaseModel):
    model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
-    prompt: str | List[str]
+    prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
+    messages: List[dict] | None = Field(default=None, description="OpenAI messages format for multimodal support. Alternative to 'prompt'.")
    best_of: int | None = Field(default=1, description="Unused parameter.")
    echo: bool | None = False
    frequency_penalty: float | None = 0

@@ -115,6 +116,17 @@ class CompletionRequestParams(BaseModel):
    top_p: float | None = 1
    user: str | None = Field(default=None, description="Unused parameter.")

+    @field_validator('prompt', 'messages')
+    @classmethod
+    def validate_prompt_or_messages(cls, v, info):
+        """Ensure either 'prompt' or 'messages' is provided for completions."""
+        if info.field_name == 'prompt':  # If we're validating 'prompt', check if neither prompt nor messages will be set
+            messages = info.data.get('messages')
+            if v is None and messages is None:
+                raise ValueError("Either 'prompt' or 'messages' must be provided")
+
+        return v
+

class CompletionRequest(GenerationOptions, CompletionRequestParams):
    pass
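Editor's note: illustrative request bodies for the updated schema (a sketch, not from the commit). The validator above is intended to accept either form: the legacy body sends `prompt`, while the multimodal-style body omits `prompt` and sends `messages` instead.

```python
# Legacy completions request: 'prompt' is set, 'messages' is omitted.
legacy_body = {
    "prompt": "Once upon a time",
    "max_tokens": 64,
}

# Multimodal-style request: 'prompt' omitted, 'messages' carries text and image parts.
multimodal_body = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What color is this image?"},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},  # truncated example data URL
            ],
        }
    ],
    "max_tokens": 64,
}
```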
modules/chat.py (126 lines changed)

@@ -271,16 +271,27 @@ def generate_chat_prompt(user_input, state, **kwargs):
        # Add attachment content if present AND if past attachments are enabled
        if (state.get('include_past_attachments', True) and user_key in metadata and "attachments" in metadata[user_key]):
            attachments_text = ""
-            for attachment in metadata[user_key]["attachments"]:
-                filename = attachment.get("name", "file")
-                content = attachment.get("content", "")
-                if attachment.get("type") == "text/html" and attachment.get("url"):
-                    attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
-                else:
-                    attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
+            image_refs = ""

-            if attachments_text:
-                enhanced_user_msg = f"{user_msg}\n\nATTACHMENTS:\n{attachments_text}"
+            for attachment in metadata[user_key]["attachments"]:
+                if attachment.get("type") == "image":
+                    # Add image reference for multimodal models
+                    image_refs += "<__media__>"
+                else:
+                    # Handle text/PDF attachments
+                    filename = attachment.get("name", "file")
+                    content = attachment.get("content", "")
+                    if attachment.get("type") == "text/html" and attachment.get("url"):
+                        attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
+                    else:
+                        attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
+
+            if image_refs or attachments_text:
+                enhanced_user_msg = user_msg
+                if image_refs:
+                    enhanced_user_msg += f" {image_refs}"
+                if attachments_text:
+                    enhanced_user_msg += f"\n\nATTACHMENTS:\n{attachments_text}"

                messages.insert(insert_pos, {"role": "user", "content": enhanced_user_msg})

@@ -301,16 +312,23 @@ def generate_chat_prompt(user_input, state, **kwargs):

        if user_key in metadata and "attachments" in metadata[user_key]:
            attachments_text = ""
-            for attachment in metadata[user_key]["attachments"]:
-                filename = attachment.get("name", "file")
-                content = attachment.get("content", "")
-                if attachment.get("type") == "text/html" and attachment.get("url"):
-                    attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
-                else:
-                    attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
+            image_refs = ""

-            if attachments_text:
-                user_input = f"{user_input}\n\nATTACHMENTS:\n{attachments_text}"
+            for attachment in metadata[user_key]["attachments"]:
+                if attachment.get("type") == "image":
+                    image_refs += "<__media__>"
+                else:
+                    filename = attachment.get("name", "file")
+                    content = attachment.get("content", "")
+                    if attachment.get("type") == "text/html" and attachment.get("url"):
+                        attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
+                    else:
+                        attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
+
+            if image_refs or attachments_text:
+                user_input = f"{user_input} {image_refs}"
+                if attachments_text:
+                    user_input += f"\n\nATTACHMENTS:\n{attachments_text}"

        messages.append({"role": "user", "content": user_input})

@@ -594,29 +612,64 @@ def add_message_attachment(history, row_idx, file_path, is_user=True):
    file_extension = path.suffix.lower()

    try:
        # Handle different file types
-        if file_extension == '.pdf':
+        # Handle image files
+        if file_extension in ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']:
+            # Convert image to base64
+            with open(path, 'rb') as f:
+                image_data = base64.b64encode(f.read()).decode('utf-8')
+
+            # Determine MIME type from extension
+            mime_type_map = {
+                '.jpg': 'image/jpeg',
+                '.jpeg': 'image/jpeg',
+                '.png': 'image/png',
+                '.webp': 'image/webp',
+                '.bmp': 'image/bmp',
+                '.gif': 'image/gif'
+            }
+            mime_type = mime_type_map.get(file_extension, 'image/jpeg')
+
+            # Format as data URL
+            data_url = f"data:{mime_type};base64,{image_data}"
+
+            # Generate unique image ID
+            image_id = len([att for att in history['metadata'][key]["attachments"] if att.get("type") == "image"]) + 1
+
+            attachment = {
+                "name": filename,
+                "type": "image",
+                "image_data": data_url,
+                "image_id": image_id,
+                "file_path": str(path)  # For UI preview
+            }
+        elif file_extension == '.pdf':
            # Process PDF file
            content = extract_pdf_text(path)
-            file_type = "application/pdf"
+            attachment = {
+                "name": filename,
+                "type": "application/pdf",
+                "content": content,
+            }
        elif file_extension == '.docx':
            content = extract_docx_text(path)
-            file_type = "application/docx"
+            attachment = {
+                "name": filename,
+                "type": "application/docx",
+                "content": content,
+            }
        else:
            # Default handling for text files
            with open(path, 'r', encoding='utf-8') as f:
                content = f.read()
-            file_type = "text/plain"
-
-        # Add attachment
-        attachment = {
-            "name": filename,
-            "type": file_type,
-            "content": content,
-        }
+            attachment = {
+                "name": filename,
+                "type": "text/plain",
+                "content": content,
+            }

        history['metadata'][key]["attachments"].append(attachment)
-        return content  # Return the content for reuse
+        return attachment  # Return the attachment for reuse
    except Exception as e:
        logger.error(f"Error processing attachment {filename}: {e}")
        return None

@@ -759,6 +812,19 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
        for file_path in files:
            add_message_attachment(output, row_idx, file_path, is_user=True)

+        # Collect image attachments for ExLlamaV3
+        image_attachments = []
+        if 'metadata' in output:
+            user_key = f"user_{row_idx}"
+            if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
+                for attachment in output['metadata'][user_key]["attachments"]:
+                    if attachment.get("type") == "image":
+                        image_attachments.append(attachment)
+
+        # Add image attachments to state for the generation
+        if image_attachments:
+            state['image_attachments'] = image_attachments
+
    # Add web search results as attachments if enabled
    if state.get('enable_web_search', False):
        search_query = generate_search_query(text, state)
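Editor's note: a sketch of the data flow added above (not code from the commit). An uploaded image becomes an attachment dict in the chat metadata, each image contributes one `<__media__>` placeholder to the user message, and the collected attachments travel with the generation state.

```python
# Shape of an image attachment as built in add_message_attachment()
# (the file name and path here are hypothetical):
attachment = {
    "name": "cat.png",
    "type": "image",
    "image_data": "data:image/png;base64,...",  # full data URL
    "image_id": 1,
    "file_path": "/tmp/cat.png",                # used for the UI preview
}

# In generate_chat_prompt(), each image attachment appends one placeholder:
user_msg = "What is in this picture?"
enhanced_user_msg = f"{user_msg} <__media__>"

# In chatbot_wrapper(), the image attachments ride along in the generation state:
state = {}
state["image_attachments"] = [attachment]
```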
modules/exllamav3.py (new file, 313 lines)

@@ -0,0 +1,313 @@
import traceback
from pathlib import Path
from typing import Any, List, Tuple

from exllamav3 import Cache, Config, Generator, Model, Tokenizer
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant

from extensions.openai.image_utils import (
    convert_image_attachments_to_pil,
    convert_openai_messages_to_images
)
from modules import shared
from modules.logging_colors import logger

try:
    import flash_attn
except Exception:
    logger.warning('Failed to load flash-attention due to the following error:\n')
    traceback.print_exc()


class Exllamav3Model:
    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path_to_model):
        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)

        # Reset global MMTokenAllocator to prevent token ID corruption when switching models
        from exllamav3.tokenizer.mm_embedding import (
            FIRST_MM_EMBEDDING_INDEX,
            global_allocator
        )
        global_allocator.next_token_index = FIRST_MM_EMBEDDING_INDEX
        logger.info("Reset MMTokenAllocator for clean multimodal token allocation")

        config = Config.from_directory(str(path_to_model))
        model = Model.from_config(config)

        # Calculate the closest multiple of 256 at or above the chosen value
        max_tokens = shared.args.ctx_size
        if max_tokens % 256 != 0:
            adjusted_tokens = ((max_tokens // 256) + 1) * 256
            logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
            max_tokens = adjusted_tokens

        # Parse cache type (ExLlamaV2 pattern)
        cache_type = shared.args.cache_type.lower()
        cache_kwargs = {}
        if cache_type == 'fp16':
            layer_type = CacheLayer_fp16
        elif cache_type.startswith('q'):
            layer_type = CacheLayer_quant
            if '_' in cache_type:
                # Different bits for k and v (e.g., q4_q8)
                k_part, v_part = cache_type.split('_')
                k_bits = int(k_part[1:])
                v_bits = int(v_part[1:])
            else:
                # Same bits for k and v (e.g., q4)
                k_bits = v_bits = int(cache_type[1:])

            # Validate bit ranges
            if not (2 <= k_bits <= 8 and 2 <= v_bits <= 8):
                logger.warning(f"Invalid quantization bits: k_bits={k_bits}, v_bits={v_bits}. Must be between 2 and 8. Falling back to fp16.")
                layer_type = CacheLayer_fp16
            else:
                cache_kwargs = {'k_bits': k_bits, 'v_bits': v_bits}
        else:
            logger.warning(f"Unrecognized cache type: {cache_type}. Falling back to fp16.")
            layer_type = CacheLayer_fp16

        cache = Cache(model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)

        load_params = {'progressbar': True}
        if shared.args.gpu_split:
            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
            load_params['use_per_device'] = split

        model.load(**load_params)

        tokenizer = Tokenizer.from_config(config)

        # Load vision model component (ExLlamaV3 native)
        vision_model = None
        try:
            logger.info("Loading vision model component...")
            vision_model = Model.from_config(config, component="vision")
            vision_model.load(progressbar=True)
            logger.info("Vision model loaded successfully")
        except Exception as e:
            logger.warning(f"Vision model loading failed (multimodal disabled): {e}")

        generator = Generator(
            model=model,
            cache=cache,
            tokenizer=tokenizer,
        )

        result = cls()
        result.model = model
        result.cache = cache
        result.tokenizer = tokenizer
        result.generator = generator
        result.config = config
        result.max_tokens = max_tokens
        result.vision_model = vision_model

        return result

    def is_multimodal(self) -> bool:
        """Check if this model supports multimodal input."""
        return hasattr(self, 'vision_model') and self.vision_model is not None

    def _process_images_for_generation(self, prompt: str, state: dict) -> Tuple[str, List[Any]]:
        """
        Process all possible image inputs and return modified prompt + embeddings.
        Returns: (processed_prompt, image_embeddings)
        """
        if not self.is_multimodal():
            return prompt, []

        # Collect images from various sources using shared utilities
        pil_images = []

        # From webui image_attachments (preferred format)
        if 'image_attachments' in state and state['image_attachments']:
            pil_images.extend(convert_image_attachments_to_pil(state['image_attachments']))

        # From OpenAI API raw_images
        elif 'raw_images' in state and state['raw_images']:
            pil_images.extend(state['raw_images'])

        # From OpenAI API messages format
        elif 'messages' in state and state['messages']:
            pil_images.extend(convert_openai_messages_to_images(state['messages']))

        if not pil_images:
            return prompt, []

        # ExLlamaV3-specific: Generate embeddings
        try:
            # Use pre-computed embeddings if available (proper MMEmbedding lifetime)
            if 'image_embeddings' in state and state['image_embeddings']:
                # Use existing embeddings - this preserves MMEmbedding lifetime
                image_embeddings = state['image_embeddings']
            else:
                # Do not reset the cache/allocator index; it causes token ID conflicts during generation.

                logger.info(f"Processing {len(pil_images)} image(s) with ExLlamaV3 vision model")
                image_embeddings = [
                    self.vision_model.get_image_embeddings(tokenizer=self.tokenizer, image=img)
                    for img in pil_images
                ]

            # ExLlamaV3-specific: Handle prompt processing with placeholders
            placeholders = [ie.text_alias for ie in image_embeddings]

            if '<__media__>' in prompt:
                # Web chat: Replace <__media__> placeholders
                for alias in placeholders:
                    prompt = prompt.replace('<__media__>', alias, 1)
                logger.info(f"Replaced {len(placeholders)} <__media__> placeholder(s)")
            else:
                # API: Prepend embedding aliases
                combined_placeholders = "\n".join(placeholders)
                prompt = combined_placeholders + "\n" + prompt
                logger.info(f"Prepended {len(placeholders)} embedding(s) to prompt")

            return prompt, image_embeddings

        except Exception as e:
            logger.error(f"Failed to process images: {e}")
            return prompt, []

    def generate_with_streaming(self, prompt, state):
        """
        Generate text with streaming using native ExLlamaV3 API
        """
        from exllamav3 import Job
        from exllamav3.generator.sampler.presets import ComboSampler

        # Process images and modify prompt (ExLlamaV3-specific)
        prompt, image_embeddings = self._process_images_for_generation(prompt, state)

        sampler = ComboSampler(
            rep_p=state.get('repetition_penalty', 1.0),
            freq_p=state.get('frequency_penalty', 0.0),
            pres_p=state.get('presence_penalty', 0.0),
            temperature=state.get('temperature', 0.7),
            min_p=state.get('min_p', 0.0),
            top_k=state.get('top_k', 0),
            top_p=state.get('top_p', 1.0),
        )

        # Encode prompt with embeddings (ExLlamaV3-specific)
        if image_embeddings:
            input_ids = self.tokenizer.encode(
                prompt,
                encode_special_tokens=True,
                embeddings=image_embeddings,
            )
        else:
            input_ids = self.tokenizer.encode(prompt, encode_special_tokens=True)

        # Get stop conditions from state (webui format) - keep as strings like ExLlamaV3 examples
        stop_conditions = []
        if 'stopping_strings' in state and state['stopping_strings']:
            # Use strings directly (ExLlamaV3 handles the conversion internally)
            stop_conditions.extend(state['stopping_strings'])

        # Add EOS token ID as ExLlamaV3 examples do
        if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            stop_conditions.append(self.tokenizer.eos_token_id)

        job = Job(
            input_ids=input_ids,
            max_new_tokens=state.get('max_new_tokens', 500),
            decode_special_tokens=True,
            embeddings=image_embeddings if image_embeddings else None,
            sampler=sampler,
            stop_conditions=stop_conditions if stop_conditions else None,
        )

        # Stream generation
        self.generator.enqueue(job)

        response_text = ""
        try:
            while self.generator.num_remaining_jobs():
                results = self.generator.iterate()
                for result in results:
                    if "eos" in result and result["eos"]:
                        break

                    chunk = result.get("text", "")
                    if chunk:
                        response_text += chunk
                        yield response_text
        finally:
            # No cleanup needed. MMEmbedding lifetime is managed by Python.
            # Cache and page table resets are unnecessary and can cause token ID conflicts.
            pass

    def generate(self, prompt, state):
        """
        Generate text using native ExLlamaV3 API (non-streaming)
        """
        output = self.generator.generate(
            prompt=prompt,
            max_new_tokens=state.get('max_new_tokens', 500),
            temperature=state.get('temperature', 0.7),
            top_p=state.get('top_p', 1.0),
            top_k=state.get('top_k', 0),
            repetition_penalty=state.get('repetition_penalty', 1.0),
            frequency_penalty=state.get('frequency_penalty', 0.0),
            presence_penalty=state.get('presence_penalty', 0.0),
            min_p=state.get('min_p', 0.0),
        )

        return output

    def encode(self, string, **kwargs):
        return self.tokenizer.encode(string, **kwargs)

    def decode(self, ids, **kwargs):
        return self.tokenizer.decode(ids, **kwargs)

    @property
    def last_prompt_token_count(self):
        # This would need to be tracked during generation
        return 0

    def unload(self):
        logger.info("Unloading ExLlamaV3 model components...")

        if hasattr(self, 'vision_model') and self.vision_model is not None:
            try:
                del self.vision_model
            except Exception as e:
                logger.warning(f"Error unloading vision model: {e}")
            self.vision_model = None

        if hasattr(self, 'model') and self.model is not None:
            try:
                self.model.unload()
                del self.model
            except Exception as e:
                logger.warning(f"Error unloading main model: {e}")
            self.model = None

        if hasattr(self, 'cache') and self.cache is not None:
            self.cache = None

        if hasattr(self, 'generator') and self.generator is not None:
            self.generator = None

        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
            self.tokenizer = None

        # Force GPU memory cleanup
        import gc

        import torch
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        logger.info("ExLlamaV3 model fully unloaded")
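Editor's note: a minimal usage sketch of the loader above (illustration only, not from the commit). It assumes `shared.args.model_dir`, `ctx_size`, `cache_type` and `gpu_split` are already configured by the webui's command-line flags (e.g. `--loader exllamav3`), and that the model folder ships a vision component; the folder name is hypothetical.

```python
from modules.exllamav3 import Exllamav3Model

# Hypothetical EXL3 model folder under shared.args.model_dir.
model = Exllamav3Model.from_pretrained("Some-EXL3-vision-model")

state = {
    "max_new_tokens": 200,
    "temperature": 0.7,
    "top_p": 1.0,
    # Filled by the UI, or pass PIL images directly via "raw_images": [...]
    "image_attachments": [],
}

prompt = "Describe the attached image. <__media__>"
output = ""
for output in model.generate_with_streaming(prompt, state):
    pass  # 'output' holds the cumulative response text at each step
print(output)
```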
@@ -406,16 +406,27 @@ def format_message_attachments(history, role, index):
        for attachment in attachments:
            name = html.escape(attachment["name"])

-            # Make clickable if URL exists
-            if "url" in attachment:
-                name = f'<a href="{html.escape(attachment["url"])}" target="_blank" rel="noopener noreferrer">{name}</a>'
+            if attachment.get("type") == "image":
+                # Show image preview
+                file_path = attachment.get("file_path", "")
+                attachments_html += (
+                    f'<div class="attachment-box image-attachment">'
+                    f'<img src="file/{file_path}" alt="{name}" class="image-preview" />'
+                    f'<div class="attachment-name">{name}</div>'
+                    f'</div>'
+                )
+            else:
+                # Make clickable if URL exists (web search)
+                if "url" in attachment:
+                    name = f'<a href="{html.escape(attachment["url"])}" target="_blank" rel="noopener noreferrer">{name}</a>'

+                attachments_html += (
+                    f'<div class="attachment-box">'
+                    f'<div class="attachment-icon">{attachment_svg}</div>'
+                    f'<div class="attachment-name">{name}</div>'
+                    f'</div>'
+                )

-            attachments_html += (
-                f'<div class="attachment-box">'
-                f'<div class="attachment-icon">{attachment_svg}</div>'
-                f'<div class="attachment-name">{name}</div>'
-                f'</div>'
-            )
        attachments_html += '</div>'
        return attachments_html
@@ -55,6 +55,11 @@ loaders_and_params = OrderedDict({
        'trust_remote_code',
        'no_use_fast',
    ],
+    'ExLlamav3': [
+        'ctx_size',
+        'cache_type',
+        'gpu_split',
+    ],
    'ExLlamav2_HF': [
        'ctx_size',
        'cache_type',

@@ -248,6 +253,41 @@ loaders_samplers = {
        'grammar_string',
        'grammar_file_row',
    },
+    'ExLlamav3': {
+        'temperature',
+        'dynatemp_low',
+        'dynatemp_high',
+        'dynatemp_exponent',
+        'smoothing_factor',
+        'min_p',
+        'top_p',
+        'top_k',
+        'typical_p',
+        'xtc_threshold',
+        'xtc_probability',
+        'tfs',
+        'top_a',
+        'dry_multiplier',
+        'dry_allowed_length',
+        'dry_base',
+        'repetition_penalty',
+        'frequency_penalty',
+        'presence_penalty',
+        'repetition_penalty_range',
+        'mirostat_mode',
+        'mirostat_tau',
+        'mirostat_eta',
+        'dynamic_temperature',
+        'temperature_last',
+        'auto_max_new_tokens',
+        'ban_eos_token',
+        'add_bos_token',
+        'enable_thinking',
+        'skip_special_tokens',
+        'seed',
+        'custom_token_bans',
+        'dry_sequence_breakers',
+    },
    'ExLlamav2': {
        'temperature',
        'dynatemp_low',
@@ -19,6 +19,7 @@ def load_model(model_name, loader=None):
        'llama.cpp': llama_cpp_server_loader,
        'Transformers': transformers_loader,
        'ExLlamav3_HF': ExLlamav3_HF_loader,
+        'ExLlamav3': ExLlamav3_loader,
        'ExLlamav2_HF': ExLlamav2_HF_loader,
        'ExLlamav2': ExLlamav2_loader,
        'TensorRT-LLM': TensorRT_LLM_loader,

@@ -88,6 +89,14 @@ def ExLlamav3_HF_loader(model_name):
    return Exllamav3HF.from_pretrained(model_name)


+def ExLlamav3_loader(model_name):
+    from modules.exllamav3 import Exllamav3Model
+
+    model = Exllamav3Model.from_pretrained(model_name)
+    tokenizer = model.tokenizer
+    return model, tokenizer
+
+
def ExLlamav2_HF_loader(model_name):
    from modules.exllamav2_hf import Exllamav2HF

@@ -116,7 +125,9 @@ def unload_model(keep_model_name=False):
        return

    is_llamacpp = (shared.model.__class__.__name__ == 'LlamaServer')
-    if shared.model.__class__.__name__ == 'Exllamav3HF':
+    if shared.args.loader in ['ExLlamav3_HF', 'ExLlamav3']:
        shared.model.unload()
+    elif shared.args.loader in ['ExLlamav2_HF', 'ExLlamav2'] and hasattr(shared.model, 'unload'):
+        shared.model.unload()

    shared.model = shared.tokenizer = None
@@ -318,6 +318,8 @@ def fix_loader_name(name):
        return 'ExLlamav2_HF'
    elif name in ['exllamav3-hf', 'exllamav3_hf', 'exllama-v3-hf', 'exllama_v3_hf', 'exllama-v3_hf', 'exllama3-hf', 'exllama3_hf', 'exllama-3-hf', 'exllama_3_hf', 'exllama-3_hf']:
        return 'ExLlamav3_HF'
+    elif name in ['exllamav3']:
+        return 'ExLlamav3'
    elif name in ['tensorrt', 'tensorrtllm', 'tensorrt_llm', 'tensorrt-llm', 'tensort', 'tensortllm']:
        return 'TensorRT-LLM'
@@ -40,7 +40,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
        yield ''
        return

-    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']:
+    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel']:
        generate_func = generate_reply_custom
    else:
        generate_func = generate_reply_HF

@@ -128,9 +128,9 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
    from modules.torch_utils import get_device

-    if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel']:
+    if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel']:
        input_ids = shared.tokenizer.encode(str(prompt))
-        if shared.model.__class__.__name__ != 'Exllamav2Model':
+        if shared.model.__class__.__name__ not in ['Exllamav2Model', 'Exllamav3Model']:
            input_ids = np.array(input_ids).reshape(1, len(input_ids))
    else:
        input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)

@@ -148,7 +148,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
    if truncation_length is not None:
        input_ids = input_ids[:, -truncation_length:]

-    if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel'] or shared.args.cpu:
        return input_ids
    else:
        device = get_device()

@@ -295,6 +295,8 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
        _StopEverythingStoppingCriteria
    )

+    # Native ExLlamav3Model handles multimodal internally - no special routing needed
+
    if shared.args.loader == 'Transformers':
        clear_torch_cache()
@@ -54,7 +54,7 @@ def create_ui():
        gr.HTML(value='<div class="hover-element" onclick="void(0)"><span style="width: 100px; display: block" id="hover-element-button">☰</span><div class="hover-menu" id="hover-menu"></div>', elem_id='gr-hover')

        with gr.Column(scale=10, elem_id='chat-input-container'):
-            shared.gradio['textbox'] = gr.MultimodalTextbox(label='', placeholder='Send a message', file_types=['text', '.pdf'], file_count="multiple", elem_id='chat-input', elem_classes=['add_scrollbar'])
+            shared.gradio['textbox'] = gr.MultimodalTextbox(label='', placeholder='Send a message', file_types=['text', '.pdf', 'image'], file_count="multiple", elem_id='chat-input', elem_classes=['add_scrollbar'])
            shared.gradio['typing-dots'] = gr.HTML(value='<div class="typing"><span></span><span class="dot1"></span><span class="dot2"></span></div>', label='typing', elem_id='typing-container')

        with gr.Column(scale=1, elem_id='generate-stop-container'):