mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
Compare commits
3 commits
235b94f097
...
7fb9f19bd8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fb9f19bd8 | ||
|
|
a838223d18 | ||
|
|
14dbc3488e |
39
css/main.css
39
css/main.css
|
|
@ -1752,3 +1752,42 @@ button#swap-height-width {
|
|||
.min.svelte-1yrv54 {
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
/* Image Generation Progress Bar */
|
||||
#image-progress .image-ai-separator {
|
||||
height: 24px;
|
||||
margin: 20px 0;
|
||||
border-top: 1px solid var(--input-border-color);
|
||||
}
|
||||
|
||||
#image-progress .image-ai-progress-wrapper {
|
||||
height: 24px;
|
||||
margin: 20px 0;
|
||||
}
|
||||
|
||||
#image-progress .image-ai-progress-track {
|
||||
background: #e5e7eb;
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
height: 8px;
|
||||
}
|
||||
|
||||
.dark #image-progress .image-ai-progress-track {
|
||||
background: #333;
|
||||
}
|
||||
|
||||
#image-progress .image-ai-progress-fill {
|
||||
background: #4a9eff;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
#image-progress .image-ai-progress-text {
|
||||
text-align: center;
|
||||
font-size: 12px;
|
||||
color: #666;
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
.dark #image-progress .image-ai-progress-text {
|
||||
color: #888;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -373,7 +373,10 @@ def create_ui():
|
|||
|
||||
shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg")
|
||||
shared.gradio['image_stop_btn'] = gr.Button("Stop", size="lg", visible=False)
|
||||
gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")
|
||||
shared.gradio['image_progress'] = gr.HTML(
|
||||
value=progress_bar_html(),
|
||||
elem_id="image-progress"
|
||||
)
|
||||
|
||||
gr.Markdown("### Dimensions")
|
||||
with gr.Row():
|
||||
|
|
@ -546,19 +549,19 @@ def create_event_handlers():
|
|||
shared.gradio['image_generate_btn'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||
|
||||
shared.gradio['image_prompt'].submit(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||
|
||||
shared.gradio['image_neg_prompt'].submit(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||
|
||||
# Stop button
|
||||
|
|
@ -661,23 +664,37 @@ def create_event_handlers():
|
|||
)
|
||||
|
||||
|
||||
def progress_bar_html(progress=0, text=""):
|
||||
"""Generate HTML for progress bar. Empty div when progress <= 0."""
|
||||
if progress <= 0:
|
||||
return '<div class="image-ai-separator"></div>'
|
||||
|
||||
return f'''<div class="image-ai-progress-wrapper">
|
||||
<div class="image-ai-progress-track">
|
||||
<div class="image-ai-progress-fill" style="width: {progress*100:.1f}%;"></div>
|
||||
</div>
|
||||
<div class="image-ai-progress-text">{text}</div>
|
||||
</div>'''
|
||||
|
||||
|
||||
def generate(state):
|
||||
"""
|
||||
Generate images using the loaded model.
|
||||
Automatically adjusts parameters based on pipeline type.
|
||||
"""
|
||||
import queue
|
||||
import threading
|
||||
|
||||
import torch
|
||||
|
||||
from modules.torch_utils import clear_torch_cache, get_device
|
||||
|
||||
clear_torch_cache()
|
||||
|
||||
try:
|
||||
model_name = state['image_model_menu']
|
||||
|
||||
if not model_name or model_name == 'None':
|
||||
logger.error("No image model selected. Go to the Model tab and select a model.")
|
||||
yield []
|
||||
yield [], progress_bar_html()
|
||||
return
|
||||
|
||||
if shared.image_model is None:
|
||||
|
|
@ -691,7 +708,7 @@ def generate(state):
|
|||
)
|
||||
if result is None:
|
||||
logger.error(f"Failed to load model `{model_name}`.")
|
||||
yield []
|
||||
yield [], progress_bar_html()
|
||||
return
|
||||
|
||||
shared.image_model_name = model_name
|
||||
|
|
@ -715,68 +732,97 @@ def generate(state):
|
|||
# Process Prompt
|
||||
prompt = state['image_prompt']
|
||||
|
||||
# Apply "Positive Magic" for Qwen models only
|
||||
if pipeline_type == 'qwenimage':
|
||||
magic_suffix = ", Ultra HD, 4K, cinematic composition"
|
||||
# Avoid duplication if user already added it
|
||||
if magic_suffix.strip(", ") not in prompt:
|
||||
prompt += magic_suffix
|
||||
|
||||
# Reset stop flag at start
|
||||
shared.stop_everything = False
|
||||
|
||||
# Callback to check for interruption during diffusion steps
|
||||
batch_count = int(state['image_batch_count'])
|
||||
steps_per_batch = int(state['image_steps'])
|
||||
total_steps = steps_per_batch * batch_count
|
||||
|
||||
# Queue for progress updates from callback
|
||||
progress_queue = queue.Queue()
|
||||
|
||||
def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
|
||||
if shared.stop_everything:
|
||||
pipe._interrupt = True
|
||||
|
||||
progress_queue.put(step_index + 1)
|
||||
return callback_kwargs
|
||||
|
||||
# Build generation kwargs
|
||||
gen_kwargs = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": state['image_neg_prompt'],
|
||||
"height": int(state['image_height']),
|
||||
"width": int(state['image_width']),
|
||||
"num_inference_steps": int(state['image_steps']),
|
||||
"num_inference_steps": steps_per_batch,
|
||||
"num_images_per_prompt": int(state['image_batch_size']),
|
||||
"generator": generator,
|
||||
"callback_on_step_end": interrupt_callback,
|
||||
}
|
||||
|
||||
# Add pipeline-specific parameters for CFG
|
||||
cfg_val = state.get('image_cfg_scale', 0.0)
|
||||
|
||||
if pipeline_type == 'qwenimage':
|
||||
# Qwen-Image uses true_cfg_scale (typically 4.0)
|
||||
gen_kwargs["true_cfg_scale"] = cfg_val
|
||||
else:
|
||||
# Z-Image and others use guidance_scale (typically 0.0 for Turbo)
|
||||
gen_kwargs["guidance_scale"] = cfg_val
|
||||
|
||||
t0 = time.time()
|
||||
for i in range(int(state['image_batch_count'])):
|
||||
|
||||
for batch_idx in range(batch_count):
|
||||
if shared.stop_everything:
|
||||
break
|
||||
|
||||
generator.manual_seed(int(seed + i))
|
||||
batch_results = shared.image_model(**gen_kwargs).images
|
||||
all_images.extend(batch_results)
|
||||
yield all_images
|
||||
generator.manual_seed(int(seed + batch_idx))
|
||||
|
||||
# Run generation in thread so we can yield progress
|
||||
result_holder = []
|
||||
error_holder = []
|
||||
|
||||
def run_batch():
|
||||
try:
|
||||
result_holder.extend(shared.image_model(**gen_kwargs).images)
|
||||
except Exception as e:
|
||||
error_holder.append(e)
|
||||
|
||||
thread = threading.Thread(target=run_batch)
|
||||
thread.start()
|
||||
|
||||
# Yield progress updates while generation runs
|
||||
while thread.is_alive():
|
||||
try:
|
||||
step = progress_queue.get(timeout=0.1)
|
||||
absolute_step = batch_idx * steps_per_batch + step
|
||||
pct = absolute_step / total_steps
|
||||
text = f"Batch {batch_idx + 1}/{batch_count} — Step {step}/{steps_per_batch}"
|
||||
yield all_images, progress_bar_html(pct, text)
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
thread.join()
|
||||
|
||||
if error_holder:
|
||||
raise error_holder[0]
|
||||
|
||||
all_images.extend(result_holder)
|
||||
yield all_images, progress_bar_html((batch_idx + 1) / batch_count, f"Batch {batch_idx + 1}/{batch_count} complete")
|
||||
|
||||
t1 = time.time()
|
||||
save_generated_images(all_images, state, seed)
|
||||
|
||||
total_images = int(state['image_batch_count']) * int(state['image_batch_size'])
|
||||
total_steps = state["image_steps"] * int(state['image_batch_count'])
|
||||
total_images = batch_count * int(state['image_batch_size'])
|
||||
logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
|
||||
|
||||
yield all_images
|
||||
yield all_images, progress_bar_html()
|
||||
clear_torch_cache()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Image generation failed: {e}")
|
||||
traceback.print_exc()
|
||||
yield []
|
||||
yield [], progress_bar_html()
|
||||
clear_torch_cache()
|
||||
|
||||
|
||||
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):
|
||||
|
|
|
|||
Loading…
Reference in a new issue