mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
Small fixes
This commit is contained in:
parent
21f992e7f7
commit
666816a773
|
|
@ -38,19 +38,19 @@ def clamp(value, min_val, max_val):
|
||||||
def apply_aspect_ratio(aspect_ratio, current_width, current_height):
|
def apply_aspect_ratio(aspect_ratio, current_width, current_height):
|
||||||
"""
|
"""
|
||||||
Apply an aspect ratio preset.
|
Apply an aspect ratio preset.
|
||||||
|
|
||||||
Logic to prevent dimension creep:
|
Logic to prevent dimension creep:
|
||||||
- For tall ratios (like 9:16): keep width fixed, calculate height
|
- For tall ratios (like 9:16): keep width fixed, calculate height
|
||||||
- For wide ratios (like 16:9): keep height fixed, calculate width
|
- For wide ratios (like 16:9): keep height fixed, calculate width
|
||||||
- For square (1:1): use the smaller of the current dimensions
|
- For square (1:1): use the smaller of the current dimensions
|
||||||
|
|
||||||
Returns (new_width, new_height).
|
Returns (new_width, new_height).
|
||||||
"""
|
"""
|
||||||
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
|
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
|
||||||
return current_width, current_height
|
return current_width, current_height
|
||||||
|
|
||||||
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
||||||
|
|
||||||
if w_ratio == h_ratio:
|
if w_ratio == h_ratio:
|
||||||
# Square ratio - use the smaller current dimension to prevent creep
|
# Square ratio - use the smaller current dimension to prevent creep
|
||||||
base = min(current_width, current_height)
|
base = min(current_width, current_height)
|
||||||
|
|
@ -64,11 +64,11 @@ def apply_aspect_ratio(aspect_ratio, current_width, current_height):
|
||||||
# Wide ratio (like 16:9) - height is the smaller side, keep it fixed
|
# Wide ratio (like 16:9) - height is the smaller side, keep it fixed
|
||||||
new_height = current_height
|
new_height = current_height
|
||||||
new_width = round_to_step(current_height * w_ratio / h_ratio)
|
new_width = round_to_step(current_height * w_ratio / h_ratio)
|
||||||
|
|
||||||
# Clamp to slider bounds
|
# Clamp to slider bounds
|
||||||
new_width = clamp(new_width, 256, 2048)
|
new_width = clamp(new_width, 256, 2048)
|
||||||
new_height = clamp(new_height, 256, 2048)
|
new_height = clamp(new_height, 256, 2048)
|
||||||
|
|
||||||
return int(new_width), int(new_height)
|
return int(new_width), int(new_height)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -76,11 +76,11 @@ def update_height_from_width(width, aspect_ratio):
|
||||||
"""Update height when width changes (if not Custom)."""
|
"""Update height when width changes (if not Custom)."""
|
||||||
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
|
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
|
||||||
return gr.update()
|
return gr.update()
|
||||||
|
|
||||||
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
||||||
new_height = round_to_step(width * h_ratio / w_ratio)
|
new_height = round_to_step(width * h_ratio / w_ratio)
|
||||||
new_height = clamp(new_height, 256, 2048)
|
new_height = clamp(new_height, 256, 2048)
|
||||||
|
|
||||||
return int(new_height)
|
return int(new_height)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -88,18 +88,18 @@ def update_width_from_height(height, aspect_ratio):
|
||||||
"""Update width when height changes (if not Custom)."""
|
"""Update width when height changes (if not Custom)."""
|
||||||
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
|
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
|
||||||
return gr.update()
|
return gr.update()
|
||||||
|
|
||||||
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
||||||
new_width = round_to_step(height * w_ratio / h_ratio)
|
new_width = round_to_step(height * w_ratio / h_ratio)
|
||||||
new_width = clamp(new_width, 256, 2048)
|
new_width = clamp(new_width, 256, 2048)
|
||||||
|
|
||||||
return int(new_width)
|
return int(new_width)
|
||||||
|
|
||||||
|
|
||||||
def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
|
def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
|
||||||
"""Swap dimensions and update aspect ratio to match (or set to Custom)."""
|
"""Swap dimensions and update aspect ratio to match (or set to Custom)."""
|
||||||
new_width, new_height = height, width
|
new_width, new_height = height, width
|
||||||
|
|
||||||
# Try to find a matching aspect ratio for the swapped dimensions
|
# Try to find a matching aspect ratio for the swapped dimensions
|
||||||
new_ratio = "Custom"
|
new_ratio = "Custom"
|
||||||
for name, ratios in ASPECT_RATIOS.items():
|
for name, ratios in ASPECT_RATIOS.items():
|
||||||
|
|
@ -111,27 +111,27 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
|
||||||
if abs(expected_height - new_height) < STEP:
|
if abs(expected_height - new_height) < STEP:
|
||||||
new_ratio = name
|
new_ratio = name
|
||||||
break
|
break
|
||||||
|
|
||||||
return new_width, new_height, new_ratio
|
return new_width, new_height, new_ratio
|
||||||
|
|
||||||
|
|
||||||
def create_ui():
|
def create_ui():
|
||||||
# Get effective settings (CLI > yaml > defaults)
|
# Get effective settings (CLI > yaml > defaults)
|
||||||
settings = get_effective_settings()
|
settings = get_effective_settings()
|
||||||
|
|
||||||
# Update shared state (but don't load the model yet)
|
# Update shared state (but don't load the model yet)
|
||||||
if settings['model_name'] != 'None':
|
if settings['model_name'] != 'None':
|
||||||
shared.image_model_name = settings['model_name']
|
shared.image_model_name = settings['model_name']
|
||||||
|
|
||||||
with gr.Tab("Image AI", elem_id="image-ai-tab"):
|
with gr.Tab("Image AI", elem_id="image-ai-tab"):
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
# TAB 1: GENERATION STUDIO
|
# TAB 1: GENERATION STUDIO
|
||||||
with gr.TabItem("Generate"):
|
with gr.TabItem("Generate"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
||||||
# === LEFT COLUMN: CONTROLS ===
|
# === LEFT COLUMN: CONTROLS ===
|
||||||
with gr.Column(scale=4, min_width=350):
|
with gr.Column(scale=4, min_width=350):
|
||||||
|
|
||||||
# 1. PROMPT
|
# 1. PROMPT
|
||||||
prompt = gr.Textbox(label="Prompt", placeholder="Describe your imagination...", lines=3, autofocus=True)
|
prompt = gr.Textbox(label="Prompt", placeholder="Describe your imagination...", lines=3, autofocus=True)
|
||||||
neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="Low quality...", lines=3)
|
neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="Low quality...", lines=3)
|
||||||
|
|
@ -170,12 +170,12 @@ def create_ui():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
batch_size_parallel = gr.Slider(1, 32, value=1, step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
|
batch_size_parallel = gr.Slider(1, 32, value=1, step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
|
||||||
batch_count_seq = gr.Slider(1, 128, value=1, step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
|
batch_count_seq = gr.Slider(1, 128, value=1, step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
|
||||||
|
|
||||||
# === RIGHT COLUMN: VIEWPORT ===
|
# === RIGHT COLUMN: VIEWPORT ===
|
||||||
with gr.Column(scale=6, min_width=500):
|
with gr.Column(scale=6, min_width=500):
|
||||||
with gr.Column(elem_classes=["viewport-container"]):
|
with gr.Column(elem_classes=["viewport-container"]):
|
||||||
output_gallery = gr.Gallery(
|
output_gallery = gr.Gallery(
|
||||||
label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True
|
label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
used_seed = gr.Markdown(label="Info", interactive=False)
|
used_seed = gr.Markdown(label="Info", interactive=False)
|
||||||
|
|
@ -203,9 +203,9 @@ def create_ui():
|
||||||
image_refresh_models = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
|
image_refresh_models = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
|
||||||
image_load_model = gr.Button("Load", variant='primary', elem_classes='refresh-button')
|
image_load_model = gr.Button("Load", variant='primary', elem_classes='refresh-button')
|
||||||
image_unload_model = gr.Button("Unload", elem_classes='refresh-button')
|
image_unload_model = gr.Button("Unload", elem_classes='refresh-button')
|
||||||
|
|
||||||
gr.Markdown("## Settings")
|
gr.Markdown("## Settings")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
image_dtype = gr.Dropdown(
|
image_dtype = gr.Dropdown(
|
||||||
|
|
@ -214,14 +214,14 @@ def create_ui():
|
||||||
label='Data Type',
|
label='Data Type',
|
||||||
info='bfloat16 recommended for modern GPUs'
|
info='bfloat16 recommended for modern GPUs'
|
||||||
)
|
)
|
||||||
|
|
||||||
image_attn_backend = gr.Dropdown(
|
image_attn_backend = gr.Dropdown(
|
||||||
choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
|
choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
|
||||||
value=settings['attn_backend'],
|
value=settings['attn_backend'],
|
||||||
label='Attention Backend',
|
label='Attention Backend',
|
||||||
info='SDPA is default. Flash Attention requires compatible GPU.'
|
info='SDPA is default. Flash Attention requires compatible GPU.'
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
image_compile = gr.Checkbox(
|
image_compile = gr.Checkbox(
|
||||||
value=settings['compile_model'],
|
value=settings['compile_model'],
|
||||||
|
|
@ -234,7 +234,7 @@ def create_ui():
|
||||||
label='CPU Offload',
|
label='CPU Offload',
|
||||||
info='Enable for low VRAM GPUs. Slower but uses less memory.'
|
info='Enable for low VRAM GPUs. Slower but uses less memory.'
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
image_download_path = gr.Textbox(
|
image_download_path = gr.Textbox(
|
||||||
label="Download model",
|
label="Download model",
|
||||||
|
|
@ -247,7 +247,7 @@ def create_ui():
|
||||||
)
|
)
|
||||||
|
|
||||||
# === WIRING ===
|
# === WIRING ===
|
||||||
|
|
||||||
# Aspect ratio preset changes -> update dimensions
|
# Aspect ratio preset changes -> update dimensions
|
||||||
preset_radio.change(
|
preset_radio.change(
|
||||||
fn=apply_aspect_ratio,
|
fn=apply_aspect_ratio,
|
||||||
|
|
@ -255,7 +255,7 @@ def create_ui():
|
||||||
outputs=[width_slider, height_slider],
|
outputs=[width_slider, height_slider],
|
||||||
show_progress=False
|
show_progress=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Width slider changes -> update height (if not Custom)
|
# Width slider changes -> update height (if not Custom)
|
||||||
width_slider.release(
|
width_slider.release(
|
||||||
fn=update_height_from_width,
|
fn=update_height_from_width,
|
||||||
|
|
@ -263,7 +263,7 @@ def create_ui():
|
||||||
outputs=[height_slider],
|
outputs=[height_slider],
|
||||||
show_progress=False
|
show_progress=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Height slider changes -> update width (if not Custom)
|
# Height slider changes -> update width (if not Custom)
|
||||||
height_slider.release(
|
height_slider.release(
|
||||||
fn=update_width_from_height,
|
fn=update_width_from_height,
|
||||||
|
|
@ -271,7 +271,7 @@ def create_ui():
|
||||||
outputs=[width_slider],
|
outputs=[width_slider],
|
||||||
show_progress=False
|
show_progress=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Swap button -> swap dimensions and update aspect ratio
|
# Swap button -> swap dimensions and update aspect ratio
|
||||||
swap_btn.click(
|
swap_btn.click(
|
||||||
fn=swap_dimensions_and_update_ratio,
|
fn=swap_dimensions_and_update_ratio,
|
||||||
|
|
@ -299,7 +299,7 @@ def create_ui():
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
outputs=outputs
|
outputs=outputs
|
||||||
)
|
)
|
||||||
|
|
||||||
# Model tab events
|
# Model tab events
|
||||||
image_refresh_models.click(
|
image_refresh_models.click(
|
||||||
fn=lambda: gr.update(choices=utils.get_available_image_models()),
|
fn=lambda: gr.update(choices=utils.get_available_image_models()),
|
||||||
|
|
@ -307,28 +307,28 @@ def create_ui():
|
||||||
outputs=[image_model_menu],
|
outputs=[image_model_menu],
|
||||||
show_progress=False
|
show_progress=False
|
||||||
)
|
)
|
||||||
|
|
||||||
image_load_model.click(
|
image_load_model.click(
|
||||||
fn=load_image_model_wrapper,
|
fn=load_image_model_wrapper,
|
||||||
inputs=[image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile],
|
inputs=[image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile],
|
||||||
outputs=[image_model_status],
|
outputs=[image_model_status],
|
||||||
show_progress=True
|
show_progress=True
|
||||||
)
|
)
|
||||||
|
|
||||||
image_unload_model.click(
|
image_unload_model.click(
|
||||||
fn=unload_image_model_wrapper,
|
fn=unload_image_model_wrapper,
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=[image_model_status],
|
outputs=[image_model_status],
|
||||||
show_progress=False
|
show_progress=False
|
||||||
)
|
)
|
||||||
|
|
||||||
image_download_btn.click(
|
image_download_btn.click(
|
||||||
fn=download_image_model_wrapper,
|
fn=download_image_model_wrapper,
|
||||||
inputs=[image_download_path],
|
inputs=[image_download_path],
|
||||||
outputs=[image_model_status, image_model_menu],
|
outputs=[image_model_status, image_model_menu],
|
||||||
show_progress=True
|
show_progress=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# History
|
# History
|
||||||
refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery, show_progress=False)
|
refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery, show_progress=False)
|
||||||
|
|
||||||
|
|
@ -336,40 +336,39 @@ def create_ui():
|
||||||
def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq,
|
def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq,
|
||||||
model_menu, dtype_dropdown, attn_dropdown, cpu_offload_checkbox, compile_checkbox):
|
model_menu, dtype_dropdown, attn_dropdown, cpu_offload_checkbox, compile_checkbox):
|
||||||
"""Generate images with the current model settings."""
|
"""Generate images with the current model settings."""
|
||||||
|
|
||||||
# Get current UI values (these are Gradio components, we need their values)
|
|
||||||
model_name = shared.image_model_name
|
model_name = shared.image_model_name
|
||||||
|
|
||||||
if model_name == 'None':
|
if model_name == 'None':
|
||||||
return [], "No image model selected. Go to the Model tab and select a model."
|
return [], "No image model selected. Go to the Model tab and select a model."
|
||||||
|
|
||||||
# Auto-load model if not loaded
|
# Auto-load model if not loaded
|
||||||
if shared.image_model is None:
|
if shared.image_model is None:
|
||||||
# Load saved settings for the model
|
# Get effective settings (CLI > yaml > defaults)
|
||||||
saved_settings = load_image_model_settings()
|
settings = get_effective_settings()
|
||||||
|
|
||||||
result = load_image_model(
|
result = load_image_model(
|
||||||
model_name,
|
model_name,
|
||||||
dtype=saved_settings['dtype'],
|
dtype=settings['dtype'],
|
||||||
attn_backend=saved_settings['attn_backend'],
|
attn_backend=settings['attn_backend'],
|
||||||
cpu_offload=saved_settings['cpu_offload'],
|
cpu_offload=settings['cpu_offload'],
|
||||||
compile_model=saved_settings['compile_model']
|
compile_model=settings['compile_model']
|
||||||
)
|
)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
return [], f"Failed to load model `{model_name}`."
|
return [], f"Failed to load model `{model_name}`."
|
||||||
|
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
seed = np.random.randint(0, 2**32 - 1)
|
seed = np.random.randint(0, 2**32 - 1)
|
||||||
|
|
||||||
generator = torch.Generator("cuda").manual_seed(int(seed))
|
generator = torch.Generator("cuda").manual_seed(int(seed))
|
||||||
all_images = []
|
all_images = []
|
||||||
|
|
||||||
# Sequential loop (easier on VRAM)
|
# Sequential loop (easier on VRAM)
|
||||||
for i in range(int(batch_count_seq)):
|
for i in range(int(batch_count_seq)):
|
||||||
current_seed = seed + i
|
current_seed = seed + i
|
||||||
generator.manual_seed(int(current_seed))
|
generator.manual_seed(int(current_seed))
|
||||||
|
|
||||||
# Parallel generation
|
# Parallel generation
|
||||||
batch_results = shared.image_model(
|
batch_results = shared.image_model(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|
@ -381,12 +380,12 @@ def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel
|
||||||
num_images_per_prompt=int(batch_size_parallel),
|
num_images_per_prompt=int(batch_size_parallel),
|
||||||
generator=generator,
|
generator=generator,
|
||||||
).images
|
).images
|
||||||
|
|
||||||
all_images.extend(batch_results)
|
all_images.extend(batch_results)
|
||||||
|
|
||||||
# Save to disk
|
# Save to disk
|
||||||
save_generated_images(all_images, prompt, seed)
|
save_generated_images(all_images, prompt, seed)
|
||||||
|
|
||||||
return all_images, f"Seed: {seed}"
|
return all_images, f"Seed: {seed}"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -395,13 +394,13 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
|
||||||
if model_name == 'None' or not model_name:
|
if model_name == 'None' or not model_name:
|
||||||
yield "No model selected"
|
yield "No model selected"
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield f"Loading `{model_name}`..."
|
yield f"Loading `{model_name}`..."
|
||||||
|
|
||||||
# Unload existing model first
|
# Unload existing model first
|
||||||
unload_image_model()
|
unload_image_model()
|
||||||
|
|
||||||
# Load the new model
|
# Load the new model
|
||||||
result = load_image_model(
|
result = load_image_model(
|
||||||
model_name,
|
model_name,
|
||||||
|
|
@ -410,14 +409,14 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
|
||||||
cpu_offload=cpu_offload,
|
cpu_offload=cpu_offload,
|
||||||
compile_model=compile_model
|
compile_model=compile_model
|
||||||
)
|
)
|
||||||
|
|
||||||
if result is not None:
|
if result is not None:
|
||||||
# Save settings to yaml
|
# Save settings to yaml
|
||||||
save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model)
|
save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model)
|
||||||
yield f"✓ Loaded **{model_name}**"
|
yield f"✓ Loaded **{model_name}**"
|
||||||
else:
|
else:
|
||||||
yield f"✗ Failed to load `{model_name}`"
|
yield f"✗ Failed to load `{model_name}`"
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
exc = traceback.format_exc()
|
exc = traceback.format_exc()
|
||||||
yield f"Error:\n```\n{exc}\n```"
|
yield f"Error:\n```\n{exc}\n```"
|
||||||
|
|
@ -426,7 +425,7 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
|
||||||
def unload_image_model_wrapper():
|
def unload_image_model_wrapper():
|
||||||
"""Unload model wrapper."""
|
"""Unload model wrapper."""
|
||||||
unload_image_model()
|
unload_image_model()
|
||||||
|
|
||||||
if shared.image_model_name != 'None':
|
if shared.image_model_name != 'None':
|
||||||
return f"Model: **{shared.image_model_name}** (not loaded)"
|
return f"Model: **{shared.image_model_name}** (not loaded)"
|
||||||
else:
|
else:
|
||||||
|
|
@ -436,36 +435,36 @@ def unload_image_model_wrapper():
|
||||||
def download_image_model_wrapper(model_path):
|
def download_image_model_wrapper(model_path):
|
||||||
"""Download a model from Hugging Face."""
|
"""Download a model from Hugging Face."""
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
if not model_path:
|
if not model_path:
|
||||||
yield "No model specified", gr.update()
|
yield "No model specified", gr.update()
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Parse model name and branch
|
# Parse model name and branch
|
||||||
if ':' in model_path:
|
if ':' in model_path:
|
||||||
model_id, branch = model_path.rsplit(':', 1)
|
model_id, branch = model_path.rsplit(':', 1)
|
||||||
else:
|
else:
|
||||||
model_id, branch = model_path, 'main'
|
model_id, branch = model_path, 'main'
|
||||||
|
|
||||||
# Output folder name
|
# Output folder name
|
||||||
folder_name = model_id.split('/')[-1]
|
folder_name = model_id.split('/')[-1]
|
||||||
output_folder = Path(shared.args.image_model_dir) / folder_name
|
output_folder = Path(shared.args.image_model_dir) / folder_name
|
||||||
|
|
||||||
yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()
|
yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()
|
||||||
|
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=model_id,
|
repo_id=model_id,
|
||||||
revision=branch,
|
revision=branch,
|
||||||
local_dir=output_folder,
|
local_dir=output_folder,
|
||||||
local_dir_use_symlinks=False,
|
local_dir_use_symlinks=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Refresh the model list
|
# Refresh the model list
|
||||||
new_choices = utils.get_available_image_models()
|
new_choices = utils.get_available_image_models()
|
||||||
|
|
||||||
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
|
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
exc = traceback.format_exc()
|
exc = traceback.format_exc()
|
||||||
yield f"Error:\n```\n{exc}\n```", gr.update()
|
yield f"Error:\n```\n{exc}\n```", gr.update()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue