diff --git a/extensions/openai/image_utils.py b/extensions/openai/image_utils.py
index c54f0532..658f00d7 100644
--- a/extensions/openai/image_utils.py
+++ b/extensions/openai/image_utils.py
@@ -11,6 +11,15 @@ from PIL import Image
 from modules.logging_colors import logger
 
 
+def convert_pil_to_base64(image: Image.Image) -> str:
+    """Converts a PIL Image to a base64 encoded string."""
+    buffered = io.BytesIO()
+    # Save image to an in-memory bytes buffer in PNG format
+    image.save(buffered, format="PNG")
+    # Encode the bytes to a base64 string
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+
 def decode_base64_image(base64_string: str) -> Image.Image:
     """Decodes a base64 string to a PIL Image."""
     try:
diff --git a/modules/chat.py b/modules/chat.py
index 639feebf..696fa350 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -813,19 +813,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
     for file_path in files:
         add_message_attachment(output, row_idx, file_path, is_user=True)
 
-    # Collect image attachments for multimodal generation
-    image_attachments = []
-    if 'metadata' in output:
-        user_key = f"user_{row_idx}"
-        if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
-            for attachment in output['metadata'][user_key]["attachments"]:
-                if attachment.get("type") == "image":
-                    image_attachments.append(attachment)
-
-    # Add image attachments to state for the generation
-    if image_attachments:
-        state['image_attachments'] = image_attachments
-
     # Add web search results as attachments if enabled
     if state.get('enable_web_search', False):
         search_query = generate_search_query(text, state)
@@ -881,6 +868,19 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
             'metadata': output['metadata']
         }
 
+    # Collect image attachments for multimodal generation
+    image_attachments = []
+    if 'metadata' in output:
+        user_key = f"user_{row_idx}"
+        if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
+            for attachment in output['metadata'][user_key]["attachments"]:
+                if attachment.get("type") == "image":
+                    image_attachments.append(attachment)
+
+    # Add image attachments to state for the generation
+    if image_attachments:
+        state['image_attachments'] = image_attachments
+
     # Generate the prompt
     kwargs = {
         '_continue': _continue,
diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index e64f1694..072ff83b 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -12,6 +12,10 @@ from pathlib import Path
 import llama_cpp_binaries
 import requests
 
+from extensions.openai.image_utils import (
+    convert_image_attachments_to_pil,
+    convert_pil_to_base64
+)
 from modules import shared
 from modules.logging_colors import logger
 
@@ -128,15 +132,40 @@ class LlamaServer:
         url = f"http://127.0.0.1:{self.port}/completion"
         payload = self.prepare_payload(state)
 
-        token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
-        self.last_prompt_token_count = len(token_ids)
+        pil_images = []
+        # Check for images from the Web UI (image_attachments)
+        if 'image_attachments' in state and state['image_attachments']:
+            pil_images.extend(convert_image_attachments_to_pil(state['image_attachments']))
+        # Else, check for images from the API (raw_images)
+        elif 'raw_images' in state and state['raw_images']:
+            pil_images.extend(state.get('raw_images', []))
+
+        if pil_images:
+            # Multimodal case
+            IMAGE_TOKEN_COST_ESTIMATE = 600  # A safe, conservative estimate per image
+
+            base64_images = [convert_pil_to_base64(img) for img in pil_images]
+            multimodal_prompt_object = {
+                "prompt": prompt,
+                "multimodal_data": base64_images
+            }
+            payload["prompt"] = multimodal_prompt_object
+
+            # Calculate an estimated token count
+            text_tokens = self.encode(prompt, add_bos_token=state["add_bos_token"])
+            self.last_prompt_token_count = len(text_tokens) + (len(pil_images) * IMAGE_TOKEN_COST_ESTIMATE)
+        else:
+            # Text only case
+            token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
+            self.last_prompt_token_count = len(token_ids)
+            payload["prompt"] = token_ids
+
         if state['auto_max_new_tokens']:
-            max_new_tokens = state['truncation_length'] - len(token_ids)
+            max_new_tokens = state['truncation_length'] - self.last_prompt_token_count
         else:
             max_new_tokens = state['max_new_tokens']
 
         payload.update({
-            "prompt": token_ids,
             "n_predict": max_new_tokens,
             "stream": True,
             "cache_prompt": True
@@ -144,7 +173,7 @@ class LlamaServer:
 
         if shared.args.verbose:
             logger.info("GENERATE_PARAMS=")
-            printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
+            printable_payload = {k: (v if k != "prompt" else "[multimodal object]" if pil_images else v) for k, v in payload.items()}
             pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
             print()
 
@@ -295,6 +324,13 @@ class LlamaServer:
             cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
         if shared.args.rope_freq_base > 0:
             cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
+        if shared.args.mmproj not in [None, 'None']:
+            path = Path(shared.args.mmproj)
+            if not path.exists():
+                path = Path('user_data/mmproj') / shared.args.mmproj
+
+            if path.exists():
+                cmd += ["--mmproj", str(path)]
         if shared.args.model_draft not in [None, 'None']:
             path = Path(shared.args.model_draft)
             if not path.exists():
diff --git a/modules/loaders.py b/modules/loaders.py
index 151de990..feca9985 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -28,6 +28,8 @@ loaders_and_params = OrderedDict({
         'device_draft',
         'ctx_size_draft',
         'speculative_decoding_accordion',
+        'mmproj',
+        'mmproj_accordion',
         'vram_info',
     ],
     'Transformers': [
diff --git a/modules/shared.py b/modules/shared.py
index 1de4306b..e9d8a62f 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -85,6 +85,7 @@ group.add_argument('--no-kv-offload', action='store_true', help='Do not offload
 group.add_argument('--row-split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
 group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"')
 group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
+group.add_argument('--mmproj', type=str, default=None, help='Path to the mmproj file for vision models.')
 
 # Cache
 group = parser.add_argument_group('Context and cache')
diff --git a/modules/ui.py b/modules/ui.py
index e7805046..1171cd48 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -167,6 +167,7 @@ def list_model_elements():
         'gpu_layers_draft',
         'device_draft',
         'ctx_size_draft',
+        'mmproj',
     ]
 
     return elements
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 0ab67e7a..9fa8a4f4 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -59,6 +59,12 @@ def create_ui():
                 shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
                 shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
 
+            # Multimodal
+            with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
+                with gr.Row():
+                    shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu)
+                    ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
+
             # Speculative decoding
             with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
                 with gr.Row():
diff --git a/modules/utils.py b/modules/utils.py
index 117ad590..4927ef04 100644
--- a/modules/utils.py
+++ b/modules/utils.py
@@ -154,6 +154,19 @@ def get_available_ggufs():
     return sorted(model_list, key=natural_keys)
 
 
+def get_available_mmproj():
+    mmproj_dir = Path('user_data/mmproj')
+    if not mmproj_dir.exists():
+        return ['None']
+
+    mmproj_files = []
+    for item in mmproj_dir.iterdir():
+        if item.is_file() and item.suffix.lower() in ('.gguf', '.bin'):
+            mmproj_files.append(item.name)
+
+    return ['None'] + sorted(mmproj_files, key=natural_keys)
+
+
 def get_available_presets():
     return sorted(set((k.stem for k in Path('user_data/presets').glob('*.yaml'))), key=natural_keys)
 
diff --git a/user_data/mmproj/place-your-mmproj-here.txt b/user_data/mmproj/place-your-mmproj-here.txt
new file mode 100644
index 00000000..e69de29b
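
Note: a minimal round-trip sketch of the new `convert_pil_to_base64` helper together with the existing `decode_base64_image`, assuming `decode_base64_image` accepts a raw base64 string; the 64x64 test image is purely illustrative:

```python
from PIL import Image

from extensions.openai.image_utils import convert_pil_to_base64, decode_base64_image

# Encode a small in-memory image to base64, then decode it back
original = Image.new("RGB", (64, 64), color="red")
encoded = convert_pil_to_base64(original)   # base64-encoded PNG as a str
restored = decode_base64_image(encoded)

assert restored.size == original.size
```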
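For reference, this sketch shows the shape of the payload that `generate_with_streaming` now builds for llama-server's `/completion` endpoint when images are attached. The prompt text, image, and `n_predict` value are placeholders, and the server is assumed to have been started with `--mmproj` so that `multimodal_data` is understood:

```python
from PIL import Image

from extensions.openai.image_utils import convert_pil_to_base64

# Illustrative inputs; in the web UI these come from the chat state
prompt = "Describe this image."
image = Image.new("RGB", (64, 64), color="blue")

payload = {
    "prompt": {
        "prompt": prompt,                                   # raw prompt text (not token ids)
        "multimodal_data": [convert_pil_to_base64(image)],  # one base64 PNG per attached image
    },
    "n_predict": 512,
    "stream": True,
    "cache_prompt": True,
}
```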