From 27931537176fef1bc1335815097ebc780cbf1dbf Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 4 Dec 2025 07:57:23 -0800
Subject: [PATCH] Image: Add LLM-generated prompt variations

---
 modules/shared.py              |  1 +
 modules/ui.py                  |  2 ++
 modules/ui_image_generation.py | 58 ++++++++++++++++++++++++++++++++++
 3 files changed, 61 insertions(+)

diff --git a/modules/shared.py b/modules/shared.py
index 4e17497b..1ecc0d28 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -319,6 +319,7 @@ settings = {
     'image_seed': -1,
     'image_batch_size': 1,
     'image_batch_count': 1,
+    'image_llm_variations': False,
     'image_model_menu': 'None',
     'image_dtype': 'bfloat16',
     'image_attn_backend': 'sdpa',
diff --git a/modules/ui.py b/modules/ui.py
index ff5686e8..d95f7bc6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -293,6 +293,7 @@ def list_interface_input_elements():
         'image_seed',
         'image_batch_size',
         'image_batch_count',
+        'image_llm_variations',
         'image_model_menu',
         'image_dtype',
         'image_attn_backend',
@@ -547,6 +548,7 @@ def setup_auto_save():
         'image_seed',
         'image_batch_size',
         'image_batch_count',
+        'image_llm_variations',
         'image_model_menu',
         'image_dtype',
         'image_attn_backend',
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index e85f1520..ceb470ff 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -10,6 +10,7 @@ import numpy as np
 from PIL.PngImagePlugin import PngInfo
 
 from modules import shared, ui, utils
+from modules.utils import check_model_loaded
 from modules.image_models import (
     get_pipeline_type,
     load_image_model,
@@ -409,6 +410,11 @@ def create_ui():
                     with gr.Column():
                         shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
                         shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
+                        shared.gradio['image_llm_variations'] = gr.Checkbox(
+                            value=shared.settings['image_llm_variations'],
+                            label='LLM Prompt Variations',
+                            info='Use the loaded LLM to generate creative prompt variations for each sequential batch.'
+                        )
 
             with gr.Column(scale=6, min_width=500):
                 with gr.Column(elem_classes=["viewport-container"]):
@@ -664,6 +670,54 @@ def create_event_handlers():
     )
 
 
+def generate_prompt_variation(state):
+    """Generate a creative variation of the image prompt using the LLM."""
+    from modules.chat import generate_chat_prompt
+    from modules.text_generation import generate_reply
+
+    prompt = state['image_prompt']
+
+    # Check if LLM is loaded
+    model_loaded, _ = check_model_loaded()
+    if not model_loaded:
+        logger.warning("No LLM loaded for prompt variation. Using original prompt.")
+        return prompt
+
+    augmented_message = f"{prompt}\n\n=====\n\nPlease create a creative variation of the image generation prompt above. Keep the same general subject and style, but vary the details, composition, lighting, or mood. Respond with only the new prompt, nothing else."
+
+    # Use minimal state for generation
+    var_state = state.copy()
+    var_state['history'] = {'internal': [], 'visible': [], 'metadata': {}}
+    var_state['auto_max_new_tokens'] = True
+    var_state['enable_thinking'] = False
+    var_state['reasoning_effort'] = 'low'
+    var_state['start_with'] = ""
+
+    formatted_prompt = generate_chat_prompt(augmented_message, var_state)
+
+    variation = ""
+    for reply in generate_reply(formatted_prompt, var_state, stopping_strings=[], is_chat=True):
+        variation = reply
+
+    # Strip thinking blocks if present
+    if "</think>" in variation:
+        variation = variation.rsplit("</think>", 1)[1]
+    elif "<|start|>assistant<|channel|>final<|message|>" in variation:
+        variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
+    elif "</seed:think>" in variation:
+        variation = variation.rsplit("</seed:think>", 1)[1]
+
+    variation = variation.strip()
+    if len(variation) >= 2 and variation.startswith('"') and variation.endswith('"'):
+        variation = variation[1:-1]
+
+    if variation:
+        logger.info(f"Prompt variation: {variation}...")
+        return variation
+
+    return prompt
+
+
 def progress_bar_html(progress=0, text=""):
     """Generate HTML for progress bar. Empty div when progress <= 0."""
     if progress <= 0:
@@ -777,6 +831,10 @@
         generator.manual_seed(int(seed + batch_idx))
 
+        # Generate prompt variation if enabled
+        if state['image_llm_variations']:
+            gen_kwargs["prompt"] = generate_prompt_variation(state)
+
         # Run generation in thread so we can yield progress
         result_holder = []
         error_holder = []
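
Note (not part of the patch): a minimal standalone sketch of the intended behavior. With 'image_llm_variations' enabled, each iteration of the sequential-count loop asks the loaded LLM for a fresh prompt before generating, while the per-batch seed offset stays the same. rewrite_prompt() and run_batches() below are hypothetical stand-ins for generate_prompt_variation() and the loop inside generate(), so the example runs outside the webui environment.

# Standalone sketch assuming a state dict shaped like the UI inputs above.
def rewrite_prompt(prompt: str) -> str:
    # Hypothetical stand-in for the LLM call made by generate_prompt_variation().
    return prompt + ", alternate composition and lighting"

def run_batches(state: dict) -> list:
    """Return one (batch_idx, seed, prompt) tuple per sequential batch."""
    plan = []
    for batch_idx in range(state["image_batch_count"]):
        prompt = state["image_prompt"]
        if state["image_llm_variations"]:
            prompt = rewrite_prompt(prompt)  # fresh variation on every loop iteration
        plan.append((batch_idx, state["image_seed"] + batch_idx, prompt))  # seed still increments per batch
    return plan

if __name__ == "__main__":
    state = {
        "image_prompt": "a lighthouse at dusk, oil painting",
        "image_seed": 42,
        "image_batch_count": 3,
        "image_llm_variations": True,
    }
    for row in run_batches(state):
        print(row)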