Mirror of https://github.com/oobabooga/text-generation-webui.git
Add multimodal support (llama.cpp) (#7027)
This commit is contained in:
parent: eb16f64017
commit: d86b0ec010
@@ -11,6 +11,15 @@ from PIL import Image
 from modules.logging_colors import logger
 
 
+def convert_pil_to_base64(image: Image.Image) -> str:
+    """Converts a PIL Image to a base64 encoded string."""
+    buffered = io.BytesIO()
+    # Save image to an in-memory bytes buffer in PNG format
+    image.save(buffered, format="PNG")
+    # Encode the bytes to a base64 string
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+
 def decode_base64_image(base64_string: str) -> Image.Image:
     """Decodes a base64 string to a PIL Image."""
     try:
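Note: a usage sketch of the two helpers above, as a hypothetical round trip. Running it assumes the web UI repository is on the Python path so that extensions.openai.image_utils (the import path used later in this diff) is available; the test image is made up.

from PIL import Image

from extensions.openai.image_utils import convert_pil_to_base64, decode_base64_image

# Build a small in-memory test image instead of reading one from disk
original = Image.new("RGB", (64, 64), color=(200, 30, 30))

encoded = convert_pil_to_base64(original)   # base64 string of the PNG-encoded image
restored = decode_base64_image(encoded)     # back to a PIL Image

assert restored.size == original.size

Since PNG is the intermediate format, the round trip is lossless for RGB/RGBA images.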
@@ -813,19 +813,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
     for file_path in files:
         add_message_attachment(output, row_idx, file_path, is_user=True)
 
-    # Collect image attachments for multimodal generation
-    image_attachments = []
-    if 'metadata' in output:
-        user_key = f"user_{row_idx}"
-        if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
-            for attachment in output['metadata'][user_key]["attachments"]:
-                if attachment.get("type") == "image":
-                    image_attachments.append(attachment)
-
-    # Add image attachments to state for the generation
-    if image_attachments:
-        state['image_attachments'] = image_attachments
-
     # Add web search results as attachments if enabled
     if state.get('enable_web_search', False):
         search_query = generate_search_query(text, state)
@@ -881,6 +868,19 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
         'metadata': output['metadata']
     }
 
+    # Collect image attachments for multimodal generation
+    image_attachments = []
+    if 'metadata' in output:
+        user_key = f"user_{row_idx}"
+        if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
+            for attachment in output['metadata'][user_key]["attachments"]:
+                if attachment.get("type") == "image":
+                    image_attachments.append(attachment)
+
+    # Add image attachments to state for the generation
+    if image_attachments:
+        state['image_attachments'] = image_attachments
+
     # Generate the prompt
     kwargs = {
         '_continue': _continue,
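Note: this is the same block removed in the previous hunk; it now runs later in chatbot_wrapper, just before the prompt is generated. A hedged sketch of the metadata shape the loop expects follows; the attachment fields other than "type" are illustrative placeholders.

row_idx = 1
output = {
    'metadata': {
        f"user_{row_idx}": {
            "attachments": [
                {"type": "image", "name": "photo.png"},   # collected
                {"type": "text", "name": "notes.txt"},    # ignored by the filter
            ]
        }
    }
}

image_attachments = []
user_key = f"user_{row_idx}"
if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
    for attachment in output['metadata'][user_key]["attachments"]:
        if attachment.get("type") == "image":
            image_attachments.append(attachment)

assert [a["name"] for a in image_attachments] == ["photo.png"]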
@@ -12,6 +12,10 @@ from pathlib import Path
 import llama_cpp_binaries
 import requests
 
+from extensions.openai.image_utils import (
+    convert_image_attachments_to_pil,
+    convert_pil_to_base64
+)
 from modules import shared
 from modules.logging_colors import logger
 
@@ -128,15 +132,40 @@ class LlamaServer:
         url = f"http://127.0.0.1:{self.port}/completion"
         payload = self.prepare_payload(state)
 
-        token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
-        self.last_prompt_token_count = len(token_ids)
+        pil_images = []
+        # Check for images from the Web UI (image_attachments)
+        if 'image_attachments' in state and state['image_attachments']:
+            pil_images.extend(convert_image_attachments_to_pil(state['image_attachments']))
+        # Else, check for images from the API (raw_images)
+        elif 'raw_images' in state and state['raw_images']:
+            pil_images.extend(state.get('raw_images', []))
+
+        if pil_images:
+            # Multimodal case
+            IMAGE_TOKEN_COST_ESTIMATE = 600  # A safe, conservative estimate per image
+
+            base64_images = [convert_pil_to_base64(img) for img in pil_images]
+            multimodal_prompt_object = {
+                "prompt": prompt,
+                "multimodal_data": base64_images
+            }
+            payload["prompt"] = multimodal_prompt_object
+
+            # Calculate an estimated token count
+            text_tokens = self.encode(prompt, add_bos_token=state["add_bos_token"])
+            self.last_prompt_token_count = len(text_tokens) + (len(pil_images) * IMAGE_TOKEN_COST_ESTIMATE)
+        else:
+            # Text only case
+            token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
+            self.last_prompt_token_count = len(token_ids)
+            payload["prompt"] = token_ids
+
         if state['auto_max_new_tokens']:
-            max_new_tokens = state['truncation_length'] - len(token_ids)
+            max_new_tokens = state['truncation_length'] - self.last_prompt_token_count
         else:
             max_new_tokens = state['max_new_tokens']
 
         payload.update({
-            "prompt": token_ids,
             "n_predict": max_new_tokens,
             "stream": True,
             "cache_prompt": True
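Note: an illustrative shape of the request body in the multimodal branch; the values are placeholders, the "multimodal_data" field is taken from this diff, and its interpretation is left to the bundled llama.cpp server. For the token estimate: with a 1,000-token text prompt and two images, last_prompt_token_count would be estimated as 1000 + 2 * 600 = 2200.

payload = {
    "prompt": {
        "prompt": "Describe the attached image.",          # the usual text prompt
        "multimodal_data": ["iVBORw0KGgoAAAANSUhEUg..."],   # base64-encoded PNGs
    },
    "n_predict": 512,
    "stream": True,
    "cache_prompt": True,
}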
@@ -144,7 +173,7 @@ class LlamaServer:
 
         if shared.args.verbose:
             logger.info("GENERATE_PARAMS=")
-            printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
+            printable_payload = {k: (v if k != "prompt" else "[multimodal object]" if pil_images else v) for k, v in payload.items()}
             pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
             print()
 
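Note: the updated one-liner is easier to read expanded; this is an equivalent self-contained sketch with stand-in data, not part of the change.

pil_images = [object()]  # stand-in: at least one image is present
payload = {"prompt": {"prompt": "...", "multimodal_data": ["..."]}, "n_predict": 512}

printable_payload = {}
for k, v in payload.items():
    if k == "prompt" and pil_images:
        # Avoid dumping base64 image data to the console
        printable_payload[k] = "[multimodal object]"
    else:
        printable_payload[k] = v

assert printable_payload["prompt"] == "[multimodal object]"

Unlike the old version, which dropped the prompt key entirely, the prompt (the token id list) is now printed in the text-only case.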
@@ -295,6 +324,13 @@ class LlamaServer:
             cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
         if shared.args.rope_freq_base > 0:
             cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
+        if shared.args.mmproj not in [None, 'None']:
+            path = Path(shared.args.mmproj)
+            if not path.exists():
+                path = Path('user_data/mmproj') / shared.args.mmproj
+
+            if path.exists():
+                cmd += ["--mmproj", str(path)]
 if shared.args.model_draft not in [None, 'None']:
            path = Path(shared.args.model_draft)
            if not path.exists():
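Note: the lookup order for the new flag is shared.args.mmproj as given first, then relative to user_data/mmproj/. A minimal standalone sketch of that resolution, using a hypothetical filename:

from pathlib import Path

mmproj_arg = "mmproj-model-f16.gguf"              # hypothetical filename
path = Path(mmproj_arg)                            # 1) as given (absolute or CWD-relative)
if not path.exists():
    path = Path('user_data/mmproj') / mmproj_arg   # 2) fall back to the mmproj directory

cmd = []
if path.exists():
    cmd += ["--mmproj", str(path)]
# If the file exists under user_data/mmproj/, cmd becomes
# ['--mmproj', 'user_data/mmproj/mmproj-model-f16.gguf']

If the file is found in neither location, the flag is silently skipped and llama-server starts text-only.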
@@ -28,6 +28,8 @@ loaders_and_params = OrderedDict({
         'device_draft',
         'ctx_size_draft',
         'speculative_decoding_accordion',
+        'mmproj',
+        'mmproj_accordion',
         'vram_info',
     ],
     'Transformers': [
@@ -85,6 +85,7 @@ group.add_argument('--no-kv-offload', action='store_true', help='Do not offload
 group.add_argument('--row-split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
 group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
 group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
+group.add_argument('--mmproj', type=str, default=None, help='Path to the mmproj file for vision models.')
 
 # Cache
 group = parser.add_argument_group('Context and cache')
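Note: a usage sketch for the new flag, assuming the usual server.py entry point; the model and mmproj filenames below are placeholders.

python server.py --model my-vision-model.Q4_K_M.gguf --mmproj mmproj-my-vision-model-f16.gguf

A bare filename is resolved against user_data/mmproj/ as shown in the LlamaServer hunk above.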
@@ -167,6 +167,7 @@ def list_model_elements():
         'gpu_layers_draft',
         'device_draft',
         'ctx_size_draft',
+        'mmproj',
     ]
 
     return elements
@@ -59,6 +59,12 @@ def create_ui():
             shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
             shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
 
+            # Multimodal
+            with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
+                with gr.Row():
+                    shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu)
+                    ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
+
             # Speculative decoding
             with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
                 with gr.Row():
@@ -154,6 +154,19 @@ def get_available_ggufs():
     return sorted(model_list, key=natural_keys)
 
 
+def get_available_mmproj():
+    mmproj_dir = Path('user_data/mmproj')
+    if not mmproj_dir.exists():
+        return ['None']
+
+    mmproj_files = []
+    for item in mmproj_dir.iterdir():
+        if item.is_file() and item.suffix.lower() in ('.gguf', '.bin'):
+            mmproj_files.append(item.name)
+
+    return ['None'] + sorted(mmproj_files, key=natural_keys)
+
+
 def get_available_presets():
     return sorted(set((k.stem for k in Path('user_data/presets').glob('*.yaml'))), key=natural_keys)
 
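Note: the expected return value of the new helper, assuming (hypothetically) two projector files have been dropped into user_data/mmproj/; the import path is assumed from the surrounding utility functions.

from modules.utils import get_available_mmproj

# user_data/mmproj/ contains: mmproj-gemma-f16.gguf, mmproj-llava-f16.bin
print(get_available_mmproj())
# -> ['None', 'mmproj-gemma-f16.gguf', 'mmproj-llava-f16.bin']

'None' is always the first choice so the dropdown can represent "no projector selected".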
New file: user_data/mmproj/place-your-mmproj-here.txt (empty placeholder, 0 bytes)