Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2026-01-25)
Image: Add LLM-generated prompt variations

parent 7fb9f19bd8
commit 2793153717
@@ -319,6 +319,7 @@ settings = {
     'image_seed': -1,
     'image_batch_size': 1,
     'image_batch_count': 1,
+    'image_llm_variations': False,
     'image_model_menu': 'None',
     'image_dtype': 'bfloat16',
     'image_attn_backend': 'sdpa',
@@ -293,6 +293,7 @@ def list_interface_input_elements():
         'image_seed',
         'image_batch_size',
         'image_batch_count',
+        'image_llm_variations',
         'image_model_menu',
         'image_dtype',
         'image_attn_backend',
@@ -547,6 +548,7 @@ def setup_auto_save():
         'image_seed',
         'image_batch_size',
         'image_batch_count',
+        'image_llm_variations',
         'image_model_menu',
         'image_dtype',
         'image_attn_backend',
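Note: the same key is registered in three separate places: as a default in settings, in list_interface_input_elements() so its value is packed into the state dict handed to generation functions, and in setup_auto_save() so user changes persist. A self-contained sketch of that idea follows; every name in it is a hypothetical stand-in, not the webui's actual plumbing:

SETTINGS_DEFAULTS = {'image_llm_variations': False}   # default, as in the settings hunk above
UI_INPUT_ELEMENTS = ['image_llm_variations']          # keys gathered into `state` per generation
AUTO_SAVE_KEYS    = ['image_llm_variations']          # keys persisted when the user changes them

def gather_state(ui_values):
    # Each generation call packs the registered UI elements into a state dict,
    # which is why generate() below can read state['image_llm_variations'].
    return {k: ui_values.get(k, SETTINGS_DEFAULTS.get(k)) for k in UI_INPUT_ELEMENTS}

print(gather_state({'image_llm_variations': True}))   # {'image_llm_variations': True}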
@@ -10,6 +10,7 @@ import numpy as np
 from PIL.PngImagePlugin import PngInfo
 
 from modules import shared, ui, utils
+from modules.utils import check_model_loaded
 from modules.image_models import (
     get_pipeline_type,
     load_image_model,
@@ -409,6 +410,11 @@ def create_ui():
             with gr.Column():
                 shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
                 shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
+                shared.gradio['image_llm_variations'] = gr.Checkbox(
+                    value=shared.settings['image_llm_variations'],
+                    label='LLM Prompt Variations',
+                    info='Use the loaded LLM to generate creative prompt variations for each sequential batch.'
+                )
 
         with gr.Column(scale=6, min_width=500):
             with gr.Column(elem_classes=["viewport-container"]):
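Note: for readers unfamiliar with how the webui builds its interface, here is a minimal standalone Gradio sketch of the new control. The Blocks layout and the demo variable are illustrative assumptions; in the actual code the component is stored in shared.gradio and collected through list_interface_input_elements():

import gradio as gr

# Hypothetical standalone demo, not the webui's real wiring.
with gr.Blocks() as demo:
    llm_variations = gr.Checkbox(
        value=False,  # mirrors the 'image_llm_variations' default added above
        label='LLM Prompt Variations',
        info='Use the loaded LLM to generate creative prompt variations for each sequential batch.'
    )

demo.launch()  # launch locally for manual testing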
@@ -664,6 +670,54 @@ def create_event_handlers():
     )
 
 
+def generate_prompt_variation(state):
+    """Generate a creative variation of the image prompt using the LLM."""
+    from modules.chat import generate_chat_prompt
+    from modules.text_generation import generate_reply
+
+    prompt = state['image_prompt']
+
+    # Check if an LLM is loaded
+    model_loaded, _ = check_model_loaded()
+    if not model_loaded:
+        logger.warning("No LLM loaded for prompt variation. Using original prompt.")
+        return prompt
+
+    augmented_message = f"{prompt}\n\n=====\n\nPlease create a creative variation of the image generation prompt above. Keep the same general subject and style, but vary the details, composition, lighting, or mood. Respond with only the new prompt, nothing else."
+
+    # Use a minimal state for generation
+    var_state = state.copy()
+    var_state['history'] = {'internal': [], 'visible': [], 'metadata': {}}
+    var_state['auto_max_new_tokens'] = True
+    var_state['enable_thinking'] = False
+    var_state['reasoning_effort'] = 'low'
+    var_state['start_with'] = ""
+
+    formatted_prompt = generate_chat_prompt(augmented_message, var_state)
+
+    variation = ""
+    for reply in generate_reply(formatted_prompt, var_state, stopping_strings=[], is_chat=True):
+        variation = reply
+
+    # Strip thinking blocks if present
+    if "</think>" in variation:
+        variation = variation.rsplit("</think>", 1)[1]
+    elif "<|start|>assistant<|channel|>final<|message|>" in variation:
+        variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
+    elif "</seed:think>" in variation:
+        variation = variation.rsplit("</seed:think>", 1)[1]
+
+    variation = variation.strip()
+    if len(variation) >= 2 and variation.startswith('"') and variation.endswith('"'):
+        variation = variation[1:-1]
+
+    if variation:
+        logger.info(f"Prompt variation: {variation}")
+        return variation
+
+    return prompt
+
+
 def progress_bar_html(progress=0, text=""):
     """Generate HTML for progress bar. Empty div when progress <= 0."""
     if progress <= 0:
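Note: the stripping above keys on the last occurrence of each reasoning terminator via rsplit(marker, 1)[1], so everything up to and including the final thinking block is discarded in one cut. A self-contained sketch of that behavior (strip_thinking is an illustrative helper, not part of the commit; the marker strings are copied from the diff):

# Illustrative helper mirroring the rsplit-based stripping in generate_prompt_variation().
def strip_thinking(text: str) -> str:
    for marker in ("</think>",
                   "<|start|>assistant<|channel|>final<|message|>",
                   "</seed:think>"):
        if marker in text:
            # Keep only what follows the LAST terminator, like rsplit(marker, 1)[1].
            return text.rsplit(marker, 1)[1].strip()
    return text.strip()

assert strip_thinking("<think>plan</think>a misty harbor at dawn") == "a misty harbor at dawn"
assert strip_thinking("plain prompt, no reasoning block") == "plain prompt, no reasoning block"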
@@ -777,6 +831,10 @@ def generate(state):
 
         generator.manual_seed(int(seed + batch_idx))
 
+        # Generate prompt variation if enabled
+        if state['image_llm_variations']:
+            gen_kwargs["prompt"] = generate_prompt_variation(state)
+
         # Run generation in thread so we can yield progress
         result_holder = []
         error_holder = []
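Note: this hunk sits inside the sequential-count loop, so each pass gets a distinct seed and, when the checkbox is enabled, a freshly rewritten prompt. A simplified, runnable sketch of that loop shape; state, gen_kwargs, seed, and fake_variation are hypothetical stand-ins for the webui's real objects:

import torch

# Hypothetical stand-ins for the webui's state and generation kwargs.
state = {'image_batch_count': 3, 'image_llm_variations': True, 'image_prompt': 'a misty harbor'}
gen_kwargs = {'prompt': state['image_prompt']}
seed = 42
generator = torch.Generator()

def fake_variation(s):
    # Placeholder for generate_prompt_variation(); the real code calls the loaded LLM.
    return s['image_prompt'] + ', variation'

for batch_idx in range(state['image_batch_count']):
    generator.manual_seed(int(seed + batch_idx))        # distinct seed per pass
    if state['image_llm_variations']:
        gen_kwargs['prompt'] = fake_variation(state)    # fresh prompt per pass
    print(batch_idx, gen_kwargs['prompt'])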