mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-05 06:35:15 +00:00
Image generation now functional
This commit is contained in:
parent
2f11b3040d
commit
a873692234
3 changed files with 127 additions and 59 deletions
|
|
@ -140,7 +140,7 @@ def create_ui():
|
|||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['image_model_menu'] = gr.Dropdown(choices=utils.get_available_image_models(), value=lambda: shared.image_model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
|
||||
ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
|
||||
ui.create_refresh_button(shared.gradio['image_model_menu'], lambda: None, lambda: {'choices': utils.get_available_image_models()}, 'refresh-button', interactive=not mu)
|
||||
shared.gradio['image_load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
|
||||
shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
|
||||
shared.gradio['image_save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
|
||||
|
|
@ -169,7 +169,7 @@ def create_ui():
|
|||
shared.gradio['image_download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
|
||||
shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.image_model_name == 'None' else 'Ready')
|
||||
|
||||
|
||||
def create_event_handlers():
|
||||
|
|
@ -220,6 +220,28 @@ def create_event_handlers():
|
|||
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
|
||||
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
|
||||
|
||||
# Image model event handlers
|
||||
shared.gradio['image_load_model'].click(
|
||||
load_image_model_wrapper,
|
||||
gradio('image_model_menu'),
|
||||
gradio('image_model_status'),
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
shared.gradio['image_unload_model'].click(
|
||||
handle_unload_image_model_click,
|
||||
None,
|
||||
gradio('image_model_status'),
|
||||
show_progress=False
|
||||
)
|
||||
|
||||
shared.gradio['image_download_model_button'].click(
|
||||
download_image_model_wrapper,
|
||||
gradio('image_custom_model_menu'),
|
||||
gradio('image_model_status'),
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
|
||||
def load_model_wrapper(selected_model, loader, autoload=False):
|
||||
try:
|
||||
|
|
@ -471,3 +493,66 @@ def format_file_size(size_bytes):
|
|||
return f"{s:.2f} {size_names[i]}"
|
||||
else:
|
||||
return f"{s:.1f} {size_names[i]}"
|
||||
|
||||
|
||||
def load_image_model_wrapper(selected_model):
|
||||
"""Wrapper for loading image models with status updates."""
|
||||
from modules.image_models import load_image_model, unload_image_model
|
||||
|
||||
if selected_model == 'None' or not selected_model:
|
||||
yield "No model selected"
|
||||
return
|
||||
|
||||
try:
|
||||
yield f"Loading `{selected_model}`..."
|
||||
unload_image_model()
|
||||
result = load_image_model(selected_model)
|
||||
|
||||
if result is not None:
|
||||
yield f"Successfully loaded `{selected_model}`."
|
||||
else:
|
||||
yield f"Failed to load `{selected_model}`."
|
||||
except Exception:
|
||||
exc = traceback.format_exc()
|
||||
yield exc.replace('\n', '\n\n')
|
||||
|
||||
|
||||
def handle_unload_image_model_click():
|
||||
"""Handler for the image model unload button."""
|
||||
from modules.image_models import unload_image_model
|
||||
unload_image_model()
|
||||
return "Image model unloaded"
|
||||
|
||||
|
||||
def download_image_model_wrapper(custom_model):
|
||||
"""Download an image model from Hugging Face."""
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
if not custom_model:
|
||||
yield "No model specified"
|
||||
return
|
||||
|
||||
try:
|
||||
# Parse model name and branch
|
||||
if ':' in custom_model:
|
||||
model_name, branch = custom_model.rsplit(':', 1)
|
||||
else:
|
||||
model_name, branch = custom_model, 'main'
|
||||
|
||||
# Output folder
|
||||
output_folder = Path(shared.args.image_model_dir) / model_name.split('/')[-1]
|
||||
|
||||
yield f"Downloading `{model_name}` (branch: {branch})..."
|
||||
|
||||
snapshot_download(
|
||||
repo_id=model_name,
|
||||
revision=branch,
|
||||
local_dir=output_folder,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
yield f"Model successfully saved to `{output_folder}/`."
|
||||
|
||||
except Exception:
|
||||
exc = traceback.format_exc()
|
||||
yield exc.replace('\n', '\n\n')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue