diff --git a/modules/exllamav3.py b/modules/exllamav3.py
index 66e25693..e580bbda 100644
--- a/modules/exllamav3.py
+++ b/modules/exllamav3.py
@@ -177,9 +177,6 @@ class Exllamav3Model:
         Process all possible image inputs and return modified prompt + embeddings.
         Returns: (processed_prompt, image_embeddings)
         """
-        if not self.is_multimodal():
-            return prompt, []
-
         # Collect images from various sources using shared utilities
         pil_images = []
 
@@ -234,8 +231,11 @@ class Exllamav3Model:
         """
         Generate text with streaming using native ExLlamaV3 API
         """
-        # Process images and modify prompt (ExLlamaV3-specific)
-        prompt, image_embeddings = self._process_images_for_generation(prompt, state)
+        image_embeddings = []
+
+        if shared.is_multimodal:
+            # Process images and modify prompt (ExLlamaV3-specific)
+            prompt, image_embeddings = self._process_images_for_generation(prompt, state)
 
         # Greedy decoding is a special case
         if state['temperature'] == 0:
diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index e82edb90..5953803a 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -8,6 +8,7 @@ import sys
 import threading
 import time
 from pathlib import Path
+from typing import Any, List
 
 import llama_cpp_binaries
 import requests
@@ -129,10 +130,10 @@ class LlamaServer:
 
         return payload
 
-    def generate_with_streaming(self, prompt, state):
-        url = f"http://127.0.0.1:{self.port}/completion"
-        payload = self.prepare_payload(state)
-
+    def _process_images_for_generation(self, state: dict) -> List[Any]:
+        """
+        Process all possible image inputs and return PIL images
+        """
         pil_images = []
         # Source 1: Web UI (from chatbot_wrapper)
         if 'image_attachments' in state and state['image_attachments']:
@@ -144,6 +145,21 @@ class LlamaServer:
         elif 'raw_images' in state and state['raw_images']:
             pil_images.extend(state.get('raw_images', []))
 
+        return pil_images
+
+    def is_multimodal(self) -> bool:
+        """Check if this model supports multimodal input."""
+        return shared.args.mmproj not in [None, 'None']
+
+    def generate_with_streaming(self, prompt, state):
+        url = f"http://127.0.0.1:{self.port}/completion"
+        payload = self.prepare_payload(state)
+
+        pil_images = []
+
+        if shared.is_multimodal:
+            pil_images = self._process_images_for_generation(state)
+
         if pil_images:
             # Multimodal case
             IMAGE_TOKEN_COST_ESTIMATE = 600  # A safe, conservative estimate per image
diff --git a/modules/models.py b/modules/models.py
index cc500a40..938eed3d 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -55,6 +55,10 @@ def load_model(model_name, loader=None):
     if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
         shared.settings['truncation_length'] = shared.args.ctx_size
 
+    shared.is_multimodal = False
+    if loader.lower() in ('exllamav3', 'llama.cpp'):
+        shared.is_multimodal = model.is_multimodal()
+
     logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
     logger.info(f"LOADER: \"{loader}\"")
     logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}")
diff --git a/modules/shared.py b/modules/shared.py
index e9d8a62f..a1f4571e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -16,6 +16,7 @@ model = None
 tokenizer = None
 model_name = 'None'
 is_seq2seq = False
+is_multimodal = False
 model_dirty_from_training = False
 lora_names = []
 
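
Note (reviewer sketch, not part of the patch): the cross-file flow of the new flag, restated from the hunks above. `loader`, `model`, `server`, and `state` stand in for the locals that already exist at each call site; nothing new is introduced.

    from modules import shared

    # modules/models.py, at load time: derive the flag once per loaded model.
    shared.is_multimodal = False
    if loader.lower() in ('exllamav3', 'llama.cpp'):
        shared.is_multimodal = model.is_multimodal()  # llama.cpp: True only when --mmproj is set

    # modules/llama_cpp_server.py and modules/exllamav3.py, at generation time:
    # image collection is skipped entirely for text-only models instead of each
    # backend re-checking its own multimodal state per request.
    pil_images = []
    if shared.is_multimodal:
        pil_images = server._process_images_for_generation(state)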