diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 6006ecf1..77b22bd4 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -17,6 +17,7 @@ from modules.image_models import (
unload_image_model
)
from modules.logging_colors import logger
+from modules.text_generation import stop_everything_event
from modules.torch_utils import get_device
from modules.utils import gradio
@@ -354,7 +355,7 @@ def create_ui():
)
shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg")
- shared.gradio['image_generating_btn'] = gr.Button("Generating...", size="lg", visible=False, interactive=False)
+ shared.gradio['image_stop_btn'] = gr.Button("Stop", size="lg", visible=False)
gr.HTML("
")
gr.Markdown("### Dimensions")
@@ -522,21 +523,26 @@ def create_event_handlers():
# Generation
shared.gradio['image_generate_btn'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
- lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
+ lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
- lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
+ lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
shared.gradio['image_prompt'].submit(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
- lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
+ lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
- lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
+ lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
shared.gradio['image_neg_prompt'].submit(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
- lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
+ lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
- lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
+ lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
+
+ # Stop button
+ shared.gradio['image_stop_btn'].click(
+ stop_everything_event, None, None, show_progress=False
+ )
# Model management
shared.gradio['image_refresh_models'].click(
@@ -680,6 +686,16 @@ def generate(state):
if magic_suffix.strip(", ") not in prompt:
prompt += magic_suffix
+ # Reset stop flag at start
+ shared.stop_everything = False
+
+ # Callback to check for interruption during diffusion steps
+ def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
+ if shared.stop_everything:
+ pipe._interrupt = True
+
+ return callback_kwargs
+
# Build generation kwargs
gen_kwargs = {
"prompt": prompt,
@@ -689,6 +705,7 @@ def generate(state):
"num_inference_steps": int(state['image_steps']),
"num_images_per_prompt": int(state['image_batch_size']),
"generator": generator,
+ "callback_on_step_end": interrupt_callback,
}
# Add pipeline-specific parameters for CFG
@@ -703,6 +720,9 @@ def generate(state):
t0 = time.time()
for i in range(int(state['image_batch_count'])):
+ if shared.stop_everything:
+ break
+
generator.manual_seed(int(seed + i))
batch_results = shared.image_model(**gen_kwargs).images
all_images.extend(batch_results)