mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-18 04:53:42 +00:00
Image: Add a progress bar during generation
This commit is contained in:
parent
14dbc3488e
commit
a838223d18
1 changed files with 72 additions and 26 deletions
|
|
@ -373,7 +373,10 @@ def create_ui():
|
|||
|
||||
shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg")
|
||||
shared.gradio['image_stop_btn'] = gr.Button("Stop", size="lg", visible=False)
|
||||
gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")
|
||||
shared.gradio['image_progress'] = gr.HTML(
|
||||
value=progress_bar_html(),
|
||||
elem_id="image-progress"
|
||||
)
|
||||
|
||||
gr.Markdown("### Dimensions")
|
||||
with gr.Row():
|
||||
|
|
@ -546,19 +549,19 @@ def create_event_handlers():
|
|||
shared.gradio['image_generate_btn'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||
|
||||
shared.gradio['image_prompt'].submit(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||
|
||||
shared.gradio['image_neg_prompt'].submit(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||
|
||||
# Stop button
|
||||
|
|
@ -661,11 +664,27 @@ def create_event_handlers():
|
|||
)
|
||||
|
||||
|
||||
def progress_bar_html(progress=0, text=""):
|
||||
"""Generate HTML for progress bar. Empty div when progress <= 0."""
|
||||
if progress <= 0:
|
||||
return '<div style="height: 24px; margin: 20px 0; border-top: 1px solid #444;"></div>'
|
||||
|
||||
return f'''<div style="height: 24px; margin: 20px 0;">
|
||||
<div style="background: #333; border-radius: 4px; overflow: hidden; height: 8px;">
|
||||
<div style="background: #4a9eff; height: 100%; width: {progress*100:.1f}%;"></div>
|
||||
</div>
|
||||
<div style="text-align: center; font-size: 11px; color: #888; margin-top: 4px;">{text}</div>
|
||||
</div>'''
|
||||
|
||||
|
||||
def generate(state):
|
||||
"""
|
||||
Generate images using the loaded model.
|
||||
Automatically adjusts parameters based on pipeline type.
|
||||
"""
|
||||
import queue
|
||||
import threading
|
||||
|
||||
import torch
|
||||
|
||||
from modules.torch_utils import clear_torch_cache, get_device
|
||||
|
|
@ -675,7 +694,7 @@ def generate(state):
|
|||
|
||||
if not model_name or model_name == 'None':
|
||||
logger.error("No image model selected. Go to the Model tab and select a model.")
|
||||
yield []
|
||||
yield [], progress_bar_html()
|
||||
return
|
||||
|
||||
if shared.image_model is None:
|
||||
|
|
@ -689,7 +708,7 @@ def generate(state):
|
|||
)
|
||||
if result is None:
|
||||
logger.error(f"Failed to load model `{model_name}`.")
|
||||
yield []
|
||||
yield [], progress_bar_html()
|
||||
return
|
||||
|
||||
shared.image_model_name = model_name
|
||||
|
|
@ -713,69 +732,96 @@ def generate(state):
|
|||
# Process Prompt
|
||||
prompt = state['image_prompt']
|
||||
|
||||
# Apply "Positive Magic" for Qwen models only
|
||||
if pipeline_type == 'qwenimage':
|
||||
magic_suffix = ", Ultra HD, 4K, cinematic composition"
|
||||
# Avoid duplication if user already added it
|
||||
if magic_suffix.strip(", ") not in prompt:
|
||||
prompt += magic_suffix
|
||||
|
||||
# Reset stop flag at start
|
||||
shared.stop_everything = False
|
||||
|
||||
# Callback to check for interruption during diffusion steps
|
||||
batch_count = int(state['image_batch_count'])
|
||||
steps_per_batch = int(state['image_steps'])
|
||||
total_steps = steps_per_batch * batch_count
|
||||
|
||||
# Queue for progress updates from callback
|
||||
progress_queue = queue.Queue()
|
||||
|
||||
def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
|
||||
if shared.stop_everything:
|
||||
pipe._interrupt = True
|
||||
|
||||
progress_queue.put(step_index + 1)
|
||||
return callback_kwargs
|
||||
|
||||
# Build generation kwargs
|
||||
gen_kwargs = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": state['image_neg_prompt'],
|
||||
"height": int(state['image_height']),
|
||||
"width": int(state['image_width']),
|
||||
"num_inference_steps": int(state['image_steps']),
|
||||
"num_inference_steps": steps_per_batch,
|
||||
"num_images_per_prompt": int(state['image_batch_size']),
|
||||
"generator": generator,
|
||||
"callback_on_step_end": interrupt_callback,
|
||||
}
|
||||
|
||||
# Add pipeline-specific parameters for CFG
|
||||
cfg_val = state.get('image_cfg_scale', 0.0)
|
||||
|
||||
if pipeline_type == 'qwenimage':
|
||||
# Qwen-Image uses true_cfg_scale (typically 4.0)
|
||||
gen_kwargs["true_cfg_scale"] = cfg_val
|
||||
else:
|
||||
# Z-Image and others use guidance_scale (typically 0.0 for Turbo)
|
||||
gen_kwargs["guidance_scale"] = cfg_val
|
||||
|
||||
t0 = time.time()
|
||||
for i in range(int(state['image_batch_count'])):
|
||||
|
||||
for batch_idx in range(batch_count):
|
||||
if shared.stop_everything:
|
||||
break
|
||||
|
||||
generator.manual_seed(int(seed + i))
|
||||
batch_results = shared.image_model(**gen_kwargs).images
|
||||
all_images.extend(batch_results)
|
||||
yield all_images
|
||||
generator.manual_seed(int(seed + batch_idx))
|
||||
|
||||
# Run generation in thread so we can yield progress
|
||||
result_holder = []
|
||||
error_holder = []
|
||||
|
||||
def run_batch():
|
||||
try:
|
||||
result_holder.extend(shared.image_model(**gen_kwargs).images)
|
||||
except Exception as e:
|
||||
error_holder.append(e)
|
||||
|
||||
thread = threading.Thread(target=run_batch)
|
||||
thread.start()
|
||||
|
||||
# Yield progress updates while generation runs
|
||||
while thread.is_alive():
|
||||
try:
|
||||
step = progress_queue.get(timeout=0.1)
|
||||
absolute_step = batch_idx * steps_per_batch + step
|
||||
pct = absolute_step / total_steps
|
||||
text = f"Batch {batch_idx + 1}/{batch_count} — Step {step}/{steps_per_batch}"
|
||||
yield all_images, progress_bar_html(pct, text)
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
thread.join()
|
||||
|
||||
if error_holder:
|
||||
raise error_holder[0]
|
||||
|
||||
all_images.extend(result_holder)
|
||||
yield all_images, progress_bar_html((batch_idx + 1) / batch_count, f"Batch {batch_idx + 1}/{batch_count} complete")
|
||||
|
||||
t1 = time.time()
|
||||
save_generated_images(all_images, state, seed)
|
||||
|
||||
total_images = int(state['image_batch_count']) * int(state['image_batch_size'])
|
||||
total_steps = state["image_steps"] * int(state['image_batch_count'])
|
||||
total_images = batch_count * int(state['image_batch_size'])
|
||||
logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
|
||||
|
||||
yield all_images
|
||||
yield all_images, progress_bar_html()
|
||||
clear_torch_cache()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Image generation failed: {e}")
|
||||
traceback.print_exc()
|
||||
yield []
|
||||
yield [], progress_bar_html()
|
||||
clear_torch_cache()
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue