Image generation now functional

This commit is contained in:
oobabooga 2025-11-27 14:24:35 -08:00
parent 2f11b3040d
commit a873692234
3 changed files with 127 additions and 59 deletions

View file

@ -140,7 +140,7 @@ def create_ui():
with gr.Column():
with gr.Row():
shared.gradio['image_model_menu'] = gr.Dropdown(choices=utils.get_available_image_models(), value=lambda: shared.image_model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
ui.create_refresh_button(shared.gradio['image_model_menu'], lambda: None, lambda: {'choices': utils.get_available_image_models()}, 'refresh-button', interactive=not mu)
shared.gradio['image_load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
shared.gradio['image_save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
@ -169,7 +169,7 @@ def create_ui():
shared.gradio['image_download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
with gr.Row():
shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.image_model_name == 'None' else 'Ready')
def create_event_handlers():
@ -220,6 +220,28 @@ def create_event_handlers():
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
# Image model event handlers
shared.gradio['image_load_model'].click(
load_image_model_wrapper,
gradio('image_model_menu'),
gradio('image_model_status'),
show_progress=True
)
shared.gradio['image_unload_model'].click(
handle_unload_image_model_click,
None,
gradio('image_model_status'),
show_progress=False
)
shared.gradio['image_download_model_button'].click(
download_image_model_wrapper,
gradio('image_custom_model_menu'),
gradio('image_model_status'),
show_progress=True
)
def load_model_wrapper(selected_model, loader, autoload=False):
try:
@ -471,3 +493,66 @@ def format_file_size(size_bytes):
return f"{s:.2f} {size_names[i]}"
else:
return f"{s:.1f} {size_names[i]}"
def load_image_model_wrapper(selected_model):
"""Wrapper for loading image models with status updates."""
from modules.image_models import load_image_model, unload_image_model
if selected_model == 'None' or not selected_model:
yield "No model selected"
return
try:
yield f"Loading `{selected_model}`..."
unload_image_model()
result = load_image_model(selected_model)
if result is not None:
yield f"Successfully loaded `{selected_model}`."
else:
yield f"Failed to load `{selected_model}`."
except Exception:
exc = traceback.format_exc()
yield exc.replace('\n', '\n\n')
def handle_unload_image_model_click():
"""Handler for the image model unload button."""
from modules.image_models import unload_image_model
unload_image_model()
return "Image model unloaded"
def download_image_model_wrapper(custom_model):
"""Download an image model from Hugging Face."""
from huggingface_hub import snapshot_download
if not custom_model:
yield "No model specified"
return
try:
# Parse model name and branch
if ':' in custom_model:
model_name, branch = custom_model.rsplit(':', 1)
else:
model_name, branch = custom_model, 'main'
# Output folder
output_folder = Path(shared.args.image_model_dir) / model_name.split('/')[-1]
yield f"Downloading `{model_name}` (branch: {branch})..."
snapshot_download(
repo_id=model_name,
revision=branch,
local_dir=output_folder,
local_dir_use_symlinks=False,
)
yield f"Model successfully saved to `{output_folder}/`."
except Exception:
exc = traceback.format_exc()
yield exc.replace('\n', '\n\n')