From 180e1f0cbf857276345bd1c4bc056c6806b00cd8 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 2 Dec 2025 06:06:07 -0800
Subject: [PATCH] Make it possible to interrupt a generation

---
 modules/ui_image_generation.py | 34 +++++++++++++++++++++++++++-------
 1 file changed, 27 insertions(+), 7 deletions(-)

diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 6006ecf1..77b22bd4 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -17,6 +17,7 @@ from modules.image_models import (
     unload_image_model
 )
 from modules.logging_colors import logger
+from modules.text_generation import stop_everything_event
 from modules.torch_utils import get_device
 from modules.utils import gradio
 
@@ -354,7 +355,7 @@ def create_ui():
                     )
 
                     shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg")
-                    shared.gradio['image_generating_btn'] = gr.Button("Generating...", size="lg", visible=False, interactive=False)
+                    shared.gradio['image_stop_btn'] = gr.Button("Stop", size="lg", visible=False)
 
                     gr.HTML("<hr>")
                     gr.Markdown("### Dimensions")
@@ -522,21 +523,26 @@ def create_event_handlers():
     # Generation
     shared.gradio['image_generate_btn'].click(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
-        lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
+        lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
         generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
-        lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
+        lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
 
     shared.gradio['image_prompt'].submit(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
-        lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
+        lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
         generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
-        lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
+        lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
 
     shared.gradio['image_neg_prompt'].submit(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
-        lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
+        lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
         generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
-        lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
+        lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
+
+    # Stop button
+    shared.gradio['image_stop_btn'].click(
+        stop_everything_event, None, None, show_progress=False
+    )
 
     # Model management
     shared.gradio['image_refresh_models'].click(
@@ -680,6 +686,16 @@ def generate(state):
         if magic_suffix.strip(", ") not in prompt:
             prompt += magic_suffix
 
+    # Reset stop flag at start
+    shared.stop_everything = False
+
+    # Callback to check for interruption during diffusion steps
+    def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
+        if shared.stop_everything:
+            pipe._interrupt = True
+
+        return callback_kwargs
+
     # Build generation kwargs
     gen_kwargs = {
         "prompt": prompt,
@@ -689,6 +705,7 @@ def generate(state):
         "num_inference_steps": int(state['image_steps']),
         "num_images_per_prompt": int(state['image_batch_size']),
         "generator": generator,
+        "callback_on_step_end": interrupt_callback,
     }
 
     # Add pipeline-specific parameters for CFG
@@ -703,6 +720,9 @@ def generate(state):
 
     t0 = time.time()
     for i in range(int(state['image_batch_count'])):
+        if shared.stop_everything:
+            break
+
         generator.manual_seed(int(seed + i))
         batch_results = shared.image_model(**gen_kwargs).images
         all_images.extend(batch_results)