mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-02-05 07:14:17 +01:00
Make it possible to interrupt a generation
This commit is contained in:
parent
c3bd1c901d
commit
180e1f0cbf
|
|
@ -17,6 +17,7 @@ from modules.image_models import (
|
|||
unload_image_model
|
||||
)
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import stop_everything_event
|
||||
from modules.torch_utils import get_device
|
||||
from modules.utils import gradio
|
||||
|
||||
|
|
@ -354,7 +355,7 @@ def create_ui():
|
|||
)
|
||||
|
||||
shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg")
|
||||
shared.gradio['image_generating_btn'] = gr.Button("Generating...", size="lg", visible=False, interactive=False)
|
||||
shared.gradio['image_stop_btn'] = gr.Button("Stop", size="lg", visible=False)
|
||||
gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")
|
||||
|
||||
gr.Markdown("### Dimensions")
|
||||
|
|
@ -522,21 +523,26 @@ def create_event_handlers():
|
|||
# Generation
|
||||
shared.gradio['image_generate_btn'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||
|
||||
shared.gradio['image_prompt'].submit(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||
|
||||
shared.gradio['image_neg_prompt'].submit(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_generating_btn', 'image_generate_btn')).then(
|
||||
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_generating_btn', 'image_generate_btn'))
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||
|
||||
# Stop button
|
||||
shared.gradio['image_stop_btn'].click(
|
||||
stop_everything_event, None, None, show_progress=False
|
||||
)
|
||||
|
||||
# Model management
|
||||
shared.gradio['image_refresh_models'].click(
|
||||
|
|
@ -680,6 +686,16 @@ def generate(state):
|
|||
if magic_suffix.strip(", ") not in prompt:
|
||||
prompt += magic_suffix
|
||||
|
||||
# Reset stop flag at start
|
||||
shared.stop_everything = False
|
||||
|
||||
# Callback to check for interruption during diffusion steps
|
||||
def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
|
||||
if shared.stop_everything:
|
||||
pipe._interrupt = True
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
# Build generation kwargs
|
||||
gen_kwargs = {
|
||||
"prompt": prompt,
|
||||
|
|
@ -689,6 +705,7 @@ def generate(state):
|
|||
"num_inference_steps": int(state['image_steps']),
|
||||
"num_images_per_prompt": int(state['image_batch_size']),
|
||||
"generator": generator,
|
||||
"callback_on_step_end": interrupt_callback,
|
||||
}
|
||||
|
||||
# Add pipeline-specific parameters for CFG
|
||||
|
|
@ -703,6 +720,9 @@ def generate(state):
|
|||
|
||||
t0 = time.time()
|
||||
for i in range(int(state['image_batch_count'])):
|
||||
if shared.stop_everything:
|
||||
break
|
||||
|
||||
generator.manual_seed(int(seed + i))
|
||||
batch_results = shared.image_model(**gen_kwargs).images
|
||||
all_images.extend(batch_results)
|
||||
|
|
|
|||
Loading…
Reference in a new issue