Make it possible to interrupt a generation

This commit is contained in:
oobabooga 2025-12-02 06:06:07 -08:00
parent c3bd1c901d
commit 180e1f0cbf

View file

@ -17,6 +17,7 @@ from modules.image_models import (
unload_image_model
)
from modules.logging_colors import logger
from modules.text_generation import stop_everything_event
from modules.torch_utils import get_device
from modules.utils import gradio
@ -354,7 +355,7 @@ def create_ui():
)
shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg")
shared.gradio['image_generating_btn'] = gr.Button("Generating...", size="lg", visible=False, interactive=False)
shared.gradio['image_stop_btn'] = gr.Button("Stop", size="lg", visible=False)
gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")
gr.Markdown("### Dimensions")
@ -522,21 +523,26 @@ def create_event_handlers():
# Generation
shared.gradio['image_generate_btn'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
shared.gradio['image_prompt'].submit(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
shared.gradio['image_neg_prompt'].submit(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
# Stop button
shared.gradio['image_stop_btn'].click(
stop_everything_event, None, None, show_progress=False
)
# Model management
shared.gradio['image_refresh_models'].click(
@ -680,6 +686,16 @@ def generate(state):
if magic_suffix.strip(", ") not in prompt:
prompt += magic_suffix
# Reset stop flag at start
shared.stop_everything = False
# Callback to check for interruption during diffusion steps
def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
if shared.stop_everything:
pipe._interrupt = True
return callback_kwargs
# Build generation kwargs
gen_kwargs = {
"prompt": prompt,
@ -689,6 +705,7 @@ def generate(state):
"num_inference_steps": int(state['image_steps']),
"num_images_per_prompt": int(state['image_batch_size']),
"generator": generator,
"callback_on_step_end": interrupt_callback,
}
# Add pipeline-specific parameters for CFG
@ -703,6 +720,9 @@ def generate(state):
t0 = time.time()
for i in range(int(state['image_batch_count'])):
if shared.stop_everything:
break
generator.manual_seed(int(seed + i))
batch_results = shared.image_model(**gen_kwargs).images
all_images.extend(batch_results)