Set multimodal status during Model Loading (#7199)

commit 57f6e9af5a
parent 725a8bcf60
Author: altoiddealer, committed by GitHub
Date: 2025-08-13 15:47:27 -04:00
4 changed files with 30 additions and 9 deletions


@@ -177,9 +177,6 @@ class Exllamav3Model:
         Process all possible image inputs and return modified prompt + embeddings.
         Returns: (processed_prompt, image_embeddings)
         """
-        if not self.is_multimodal():
-            return prompt, []
-
         # Collect images from various sources using shared utilities
         pil_images = []

@@ -234,8 +231,11 @@ class Exllamav3Model:
         """
         Generate text with streaming using native ExLlamaV3 API
         """
-        # Process images and modify prompt (ExLlamaV3-specific)
-        prompt, image_embeddings = self._process_images_for_generation(prompt, state)
+        image_embeddings = []
+        if shared.is_multimodal:
+            # Process images and modify prompt (ExLlamaV3-specific)
+            prompt, image_embeddings = self._process_images_for_generation(prompt, state)
+
         # Greedy decoding is a special case
         if state['temperature'] == 0:
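
The shape of the change in this backend: the multimodal check moves out of the image helper and into the caller, gating on a process-wide flag that is set once at load time (see the load_model hunk further down). A minimal self-contained sketch of the pattern; `shared` and the model object are stand-ins here, not the project's actual objects:

    from types import SimpleNamespace

    # Stand-in for the project's shared module state; the real flag is
    # assigned once in load_model() rather than per generation.
    shared = SimpleNamespace(is_multimodal=False)

    def generate_with_streaming(model, prompt, state):
        image_embeddings = []
        # Caller-side gate: a plain flag check replaces the helper's old
        # early-return, so text-only models never enter the image path.
        if shared.is_multimodal:
            prompt, image_embeddings = model._process_images_for_generation(prompt, state)
        # ... sampling and streaming would continue here as before
        return prompt, image_embeddings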


@@ -8,6 +8,7 @@ import sys
 import threading
 import time
 from pathlib import Path
+from typing import Any, List

 import llama_cpp_binaries
 import requests
@@ -129,10 +130,10 @@ class LlamaServer:
         return payload

-    def generate_with_streaming(self, prompt, state):
-        url = f"http://127.0.0.1:{self.port}/completion"
-        payload = self.prepare_payload(state)
-
+    def _process_images_for_generation(self, state: dict) -> List[Any]:
+        """
+        Process all possible image inputs and return PIL images
+        """
         pil_images = []

         # Source 1: Web UI (from chatbot_wrapper)
         if 'image_attachments' in state and state['image_attachments']:
@@ -144,6 +145,21 @@ class LlamaServer:
         elif 'raw_images' in state and state['raw_images']:
             pil_images.extend(state.get('raw_images', []))

+        return pil_images
+
+    def is_multimodal(self) -> bool:
+        """Check if this model supports multimodal input."""
+        return shared.args.mmproj not in [None, 'None']
+
+    def generate_with_streaming(self, prompt, state):
+        url = f"http://127.0.0.1:{self.port}/completion"
+        payload = self.prepare_payload(state)
+
+        pil_images = []
+
+        if shared.is_multimodal:
+            pil_images = self._process_images_for_generation(state)
+
         if pil_images:
             # Multimodal case
             IMAGE_TOKEN_COST_ESTIMATE = 600  # A safe, conservative estimate per image
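
For this backend, "multimodal" simply means a multimodal projector (mmproj) file was configured: `shared.args.mmproj` can be `None` when the option was never set, or the literal string 'None', presumably a UI placeholder. A self-contained sketch of that check; the projector filename below is made up:

    class Args:
        """Stand-in for shared.args; only mmproj matters for this check."""
        mmproj = None

    def is_multimodal(args) -> bool:
        # Both "never set" (None) and the placeholder string 'None'
        # mean no projector, hence a text-only model.
        return args.mmproj not in (None, 'None')

    args = Args()
    assert is_multimodal(args) is False        # option absent
    args.mmproj = 'None'
    assert is_multimodal(args) is False        # placeholder string
    args.mmproj = 'mmproj-model-f16.gguf'      # hypothetical projector file
    assert is_multimodal(args) is True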


@@ -55,6 +55,10 @@ def load_model(model_name, loader=None):
     if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
         shared.settings['truncation_length'] = shared.args.ctx_size

+    shared.is_multimodal = False
+    if loader.lower() in ('exllamav3', 'llama.cpp'):
+        shared.is_multimodal = model.is_multimodal()
+
     logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
     logger.info(f"LOADER: \"{loader}\"")
     logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}")


@@ -16,6 +16,7 @@ model = None
 tokenizer = None
 model_name = 'None'
 is_seq2seq = False
+is_multimodal = False
 model_dirty_from_training = False
 lora_names = []
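
With the flag living in the shared module next to `is_seq2seq`, downstream code can gate multimodal work on one attribute instead of loader-specific method calls. A hedged usage sketch, assuming the project's `modules.shared` layout:

    from modules import shared  # module-level flags, as defined in the hunk above

    def collect_images(state):
        # Text-only sessions skip image handling entirely.
        if not shared.is_multimodal:
            return []
        return state.get('image_attachments', [])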