Set multimodal status during Model Loading (#7199)

Repository: https://github.com/oobabooga/text-generation-webui.git (mirror)
Commit: 57f6e9af5a
Parent: 725a8bcf60
@@ -177,9 +177,6 @@ class Exllamav3Model:
         Process all possible image inputs and return modified prompt + embeddings.
         Returns: (processed_prompt, image_embeddings)
         """
-        if not self.is_multimodal():
-            return prompt, []
-
         # Collect images from various sources using shared utilities
         pil_images = []

@@ -234,8 +231,11 @@ class Exllamav3Model:
         """
         Generate text with streaming using native ExLlamaV3 API
         """
-        # Process images and modify prompt (ExLlamaV3-specific)
-        prompt, image_embeddings = self._process_images_for_generation(prompt, state)
+        image_embeddings = []
+
+        if shared.is_multimodal:
+            # Process images and modify prompt (ExLlamaV3-specific)
+            prompt, image_embeddings = self._process_images_for_generation(prompt, state)

         # Greedy decoding is a special case
         if state['temperature'] == 0:
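Taken together, the two ExLlamaV3 hunks move the multimodal check out of the per-request image path: _process_images_for_generation no longer guards itself with self.is_multimodal(), and generate_with_streaming calls it only when the load-time shared.is_multimodal flag is set. A minimal, runnable sketch of that gating pattern follows; SketchModel, the SimpleNamespace shared object, and the placeholder embedding are illustrative stand-ins, not the project's real classes:

# Sketch only: stand-in names, assuming the flag is set once at model load.
from types import SimpleNamespace

shared = SimpleNamespace(is_multimodal=False)  # stand-in for the project's shared module


class SketchModel:
    def _process_images_for_generation(self, prompt, state):
        # The real helper collects PIL images and returns (modified_prompt, embeddings); elided here
        return prompt, ["<image-embedding>"]

    def generate_with_streaming(self, prompt, state):
        image_embeddings = []
        if shared.is_multimodal:
            # Image handling now runs only when the load-time flag is set
            prompt, image_embeddings = self._process_images_for_generation(prompt, state)
        return prompt, image_embeddings


print(SketchModel().generate_with_streaming("hello", {}))  # ('hello', []) while the flag is False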
@@ -8,6 +8,7 @@ import sys
 import threading
 import time
 from pathlib import Path
+from typing import Any, List

 import llama_cpp_binaries
 import requests
@@ -129,10 +130,10 @@ class LlamaServer:

         return payload

-    def generate_with_streaming(self, prompt, state):
-        url = f"http://127.0.0.1:{self.port}/completion"
-        payload = self.prepare_payload(state)
-
+    def _process_images_for_generation(self, state: dict) -> List[Any]:
+        """
+        Process all possible image inputs and return PIL images
+        """
         pil_images = []
         # Source 1: Web UI (from chatbot_wrapper)
         if 'image_attachments' in state and state['image_attachments']:
@@ -144,6 +145,21 @@ class LlamaServer:
         elif 'raw_images' in state and state['raw_images']:
             pil_images.extend(state.get('raw_images', []))

+        return pil_images
+
+    def is_multimodal(self) -> bool:
+        """Check if this model supports multimodal input."""
+        return shared.args.mmproj not in [None, 'None']
+
+    def generate_with_streaming(self, prompt, state):
+        url = f"http://127.0.0.1:{self.port}/completion"
+        payload = self.prepare_payload(state)
+
+        pil_images = []
+
+        if shared.is_multimodal:
+            pil_images = self._process_images_for_generation(state)
+
         if pil_images:
             # Multimodal case
             IMAGE_TOKEN_COST_ESTIMATE = 600  # A safe, conservative estimate per image
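On the llama.cpp side, image collection is factored out of generate_with_streaming into _process_images_for_generation(state), and is_multimodal() reports support based on whether an mmproj projector was supplied. The sketch below approximates that helper pair under the same mmproj check; args is a stand-in for shared.args, and the Web UI attachment conversion (handled by shared utilities in the real method) is omitted:

# Approximation with stand-in names; attachment decoding is intentionally left out.
from types import SimpleNamespace
from typing import Any, List

args = SimpleNamespace(mmproj='None')  # stand-in for shared.args; set a projector path to enable multimodal


def is_multimodal() -> bool:
    # Same check as the diff: multimodal only when an mmproj projector is configured
    return args.mmproj not in [None, 'None']


def process_images_for_generation(state: dict) -> List[Any]:
    pil_images: List[Any] = []
    if 'image_attachments' in state and state['image_attachments']:
        # Source 1: Web UI attachments; conversion to PIL images is omitted in this sketch
        pass
    elif 'raw_images' in state and state['raw_images']:
        # Second source: PIL images already present in state
        pil_images.extend(state.get('raw_images', []))
    return pil_images


print(is_multimodal())                                         # False with the default above
print(process_images_for_generation({'raw_images': ['img']}))  # ['img']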
@@ -55,6 +55,10 @@ def load_model(model_name, loader=None):
     if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
         shared.settings['truncation_length'] = shared.args.ctx_size

+    shared.is_multimodal = False
+    if loader.lower() in ('exllamav3', 'llama.cpp'):
+        shared.is_multimodal = model.is_multimodal()
+
     logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
     logger.info(f"LOADER: \"{loader}\"")
     logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}")
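The load_model change above decides multimodal status once per load and caches it on shared state, so generation code can branch on a plain boolean instead of re-probing the backend. A self-contained sketch of that logic, with SketchBackend and the SimpleNamespace shared object as placeholders:

# Stand-in objects; only the flag-setting logic mirrors the hunk above.
from types import SimpleNamespace

shared = SimpleNamespace(is_multimodal=False)  # default comes from the shared-state hunk below


class SketchBackend:
    def __init__(self, multimodal: bool):
        self._multimodal = multimodal

    def is_multimodal(self) -> bool:
        return self._multimodal


def set_multimodal_status(loader: str, model) -> None:
    # Reset first, then ask only the backends that implement the check
    shared.is_multimodal = False
    if loader.lower() in ('exllamav3', 'llama.cpp'):
        shared.is_multimodal = model.is_multimodal()


set_multimodal_status('llama.cpp', SketchBackend(multimodal=True))
print(shared.is_multimodal)  # True
set_multimodal_status('Transformers', SketchBackend(multimodal=True))
print(shared.is_multimodal)  # False: other loaders stay non-multimodal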
@@ -16,6 +16,7 @@ model = None
 tokenizer = None
 model_name = 'None'
 is_seq2seq = False
+is_multimodal = False
 model_dirty_from_training = False
 lora_names = []
