diff --git a/css/main.css b/css/main.css
index 240a94d5..de16d81d 100644
--- a/css/main.css
+++ b/css/main.css
@@ -1577,6 +1577,19 @@ strong {
     margin-top: 4px;
 }
 
+.image-attachment {
+    flex-direction: column;
+}
+
+.image-preview {
+    border-radius: 16px;
+    margin-bottom: 5px;
+    object-fit: cover;
+    object-position: center;
+    border: 2px solid var(--border-color-primary);
+    aspect-ratio: 1 / 1;
+}
+
 button:focus {
     outline: none;
 }
diff --git a/docs/12 - OpenAI API.md b/docs/12 - OpenAI API.md
index ec999397..b7b5fbc1 100644
--- a/docs/12 - OpenAI API.md
+++ b/docs/12 - OpenAI API.md
@@ -77,6 +77,24 @@ curl http://127.0.0.1:5000/v1/chat/completions \
   }'
 ```
 
+#### Multimodal support (ExLlamaV3)
+
+```shell
+curl http://127.0.0.1:5000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "messages": [
+      {
+        "role": "user",
+        "content": [
+          {"type": "text", "text": "What color is this image?"},
+          {"type": "image_url", "image_url": {"url": "https://github.com/turboderp-org/exllamav3/blob/master/examples/media/cat.png?raw=true"}}
+        ]
+      }
+    ]
+  }'
+```
+
 #### SSE streaming
 
 ```shell
diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index 5181b18b..3d389f0b 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -7,6 +7,7 @@ import tiktoken
 from pydantic import ValidationError
 
 from extensions.openai.errors import InvalidRequestError
+from extensions.openai.image_utils import convert_openai_messages_to_images
 from extensions.openai.typing import ToolDefinition
 from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
 from modules import shared
@@ -16,6 +17,7 @@ from modules.chat import (
     load_character_memoized,
     load_instruction_template_memoized
 )
+from modules.logging_colors import logger
 from modules.presets import load_preset_memoized
 from modules.text_generation import decode, encode, generate_reply
 
@@ -82,6 +84,21 @@ def process_parameters(body, is_legacy=False):
     return generate_params
 
 
+def process_multimodal_content(content):
+    """Extract text from OpenAI multimodal format for non-multimodal models"""
+    if isinstance(content, str):
+        return content
+
+    if isinstance(content, list):
+        text_parts = []
+        for item in content:
+            if isinstance(item, dict) and item.get('type') == 'text':
+                text_parts.append(item.get('text', ''))
+        return ' '.join(text_parts) if text_parts else str(content)
+
+    return str(content)
+
+
 def convert_history(history):
     '''
     Chat histories in this program are in the format [message, reply].
@@ -99,8 +116,11 @@ def convert_history(history):
         role = entry["role"]
 
         if role == "user":
+            # Extract text content (images handled by model-specific code)
+            content = process_multimodal_content(content)
             user_input = content
             user_input_last = True
+
             if current_message:
                 chat_dialogue.append([current_message, '', ''])
                 current_message = ""
@@ -126,7 +146,11 @@ def convert_history(history):
     if not user_input_last:
         user_input = ""
 
-    return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
+    return user_input, system_message, {
+        'internal': chat_dialogue,
+        'visible': copy.deepcopy(chat_dialogue),
+        'messages': history  # Store original messages for multimodal models
+    }
 
 
 def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
@@ -150,9 +174,23 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         elif m['role'] == 'function':
             raise InvalidRequestError(message="role: function is not supported.", param='messages')
 
-        if 'content' not in m and "image_url" not in m:
+        # Handle multimodal content validation
+        content = m.get('content')
+        if content is None:
             raise InvalidRequestError(message="messages: missing content", param='messages')
 
+        # Validate multimodal content structure
+        if isinstance(content, list):
+            for item in content:
+                if not isinstance(item, dict) or 'type' not in item:
+                    raise InvalidRequestError(message="messages: invalid content item format", param='messages')
+                if item['type'] not in ['text', 'image_url']:
+                    raise InvalidRequestError(message="messages: unsupported content type", param='messages')
+                if item['type'] == 'text' and 'text' not in item:
+                    raise InvalidRequestError(message="messages: missing text in content item", param='messages')
+                if item['type'] == 'image_url' and ('image_url' not in item or 'url' not in item['image_url']):
+                    raise InvalidRequestError(message="messages: missing image_url in content item", param='messages')
+
     # Chat Completions
     object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
     created_time = int(time.time())
@@ -336,9 +374,26 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
     prompt_str = 'context' if is_legacy else 'prompt'
 
-    # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
-    if prompt_str not in body:
-        raise InvalidRequestError("Missing required input", param=prompt_str)
+    # Handle both prompt and messages format for unified multimodal support
+    if prompt_str not in body or body[prompt_str] is None:
+        if 'messages' in body:
+            # Convert messages format to prompt for completions endpoint
+            prompt_text = ""
+            for message in body.get('messages', []):
+                if isinstance(message, dict) and 'content' in message:
+                    # Extract text content from multimodal messages
+                    content = message['content']
+                    if isinstance(content, str):
+                        prompt_text += content
+                    elif isinstance(content, list):
+                        for item in content:
+                            if isinstance(item, dict) and item.get('type') == 'text':
+                                prompt_text += item.get('text', '')
+
+            # Allow empty prompts for image-only requests
+            body[prompt_str] = prompt_text
+        else:
+            raise InvalidRequestError("Missing required input", param=prompt_str)
 
     # common params
     generate_params = process_parameters(body, is_legacy=is_legacy)
 
@@ -349,9 +404,18 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
     suffix = body['suffix'] if body['suffix'] else ''
     echo = body['echo']
 
+    # Add messages to generate_params if present for multimodal processing
+    if 'messages' in body:
+        generate_params['messages'] = body['messages']
+
     if not stream:
         prompt_arg = body[prompt_str]
-        if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
+
+        # Handle empty/None prompts (e.g., image-only requests)
+        if prompt_arg is None:
+            prompt_arg = ""
+
+        if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and len(prompt_arg) > 0 and isinstance(prompt_arg[0], int)):
             prompt_arg = [prompt_arg]
 
         resp_list_data = []
@@ -374,7 +438,19 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
             # generate reply #######################################
             debug_msg({'prompt': prompt, 'generate_params': generate_params})
-            generator = generate_reply(prompt, generate_params, is_chat=False)
+
+            # Use multimodal generation if images are present
+            if 'messages' in generate_params:
+                raw_images = convert_openai_messages_to_images(generate_params['messages'])
+                if raw_images:
+                    logger.info(f"Using multimodal generation for {len(raw_images)} images")
+                    generate_params['raw_images'] = raw_images
+                    generator = shared.model.generate_with_streaming(prompt, generate_params)
+                else:
+                    generator = generate_reply(prompt, generate_params, is_chat=False)
+            else:
+                generator = generate_reply(prompt, generate_params, is_chat=False)
+
             answer = ''
 
             for a in generator:
@@ -447,7 +523,17 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
         # generate reply #######################################
         debug_msg({'prompt': prompt, 'generate_params': generate_params})
-        generator = generate_reply(prompt, generate_params, is_chat=False)
+        # Use multimodal generation if images are present
+        if 'messages' in generate_params:
+            raw_images = convert_openai_messages_to_images(generate_params['messages'])
+            if raw_images:
+                logger.info(f"Using multimodal generation for {len(raw_images)} images")
+                generate_params['raw_images'] = raw_images
+                generator = shared.model.generate_with_streaming(prompt, generate_params)
+            else:
+                generator = generate_reply(prompt, generate_params, is_chat=False)
+        else:
+            generator = generate_reply(prompt, generate_params, is_chat=False)
 
         answer = ''
         seen_content = ''
diff --git a/extensions/openai/image_utils.py b/extensions/openai/image_utils.py
new file mode 100644
index 00000000..c54f0532
--- /dev/null
+++ b/extensions/openai/image_utils.py
@@ -0,0 +1,97 @@
+"""
+Shared image processing utilities for multimodal support.
+Used by both ExLlamaV3 and llama.cpp implementations.
+"""
+import base64
+import io
+from typing import Any, List, Tuple
+
+from PIL import Image
+
+from modules.logging_colors import logger
+
+
+def decode_base64_image(base64_string: str) -> Image.Image:
+    """Decodes a base64 string to a PIL Image."""
+    try:
+        if base64_string.startswith('data:image/'):
+            base64_string = base64_string.split(',', 1)[1]
+
+        image_data = base64.b64decode(base64_string)
+        image = Image.open(io.BytesIO(image_data))
+        return image
+    except Exception as e:
+        logger.error(f"Failed to decode base64 image: {e}")
+        raise ValueError(f"Invalid base64 image data: {e}")
+
+
+def process_message_content(content: Any) -> Tuple[str, List[Image.Image]]:
+    """
+    Processes message content that may contain text and images.
+    Returns: A tuple of (text_content, list_of_pil_images).
+    """
+    if isinstance(content, str):
+        return content, []
+
+    if isinstance(content, list):
+        text_parts = []
+        images = []
+        for item in content:
+            if not isinstance(item, dict):
+                continue
+
+            item_type = item.get('type', '')
+            if item_type == 'text':
+                text_parts.append(item.get('text', ''))
+            elif item_type == 'image_url':
+                image_url_data = item.get('image_url', {})
+                image_url = image_url_data.get('url', '')
+
+                if image_url.startswith('data:image/'):
+                    try:
+                        images.append(decode_base64_image(image_url))
+                    except Exception as e:
+                        logger.warning(f"Failed to process a base64 image: {e}")
+                elif image_url.startswith('http'):
+                    # Support external URLs
+                    try:
+                        import requests
+                        response = requests.get(image_url, timeout=10)
+                        response.raise_for_status()
+                        image_data = response.content
+                        image = Image.open(io.BytesIO(image_data))
+                        images.append(image)
+                        logger.info("Successfully loaded external image from URL")
+                    except Exception as e:
+                        logger.warning(f"Failed to fetch external image: {e}")
+                else:
+                    logger.warning(f"Unsupported image URL format: {image_url[:70]}...")
+
+        return ' '.join(text_parts), images
+
+    return str(content), []
+
+
+def convert_image_attachments_to_pil(image_attachments: List[dict]) -> List[Image.Image]:
+    """Convert webui image_attachments format to PIL Images."""
+    pil_images = []
+    for attachment in image_attachments:
+        if attachment.get('type') == 'image' and 'image_data' in attachment:
+            try:
+                image = decode_base64_image(attachment['image_data'])
+                if image.mode != 'RGB':
+                    image = image.convert('RGB')
+                pil_images.append(image)
+            except Exception as e:
+                logger.warning(f"Failed to process image attachment: {e}")
+    return pil_images
+
+
+def convert_openai_messages_to_images(messages: List[dict]) -> List[Image.Image]:
+    """Convert OpenAI messages format to PIL Images."""
+    all_images = []
+    for message in messages:
+        if isinstance(message, dict) and 'content' in message:
+            _, images = process_message_content(message['content'])
+            all_images.extend(images)
+    return all_images
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index 6bd3749f..e9f92da5 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -2,7 +2,7 @@ import json
 import time
 from typing import Dict, List, Optional
 
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator, validator
 
 
 class GenerationOptions(BaseModel):
@@ -99,7 +99,8 @@ class ToolCall(BaseModel):
 
 class CompletionRequestParams(BaseModel):
     model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
-    prompt: str | List[str]
+    prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
+    messages: List[dict] | None = Field(default=None, description="OpenAI messages format for multimodal support. Alternative to 'prompt'.")
     best_of: int | None = Field(default=1, description="Unused parameter.")
     echo: bool | None = False
     frequency_penalty: float | None = 0
@@ -115,6 +116,17 @@ class CompletionRequestParams(BaseModel):
     top_p: float | None = 1
     user: str | None = Field(default=None, description="Unused parameter.")
 
+    @field_validator('prompt', 'messages')
+    @classmethod
+    def validate_prompt_or_messages(cls, v, info):
+        """Ensure either 'prompt' or 'messages' is provided for completions."""
+        if info.field_name == 'prompt':  # If we're validating 'prompt', check if neither prompt nor messages will be set
+            messages = info.data.get('messages')
+            if v is None and messages is None:
+                raise ValueError("Either 'prompt' or 'messages' must be provided")
+
+        return v
+
 
 class CompletionRequest(GenerationOptions, CompletionRequestParams):
     pass
diff --git a/modules/chat.py b/modules/chat.py
index 1ab91b5e..354ae46b 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -271,16 +271,27 @@ def generate_chat_prompt(user_input, state, **kwargs):
 
             # Add attachment content if present AND if past attachments are enabled
             if (state.get('include_past_attachments', True) and user_key in metadata and "attachments" in metadata[user_key]):
                 attachments_text = ""
-                for attachment in metadata[user_key]["attachments"]:
-                    filename = attachment.get("name", "file")
-                    content = attachment.get("content", "")
-                    if attachment.get("type") == "text/html" and attachment.get("url"):
-                        attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
-                    else:
-                        attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
+                image_refs = ""
 
-                if attachments_text:
-                    enhanced_user_msg = f"{user_msg}\n\nATTACHMENTS:\n{attachments_text}"
+                for attachment in metadata[user_key]["attachments"]:
+                    if attachment.get("type") == "image":
+                        # Add image reference for multimodal models
+                        image_refs += "<__media__>"
+                    else:
+                        # Handle text/PDF attachments
+                        filename = attachment.get("name", "file")
+                        content = attachment.get("content", "")
+                        if attachment.get("type") == "text/html" and attachment.get("url"):
+                            attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
+                        else:
+                            attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
+
+                if image_refs or attachments_text:
+                    enhanced_user_msg = user_msg
+                    if image_refs:
+                        enhanced_user_msg += f" {image_refs}"
+                    if attachments_text:
+                        enhanced_user_msg += f"\n\nATTACHMENTS:\n{attachments_text}"
 
                     messages.insert(insert_pos, {"role": "user", "content": enhanced_user_msg})
 
@@ -301,16 +312,23 @@ def generate_chat_prompt(user_input, state, **kwargs):
 
         if user_key in metadata and "attachments" in metadata[user_key]:
             attachments_text = ""
-            for attachment in metadata[user_key]["attachments"]:
-                filename = attachment.get("name", "file")
-                content = attachment.get("content", "")
-                if attachment.get("type") == "text/html" and attachment.get("url"):
-                    attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
-                else:
-                    attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
+            image_refs = ""
 
-            if attachments_text:
-                user_input = f"{user_input}\n\nATTACHMENTS:\n{attachments_text}"
+            for attachment in metadata[user_key]["attachments"]:
+                if attachment.get("type") == "image":
+                    image_refs += "<__media__>"
+                else:
+                    filename = attachment.get("name", "file")
+                    content = attachment.get("content", "")
+                    if attachment.get("type") == "text/html" and attachment.get("url"):
+                        attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n"
+                    else:
+                        attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
+
+            if image_refs or attachments_text:
+                user_input = f"{user_input} {image_refs}"
+                if attachments_text:
+                    user_input += f"\n\nATTACHMENTS:\n{attachments_text}"
 
         messages.append({"role": "user", "content": user_input})
 
@@ -594,29 +612,64 @@ def add_message_attachment(history, row_idx, file_path, is_user=True):
     file_extension = path.suffix.lower()
 
     try:
-        # Handle different file types
-        if file_extension == '.pdf':
+        # Handle image files
+        if file_extension in ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']:
+            # Convert image to base64
+            with open(path, 'rb') as f:
+                image_data = base64.b64encode(f.read()).decode('utf-8')
+
+            # Determine MIME type from extension
+            mime_type_map = {
+                '.jpg': 'image/jpeg',
+                '.jpeg': 'image/jpeg',
+                '.png': 'image/png',
+                '.webp': 'image/webp',
+                '.bmp': 'image/bmp',
+                '.gif': 'image/gif'
+            }
+            mime_type = mime_type_map.get(file_extension, 'image/jpeg')
+
+            # Format as data URL
+            data_url = f"data:{mime_type};base64,{image_data}"
+
+            # Generate unique image ID
+            image_id = len([att for att in history['metadata'][key]["attachments"] if att.get("type") == "image"]) + 1
+
+            attachment = {
+                "name": filename,
+                "type": "image",
+                "image_data": data_url,
+                "image_id": image_id,
+                "file_path": str(path)  # For UI preview
+            }
+        elif file_extension == '.pdf':
             # Process PDF file
             content = extract_pdf_text(path)
-            file_type = "application/pdf"
+            attachment = {
+                "name": filename,
+                "type": "application/pdf",
+                "content": content,
+            }
         elif file_extension == '.docx':
             content = extract_docx_text(path)
-            file_type = "application/docx"
+            attachment = {
+                "name": filename,
+                "type": "application/docx",
+                "content": content,
+            }
         else:
             # Default handling for text files
             with open(path, 'r', encoding='utf-8') as f:
                 content = f.read()
-            file_type = "text/plain"
 
-        # Add attachment
-        attachment = {
-            "name": filename,
-            "type": file_type,
-            "content": content,
-        }
+            attachment = {
+                "name": filename,
+                "type": "text/plain",
+                "content": content,
+            }
 
         history['metadata'][key]["attachments"].append(attachment)
-        return content  # Return the content for reuse
+        return attachment  # Return the attachment for reuse
     except Exception as e:
         logger.error(f"Error processing attachment {filename}: {e}")
         return None
@@ -759,6 +812,19 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
         for file_path in files:
             add_message_attachment(output, row_idx, file_path, is_user=True)
 
+        # Collect image attachments for ExLlamaV3
+        image_attachments = []
+        if 'metadata' in output:
+            user_key = f"user_{row_idx}"
+            if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
+                for attachment in output['metadata'][user_key]["attachments"]:
+                    if attachment.get("type") == "image":
+                        image_attachments.append(attachment)
+
+        # Add image attachments to state for the generation
+        if image_attachments:
+            state['image_attachments'] = image_attachments
+
     # Add web search results as attachments if enabled
     if state.get('enable_web_search', False):
         search_query = generate_search_query(text, state)
diff --git a/modules/exllamav3.py b/modules/exllamav3.py
new file mode 100644
index 00000000..c2532ec3
--- /dev/null
+++ b/modules/exllamav3.py
@@ -0,0 +1,313 @@
+import traceback
+from pathlib import Path
+from typing import Any, List, Tuple
+
+from exllamav3 import Cache, Config, Generator, Model, Tokenizer
+from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
+
+from extensions.openai.image_utils import (
+    convert_image_attachments_to_pil,
+    convert_openai_messages_to_images
+)
+from modules import shared
+from modules.logging_colors import logger
+
+try:
+    import flash_attn
+except Exception:
+    logger.warning('Failed to load flash-attention due to the following error:\n')
+    traceback.print_exc()
+
+
+class Exllamav3Model:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_pretrained(cls, path_to_model):
+        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
+
+        # Reset global MMTokenAllocator to prevent token ID corruption when switching models
+        from exllamav3.tokenizer.mm_embedding import (
+            FIRST_MM_EMBEDDING_INDEX,
+            global_allocator
+        )
+        global_allocator.next_token_index = FIRST_MM_EMBEDDING_INDEX
+        logger.info("Reset MMTokenAllocator for clean multimodal token allocation")
+
+        config = Config.from_directory(str(path_to_model))
+        model = Model.from_config(config)
+
+        # Calculate the closest multiple of 256 at or above the chosen value
+        max_tokens = shared.args.ctx_size
+        if max_tokens % 256 != 0:
+            adjusted_tokens = ((max_tokens // 256) + 1) * 256
+            logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
+            max_tokens = adjusted_tokens
+
+        # Parse cache type (ExLlamaV2 pattern)
+        cache_type = shared.args.cache_type.lower()
+        cache_kwargs = {}
+        if cache_type == 'fp16':
+            layer_type = CacheLayer_fp16
+        elif cache_type.startswith('q'):
+            layer_type = CacheLayer_quant
+            if '_' in cache_type:
+                # Different bits for k and v (e.g., q4_q8)
+                k_part, v_part = cache_type.split('_')
+                k_bits = int(k_part[1:])
+                v_bits = int(v_part[1:])
+            else:
+                # Same bits for k and v (e.g., q4)
+                k_bits = v_bits = int(cache_type[1:])
+
+            # Validate bit ranges
+            if not (2 <= k_bits <= 8 and 2 <= v_bits <= 8):
+                logger.warning(f"Invalid quantization bits: k_bits={k_bits}, v_bits={v_bits}. Must be between 2 and 8. Falling back to fp16.")
+                layer_type = CacheLayer_fp16
+            else:
+                cache_kwargs = {'k_bits': k_bits, 'v_bits': v_bits}
+        else:
+            logger.warning(f"Unrecognized cache type: {cache_type}. Falling back to fp16.")
+            layer_type = CacheLayer_fp16
+
+        cache = Cache(model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
+
+        load_params = {'progressbar': True}
+        if shared.args.gpu_split:
+            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
+            load_params['use_per_device'] = split
+
+        model.load(**load_params)
+
+        tokenizer = Tokenizer.from_config(config)
+
+        # Load vision model component (ExLlamaV3 native)
+        vision_model = None
+        try:
+            logger.info("Loading vision model component...")
+            vision_model = Model.from_config(config, component="vision")
+            vision_model.load(progressbar=True)
+            logger.info("Vision model loaded successfully")
+        except Exception as e:
+            logger.warning(f"Vision model loading failed (multimodal disabled): {e}")
+
+        generator = Generator(
+            model=model,
+            cache=cache,
+            tokenizer=tokenizer,
+        )
+
+        result = cls()
+        result.model = model
+        result.cache = cache
+        result.tokenizer = tokenizer
+        result.generator = generator
+        result.config = config
+        result.max_tokens = max_tokens
+        result.vision_model = vision_model
+
+        return result
+
+    def is_multimodal(self) -> bool:
+        """Check if this model supports multimodal input."""
+        return hasattr(self, 'vision_model') and self.vision_model is not None
+
+    def _process_images_for_generation(self, prompt: str, state: dict) -> Tuple[str, List[Any]]:
+        """
+        Process all possible image inputs and return modified prompt + embeddings.
+        Returns: (processed_prompt, image_embeddings)
+        """
+        if not self.is_multimodal():
+            return prompt, []
+
+        # Collect images from various sources using shared utilities
+        pil_images = []
+
+        # From webui image_attachments (preferred format)
+        if 'image_attachments' in state and state['image_attachments']:
+            pil_images.extend(convert_image_attachments_to_pil(state['image_attachments']))
+
+        # From OpenAI API raw_images
+        elif 'raw_images' in state and state['raw_images']:
+            pil_images.extend(state['raw_images'])
+
+        # From OpenAI API messages format
+        elif 'messages' in state and state['messages']:
+            pil_images.extend(convert_openai_messages_to_images(state['messages']))
+
+        if not pil_images:
+            return prompt, []
+
+        # ExLlamaV3-specific: Generate embeddings
+        try:
+            # Use pre-computed embeddings if available (proper MMEmbedding lifetime)
+            if 'image_embeddings' in state and state['image_embeddings']:
+                # Use existing embeddings - this preserves MMEmbedding lifetime
+                image_embeddings = state['image_embeddings']
+            else:
+                # Do not reset the cache/allocator index; it causes token ID conflicts during generation.
+
+                logger.info(f"Processing {len(pil_images)} image(s) with ExLlamaV3 vision model")
+                image_embeddings = [
+                    self.vision_model.get_image_embeddings(tokenizer=self.tokenizer, image=img)
+                    for img in pil_images
+                ]
+
+            # ExLlamaV3-specific: Handle prompt processing with placeholders
+            placeholders = [ie.text_alias for ie in image_embeddings]
+
+            if '<__media__>' in prompt:
+                # Web chat: Replace <__media__> placeholders
+                for alias in placeholders:
+                    prompt = prompt.replace('<__media__>', alias, 1)
+                logger.info(f"Replaced {len(placeholders)} <__media__> placeholder(s)")
+            else:
+                # API: Prepend embedding aliases
+                combined_placeholders = "\n".join(placeholders)
+                prompt = combined_placeholders + "\n" + prompt
+                logger.info(f"Prepended {len(placeholders)} embedding(s) to prompt")
+
+            return prompt, image_embeddings
+
+        except Exception as e:
+            logger.error(f"Failed to process images: {e}")
+            return prompt, []
+
+    def generate_with_streaming(self, prompt, state):
+        """
+        Generate text with streaming using native ExLlamaV3 API
+        """
+        from exllamav3 import Job
+        from exllamav3.generator.sampler.presets import ComboSampler
+
+        # Process images and modify prompt (ExLlamaV3-specific)
+        prompt, image_embeddings = self._process_images_for_generation(prompt, state)
+
+        sampler = ComboSampler(
+            rep_p=state.get('repetition_penalty', 1.0),
+            freq_p=state.get('frequency_penalty', 0.0),
+            pres_p=state.get('presence_penalty', 0.0),
+            temperature=state.get('temperature', 0.7),
+            min_p=state.get('min_p', 0.0),
+            top_k=state.get('top_k', 0),
+            top_p=state.get('top_p', 1.0),
+        )
+
+        # Encode prompt with embeddings (ExLlamaV3-specific)
+        if image_embeddings:
+            input_ids = self.tokenizer.encode(
+                prompt,
+                encode_special_tokens=True,
+                embeddings=image_embeddings,
+            )
+        else:
+            input_ids = self.tokenizer.encode(prompt, encode_special_tokens=True)
+
+        # Get stop conditions from state (webui format) - keep as strings like ExLlamaV3 examples
+        stop_conditions = []
+        if 'stopping_strings' in state and state['stopping_strings']:
+            # Use strings directly (ExLlamaV3 handles the conversion internally)
+            stop_conditions.extend(state['stopping_strings'])
+
+        # Add EOS token ID as ExLlamaV3 examples do
+        if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
+            stop_conditions.append(self.tokenizer.eos_token_id)
+
+        job = Job(
+            input_ids=input_ids,
+            max_new_tokens=state.get('max_new_tokens', 500),
+            decode_special_tokens=True,
+            embeddings=image_embeddings if image_embeddings else None,
+            sampler=sampler,
+            stop_conditions=stop_conditions if stop_conditions else None,
+        )
+
+        # Stream generation
+        self.generator.enqueue(job)
+
+        response_text = ""
+        try:
+            while self.generator.num_remaining_jobs():
+                results = self.generator.iterate()
+                for result in results:
+                    if "eos" in result and result["eos"]:
+                        break
+
+                    chunk = result.get("text", "")
+                    if chunk:
+                        response_text += chunk
+                        yield response_text
+        finally:
+            # No cleanup needed. MMEmbedding lifetime is managed by Python.
+            # Cache and page table resets are unnecessary and can cause token ID conflicts.
+            pass
+
+    def generate(self, prompt, state):
+        """
+        Generate text using native ExLlamaV3 API (non-streaming)
+        """
+        output = self.generator.generate(
+            prompt=prompt,
+            max_new_tokens=state.get('max_new_tokens', 500),
+            temperature=state.get('temperature', 0.7),
+            top_p=state.get('top_p', 1.0),
+            top_k=state.get('top_k', 0),
+            repetition_penalty=state.get('repetition_penalty', 1.0),
+            frequency_penalty=state.get('frequency_penalty', 0.0),
+            presence_penalty=state.get('presence_penalty', 0.0),
+            min_p=state.get('min_p', 0.0),
+        )
+
+        return output
+
+    def encode(self, string, **kwargs):
+        return self.tokenizer.encode(string, **kwargs)
+
+    def decode(self, ids, **kwargs):
+        return self.tokenizer.decode(ids, **kwargs)
+
+    @property
+    def last_prompt_token_count(self):
+        # This would need to be tracked during generation
+        return 0
+
+    def unload(self):
+        logger.info("Unloading ExLlamaV3 model components...")
+
+        if hasattr(self, 'vision_model') and self.vision_model is not None:
+            try:
+                del self.vision_model
+            except Exception as e:
+                logger.warning(f"Error unloading vision model: {e}")
+            self.vision_model = None
+
+        if hasattr(self, 'model') and self.model is not None:
+            try:
+                self.model.unload()
+                del self.model
+            except Exception as e:
+                logger.warning(f"Error unloading main model: {e}")
+            self.model = None
+
+        if hasattr(self, 'cache') and self.cache is not None:
+            self.cache = None
+
+        if hasattr(self, 'generator') and self.generator is not None:
+            self.generator = None
+
+        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
+            self.tokenizer = None
+
+        # Force GPU memory cleanup
+        import gc
+
+        import torch
+        gc.collect()
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+
+        logger.info("ExLlamaV3 model fully unloaded")
diff --git a/modules/html_generator.py b/modules/html_generator.py
index 79237f7f..63a0cdd0 100644
--- a/modules/html_generator.py
+++ b/modules/html_generator.py
@@ -406,16 +406,27 @@ def format_message_attachments(history, role, index):
     for attachment in attachments:
         name = html.escape(attachment["name"])
 
-        # Make clickable if URL exists
-        if "url" in attachment:
-            name = f'{name}'
+        if attachment.get("type") == "image":
+            # Show image preview
+            file_path = attachment.get("file_path", "")
+            attachments_html += (
+                f''
+            )
+        else:
+            # Make clickable if URL exists (web search)
+            if "url" in attachment:
+                name = f'{name}'
+
+            attachments_html += (
+                f''
+            )
 
-        attachments_html += (
-            f''
-        )
     attachments_html += ''
 
     return attachments_html
diff --git a/modules/loaders.py b/modules/loaders.py
index 7546bc5b..e9437c16 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -55,6 +55,11 @@ loaders_and_params = OrderedDict({
         'trust_remote_code',
         'no_use_fast',
     ],
+    'ExLlamav3': [
+        'ctx_size',
+        'cache_type',
+        'gpu_split',
+    ],
     'ExLlamav2_HF': [
         'ctx_size',
         'cache_type',
@@ -248,6 +253,41 @@ loaders_samplers = {
         'grammar_string',
         'grammar_file_row',
     },
+    'ExLlamav3': {
+        'temperature',
+        'dynatemp_low',
+        'dynatemp_high',
+        'dynatemp_exponent',
+        'smoothing_factor',
+        'min_p',
+        'top_p',
+        'top_k',
+        'typical_p',
+        'xtc_threshold',
+        'xtc_probability',
+        'tfs',
+        'top_a',
+        'dry_multiplier',
+        'dry_allowed_length',
+        'dry_base',
+        'repetition_penalty',
+        'frequency_penalty',
+        'presence_penalty',
+        'repetition_penalty_range',
+        'mirostat_mode',
+        'mirostat_tau',
+        'mirostat_eta',
+        'dynamic_temperature',
+        'temperature_last',
+        'auto_max_new_tokens',
+        'ban_eos_token',
+        'add_bos_token',
+        'enable_thinking',
+        'skip_special_tokens',
+        'seed',
+        'custom_token_bans',
+        'dry_sequence_breakers',
+    },
     'ExLlamav2': {
         'temperature',
         'dynatemp_low',
diff --git a/modules/models.py b/modules/models.py
index c1e7fb56..cc500a40 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -19,6 +19,7 @@ def load_model(model_name, loader=None):
         'llama.cpp': llama_cpp_server_loader,
         'Transformers': transformers_loader,
         'ExLlamav3_HF': ExLlamav3_HF_loader,
+        'ExLlamav3': ExLlamav3_loader,
         'ExLlamav2_HF': ExLlamav2_HF_loader,
         'ExLlamav2': ExLlamav2_loader,
         'TensorRT-LLM': TensorRT_LLM_loader,
@@ -88,6 +89,14 @@ def ExLlamav3_HF_loader(model_name):
     return Exllamav3HF.from_pretrained(model_name)
 
 
+def ExLlamav3_loader(model_name):
+    from modules.exllamav3 import Exllamav3Model
+
+    model = Exllamav3Model.from_pretrained(model_name)
+    tokenizer = model.tokenizer
+    return model, tokenizer
+
+
 def ExLlamav2_HF_loader(model_name):
     from modules.exllamav2_hf import Exllamav2HF
 
@@ -116,7 +125,9 @@ def unload_model(keep_model_name=False):
         return
 
     is_llamacpp = (shared.model.__class__.__name__ == 'LlamaServer')
-    if shared.model.__class__.__name__ == 'Exllamav3HF':
+    if shared.args.loader in ['ExLlamav3_HF', 'ExLlamav3']:
+        shared.model.unload()
+    elif shared.args.loader in ['ExLlamav2_HF', 'ExLlamav2'] and hasattr(shared.model, 'unload'):
         shared.model.unload()
 
     shared.model = shared.tokenizer = None
diff --git a/modules/shared.py b/modules/shared.py
index ab5198d1..1de4306b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -318,6 +318,8 @@ def fix_loader_name(name):
         return 'ExLlamav2_HF'
     elif name in ['exllamav3-hf', 'exllamav3_hf', 'exllama-v3-hf', 'exllama_v3_hf', 'exllama-v3_hf', 'exllama3-hf', 'exllama3_hf', 'exllama-3-hf', 'exllama_3_hf', 'exllama-3_hf']:
         return 'ExLlamav3_HF'
+    elif name in ['exllamav3']:
+        return 'ExLlamav3'
     elif name in ['tensorrt', 'tensorrtllm', 'tensorrt_llm', 'tensorrt-llm', 'tensort', 'tensortllm']:
         return 'TensorRT-LLM'
 
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 8d1950b9..d6a87ce8 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -40,7 +40,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
         yield ''
         return
 
-    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']:
+    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel']:
         generate_func = generate_reply_custom
     else:
         generate_func = generate_reply_HF
@@ -128,9 +128,9 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     from modules.torch_utils import get_device
 
-    if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel']:
+    if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel']:
         input_ids = shared.tokenizer.encode(str(prompt))
-        if shared.model.__class__.__name__ != 'Exllamav2Model':
+        if shared.model.__class__.__name__ not in ['Exllamav2Model', 'Exllamav3Model']:
             input_ids = np.array(input_ids).reshape(1, len(input_ids))
     else:
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
@@ -148,7 +148,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
 
-    if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel'] or shared.args.cpu:
        return input_ids
     else:
         device = get_device()
@@ -295,6 +295,8 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
         _StopEverythingStoppingCriteria
     )
 
+    # Native ExLlamav3Model handles multimodal internally - no special routing needed
+
     if shared.args.loader == 'Transformers':
         clear_torch_cache()
 
diff --git a/modules/ui_chat.py b/modules/ui_chat.py
index 1d85a398..3b922fb4 100644
--- a/modules/ui_chat.py
+++ b/modules/ui_chat.py
@@ -54,7 +54,7 @@ def create_ui():
             gr.HTML(value='