mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
Add multimodal support (llama.cpp) (#7027)
This commit is contained in:
parent
eb16f64017
commit
d86b0ec010
|
|
@ -11,6 +11,15 @@ from PIL import Image
|
|||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
def convert_pil_to_base64(image: Image.Image) -> str:
|
||||
"""Converts a PIL Image to a base64 encoded string."""
|
||||
buffered = io.BytesIO()
|
||||
# Save image to an in-memory bytes buffer in PNG format
|
||||
image.save(buffered, format="PNG")
|
||||
# Encode the bytes to a base64 string
|
||||
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
|
||||
|
||||
def decode_base64_image(base64_string: str) -> Image.Image:
|
||||
"""Decodes a base64 string to a PIL Image."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -813,19 +813,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
for file_path in files:
|
||||
add_message_attachment(output, row_idx, file_path, is_user=True)
|
||||
|
||||
# Collect image attachments for multimodal generation
|
||||
image_attachments = []
|
||||
if 'metadata' in output:
|
||||
user_key = f"user_{row_idx}"
|
||||
if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
|
||||
for attachment in output['metadata'][user_key]["attachments"]:
|
||||
if attachment.get("type") == "image":
|
||||
image_attachments.append(attachment)
|
||||
|
||||
# Add image attachments to state for the generation
|
||||
if image_attachments:
|
||||
state['image_attachments'] = image_attachments
|
||||
|
||||
# Add web search results as attachments if enabled
|
||||
if state.get('enable_web_search', False):
|
||||
search_query = generate_search_query(text, state)
|
||||
|
|
@ -881,6 +868,19 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
'metadata': output['metadata']
|
||||
}
|
||||
|
||||
# Collect image attachments for multimodal generation
|
||||
image_attachments = []
|
||||
if 'metadata' in output:
|
||||
user_key = f"user_{row_idx}"
|
||||
if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
|
||||
for attachment in output['metadata'][user_key]["attachments"]:
|
||||
if attachment.get("type") == "image":
|
||||
image_attachments.append(attachment)
|
||||
|
||||
# Add image attachments to state for the generation
|
||||
if image_attachments:
|
||||
state['image_attachments'] = image_attachments
|
||||
|
||||
# Generate the prompt
|
||||
kwargs = {
|
||||
'_continue': _continue,
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@ from pathlib import Path
|
|||
import llama_cpp_binaries
|
||||
import requests
|
||||
|
||||
from extensions.openai.image_utils import (
|
||||
convert_image_attachments_to_pil,
|
||||
convert_pil_to_base64
|
||||
)
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
|
@ -128,15 +132,40 @@ class LlamaServer:
|
|||
url = f"http://127.0.0.1:{self.port}/completion"
|
||||
payload = self.prepare_payload(state)
|
||||
|
||||
pil_images = []
|
||||
# Check for images from the Web UI (image_attachments)
|
||||
if 'image_attachments' in state and state['image_attachments']:
|
||||
pil_images.extend(convert_image_attachments_to_pil(state['image_attachments']))
|
||||
# Else, check for images from the API (raw_images)
|
||||
elif 'raw_images' in state and state['raw_images']:
|
||||
pil_images.extend(state.get('raw_images', []))
|
||||
|
||||
if pil_images:
|
||||
# Multimodal case
|
||||
IMAGE_TOKEN_COST_ESTIMATE = 600 # A safe, conservative estimate per image
|
||||
|
||||
base64_images = [convert_pil_to_base64(img) for img in pil_images]
|
||||
multimodal_prompt_object = {
|
||||
"prompt": prompt,
|
||||
"multimodal_data": base64_images
|
||||
}
|
||||
payload["prompt"] = multimodal_prompt_object
|
||||
|
||||
# Calculate an estimated token count
|
||||
text_tokens = self.encode(prompt, add_bos_token=state["add_bos_token"])
|
||||
self.last_prompt_token_count = len(text_tokens) + (len(pil_images) * IMAGE_TOKEN_COST_ESTIMATE)
|
||||
else:
|
||||
# Text only case
|
||||
token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
|
||||
self.last_prompt_token_count = len(token_ids)
|
||||
payload["prompt"] = token_ids
|
||||
|
||||
if state['auto_max_new_tokens']:
|
||||
max_new_tokens = state['truncation_length'] - len(token_ids)
|
||||
max_new_tokens = state['truncation_length'] - self.last_prompt_token_count
|
||||
else:
|
||||
max_new_tokens = state['max_new_tokens']
|
||||
|
||||
payload.update({
|
||||
"prompt": token_ids,
|
||||
"n_predict": max_new_tokens,
|
||||
"stream": True,
|
||||
"cache_prompt": True
|
||||
|
|
@ -144,7 +173,7 @@ class LlamaServer:
|
|||
|
||||
if shared.args.verbose:
|
||||
logger.info("GENERATE_PARAMS=")
|
||||
printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
|
||||
printable_payload = {k: (v if k != "prompt" else "[multimodal object]" if pil_images else v) for k, v in payload.items()}
|
||||
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
|
||||
print()
|
||||
|
||||
|
|
@ -295,6 +324,13 @@ class LlamaServer:
|
|||
cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
|
||||
if shared.args.rope_freq_base > 0:
|
||||
cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
|
||||
if shared.args.mmproj not in [None, 'None']:
|
||||
path = Path(shared.args.mmproj)
|
||||
if not path.exists():
|
||||
path = Path('user_data/mmproj') / shared.args.mmproj
|
||||
|
||||
if path.exists():
|
||||
cmd += ["--mmproj", str(path)]
|
||||
if shared.args.model_draft not in [None, 'None']:
|
||||
path = Path(shared.args.model_draft)
|
||||
if not path.exists():
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ loaders_and_params = OrderedDict({
|
|||
'device_draft',
|
||||
'ctx_size_draft',
|
||||
'speculative_decoding_accordion',
|
||||
'mmproj',
|
||||
'mmproj_accordion',
|
||||
'vram_info',
|
||||
],
|
||||
'Transformers': [
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ group.add_argument('--no-kv-offload', action='store_true', help='Do not offload
|
|||
group.add_argument('--row-split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
|
||||
group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
|
||||
group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
|
||||
group.add_argument('--mmproj', type=str, default=None, help='Path to the mmproj file for vision models.')
|
||||
|
||||
# Cache
|
||||
group = parser.add_argument_group('Context and cache')
|
||||
|
|
|
|||
|
|
@ -167,6 +167,7 @@ def list_model_elements():
|
|||
'gpu_layers_draft',
|
||||
'device_draft',
|
||||
'ctx_size_draft',
|
||||
'mmproj',
|
||||
]
|
||||
|
||||
return elements
|
||||
|
|
|
|||
|
|
@ -59,6 +59,12 @@ def create_ui():
|
|||
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
|
||||
shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
|
||||
|
||||
# Multimodal
|
||||
with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
|
||||
with gr.Row():
|
||||
shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu)
|
||||
ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
|
||||
|
||||
# Speculative decoding
|
||||
with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
|
||||
with gr.Row():
|
||||
|
|
|
|||
|
|
@ -154,6 +154,19 @@ def get_available_ggufs():
|
|||
return sorted(model_list, key=natural_keys)
|
||||
|
||||
|
||||
def get_available_mmproj():
|
||||
mmproj_dir = Path('user_data/mmproj')
|
||||
if not mmproj_dir.exists():
|
||||
return ['None']
|
||||
|
||||
mmproj_files = []
|
||||
for item in mmproj_dir.iterdir():
|
||||
if item.is_file() and item.suffix.lower() in ('.gguf', '.bin'):
|
||||
mmproj_files.append(item.name)
|
||||
|
||||
return ['None'] + sorted(mmproj_files, key=natural_keys)
|
||||
|
||||
|
||||
def get_available_presets():
|
||||
return sorted(set((k.stem for k in Path('user_data/presets').glob('*.yaml'))), key=natural_keys)
|
||||
|
||||
|
|
|
|||
0
user_data/mmproj/place-your-mmproj-here.txt
Normal file
0
user_data/mmproj/place-your-mmproj-here.txt
Normal file
Loading…
Reference in a new issue