Image: Add LLM-generated prompt variations

oobabooga 2025-12-04 07:57:23 -08:00
parent 7fb9f19bd8
commit 2793153717
3 changed files with 61 additions and 0 deletions

View file

@@ -319,6 +319,7 @@ settings = {
    'image_seed': -1,
    'image_batch_size': 1,
    'image_batch_count': 1,
    'image_llm_variations': False,
    'image_model_menu': 'None',
    'image_dtype': 'bfloat16',
    'image_attn_backend': 'sdpa',
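
The new 'image_llm_variations' key defaults to False, so prompt variations stay off unless the user enables the new checkbox. As a rough, hypothetical illustration (not part of this diff), code reading a state saved before the key existed could fall back to that default:

    # Hypothetical helper, for illustration only: tolerate state dicts
    # saved before the 'image_llm_variations' key was introduced.
    def llm_variations_enabled(state, default=False):
        return bool(state.get('image_llm_variations', default))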

View file

@@ -293,6 +293,7 @@ def list_interface_input_elements():
        'image_seed',
        'image_batch_size',
        'image_batch_count',
        'image_llm_variations',
        'image_model_menu',
        'image_dtype',
        'image_attn_backend',
@@ -547,6 +548,7 @@ def setup_auto_save():
        'image_seed',
        'image_batch_size',
        'image_batch_count',
        'image_llm_variations',
        'image_model_menu',
        'image_dtype',
        'image_attn_backend',
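
Both lists in this file feed the webui's per-request state handling: names registered in list_interface_input_elements() are gathered into the flat state dict that the generation functions index by key, and the setup_auto_save() list hooks the checkbox into settings auto-save. A minimal, self-contained sketch of the gather pattern (hypothetical names, not the project's actual implementation):

    # Sketch only: map each registered element name to its current UI value,
    # producing the flat state dict that generate() later indexes by key.
    def gather_state(values_by_element, element_names):
        return {name: values_by_element[name] for name in element_names}

    state = gather_state(
        {'image_llm_variations': True, 'image_batch_count': 4},
        ['image_llm_variations', 'image_batch_count'],
    )
    assert state['image_llm_variations'] is True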

View file

@@ -10,6 +10,7 @@ import numpy as np
from PIL.PngImagePlugin import PngInfo
from modules import shared, ui, utils
from modules.utils import check_model_loaded
from modules.image_models import (
    get_pipeline_type,
    load_image_model,
@@ -409,6 +410,11 @@ def create_ui():
    with gr.Column():
        shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
        shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
        shared.gradio['image_llm_variations'] = gr.Checkbox(
            value=shared.settings['image_llm_variations'],
            label='LLM Prompt Variations',
            info='Use the loaded LLM to generate creative prompt variations for each sequential batch.'
        )
    with gr.Column(scale=6, min_width=500):
        with gr.Column(elem_classes=["viewport-container"]):
@@ -664,6 +670,54 @@ def create_event_handlers():
    )

def generate_prompt_variation(state):
    """Generate a creative variation of the image prompt using the LLM."""
    from modules.chat import generate_chat_prompt
    from modules.text_generation import generate_reply

    prompt = state['image_prompt']

    # Check if LLM is loaded
    model_loaded, _ = check_model_loaded()
    if not model_loaded:
        logger.warning("No LLM loaded for prompt variation. Using original prompt.")
        return prompt

    augmented_message = f"{prompt}\n\n=====\n\nPlease create a creative variation of the image generation prompt above. Keep the same general subject and style, but vary the details, composition, lighting, or mood. Respond with only the new prompt, nothing else."

    # Use minimal state for generation
    var_state = state.copy()
    var_state['history'] = {'internal': [], 'visible': [], 'metadata': {}}
    var_state['auto_max_new_tokens'] = True
    var_state['enable_thinking'] = False
    var_state['reasoning_effort'] = 'low'
    var_state['start_with'] = ""

    formatted_prompt = generate_chat_prompt(augmented_message, var_state)

    variation = ""
    for reply in generate_reply(formatted_prompt, var_state, stopping_strings=[], is_chat=True):
        variation = reply

    # Strip thinking blocks if present
    if "</think>" in variation:
        variation = variation.rsplit("</think>", 1)[1]
    elif "<|start|>assistant<|channel|>final<|message|>" in variation:
        variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
    elif "</seed:think>" in variation:
        variation = variation.rsplit("</seed:think>", 1)[1]

    variation = variation.strip()
    if len(variation) >= 2 and variation.startswith('"') and variation.endswith('"'):
        variation = variation[1:-1]

    if variation:
        logger.info(f"Prompt variation: {variation}...")
        return variation

    return prompt
def progress_bar_html(progress=0, text=""):
"""Generate HTML for progress bar. Empty div when progress <= 0."""
if progress <= 0:
@@ -777,6 +831,10 @@ def generate(state):
        generator.manual_seed(int(seed + batch_idx))

        # Generate prompt variation if enabled
        if state['image_llm_variations']:
            gen_kwargs["prompt"] = generate_prompt_variation(state)

        # Run generation in thread so we can yield progress
        result_holder = []
        error_holder = []
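
In generate(), the variation is requested once per sequential batch, inside the loop that also advances the seed with batch_idx, so each pass can render a different rewrite of the same prompt while falling back to the original text when no LLM is loaded or the reply is empty. A standalone sketch of the reply cleanup performed by generate_prompt_variation, run on sample data (same technique, not wired into the webui):

    # Illustration only: strip a leading thinking block, trim whitespace
    # and surrounding quotes, and fall back when nothing usable remains.
    def clean_variation(reply, fallback):
        for marker in ("</think>",
                       "<|start|>assistant<|channel|>final<|message|>",
                       "</seed:think>"):
            if marker in reply:
                reply = reply.rsplit(marker, 1)[1]
                break
        reply = reply.strip()
        if len(reply) >= 2 and reply.startswith('"') and reply.endswith('"'):
            reply = reply[1:-1]
        return reply or fallback

    print(clean_variation('<think>plan...</think>\n"A misty harbor at dawn"', "a harbor"))
    # -> A misty harbor at dawn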