Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2026-04-05 14:45:28 +00:00
Add multimodal support (llama.cpp) (#7027)
parent eb16f64017
commit d86b0ec010
9 changed files with 86 additions and 18 deletions
@@ -12,6 +12,10 @@ from pathlib import Path
 import llama_cpp_binaries
 import requests
 
+from extensions.openai.image_utils import (
+    convert_image_attachments_to_pil,
+    convert_pil_to_base64
+)
 from modules import shared
 from modules.logging_colors import logger
 
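The two helpers imported here come from extensions/openai/image_utils.py, which belongs to the same changeset but is not shown in this diff. For orientation only, a hypothetical sketch of what such helpers could look like, assuming each attachment dict stores base64-encoded image bytes under an 'image_data' key (the key name and both signatures are assumptions, not taken from the commit):

# Hypothetical sketch; not the commit's actual extensions/openai/image_utils.py.
import base64
import io

from PIL import Image


def convert_image_attachments_to_pil(image_attachments):
    # Decode each attachment's base64 payload (the 'image_data' key is an
    # assumption) into an RGB PIL image.
    pil_images = []
    for attachment in image_attachments:
        raw = base64.b64decode(attachment['image_data'])
        pil_images.append(Image.open(io.BytesIO(raw)).convert('RGB'))
    return pil_images


def convert_pil_to_base64(image):
    # Serialize a PIL image back to a base64 string (PNG chosen arbitrarily).
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')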
@@ -128,15 +132,40 @@ class LlamaServer:
         url = f"http://127.0.0.1:{self.port}/completion"
         payload = self.prepare_payload(state)
 
-        token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
-        self.last_prompt_token_count = len(token_ids)
+        pil_images = []
+        # Check for images from the Web UI (image_attachments)
+        if 'image_attachments' in state and state['image_attachments']:
+            pil_images.extend(convert_image_attachments_to_pil(state['image_attachments']))
+        # Else, check for images from the API (raw_images)
+        elif 'raw_images' in state and state['raw_images']:
+            pil_images.extend(state.get('raw_images', []))
+
+        if pil_images:
+            # Multimodal case
+            IMAGE_TOKEN_COST_ESTIMATE = 600  # A safe, conservative estimate per image
+
+            base64_images = [convert_pil_to_base64(img) for img in pil_images]
+            multimodal_prompt_object = {
+                "prompt": prompt,
+                "multimodal_data": base64_images
+            }
+            payload["prompt"] = multimodal_prompt_object
+
+            # Calculate an estimated token count
+            text_tokens = self.encode(prompt, add_bos_token=state["add_bos_token"])
+            self.last_prompt_token_count = len(text_tokens) + (len(pil_images) * IMAGE_TOKEN_COST_ESTIMATE)
+        else:
+            # Text only case
+            token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
+            self.last_prompt_token_count = len(token_ids)
+            payload["prompt"] = token_ids
+
         if state['auto_max_new_tokens']:
-            max_new_tokens = state['truncation_length'] - len(token_ids)
+            max_new_tokens = state['truncation_length'] - self.last_prompt_token_count
         else:
             max_new_tokens = state['max_new_tokens']
 
         payload.update({
-            "prompt": token_ids,
             "n_predict": max_new_tokens,
             "stream": True,
             "cache_prompt": True
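The multimodal branch swaps the usual token-ID prompt for a nested object that the llama.cpp server unpacks into text plus base64 images. For reference, a minimal standalone request that mirrors the payload shape built above, assuming a llama-server instance already running on port 8080 with a matching --mmproj projector loaded; the port, image path, prompt text, and sampling values are placeholders:

# Hypothetical client-side request mirroring the payload built above.
import base64

import requests

with open('example.png', 'rb') as f:  # placeholder image path
    image_b64 = base64.b64encode(f.read()).decode('utf-8')

payload = {
    # Same nested shape the multimodal branch assigns to payload["prompt"]
    "prompt": {
        "prompt": "Describe this image.",
        "multimodal_data": [image_b64],
    },
    "n_predict": 200,
    "stream": False,
    "cache_prompt": True,
}

response = requests.post("http://127.0.0.1:8080/completion", json=payload)
print(response.json().get("content"))  # generated text in llama.cpp's response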
@@ -144,7 +173,7 @@ class LlamaServer:
 
         if shared.args.verbose:
             logger.info("GENERATE_PARAMS=")
-            printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
+            printable_payload = {k: (v if k != "prompt" else "[multimodal object]" if pil_images else v) for k, v in payload.items()}
             pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
             print()
 
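The reworked comprehension no longer drops the prompt key from the verbose log: when images are present it masks the nested object as "[multimodal object]", and in the text-only path it now prints the token-ID list that the old code suppressed. A quick illustration with dummy values:

# Illustration of the logging comprehension with dummy values.
payload = {
    "prompt": {"prompt": "Describe this image.", "multimodal_data": ["<base64>"]},
    "n_predict": 200,
}
pil_images = [object()]  # stand-in for a non-empty image list

printable_payload = {k: (v if k != "prompt" else "[multimodal object]" if pil_images else v)
                     for k, v in payload.items()}
print(printable_payload)  # {'prompt': '[multimodal object]', 'n_predict': 200}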
@@ -295,6 +324,13 @@ class LlamaServer:
             cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
         if shared.args.rope_freq_base > 0:
             cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)]
+        if shared.args.mmproj not in [None, 'None']:
+            path = Path(shared.args.mmproj)
+            if not path.exists():
+                path = Path('user_data/mmproj') / shared.args.mmproj
+
+            if path.exists():
+                cmd += ["--mmproj", str(path)]
         if shared.args.model_draft not in [None, 'None']:
             path = Path(shared.args.model_draft)
             if not path.exists():
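The new option accepts either a full path or a bare filename that is looked up under user_data/mmproj/; if neither resolves, the projector flag is silently omitted from the llama-server command line rather than raising an error. A standalone rendition of that resolution logic, for illustration (the function name is made up):

# Hypothetical standalone version of the --mmproj path resolution above.
from pathlib import Path


def resolve_mmproj(mmproj_arg):
    # Try the argument as a path first, then fall back to user_data/mmproj/.
    # Returns None when the projector cannot be found, in which case the
    # caller skips the --mmproj flag entirely (as in the diff).
    if mmproj_arg in [None, 'None']:
        return None
    path = Path(mmproj_arg)
    if not path.exists():
        path = Path('user_data/mmproj') / mmproj_arg
    return path if path.exists() else None


# e.g. resolve_mmproj('mmproj-f16.gguf') would yield
# Path('user_data/mmproj/mmproj-f16.gguf') if that file exists there.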