Add multimodal support (llama.cpp) (#7027)

This commit is contained in:
oobabooga 2025-08-10 01:27:25 -03:00 committed by GitHub
parent eb16f64017
commit d86b0ec010
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 86 additions and 18 deletions

View file

@ -11,6 +11,15 @@ from PIL import Image
from modules.logging_colors import logger
def convert_pil_to_base64(image: Image.Image) -> str:
"""Converts a PIL Image to a base64 encoded string."""
buffered = io.BytesIO()
# Save image to an in-memory bytes buffer in PNG format
image.save(buffered, format="PNG")
# Encode the bytes to a base64 string
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def decode_base64_image(base64_string: str) -> Image.Image:
"""Decodes a base64 string to a PIL Image."""
try:

View file

@ -813,19 +813,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
for file_path in files:
add_message_attachment(output, row_idx, file_path, is_user=True)
# Collect image attachments for multimodal generation
image_attachments = []
if 'metadata' in output:
user_key = f"user_{row_idx}"
if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
for attachment in output['metadata'][user_key]["attachments"]:
if attachment.get("type") == "image":
image_attachments.append(attachment)
# Add image attachments to state for the generation
if image_attachments:
state['image_attachments'] = image_attachments
# Add web search results as attachments if enabled
if state.get('enable_web_search', False):
search_query = generate_search_query(text, state)
@ -881,6 +868,19 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
'metadata': output['metadata']
}
# Collect image attachments for multimodal generation
image_attachments = []
if 'metadata' in output:
user_key = f"user_{row_idx}"
if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
for attachment in output['metadata'][user_key]["attachments"]:
if attachment.get("type") == "image":
image_attachments.append(attachment)
# Add image attachments to state for the generation
if image_attachments:
state['image_attachments'] = image_attachments
# Generate the prompt
kwargs = {
'_continue': _continue,

View file

@ -12,6 +12,10 @@ from pathlib import Path
import llama_cpp_binaries
import requests
from extensions.openai.image_utils import (
convert_image_attachments_to_pil,
convert_pil_to_base64
)
from modules import shared
from modules.logging_colors import logger
@ -128,15 +132,40 @@ class LlamaServer:
url = f"http://127.0.0.1:{self.port}/completion"
payload = self.prepare_payload(state)
pil_images = []
# Check for images from the Web UI (image_attachments)
if 'image_attachments' in state and state['image_attachments']:
pil_images.extend(convert_image_attachments_to_pil(state['image_attachments']))
# Else, check for images from the API (raw_images)
elif 'raw_images' in state and state['raw_images']:
pil_images.extend(state.get('raw_images', []))
if pil_images:
# Multimodal case
IMAGE_TOKEN_COST_ESTIMATE = 600 # A safe, conservative estimate per image
base64_images = [convert_pil_to_base64(img) for img in pil_images]
multimodal_prompt_object = {
"prompt": prompt,
"multimodal_data": base64_images
}
payload["prompt"] = multimodal_prompt_object
# Calculate an estimated token count
text_tokens = self.encode(prompt, add_bos_token=state["add_bos_token"])
self.last_prompt_token_count = len(text_tokens) + (len(pil_images) * IMAGE_TOKEN_COST_ESTIMATE)
else:
# Text only case
token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
self.last_prompt_token_count = len(token_ids)
payload["prompt"] = token_ids
if state['auto_max_new_tokens']:
max_new_tokens = state['truncation_length'] - len(token_ids)
max_new_tokens = state['truncation_length'] - self.last_prompt_token_count
else:
max_new_tokens = state['max_new_tokens']
payload.update({
"prompt": token_ids,
"n_predict": max_new_tokens,
"stream": True,
"cache_prompt": True
@ -144,7 +173,7 @@ class LlamaServer:
if shared.args.verbose:
logger.info("GENERATE_PARAMS=")
printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
printable_payload = {k: (v if k != "prompt" else "[multimodal object]" if pil_images else v) for k, v in payload.items()}
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
print()
@ -295,6 +324,13 @@ class LlamaServer:
cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
if shared.args.rope_freq_base > 0:
cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
if shared.args.mmproj not in [None, 'None']:
path = Path(shared.args.mmproj)
if not path.exists():
path = Path('user_data/mmproj') / shared.args.mmproj
if path.exists():
cmd += ["--mmproj", str(path)]
if shared.args.model_draft not in [None, 'None']:
path = Path(shared.args.model_draft)
if not path.exists():

View file

@ -28,6 +28,8 @@ loaders_and_params = OrderedDict({
'device_draft',
'ctx_size_draft',
'speculative_decoding_accordion',
'mmproj',
'mmproj_accordion',
'vram_info',
],
'Transformers': [

View file

@ -85,6 +85,7 @@ group.add_argument('--no-kv-offload', action='store_true', help='Do not offload
group.add_argument('--row-split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
group.add_argument('--mmproj', type=str, default=None, help='Path to the mmproj file for vision models.')
# Cache
group = parser.add_argument_group('Context and cache')

View file

@ -167,6 +167,7 @@ def list_model_elements():
'gpu_layers_draft',
'device_draft',
'ctx_size_draft',
'mmproj',
]
return elements

View file

@ -59,6 +59,12 @@ def create_ui():
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
# Multimodal
with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
with gr.Row():
shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu)
ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
# Speculative decoding
with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
with gr.Row():

View file

@ -154,6 +154,19 @@ def get_available_ggufs():
return sorted(model_list, key=natural_keys)
def get_available_mmproj():
mmproj_dir = Path('user_data/mmproj')
if not mmproj_dir.exists():
return ['None']
mmproj_files = []
for item in mmproj_dir.iterdir():
if item.is_file() and item.suffix.lower() in ('.gguf', '.bin'):
mmproj_files.append(item.name)
return ['None'] + sorted(mmproj_files, key=natural_keys)
def get_available_presets():
return sorted(set((k.stem for k in Path('user_data/presets').glob('*.yaml'))), key=natural_keys)