Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2026-01-20 07:30:19 +01:00)
Image: Several fixes
parent 8eac99599a
commit afa29b9554
@@ -28,8 +28,7 @@ A Gradio web UI for Large Language Models.
- 100% offline and private, with zero telemetry, external resources, or remote update requests.
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
- **Vision (multimodal models)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
Image generation: A dedicated tab for diffusers models like Z-Image-Turbo and Qwen-Image. Features 4-bit/8-bit quantization and a persistent gallery with metadata (tutorial).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo** and **Qwen-Image**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- **Web search**: Optionally search the internet with LLM-generated queries to add context to the conversation.
- Aesthetic UI with dark and light themes.
- Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions.
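The image generation bullet above refers to the `diffusers` tab. As a rough standalone sketch of what such a tab does under the hood (the model id, prompt, and step count below are placeholders, not what the webui ships with):

```python
# Minimal diffusers sketch: load a text-to-image pipeline and generate one image.
# "some-org/Z-Image-Turbo" is a hypothetical model id; use any diffusers-compatible model.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "some-org/Z-Image-Turbo",       # placeholder id
    torch_dtype=torch.bfloat16,     # the tab additionally offers 4-bit/8-bit quantization
)
pipe.to("cuda")

image = pipe("a lighthouse at dusk", num_inference_steps=8).images[0]
image.save("output.png")
```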
@@ -43,6 +43,9 @@ def generations(request):
    for images, _ in generate(state, save_images=False):
        pass

    if not images:
        raise ServiceUnavailableError("Image generation failed or produced no images.")

    # Build response
    resp = {'created': int(time.time()), 'data': []}
    for img in images:
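The hunk above adds a guard after the generation loop: the loop only drains the `generate()` generator, and if its final yield contains no images the endpoint now raises `ServiceUnavailableError`. A self-contained sketch of that pattern, with stand-ins for the webui's own `generate` and error class:

```python
# Drain a generator that yields (images, extra) tuples, keep the final yield,
# and fail loudly if it is empty. fake_generate and ServiceUnavailableError
# are stand-ins, not the webui's own objects.
class ServiceUnavailableError(Exception):
    pass

def fake_generate(fail=False):
    yield ([] if fail else ["image-0.png"]), None

for fail in (False, True):
    images = None
    for images, _ in fake_generate(fail=fail):
        pass  # only the last yield matters
    try:
        if not images:
            raise ServiceUnavailableError("Image generation failed or produced no images.")
        print("got:", images)
    except ServiceUnavailableError as e:
        print("error:", e)
```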
@@ -141,16 +141,24 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
    if not cpu_offload:
        pipe.to(get_device())

    # Set attention backend (if supported by the pipeline)
    if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
        if attn_backend == 'flash_attention_2':
            pipe.transformer.set_attention_backend("flash")
        # sdpa is the default, no action needed
    modules = ["transformer", "unet"]

    # Set attention backend
    if attn_backend == 'flash_attention_2':
        for name in modules:
            mod = getattr(pipe, name, None)
            if hasattr(mod, "set_attention_backend"):
                mod.set_attention_backend("flash")
                break

    # Compile model
    if compile_model:
        if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
            logger.info("Compiling model (first run will be slow)...")
            pipe.transformer.compile()
        for name in modules:
            mod = getattr(pipe, name, None)
            if hasattr(mod, "compile"):
                logger.info("Compiling model (first run will be slow)...")
                mod.compile()
                break

    if cpu_offload:
        pipe.enable_model_cpu_offload()
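The rewrite above stops assuming the pipeline has a `transformer` attribute and instead probes a list of candidate sub-modules (`transformer`, then `unet`), using the first one that supports the requested method, which covers both DiT-style and UNet-style `diffusers` pipelines. A sketch of that probing pattern in isolation (`first_supporting` and the dummy classes are hypothetical, not webui or diffusers code):

```python
# Find the first pipeline sub-module that implements a given method, then call it.
def first_supporting(pipe, names, method):
    for name in names:
        mod = getattr(pipe, name, None)
        if hasattr(mod, method):
            return mod
    return None

class DummyUNet:
    def set_attention_backend(self, backend):
        print(f"backend set to {backend}")

class DummyPipe:
    unet = DummyUNet()  # no `transformer` attribute, like UNet-based pipelines

mod = first_supporting(DummyPipe(), ["transformer", "unet"], "set_attention_backend")
if mod is not None:
    mod.set_attention_backend("flash")
```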
@@ -43,10 +43,6 @@ METADATA_SETTINGS_KEYS = [
    'image_cfg_scale',
]

# Cache for all image paths
_image_cache = []
_cache_timestamp = 0


def round_to_step(value, step=STEP):
    return round(value / step) * step
@@ -134,6 +130,9 @@ def build_generation_metadata(state, actual_seed):

def save_generated_images(images, state, actual_seed):
    """Save images with generation metadata embedded in PNG."""
    if shared.args.multi_user:
        return

    date_str = datetime.now().strftime("%Y-%m-%d")
    folder_path = os.path.join("user_data", "image_outputs", date_str)
    os.makedirs(folder_path, exist_ok=True)
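Per its docstring, `save_generated_images()` embeds the generation settings in the PNG itself; the reader below looks for an `image_gen_settings` text chunk. A hedged Pillow sketch of writing and reading such a chunk (the image and settings here are made up; only the key name comes from the diff):

```python
# Embed settings in a PNG text chunk and read them back with Pillow.
import json
from PIL import Image
from PIL.PngImagePlugin import PngInfo

img = Image.new("RGB", (64, 64), "black")           # illustrative image
settings = {"prompt": "a lighthouse at dusk", "steps": 8, "seed": 42}

pnginfo = PngInfo()
pnginfo.add_text("image_gen_settings", json.dumps(settings))
img.save("example.png", pnginfo=pnginfo)

with Image.open("example.png") as loaded:
    print(json.loads(loaded.text["image_gen_settings"]))  # PNG text chunks via .text
```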
@@ -157,9 +156,14 @@ def save_generated_images(images, state, actual_seed):
def read_image_metadata(image_path):
    """Read generation metadata from PNG file."""
    try:
        with open_image_safely(image_path) as img:
        img = open_image_safely(image_path)
        if img is None:
            return None
        try:
            if hasattr(img, 'text') and 'image_gen_settings' in img.text:
                return json.loads(img.text['image_gen_settings'])
        finally:
            img.close()
    except Exception as e:
        logger.debug(f"Could not read metadata from {image_path}: {e}")
    return None
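The change above drops the `with` block because `open_image_safely()` can apparently return `None`, which a context manager cannot handle; the new code checks for `None` first and closes the handle in a `finally`. A minimal sketch of that pattern with a stand-in opener:

```python
# Open / None-check / try-finally pattern. open_safely is a stand-in for the
# webui's open_image_safely wrapper.
def open_safely(path):
    try:
        return open(path, "rb")
    except OSError:
        return None

def read_first_bytes(path):
    handle = open_safely(path)
    if handle is None:
        return None          # `with open_safely(path) as h:` would fail on None
    try:
        return handle.read(16)
    finally:
        handle.close()       # always runs, even if read() raises

print(read_first_bytes("does-not-exist.png"))
```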
@@ -198,19 +202,12 @@ def format_metadata_for_display(metadata):
    return "\n\n".join(lines)


def get_all_history_images(force_refresh=False):
    """Get all history images sorted by modification time (newest first). Uses caching."""
    global _image_cache, _cache_timestamp

def get_all_history_images():
    """Get all history images sorted by modification time (newest first)."""
    output_dir = os.path.join("user_data", "image_outputs")
    if not os.path.exists(output_dir):
        return []

    # Check if we need to refresh cache
    current_time = time.time()
    if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2:
        return _image_cache

    image_files = []
    for root, _, files in os.walk(output_dir):
        for file in files:
@@ -219,15 +216,12 @@ def get_all_history_images(force_refresh=False):
            image_files.append((full_path, os.path.getmtime(full_path)))

    image_files.sort(key=lambda x: x[1], reverse=True)
    _image_cache = [x[0] for x in image_files]
    _cache_timestamp = current_time

    return _image_cache
    return [x[0] for x in image_files]


def get_paginated_images(page=0, force_refresh=False):
def get_paginated_images(page=0):
    """Get images for a specific page."""
    all_images = get_all_history_images(force_refresh)
    all_images = get_all_history_images()
    total_images = len(all_images)
    total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
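With the module-level cache removed, every call to `get_all_history_images()` re-walks the output directory and sorts by modification time, and the pagination helper slices out one page. A standalone sketch of that walk, sort, and paginate flow (the directory name and page size are placeholders, not the webui's settings):

```python
import os

IMAGES_PER_PAGE = 20  # placeholder page size

def all_images_newest_first(output_dir="image_outputs"):
    """Walk output_dir and return image paths sorted newest-first by mtime."""
    if not os.path.exists(output_dir):
        return []
    found = []
    for root, _, files in os.walk(output_dir):
        for name in files:
            if name.lower().endswith(".png"):
                path = os.path.join(root, name)
                found.append((path, os.path.getmtime(path)))
    found.sort(key=lambda item: item[1], reverse=True)
    return [path for path, _ in found]

def page_of_images(page=0):
    """Return (page_items, page, total_pages, total_images) for one page."""
    images = all_images_newest_first()
    total_pages = max(1, (len(images) + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
    page = max(0, min(page, total_pages - 1))
    start = page * IMAGES_PER_PAGE
    return images[start:start + IMAGES_PER_PAGE], page, total_pages, len(images)

print(page_of_images(0))
```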
@@ -250,7 +244,7 @@ def get_initial_page_info():

def refresh_gallery(current_page=0):
    """Refresh gallery with current page."""
    images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True)
    images, page, total_pages, total_images = get_paginated_images(current_page)
    page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
    return images, page, page_info
@@ -286,11 +280,7 @@ def on_gallery_select(evt: gr.SelectData, current_page):
    if evt.index is None:
        return "", "Select an image to view its settings"

    if not _image_cache:
        get_all_history_images()

    # Get the current page's images to find the actual file path
    all_images = _image_cache
    all_images = get_all_history_images()
    total_images = len(all_images)

    # Calculate the actual index in the full list
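With the cache gone, the selection handler rebuilds the full list and then maps the within-page gallery index to an absolute position. Assuming the same page slicing as in the sketch above, the index arithmetic is simply `page * IMAGES_PER_PAGE + evt.index`; a tiny illustrative sketch:

```python
# Map a within-page selection back to the full, newest-first image list.
IMAGES_PER_PAGE = 20  # placeholder, as above

def absolute_index(page, index_on_page):
    return page * IMAGES_PER_PAGE + index_on_page

all_images = [f"img_{i:03}.png" for i in range(55)]   # illustrative list
page, selected = 2, 4                                 # 5th thumbnail on the 3rd page
i = absolute_index(page, selected)
if i < len(all_images):
    print(all_images[i])                              # -> img_044.png
```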