Compare commits


28 commits

Author SHA1 Message Date
oobabooga 536b433a30 Merge 366fe353f0 into a83821e941 2025-12-01 15:54:45 -03:00
oobabooga 366fe353f0 Revert CSS changes 2025-12-01 10:53:17 -08:00
oobabooga 9b07a83330 Populate the history gallery by default 2025-12-01 10:51:12 -08:00
oobabooga e301dd231e Remove some emojis 2025-12-01 10:49:22 -08:00
oobabooga 5b385dc546 Make the image galleries taller 2025-12-01 10:48:55 -08:00
oobabooga b42192c2b7 Implement settings autosaving 2025-12-01 10:43:42 -08:00
oobabooga c51e686135 Merge branch 'dev' into image_generation 2025-12-01 10:34:38 -08:00
oobabooga a83821e941 Revert "UI: Optimize typing in all textareas" (reverts commit e24ba92ef2) 2025-12-01 10:34:23 -08:00
oobabooga 41618cf799 Merge branch 'dev' into image_generation 2025-12-01 09:35:22 -08:00
oobabooga 24fd963c38 Merge remote-tracking branch 'refs/remotes/origin/dev' into dev 2025-12-01 08:06:08 -08:00
oobabooga e24ba92ef2 UI: Optimize typing in all textareas 2025-12-01 08:05:21 -08:00
oobabooga cecb172d2c Add the code for 4-bit quantization 2025-11-27 18:29:32 -08:00
oobabooga 742db85de0 Hardcode 8-bit quantization for now 2025-11-27 18:23:36 -08:00
oobabooga 822e74ac97 Lint 2025-11-27 18:15:15 -08:00
oobabooga 30d1f502aa More informative download message 2025-11-27 16:37:03 -08:00
oobabooga 74eedf6050 Remove the CFG slider 2025-11-27 16:28:40 -08:00
oobabooga 9e33c6bfb7 Add missing files 2025-11-27 15:56:58 -08:00
oobabooga 666816a773 Small fixes 2025-11-27 15:48:53 -08:00
oobabooga 21f992e7f7 Organize the UI 2025-11-27 15:42:11 -08:00
oobabooga 148a5d1e44 Keep things more modular 2025-11-27 15:32:01 -08:00
oobabooga 0adda7a5c5 Lint 2025-11-27 14:39:21 -08:00
oobabooga aa074409cb Better events for the dimensions 2025-11-27 14:38:50 -08:00
oobabooga be799ba8eb Lint 2025-11-27 14:25:49 -08:00
oobabooga a873692234 Image generation now functional 2025-11-27 14:24:35 -08:00
oobabooga 2f11b3040d Add functions 2025-11-27 13:53:46 -08:00
oobabooga aa63c612de Progress on model loading 2025-11-27 13:46:54 -08:00
oobabooga 164c6fcdbf Add the UI structure 2025-11-27 13:44:07 -08:00
oobabooga 4ad2ad468e Add basic structure 2025-11-27 10:10:11 -08:00
7 changed files with 642 additions and 5 deletions

css/main.css

@@ -93,11 +93,11 @@ ol li p, ul li p {
    display: inline-block;
}

-#notebook-parent-tab, #chat-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab {
+#notebook-parent-tab, #chat-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab, #image-ai-tab {
    border: 0;
}

-#notebook-parent-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab {
+#notebook-parent-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab, #image-ai-tab {
    padding: 1rem;
}
@@ -1674,3 +1674,10 @@ button:focus {
.dark .sidebar-vertical-separator {
    border-bottom: 1px solid rgb(255 255 255 / 10%);
}

+button#swap-height-width {
+    position: absolute;
+    top: -50px;
+    right: 0;
+    border: 0;
+}

modules/image_models.py Normal file (97 lines)

@@ -0,0 +1,97 @@
import time

import torch

import modules.shared as shared
from modules.logging_colors import logger
from modules.torch_utils import get_device
from modules.utils import resolve_model_path


def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False):
    """
    Load a diffusers image generation model.

    Args:
        model_name: Name of the model directory
        dtype: 'bfloat16' or 'float16'
        attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
        cpu_offload: Enable CPU offloading for low VRAM
        compile_model: Compile the model for faster inference (slow first run)
    """
    from diffusers import PipelineQuantizationConfig, ZImagePipeline

    logger.info(f"Loading image model \"{model_name}\"")
    t0 = time.time()

    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
    target_dtype = dtype_map.get(dtype, torch.bfloat16)

    model_path = resolve_model_path(model_name, image_model=True)

    try:
        # Define quantization config for 8-bit
        pipeline_quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
        )

        # Define quantization config for 4-bit
        # pipeline_quant_config = PipelineQuantizationConfig(
        #     quant_backend="bitsandbytes_4bit",
        #     quant_kwargs={
        #         "load_in_4bit": True,
        #         "bnb_4bit_quant_type": "nf4",  # Or "fp4" for floating point
        #         "bnb_4bit_compute_dtype": torch.bfloat16,  # For faster computation
        #         "bnb_4bit_use_double_quant": True,  # Nested quantization for extra savings
        #     },
        # )

        pipe = ZImagePipeline.from_pretrained(
            str(model_path),
            quantization_config=pipeline_quant_config,
            torch_dtype=target_dtype,
            low_cpu_mem_usage=True,
        )

        if not cpu_offload:
            pipe.to(get_device())

        # Set attention backend
        if attn_backend == 'flash_attention_2':
            pipe.transformer.set_attention_backend("flash")
        elif attn_backend == 'flash_attention_3':
            pipe.transformer.set_attention_backend("_flash_3")
        # sdpa is the default, no action needed

        if compile_model:
            logger.info("Compiling model (first run will be slow)...")
            pipe.transformer.compile()

        if cpu_offload:
            pipe.enable_model_cpu_offload()

        shared.image_model = pipe
        shared.image_model_name = model_name
        logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
        return pipe
    except Exception as e:
        logger.error(f"Failed to load image model: {str(e)}")
        return None


def unload_image_model():
    """Unload the current image model and free VRAM."""
    if shared.image_model is None:
        return

    del shared.image_model
    shared.image_model = None
    shared.image_model_name = 'None'

    from modules.torch_utils import clear_torch_cache
    clear_torch_cache()
    logger.info("Image model unloaded.")

modules/shared.py

@@ -11,7 +11,7 @@ import yaml
from modules.logging_colors import logger
from modules.presets import default_preset

-# Model variables
+# Text model variables
model = None
tokenizer = None
model_name = 'None'
@@ -20,6 +20,10 @@ is_multimodal = False
model_dirty_from_training = False
lora_names = []

+# Image model variables
+image_model = None
+image_model_name = 'None'
+
# Generation variables
stop_everything = False
generation_lock = None
@@ -46,6 +50,15 @@ group.add_argument('--extensions', type=str, nargs='+', help='The list of extens
group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')

+# Image generation
+group = parser.add_argument_group('Image model')
+group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).')
+group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
+group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.')
+group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
+group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
+group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.')
+
# Model loader
group = parser.add_argument_group('Model loader')
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, TensorRT-LLM.')
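Illustrative launch line using the new flags (the model folder name is hypothetical):

    python server.py --image-model Tongyi-MAI_Z-Image-Turbo --image-dtype bfloat16 --image-cpu-offload

These flags override the saved settings via apply_image_model_cli_overrides() further down.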
@@ -290,6 +303,22 @@ settings = {
    # Extensions
    'default_extensions': [],

+    # Image generation settings
+    'image_prompt': '',
+    'image_neg_prompt': '',
+    'image_width': 1024,
+    'image_height': 1024,
+    'image_aspect_ratio': '1:1 Square',
+    'image_steps': 9,
+    'image_seed': -1,
+    'image_batch_size': 1,
+    'image_batch_count': 1,
+    'image_model_menu': 'None',
+    'image_dtype': 'bfloat16',
+    'image_attn_backend': 'sdpa',
+    'image_compile': False,
+    'image_cpu_offload': False,
}

default_settings = copy.deepcopy(settings)
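For reference, once autosaved (see the setup_auto_save() change below) these keys would appear in the user settings YAML; an illustrative snippet with arbitrary values:

    image_width: 1024
    image_height: 576
    image_aspect_ratio: '16:9 Cinema'
    image_steps: 9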
@@ -314,6 +343,20 @@ def do_cmd_flags_warnings():
        logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')


+def apply_image_model_cli_overrides():
+    """Apply CLI flags for image model settings, overriding saved settings."""
+    if args.image_model:
+        settings['image_model_menu'] = args.image_model
+    if args.image_dtype is not None:
+        settings['image_dtype'] = args.image_dtype
+    if args.image_attn_backend is not None:
+        settings['image_attn_backend'] = args.image_attn_backend
+    if args.image_cpu_offload:
+        settings['image_cpu_offload'] = True
+    if args.image_compile:
+        settings['image_compile'] = True
+
+
def fix_loader_name(name):
    if not name:
        return name

modules/ui.py

@@ -280,6 +280,24 @@ def list_interface_input_elements():
        'include_past_attachments',
    ]

+    # Image generation elements
+    elements += [
+        'image_prompt',
+        'image_neg_prompt',
+        'image_width',
+        'image_height',
+        'image_aspect_ratio',
+        'image_steps',
+        'image_seed',
+        'image_batch_size',
+        'image_batch_count',
+        'image_model_menu',
+        'image_dtype',
+        'image_attn_backend',
+        'image_compile',
+        'image_cpu_offload',
+    ]
+
    return elements
@@ -509,7 +527,21 @@ def setup_auto_save():
        'theme_state',
        'show_two_notebook_columns',
        'paste_to_attachment',
-        'include_past_attachments'
+        'include_past_attachments',
+
+        # Image generation tab (ui_image_generation.py)
+        'image_width',
+        'image_height',
+        'image_aspect_ratio',
+        'image_steps',
+        'image_seed',
+        'image_batch_size',
+        'image_batch_count',
+        'image_model_menu',
+        'image_dtype',
+        'image_attn_backend',
+        'image_compile',
+        'image_cpu_offload',
    ]

    for element_name in change_elements:

modules/ui_image_generation.py Normal file (425 lines)

@@ -0,0 +1,425 @@
import os
import traceback
from datetime import datetime
from pathlib import Path

import gradio as gr
import numpy as np
import torch

from modules import shared, ui, utils
from modules.image_models import load_image_model, unload_image_model
from modules.utils import gradio

ASPECT_RATIOS = {
    "1:1 Square": (1, 1),
    "16:9 Cinema": (16, 9),
    "9:16 Mobile": (9, 16),
    "4:3 Photo": (4, 3),
    "Custom": None,
}

STEP = 32


def round_to_step(value, step=STEP):
    return round(value / step) * step


def clamp(value, min_val, max_val):
    return max(min_val, min(max_val, value))


def apply_aspect_ratio(aspect_ratio, current_width, current_height):
    if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
        return current_width, current_height

    w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
    if w_ratio == h_ratio:
        base = min(current_width, current_height)
        new_width = base
        new_height = base
    elif w_ratio < h_ratio:
        new_width = current_width
        new_height = round_to_step(current_width * h_ratio / w_ratio)
    else:
        new_height = current_height
        new_width = round_to_step(current_height * w_ratio / h_ratio)

    new_width = clamp(new_width, 256, 2048)
    new_height = clamp(new_height, 256, 2048)
    return int(new_width), int(new_height)


def update_height_from_width(width, aspect_ratio):
    if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
        return gr.update()

    w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
    new_height = round_to_step(width * h_ratio / w_ratio)
    new_height = clamp(new_height, 256, 2048)
    return int(new_height)


def update_width_from_height(height, aspect_ratio):
    if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
        return gr.update()

    w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
    new_width = round_to_step(height * w_ratio / h_ratio)
    new_width = clamp(new_width, 256, 2048)
    return int(new_width)


def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
    new_width, new_height = height, width

    new_ratio = "Custom"
    for name, ratios in ASPECT_RATIOS.items():
        if ratios is None:
            continue

        w_r, h_r = ratios
        expected_height = new_width * h_r / w_r
        if abs(expected_height - new_height) < STEP:
            new_ratio = name
            break

    return new_width, new_height, new_ratio
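# Worked example (illustrative; values follow from the code above):
#   apply_aspect_ratio("16:9 Cinema", 1024, 1024) keeps height=1024 and sets
#   width = round(1024 * 16 / 9 / 32) * 32 = 57 * 32 = 1824, returning (1824, 1024).
#   swap_dimensions_and_update_ratio(1824, 1024, "16:9 Cinema") swaps to
#   (1024, 1824) and re-detects the ratio as "9:16 Mobile", since
#   |1024 * 16 / 9 - 1824| is about 3.6, below STEP.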
def create_ui():
    if shared.settings['image_model_menu'] != 'None':
        shared.image_model_name = shared.settings['image_model_menu']

    with gr.Tab("Image AI", elem_id="image-ai-tab"):
        with gr.Tabs():
            # TAB 1: GENERATE
            with gr.TabItem("Generate"):
                with gr.Row():
                    with gr.Column(scale=4, min_width=350):
                        shared.gradio['image_prompt'] = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe your imagination...",
                            lines=3,
                            autofocus=True,
                            value=shared.settings['image_prompt']
                        )
                        shared.gradio['image_neg_prompt'] = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="Low quality...",
                            lines=3,
                            value=shared.settings['image_neg_prompt']
                        )
                        shared.gradio['image_generate_btn'] = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
                        gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")

                        gr.Markdown("### Dimensions")
                        with gr.Row():
                            with gr.Column():
                                shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=32, label="Width")
                            with gr.Column():
                                shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=32, label="Height")
                            shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")

                        with gr.Row():
                            shared.gradio['image_aspect_ratio'] = gr.Radio(
                                choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"],
                                value=shared.settings['image_aspect_ratio'],
                                label="Aspect Ratio",
                                interactive=True
                            )

                        gr.Markdown("### Config")
                        with gr.Row():
                            with gr.Column():
                                shared.gradio['image_steps'] = gr.Slider(1, 15, value=shared.settings['image_steps'], step=1, label="Steps")
                                shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
                            with gr.Column():
                                shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
                                shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")

                    with gr.Column(scale=6, min_width=500):
                        with gr.Column(elem_classes=["viewport-container"]):
                            shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")
                            with gr.Row():
                                shared.gradio['image_used_seed'] = gr.Markdown(label="Info", interactive=False)

            # TAB 2: GALLERY
            with gr.TabItem("Gallery"):
                with gr.Row():
                    shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")

                shared.gradio['image_history_gallery'] = gr.Gallery(value=lambda : get_history_images(), label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery")

            # TAB 3: MODEL
            with gr.TabItem("Model"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            shared.gradio['image_model_menu'] = gr.Dropdown(
                                choices=utils.get_available_image_models(),
                                value=shared.settings['image_model_menu'],
                                label='Model',
                                elem_classes='slim-dropdown'
                            )
                            shared.gradio['image_refresh_models'] = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
                            shared.gradio['image_load_model'] = gr.Button("Load", variant='primary', elem_classes='refresh-button')
                            shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button')

                        gr.Markdown("## Settings")
                        with gr.Row():
                            with gr.Column():
                                shared.gradio['image_dtype'] = gr.Dropdown(
                                    choices=['bfloat16', 'float16'],
                                    value=shared.settings['image_dtype'],
                                    label='Data Type',
                                    info='bfloat16 recommended for modern GPUs'
                                )
                                shared.gradio['image_attn_backend'] = gr.Dropdown(
                                    choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
                                    value=shared.settings['image_attn_backend'],
                                    label='Attention Backend',
                                    info='SDPA is default. Flash Attention requires compatible GPU.'
                                )
                            with gr.Column():
                                shared.gradio['image_compile'] = gr.Checkbox(
                                    value=shared.settings['image_compile'],
                                    label='Compile Model',
                                    info='Faster inference after first run. First run will be slow.'
                                )
                                shared.gradio['image_cpu_offload'] = gr.Checkbox(
                                    value=shared.settings['image_cpu_offload'],
                                    label='CPU Offload',
                                    info='Enable for low VRAM GPUs. Slower but uses less memory.'
                                )

                    with gr.Column():
                        shared.gradio['image_download_path'] = gr.Textbox(
                            label="Download model",
                            placeholder="Tongyi-MAI/Z-Image-Turbo",
                            info="Enter HuggingFace path. Use : for branch, e.g. user/model:main"
                        )
                        shared.gradio['image_download_btn'] = gr.Button("Download", variant='primary')
                        shared.gradio['image_model_status'] = gr.Markdown(
                            value=f"Model: **{shared.settings['image_model_menu']}** (not loaded)" if shared.settings['image_model_menu'] != 'None' else "No model selected"
                        )
def create_event_handlers():
    # Dimension controls
    shared.gradio['image_aspect_ratio'].change(
        apply_aspect_ratio,
        gradio('image_aspect_ratio', 'image_width', 'image_height'),
        gradio('image_width', 'image_height'),
        show_progress=False
    )

    shared.gradio['image_width'].release(
        update_height_from_width,
        gradio('image_width', 'image_aspect_ratio'),
        gradio('image_height'),
        show_progress=False
    )

    shared.gradio['image_height'].release(
        update_width_from_height,
        gradio('image_height', 'image_aspect_ratio'),
        gradio('image_width'),
        show_progress=False
    )

    shared.gradio['image_swap_btn'].click(
        swap_dimensions_and_update_ratio,
        gradio('image_width', 'image_height', 'image_aspect_ratio'),
        gradio('image_width', 'image_height', 'image_aspect_ratio'),
        show_progress=False
    )

    # Generation
    shared.gradio['image_generate_btn'].click(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))

    shared.gradio['image_prompt'].submit(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))

    shared.gradio['image_neg_prompt'].submit(
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))

    # Model management
    shared.gradio['image_refresh_models'].click(
        lambda: gr.update(choices=utils.get_available_image_models()),
        None,
        gradio('image_model_menu'),
        show_progress=False
    )

    shared.gradio['image_load_model'].click(
        load_image_model_wrapper,
        gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'),
        gradio('image_model_status'),
        show_progress=True
    )

    shared.gradio['image_unload_model'].click(
        unload_image_model_wrapper,
        None,
        gradio('image_model_status'),
        show_progress=False
    )

    shared.gradio['image_download_btn'].click(
        download_image_model_wrapper,
        gradio('image_download_path'),
        gradio('image_model_status', 'image_model_menu'),
        show_progress=True
    )

    # History
    shared.gradio['image_refresh_history'].click(
        get_history_images,
        None,
        gradio('image_history_gallery'),
        show_progress=False
    )
def generate(state):
    model_name = state['image_model_menu']
    if not model_name or model_name == 'None':
        return [], "No image model selected. Go to the Model tab and select a model."

    if shared.image_model is None:
        result = load_image_model(
            model_name,
            dtype=state['image_dtype'],
            attn_backend=state['image_attn_backend'],
            cpu_offload=state['image_cpu_offload'],
            compile_model=state['image_compile']
        )

        if result is None:
            return [], f"Failed to load model `{model_name}`."

        shared.image_model_name = model_name

    seed = state['image_seed']
    if seed == -1:
        seed = np.random.randint(0, 2**32 - 1)

    generator = torch.Generator("cuda").manual_seed(int(seed))

    all_images = []
    for i in range(int(state['image_batch_count'])):
        generator.manual_seed(int(seed + i))
        batch_results = shared.image_model(
            prompt=state['image_prompt'],
            negative_prompt=state['image_neg_prompt'],
            height=int(state['image_height']),
            width=int(state['image_width']),
            num_inference_steps=int(state['image_steps']),
            guidance_scale=0.0,
            num_images_per_prompt=int(state['image_batch_size']),
            generator=generator,
        ).images
        all_images.extend(batch_results)

    save_generated_images(all_images, state['image_prompt'], seed)
    return all_images, f"Seed: {seed}"
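# Reproducibility note (follows from the code above): when image_seed is -1 a random
# base seed is drawn once, then sequential batch i reseeds the generator with
# seed + i, so the first batch uses the reported "Seed" value and any later batch
# can be re-generated from that value plus its batch index.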
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
    if not model_name or model_name == 'None':
        yield "No model selected"
        return

    try:
        yield f"Loading `{model_name}`..."
        unload_image_model()
        result = load_image_model(
            model_name,
            dtype=dtype,
            attn_backend=attn_backend,
            cpu_offload=cpu_offload,
            compile_model=compile_model
        )

        if result is not None:
            shared.image_model_name = model_name
            yield f"✓ Loaded **{model_name}**"
        else:
            yield f"✗ Failed to load `{model_name}`"
    except Exception:
        yield f"Error:\n```\n{traceback.format_exc()}\n```"


def unload_image_model_wrapper():
    unload_image_model()
    if shared.image_model_name != 'None':
        return f"Model: **{shared.image_model_name}** (not loaded)"

    return "No model loaded"
def download_image_model_wrapper(model_path):
    from huggingface_hub import snapshot_download

    if not model_path:
        yield "No model specified", gr.update()
        return

    try:
        if ':' in model_path:
            model_id, branch = model_path.rsplit(':', 1)
        else:
            model_id, branch = model_path, 'main'

        folder_name = model_id.replace('/', '_')
        output_folder = Path(shared.args.image_model_dir) / folder_name

        yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()
        snapshot_download(
            repo_id=model_id,
            revision=branch,
            local_dir=output_folder,
            local_dir_use_symlinks=False,
        )

        new_choices = utils.get_available_image_models()
        yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
    except Exception:
        yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
def save_generated_images(images, prompt, seed):
    date_str = datetime.now().strftime("%Y-%m-%d")
    folder_path = os.path.join("user_data", "image_outputs", date_str)
    os.makedirs(folder_path, exist_ok=True)

    for idx, img in enumerate(images):
        timestamp = datetime.now().strftime("%H-%M-%S")
        filename = f"{timestamp}_{seed}_{idx}.png"
        img.save(os.path.join(folder_path, filename))
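# Resulting layout (follows from the code above):
#   user_data/image_outputs/YYYY-MM-DD/HH-MM-SS_<seed>_<idx>.png
# get_history_images() below walks this same tree and returns files newest-first.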
def get_history_images():
    output_dir = os.path.join("user_data", "image_outputs")
    if not os.path.exists(output_dir):
        return []

    image_files = []
    for root, _, files in os.walk(output_dir):
        for file in files:
            if file.endswith((".png", ".jpg", ".jpeg")):
                full_path = os.path.join(root, file)
                image_files.append((full_path, os.path.getmtime(full_path)))

    image_files.sort(key=lambda x: x[1], reverse=True)
    return [x[0] for x in image_files]

modules/utils.py

@@ -86,7 +86,7 @@ def check_model_loaded():
    return True, None


-def resolve_model_path(model_name_or_path):
+def resolve_model_path(model_name_or_path, image_model=False):
    """
    Resolves a model path, checking for a direct path
    before the default models directory.

@@ -95,6 +95,8 @@ def resolve_model_path(model_name_or_path):
    path_candidate = Path(model_name_or_path)
    if path_candidate.exists():
        return path_candidate
+    elif image_model:
+        return Path(f'{shared.args.image_model_dir}/{model_name_or_path}')
    else:
        return Path(f'{shared.args.model_dir}/{model_name_or_path}')
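Illustrative effect of the change: resolve_model_path('my-image-model', image_model=True) (name hypothetical) now falls through to user_data/image_models/my-image-model, while text models keep resolving against --model-dir; an existing literal path still wins in both cases.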
@@ -153,6 +155,31 @@ def get_available_models():
    return filtered_gguf_files + model_dirs


+def get_available_image_models():
+    model_dir = Path(shared.args.image_model_dir)
+
+    # Find directories with safetensors files
+    dirs_with_safetensors = set()
+    for item in os.listdir(model_dir):
+        item_path = model_dir / item
+        if item_path.is_dir():
+            if any(file.lower().endswith(('.safetensors', '.pt')) for file in os.listdir(item_path) if (item_path / file).is_file()):
+                dirs_with_safetensors.add(item)
+
+    # Find valid model directories
+    model_dirs = []
+    for item in os.listdir(model_dir):
+        item_path = model_dir / item
+        if not item_path.is_dir():
+            continue
+
+        model_dirs.append(item)
+
+    model_dirs = sorted(model_dirs, key=natural_keys)
+    return model_dirs
+
+
def get_available_ggufs():
    model_list = []
    model_dir = Path(shared.args.model_dir)
server.py

@@ -50,6 +50,7 @@ from modules import (
    ui_chat,
    ui_default,
    ui_file_saving,
+    ui_image_generation,
    ui_model_menu,
    ui_notebook,
    ui_parameters,
@@ -163,6 +164,7 @@ def create_interface():
        ui_chat.create_character_settings_ui()  # Character tab
        ui_model_menu.create_ui()  # Model tab
        if not shared.args.portable:
+            ui_image_generation.create_ui()  # Image generation tab
            training.create_ui()  # Training tab
        ui_session.create_ui()  # Session tab
@@ -170,6 +172,7 @@ def create_interface():
        ui_chat.create_event_handlers()
        ui_default.create_event_handlers()
        ui_notebook.create_event_handlers()
+        ui_image_generation.create_event_handlers()

        # Other events
        ui_file_saving.create_event_handlers()
@@ -256,6 +259,9 @@ if __name__ == "__main__":
    if new_settings:
        shared.settings.update(new_settings)

+    # Apply CLI overrides for image model settings (CLI flags take precedence over saved settings)
+    shared.apply_image_model_cli_overrides()
+
    # Fallback settings for models
    shared.model_config['.*'] = get_fallback_settings()
    shared.model_config.move_to_end('.*', last=False)  # Move to the beginning