mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
mtmd: Fail early if images are provided but the model doesn't support them (llama.cpp)
This commit is contained in:
parent e6447cd24a
commit d8fcc71616
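A minimal self-contained sketch of the behavior this commit adds. The check itself mirrors the diff below; the FakeLlamaServer class and check_images method are illustrative stand-ins, not the project's real API:

# Sketch of the fail-early behavior added in this commit.
# The flag and the error message come from the diff below; everything else is hypothetical.
class FakeLlamaServer:
    def __init__(self, has_multimodal=False):
        # In the real class this is set by _get_model_properties() after the server starts
        self.has_multimodal = has_multimodal

    def check_images(self, pil_images):
        # Fail early if images are provided but the model doesn't support them
        if pil_images and not self.has_multimodal:
            raise RuntimeError(
                "The loaded llama.cpp model does not support multimodal requests. "
                "You must load a vision model and provide an mmproj file."
            )

server = FakeLlamaServer(has_multimodal=False)
try:
    server.check_images(pil_images=["<PIL image placeholder>"])  # stand-in for real PIL images
except RuntimeError as e:
    print(e)  # raised before any request is sent to the llama.cpp server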
@@ -34,6 +34,7 @@ class LlamaServer:
         self.process = None
         self.session = requests.Session()
         self.vocabulary_size = None
+        self.has_multimodal = False
         self.bos_token = "<s>"
         self.last_prompt_token_count = 0

@@ -144,6 +145,10 @@ class LlamaServer:
         elif 'raw_images' in state and state['raw_images']:
             pil_images.extend(state.get('raw_images', []))

+        # Fail early if images are provided but the model doesn't support them
+        if pil_images and not self.has_multimodal:
+            raise RuntimeError("The loaded llama.cpp model does not support multimodal requests. You must load a vision model and provide an mmproj file.")
+
         if pil_images:
             # Multimodal case
             IMAGE_TOKEN_COST_ESTIMATE = 600  # A safe, conservative estimate per image
@@ -261,8 +266,8 @@ class LlamaServer:
         else:
             raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")

-    def _get_vocabulary_size(self):
-        """Get and store the model's maximum context length."""
+    def _get_model_properties(self):
+        """Get and store the model's properties, including vocab size and multimodal capability."""
         url = f"http://127.0.0.1:{self.port}/v1/models"
         response = self.session.get(url).json()

@@ -271,6 +276,10 @@ class LlamaServer:
         if "meta" in model_info and "n_vocab" in model_info["meta"]:
             self.vocabulary_size = model_info["meta"]["n_vocab"]

+        # Check for multimodal capability
+        if "capabilities" in model_info and "multimodal" in model_info["capabilities"]:
+            self.has_multimodal = True
+
     def _get_bos_token(self):
         """Get and store the model's BOS token."""
         url = f"http://127.0.0.1:{self.port}/props"
@@ -421,7 +430,7 @@ class LlamaServer:
             time.sleep(1)

         # Server is now healthy, get model info
-        self._get_vocabulary_size()
+        self._get_model_properties()
         self._get_bos_token()
         return self.port
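For reference, a sketch of the /v1/models payload shape that the renamed _get_model_properties() reads. The keys "meta", "n_vocab", and "capabilities" come from the diff above; the surrounding "data" list, the single-model assumption, and the placeholder values are assumptions, not a documented llama.cpp server contract:

# Hypothetical /v1/models response illustrating the fields _get_model_properties() checks.
# Only the keys referenced in the diff are meaningful; values are placeholders.
response = {
    "data": [
        {
            "id": "placeholder-model",
            "meta": {"n_vocab": 32000},                    # -> self.vocabulary_size
            "capabilities": ["completion", "multimodal"],  # "multimodal" -> self.has_multimodal = True
        }
    ]
}

model_info = response["data"][0]  # assumption: the server reports one loaded model
vocabulary_size = model_info["meta"]["n_vocab"] if "meta" in model_info and "n_vocab" in model_info["meta"] else None
has_multimodal = "capabilities" in model_info and "multimodal" in model_info["capabilities"]
print(vocabulary_size, has_multimodal)  # 32000 True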