Compare commits

...

3 commits

Author SHA1 Message Date
oobabooga 7fb9f19bd8 Progress bar style improvements 2025-12-04 06:20:45 -08:00
oobabooga a838223d18 Image: Add a progress bar during generation 2025-12-04 05:49:57 -08:00
oobabooga 14dbc3488e Image: Clear the torch cache after generation, not before 2025-12-04 05:32:58 -08:00
2 changed files with 113 additions and 28 deletions


@@ -1752,3 +1752,42 @@ button#swap-height-width {
 .min.svelte-1yrv54 {
     min-height: 0;
 }
+
+/* Image Generation Progress Bar */
+#image-progress .image-ai-separator {
+    height: 24px;
+    margin: 20px 0;
+    border-top: 1px solid var(--input-border-color);
+}
+
+#image-progress .image-ai-progress-wrapper {
+    height: 24px;
+    margin: 20px 0;
+}
+
+#image-progress .image-ai-progress-track {
+    background: #e5e7eb;
+    border-radius: 4px;
+    overflow: hidden;
+    height: 8px;
+}
+
+.dark #image-progress .image-ai-progress-track {
+    background: #333;
+}
+
+#image-progress .image-ai-progress-fill {
+    background: #4a9eff;
+    height: 100%;
+}
+
+#image-progress .image-ai-progress-text {
+    text-align: center;
+    font-size: 12px;
+    color: #666;
+    margin-top: 4px;
+}
+
+.dark #image-progress .image-ai-progress-text {
+    color: #888;
+}

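The markup these selectors style is produced by the progress_bar_html() helper introduced in the Python file below. A minimal standalone sketch, with the helper copied verbatim and illustrative print calls added:

    def progress_bar_html(progress=0, text=""):
        """Generate HTML for progress bar. Empty div when progress <= 0."""
        if progress <= 0:
            return '<div class="image-ai-separator"></div>'

        return f'''<div class="image-ai-progress-wrapper">
        <div class="image-ai-progress-track">
            <div class="image-ai-progress-fill" style="width: {progress*100:.1f}%;"></div>
        </div>
        <div class="image-ai-progress-text">{text}</div>
    </div>'''

    # Idle: only the separator div, styled by .image-ai-separator above.
    print(progress_bar_html())

    # Mid-run: wrapper/track/fill/text, with the fill width rendered as 40.0%.
    print(progress_bar_html(0.4, "Step 12/30"))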

@@ -373,7 +373,10 @@ def create_ui():
         shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg")
         shared.gradio['image_stop_btn'] = gr.Button("Stop", size="lg", visible=False)
 
-        gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")
+        shared.gradio['image_progress'] = gr.HTML(
+            value=progress_bar_html(),
+            elem_id="image-progress"
+        )
 
         gr.Markdown("### Dimensions")
         with gr.Row():
@@ -546,19 +549,19 @@ def create_event_handlers():
     shared.gradio['image_generate_btn'].click(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
         lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
-        generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
+        generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then(
         lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
 
     shared.gradio['image_prompt'].submit(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
         lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
-        generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
+        generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then(
         lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
 
     shared.gradio['image_neg_prompt'].submit(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
         lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
-        generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
+        generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then(
         lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
 
     # Stop button
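The only change in these handlers is that generate's output list now includes image_progress, so each value the generator yields must be a pair: the gallery images and the progress HTML. In Gradio, a generator function wired to multiple outputs updates all of them on every yield; a minimal sketch with placeholder components (hypothetical names, not the app's real ones):

    import time

    import gradio as gr

    def stream(n=5):
        for i in range(1, n + 1):
            time.sleep(0.5)
            # One value per output component, in the order given to click().
            yield [f"item {i}"], f"<b>{i}/{n}</b>"

    with gr.Blocks() as demo:
        btn = gr.Button("Go")
        gallery = gr.JSON()   # stand-in for the app's real gr.Gallery
        progress = gr.HTML()
        btn.click(stream, None, [gallery, progress], show_progress=False)

    demo.launch()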
@@ -661,23 +664,37 @@ def create_event_handlers():
     )
 
 
+def progress_bar_html(progress=0, text=""):
+    """Generate HTML for progress bar. Empty div when progress <= 0."""
+    if progress <= 0:
+        return '<div class="image-ai-separator"></div>'
+
+    return f'''<div class="image-ai-progress-wrapper">
+    <div class="image-ai-progress-track">
+        <div class="image-ai-progress-fill" style="width: {progress*100:.1f}%;"></div>
+    </div>
+    <div class="image-ai-progress-text">{text}</div>
+</div>'''
+
+
 def generate(state):
     """
     Generate images using the loaded model.
     Automatically adjusts parameters based on pipeline type.
     """
+    import queue
+    import threading
     import torch
 
     from modules.torch_utils import clear_torch_cache, get_device
 
-    clear_torch_cache()
-
     try:
         model_name = state['image_model_menu']
         if not model_name or model_name == 'None':
             logger.error("No image model selected. Go to the Model tab and select a model.")
-            yield []
+            yield [], progress_bar_html()
             return
 
         if shared.image_model is None:
@@ -691,7 +708,7 @@ def generate(state):
             )
 
             if result is None:
                 logger.error(f"Failed to load model `{model_name}`.")
-                yield []
+                yield [], progress_bar_html()
                 return
 
             shared.image_model_name = model_name
@@ -715,68 +732,97 @@ def generate(state):
         # Process Prompt
         prompt = state['image_prompt']
 
         # Apply "Positive Magic" for Qwen models only
         if pipeline_type == 'qwenimage':
             magic_suffix = ", Ultra HD, 4K, cinematic composition"
             # Avoid duplication if user already added it
             if magic_suffix.strip(", ") not in prompt:
                 prompt += magic_suffix
 
         # Reset stop flag at start
         shared.stop_everything = False
 
         # Callback to check for interruption during diffusion steps
+        batch_count = int(state['image_batch_count'])
+        steps_per_batch = int(state['image_steps'])
+        total_steps = steps_per_batch * batch_count
+        # Queue for progress updates from callback
+        progress_queue = queue.Queue()
         def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
             if shared.stop_everything:
                 pipe._interrupt = True
 
+            progress_queue.put(step_index + 1)
             return callback_kwargs
 
         # Build generation kwargs
         gen_kwargs = {
             "prompt": prompt,
             "negative_prompt": state['image_neg_prompt'],
             "height": int(state['image_height']),
             "width": int(state['image_width']),
-            "num_inference_steps": int(state['image_steps']),
+            "num_inference_steps": steps_per_batch,
             "num_images_per_prompt": int(state['image_batch_size']),
             "generator": generator,
             "callback_on_step_end": interrupt_callback,
         }
 
         # Add pipeline-specific parameters for CFG
         cfg_val = state.get('image_cfg_scale', 0.0)
         if pipeline_type == 'qwenimage':
             # Qwen-Image uses true_cfg_scale (typically 4.0)
             gen_kwargs["true_cfg_scale"] = cfg_val
         else:
             # Z-Image and others use guidance_scale (typically 0.0 for Turbo)
             gen_kwargs["guidance_scale"] = cfg_val
 
         t0 = time.time()
-        for i in range(int(state['image_batch_count'])):
+        for batch_idx in range(batch_count):
             if shared.stop_everything:
                 break
 
-            generator.manual_seed(int(seed + i))
-            batch_results = shared.image_model(**gen_kwargs).images
-            all_images.extend(batch_results)
-            yield all_images
+            generator.manual_seed(int(seed + batch_idx))
+            # Run generation in thread so we can yield progress
+            result_holder = []
+            error_holder = []
+            def run_batch():
+                try:
+                    result_holder.extend(shared.image_model(**gen_kwargs).images)
+                except Exception as e:
+                    error_holder.append(e)
+            thread = threading.Thread(target=run_batch)
+            thread.start()
+            # Yield progress updates while generation runs
+            while thread.is_alive():
+                try:
+                    step = progress_queue.get(timeout=0.1)
+                    absolute_step = batch_idx * steps_per_batch + step
+                    pct = absolute_step / total_steps
+                    text = f"Batch {batch_idx + 1}/{batch_count} — Step {step}/{steps_per_batch}"
+                    yield all_images, progress_bar_html(pct, text)
+                except queue.Empty:
+                    pass
+            thread.join()
+            if error_holder:
+                raise error_holder[0]
+            all_images.extend(result_holder)
+            yield all_images, progress_bar_html((batch_idx + 1) / batch_count, f"Batch {batch_idx + 1}/{batch_count} complete")
 
         t1 = time.time()
         save_generated_images(all_images, state, seed)
 
-        total_images = int(state['image_batch_count']) * int(state['image_batch_size'])
-        total_steps = state["image_steps"] * int(state['image_batch_count'])
+        total_images = batch_count * int(state['image_batch_size'])
         logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
 
-        yield all_images
+        yield all_images, progress_bar_html()
+        clear_torch_cache()
     except Exception as e:
         logger.error(f"Image generation failed: {e}")
         traceback.print_exc()
-        yield []
+        yield [], progress_bar_html()
+        clear_torch_cache()
 
 
 def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):
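Stripped of the Gradio wiring, the progress mechanism in generate() is a worker thread plus a queue.Queue: the pipeline's per-step callback pushes step numbers into the queue, and the generator polls it with a timeout so it can keep yielding while the blocking pipeline call runs. A self-contained sketch of the same pattern (fake_pipeline and generate_with_progress are illustrative stand-ins, not code from this diff):

    import queue
    import threading
    import time

    def fake_pipeline(steps, on_step_end):
        # Stand-in for the diffusers pipeline call: does some work per step
        # and invokes the callback, like callback_on_step_end does above.
        for step_index in range(steps):
            time.sleep(0.05)
            on_step_end(step_index)
        return ["image"]

    def generate_with_progress(steps=30):
        progress_queue = queue.Queue()
        results, errors = [], []

        def run():
            try:
                results.extend(fake_pipeline(steps, lambda i: progress_queue.put(i + 1)))
            except Exception as e:
                errors.append(e)

        thread = threading.Thread(target=run)
        thread.start()

        # Poll with a timeout so the loop keeps yielding even when no new
        # step has completed yet, mirroring the while loop in generate().
        while thread.is_alive():
            try:
                step = progress_queue.get(timeout=0.1)
                yield f"{step}/{steps}"
            except queue.Empty:
                pass

        thread.join()
        if errors:
            raise errors[0]
        yield f"done: {results}"

    for update in generate_with_progress(5):
        print(update)

Polling with a timeout rather than blocking on get() is what lets the loop notice thread.is_alive() turning false once the batch finishes, instead of hanging on an empty queue.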