Add multimodal support (llama.cpp) (#7027)

This commit is contained in:
oobabooga 2025-08-10 01:27:25 -03:00 committed by GitHub
parent eb16f64017
commit d86b0ec010
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 86 additions and 18 deletions

View file

@ -12,6 +12,10 @@ from pathlib import Path
import llama_cpp_binaries
import requests
from extensions.openai.image_utils import (
convert_image_attachments_to_pil,
convert_pil_to_base64
)
from modules import shared
from modules.logging_colors import logger
@ -128,15 +132,40 @@ class LlamaServer:
url = f"http://127.0.0.1:{self.port}/completion"
payload = self.prepare_payload(state)
token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
self.last_prompt_token_count = len(token_ids)
pil_images = []
# Check for images from the Web UI (image_attachments)
if 'image_attachments' in state and state['image_attachments']:
pil_images.extend(convert_image_attachments_to_pil(state['image_attachments']))
# Else, check for images from the API (raw_images)
elif 'raw_images' in state and state['raw_images']:
pil_images.extend(state.get('raw_images', []))
if pil_images:
# Multimodal case
IMAGE_TOKEN_COST_ESTIMATE = 600 # A safe, conservative estimate per image
base64_images = [convert_pil_to_base64(img) for img in pil_images]
multimodal_prompt_object = {
"prompt": prompt,
"multimodal_data": base64_images
}
payload["prompt"] = multimodal_prompt_object
# Calculate an estimated token count
text_tokens = self.encode(prompt, add_bos_token=state["add_bos_token"])
self.last_prompt_token_count = len(text_tokens) + (len(pil_images) * IMAGE_TOKEN_COST_ESTIMATE)
else:
# Text only case
token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
self.last_prompt_token_count = len(token_ids)
payload["prompt"] = token_ids
if state['auto_max_new_tokens']:
max_new_tokens = state['truncation_length'] - len(token_ids)
max_new_tokens = state['truncation_length'] - self.last_prompt_token_count
else:
max_new_tokens = state['max_new_tokens']
payload.update({
"prompt": token_ids,
"n_predict": max_new_tokens,
"stream": True,
"cache_prompt": True
@ -144,7 +173,7 @@ class LlamaServer:
if shared.args.verbose:
logger.info("GENERATE_PARAMS=")
printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
printable_payload = {k: (v if k != "prompt" else "[multimodal object]" if pil_images else v) for k, v in payload.items()}
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
print()
@ -295,6 +324,13 @@ class LlamaServer:
cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
if shared.args.rope_freq_base > 0:
cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
if shared.args.mmproj not in [None, 'None']:
path = Path(shared.args.mmproj)
if not path.exists():
path = Path('user_data/mmproj') / shared.args.mmproj
if path.exists():
cmd += ["--mmproj", str(path)]
if shared.args.model_draft not in [None, 'None']:
path = Path(shared.args.model_draft)
if not path.exists():