Compare commits

...

10 commits

Author SHA1 Message Date
dependabot[bot] 622f43fab2
Merge 3670ef56cf into c93d27add3 2025-12-04 02:34:17 +00:00
oobabooga c93d27add3 Update llama.cpp 2025-12-03 18:29:43 -08:00
oobabooga fbca54957e Image generation: Yield partial results for batch count > 1 2025-12-03 16:13:07 -08:00
oobabooga 49c60882bf Image generation: Safer image uploading 2025-12-03 16:07:51 -08:00
oobabooga 59285d501d Image generation: Small UI improvements 2025-12-03 16:03:31 -08:00
oobabooga 373baa5c9c UI: Minor image gallery improvements 2025-12-03 14:45:02 -08:00
oobabooga 906dc54969 Load --image-model before --model 2025-12-03 12:15:38 -08:00
oobabooga 4468c49439 Add semaphore to image generation API endpoint 2025-12-03 12:02:47 -08:00
oobabooga 5ad174fad2 docs: Add an image generation API example 2025-12-03 11:58:54 -08:00
oobabooga 5433ef3333 Add an API endpoint for generating images 2025-12-03 11:50:56 -08:00
27 changed files with 330 additions and 158 deletions

View file

@ -139,6 +139,35 @@ curl http://127.0.0.1:5000/v1/completions \
For base64-encoded images, just replace the inner "url" values with this format: `data:image/FORMAT;base64,BASE64_STRING` where FORMAT is the file type (png, jpeg, gif, etc.) and BASE64_STRING is your base64-encoded image data.
#### Image generation
```shell
curl http://127.0.0.1:5000/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"prompt": "an orange tree",
"steps": 9,
"cfg_scale": 0,
"batch_size": 1,
"batch_count": 1
}'
```
You need to load an image model first. You can do this via the UI, or by adding `--image-model your_model_name` when launching the server.
The output is a JSON object containing a `data` array. Each element has a `b64_json` field with the base64-encoded PNG image:
```json
{
"created": 1764791227,
"data": [
{
"b64_json": "iVBORw0KGgo..."
}
]
}
```
#### SSE streaming
```shell
@ -419,7 +448,6 @@ The following environment variables can be used (they take precedence over every
| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | sentence-transformers/all-mpnet-base-v2 |
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
@ -430,7 +458,6 @@ You can also set the following variables in your `settings.yaml` file:
```
openai-embedding_device: cuda
openai-embedding_model: "sentence-transformers/all-mpnet-base-v2"
openai-sd_webui_url: http://127.0.0.1:7861
openai-debug: 1
```

View file

@ -1,70 +1,154 @@
"""
OpenAI-compatible image generation using local diffusion models.
"""
import base64
import io
import json
import os
import time
from datetime import datetime
import requests
import numpy as np
from extensions.openai.errors import ServiceUnavailableError
from modules import shared
from modules.logging_colors import logger
from PIL.PngImagePlugin import PngInfo
def generations(prompt: str, size: str, response_format: str, n: int):
# Stable Diffusion callout wrapper for txt2img
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
# the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
# If you want high quality tailored results you should just use the Stable Diffusion API directly.
# it's too general an API to try and shape the result with specific tags like negative prompts
# or "masterpiece", etc. SD configuration is beyond the scope of this API.
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
# require changing the form data handling to accept multipart form data, also to properly support
# url return types will require file management and a web serving files... Perhaps later!
base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
sd_defaults = {
'sampler_name': 'DPM++ 2M Karras', # vast improvement
'steps': 30,
def generations(prompt: str, size: str, response_format: str, n: int,
negative_prompt: str = "", steps: int = 9, seed: int = -1,
cfg_scale: float = 0.0, batch_count: int = 1):
"""
Generate images using the loaded diffusion model.
Args:
prompt: Text description of the desired image
size: Image dimensions as "WIDTHxHEIGHT"
response_format: 'url' or 'b64_json'
n: Number of images per batch
negative_prompt: What to avoid in the image
steps: Number of inference steps
seed: Random seed (-1 for random)
cfg_scale: Classifier-free guidance scale
batch_count: Number of sequential batches
Returns:
dict with 'created' timestamp and 'data' list of images
"""
import torch
from modules.image_models import get_pipeline_type
from modules.torch_utils import clear_torch_cache, get_device
if shared.image_model is None:
raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")
clear_torch_cache()
# Parse dimensions
try:
width, height = [int(x) for x in size.split('x')]
except (ValueError, IndexError):
width, height = 1024, 1024
# Handle seed
if seed == -1:
seed = np.random.randint(0, 2**32 - 1)
device = get_device() or "cpu"
generator = torch.Generator(device).manual_seed(int(seed))
# Get pipeline type for CFG parameter name
pipeline_type = getattr(shared, 'image_pipeline_type', None) or get_pipeline_type(shared.image_model)
# Build generation kwargs
gen_kwargs = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"num_inference_steps": steps,
"num_images_per_prompt": n,
"generator": generator,
}
width, height = [int(x) for x in size.split('x')] # ignore the restrictions on size
# Pipeline-specific CFG parameter
if pipeline_type == 'qwenimage':
gen_kwargs["true_cfg_scale"] = cfg_scale
else:
gen_kwargs["guidance_scale"] = cfg_scale
# to hack on better generation, edit default payload.
payload = {
'prompt': prompt, # ignore prompt limit of 1000 characters
'width': width,
'height': height,
'batch_size': n,
}
payload.update(sd_defaults)
# Generate
all_images = []
t0 = time.time()
scale = min(width, height) / base_model_size
if scale >= 1.2:
# for better performance with the default size (1024), and larger res.
scaler = {
'width': width // scale,
'height': height // scale,
'hr_scale': scale,
'enable_hr': True,
'hr_upscaler': 'Latent',
'denoising_strength': 0.68,
}
payload.update(scaler)
shared.stop_everything = False
def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
if shared.stop_everything:
pipe._interrupt = True
return callback_kwargs
gen_kwargs["callback_on_step_end"] = interrupt_callback
for i in range(batch_count):
if shared.stop_everything:
break
generator.manual_seed(int(seed + i))
batch_results = shared.image_model(**gen_kwargs).images
all_images.extend(batch_results)
t1 = time.time()
total_images = len(all_images)
total_steps = steps * batch_count
logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
# Save images
_save_images(all_images, prompt, negative_prompt, width, height, steps, seed, cfg_scale)
# Build response
resp = {
'created': int(time.time()),
'data': []
}
from extensions.openai.script import params
# TODO: support SD_WEBUI_AUTH username:password pair.
sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"
response = requests.post(url=sd_url, json=payload)
r = response.json()
if response.status_code != 200 or 'images' not in r:
print(r)
raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None))
# r['parameters']...
for b64_json in r['images']:
for img in all_images:
b64 = _image_to_base64(img)
if response_format == 'b64_json':
resp['data'].extend([{'b64_json': b64_json}])
resp['data'].append({'b64_json': b64})
else:
resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this
resp['data'].append({'url': f'data:image/png;base64,{b64}'})
return resp
def _image_to_base64(image) -> str:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def _save_images(images, prompt, negative_prompt, width, height, steps, seed, cfg_scale):
"""Save images with metadata."""
date_str = datetime.now().strftime("%Y-%m-%d")
folder = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder, exist_ok=True)
metadata = {
'image_prompt': prompt,
'image_neg_prompt': negative_prompt,
'image_width': width,
'image_height': height,
'image_steps': steps,
'image_seed': seed,
'image_cfg_scale': cfg_scale,
'model': getattr(shared, 'image_model_name', 'unknown'),
}
for idx, img in enumerate(images):
ts = datetime.now().strftime("%H-%M-%S")
filepath = os.path.join(folder, f"{ts}_{seed:010d}_{idx:03d}.png")
png_info = PngInfo()
png_info.add_text("image_gen_settings", json.dumps(metadata))
img.save(filepath, pnginfo=png_info)

View file

@ -7,26 +7,23 @@ import traceback
from collections import deque
from threading import Thread
import extensions.openai.completions as OAIcompletions
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
import uvicorn
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool
import extensions.openai.completions as OAIcompletions
import extensions.openai.images as OAIimages
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from modules import shared
from modules.logging_colors import logger
from modules.models import unload_model
from modules.text_generation import stop_everything_event
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool
from .typing import (
ChatCompletionRequest,
@ -40,6 +37,8 @@ from .typing import (
EmbeddingsResponse,
EncodeRequest,
EncodeResponse,
ImageGenerationRequest,
ImageGenerationResponse,
LoadLorasRequest,
LoadModelRequest,
LogitsRequest,
@ -54,12 +53,12 @@ from .typing import (
params = {
'embedding_device': 'cpu',
'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
'sd_webui_url': '',
'debug': 0
}
streaming_semaphore = asyncio.Semaphore(1)
image_generation_semaphore = asyncio.Semaphore(1)
def verify_api_key(authorization: str = Header(None)) -> None:
@ -228,20 +227,26 @@ async def handle_audio_transcription(request: Request):
return JSONResponse(content=transcription)
@app.post('/v1/images/generations', dependencies=check_key)
async def handle_image_generation(request: Request):
@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
async def handle_image_generation(request_data: ImageGenerationRequest):
import extensions.openai.images as OAIimages
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
async with image_generation_semaphore:
width, height = request_data.get_width_height()
body = await request.json()
prompt = body['prompt']
size = body.get('size', '1024x1024')
response_format = body.get('response_format', 'url') # or b64_json
n = body.get('n', 1) # ignore the batch limits of max 10
response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
return JSONResponse(response)
response = await asyncio.to_thread(
OAIimages.generations,
prompt=request_data.prompt,
size=f"{width}x{height}",
response_format=request_data.response_format,
n=request_data.batch_size, # <-- use resolved batch_size
negative_prompt=request_data.negative_prompt,
steps=request_data.steps,
seed=request_data.seed,
cfg_scale=request_data.cfg_scale,
batch_count=request_data.batch_count,
)
return JSONResponse(response)
@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)

View file

@ -264,6 +264,46 @@ class LoadLorasRequest(BaseModel):
lora_names: List[str]
class ImageGenerationRequest(BaseModel):
"""OpenAI-compatible image generation request with extended parameters."""
# Required
prompt: str
# Generation parameters
negative_prompt: str = ""
size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
steps: int = Field(default=9, ge=1)
cfg_scale: float = Field(default=0.0, ge=0.0)
seed: int = Field(default=-1, description="-1 for random")
batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
batch_count: int = Field(default=1, ge=1, description="Sequential batch count")
# OpenAI compatibility (unused)
model: str | None = None
response_format: str = "b64_json"
user: str | None = None
@model_validator(mode='after')
def resolve_batch_size(self):
"""Use batch_size if provided, otherwise fall back to n."""
if self.batch_size is None:
self.batch_size = self.n
return self
def get_width_height(self) -> tuple[int, int]:
try:
parts = self.size.lower().split('x')
return int(parts[0]), int(parts[1])
except (ValueError, IndexError):
return 1024, 1024
class ImageGenerationResponse(BaseModel):
created: int = int(time.time())
data: List[dict]
def to_json(obj):
return json.dumps(obj.__dict__, indent=4)

View file

@ -36,3 +36,17 @@ function switch_to_character() {
document.getElementById("character-tab-button").click();
scrollToTop();
}
function switch_to_image_ai_generate() {
const container = document.querySelector("#image-ai-tab");
const buttons = container.getElementsByTagName("button");
for (let i = 0; i < buttons.length; i++) {
if (buttons[i].textContent.trim() === "Generate") {
buttons[i].click();
break;
}
}
scrollToTop();
}

View file

@ -3,7 +3,6 @@ import copy
import functools
import html
import json
import os
import pprint
import re
import shutil
@ -26,6 +25,7 @@ from modules.html_generator import (
convert_to_markdown,
make_thumbnail
)
from modules.image_utils import open_image_safely
from modules.logging_colors import logger
from modules.text_generation import (
generate_reply,
@ -1516,20 +1516,6 @@ def load_instruction_template_memoized(template):
return load_instruction_template(template)
def open_image_safely(path):
if path is None or not isinstance(path, str) or not Path(path).exists():
return None
if os.path.islink(path):
return None
try:
return Image.open(path)
except Exception as e:
logger.error(f"Failed to open image file: {path}. Reason: {e}")
return None
def upload_character(file, img_path, tavern=False):
img = open_image_safely(img_path)
decoded_file = file if isinstance(file, str) else file.decode('utf-8')

View file

@ -1,9 +1,7 @@
"""
Shared image processing utilities for multimodal support.
Used by both ExLlamaV3 and llama.cpp implementations.
"""
import base64
import io
import os
from pathlib import Path
from typing import Any, List, Tuple
from PIL import Image
@ -11,6 +9,20 @@ from PIL import Image
from modules.logging_colors import logger
def open_image_safely(path):
if path is None or not isinstance(path, str) or not Path(path).exists():
return None
if os.path.islink(path):
return None
try:
return Image.open(path)
except Exception as e:
logger.error(f"Failed to open image file: {path}. Reason: {e}")
return None
def convert_pil_to_base64(image: Image.Image) -> str:
"""Converts a PIL Image to a base64 encoded string."""
buffered = io.BytesIO()

View file

@ -7,7 +7,6 @@ from pathlib import Path
import gradio as gr
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from modules import shared, ui, utils
@ -16,6 +15,7 @@ from modules.image_models import (
load_image_model,
unload_image_model
)
from modules.image_utils import open_image_safely
from modules.logging_colors import logger
from modules.text_generation import stop_everything_event
from modules.utils import gradio
@ -29,7 +29,7 @@ ASPECT_RATIOS = {
}
STEP = 16
IMAGES_PER_PAGE = 64
IMAGES_PER_PAGE = 32
# Settings keys to save in PNG metadata (Generate tab only)
METADATA_SETTINGS_KEYS = [
@ -159,7 +159,7 @@ def save_generated_images(images, state, actual_seed):
def read_image_metadata(image_path):
"""Read generation metadata from PNG file."""
try:
with Image.open(image_path) as img:
with open_image_safely(image_path) as img:
if hasattr(img, 'text') and 'image_gen_settings' in img.text:
return json.loads(img.text['image_gen_settings'])
except Exception as e:
@ -172,7 +172,7 @@ def format_metadata_for_display(metadata):
if not metadata:
return "No generation settings found in this image."
lines = ["**Generation Settings**", ""]
lines = []
# Display in a nice order
display_order = [
@ -418,9 +418,9 @@ def create_ui():
# Pagination controls
with gr.Row():
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button")
shared.gradio['image_prev_page'] = gr.Button("◀ Prev", elem_classes="refresh-button")
shared.gradio['image_prev_page'] = gr.Button("◀ Prev Page", elem_classes="refresh-button")
shared.gradio['image_page_info'] = gr.Markdown(value=get_initial_page_info, elem_id="image-page-info")
shared.gradio['image_next_page'] = gr.Button("Next ", elem_classes="refresh-button")
shared.gradio['image_next_page'] = gr.Button("Next Page ", elem_classes="refresh-button")
shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80)
shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50)
@ -441,7 +441,7 @@ def create_ui():
)
with gr.Column(scale=1):
gr.Markdown("### Selected Image")
gr.Markdown("### Generation Settings")
shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings")
shared.gradio['image_send_to_generate'] = gr.Button("Send to Generate", variant="primary")
shared.gradio['image_gallery_status'] = gr.Markdown("")
@ -649,6 +649,7 @@ def create_event_handlers():
'image_cfg_scale',
'image_gallery_status'
),
js=f'() => {{{ui.switch_tabs_js}; switch_to_image_ai_generate()}}',
show_progress=False
)
@ -676,7 +677,8 @@ def generate(state):
if not model_name or model_name == 'None':
logger.error("No image model selected. Go to the Model tab and select a model.")
return []
yield []
return
if shared.image_model is None:
result = load_image_model(
@ -689,7 +691,8 @@ def generate(state):
)
if result is None:
logger.error(f"Failed to load model `{model_name}`.")
return []
yield []
return
shared.image_model_name = model_name
@ -759,6 +762,7 @@ def generate(state):
generator.manual_seed(int(seed + i))
batch_results = shared.image_model(**gen_kwargs).images
all_images.extend(batch_results)
yield all_images
t1 = time.time()
save_generated_images(all_images, state, seed)
@ -767,12 +771,12 @@ def generate(state):
total_steps = state["image_steps"] * int(state['image_batch_count'])
logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
return all_images
yield all_images
except Exception as e:
logger.error(f"Image generation failed: {e}")
traceback.print_exc()
return []
yield []
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):

View file

@ -44,8 +44,8 @@ sse-starlette==1.6.5
tiktoken
# CUDA wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.16/exllamav3-0.0.16+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.16/exllamav3-0.0.16+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"

View file

@ -42,7 +42,7 @@ sse-starlette==1.6.5
tiktoken
# AMD wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+rocm6.4.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+rocm6.4.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+rocm6.2.4.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64"

View file

@ -42,7 +42,7 @@ sse-starlette==1.6.5
tiktoken
# AMD wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+vulkanavx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+vulkanavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+vulkanavx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+vulkanavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+rocm6.2.4.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64"

View file

@ -42,5 +42,5 @@ sse-starlette==1.6.5
tiktoken
# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0-py3-none-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0-py3-none-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0-py3-none-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0-py3-none-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0"

View file

@ -42,5 +42,5 @@ sse-starlette==1.6.5
tiktoken
# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0-py3-none-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0-py3-none-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0-py3-none-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0-py3-none-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0"

View file

@ -42,5 +42,5 @@ sse-starlette==1.6.5
tiktoken
# llama.cpp (CPU only, AVX2)
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cpuavx2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cpuavx2-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cpuavx2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cpuavx2-py3-none-win_amd64.whl; platform_system == "Windows"

View file

@ -42,5 +42,5 @@ sse-starlette==1.6.5
tiktoken
# llama.cpp (CPU only, no AVX2)
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cpuavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cpuavx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cpuavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cpuavx-py3-none-win_amd64.whl; platform_system == "Windows"

View file

@ -44,8 +44,8 @@ sse-starlette==1.6.5
tiktoken
# CUDA wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cu124avx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cu124avx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cu124avx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cu124avx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.16/exllamav3-0.0.16+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.16/exllamav3-0.0.16+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"

View file

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# CUDA wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# AMD wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+rocm6.4.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+rocm6.4.4-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# AMD wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+vulkanavx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+rocm6.4.4avx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+vulkanavx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+rocm6.4.4avx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0-py3-none-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0-py3-none-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0-py3-none-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0-py3-none-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0"

View file

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# Mac wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0-py3-none-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0-py3-none-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0-py3-none-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0-py3-none-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0"

View file

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# llama.cpp (CPU only, AVX2)
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cpuavx2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cpuavx2-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cpuavx2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cpuavx2-py3-none-win_amd64.whl; platform_system == "Windows"

View file

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# llama.cpp (CPU only, no AVX2)
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cpuavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cpuavx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cpuavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cpuavx-py3-none-win_amd64.whl; platform_system == "Windows"

View file

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# CUDA wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cu124avx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+cu124avx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cu124avx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+cu124avx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# Vulkan wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -23,5 +23,5 @@ sse-starlette==1.6.5
tiktoken
# CUDA wheels
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+vulkanavx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.62.0/llama_cpp_binaries-0.62.0+vulkanavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+vulkanavx-py3-none-win_amd64.whl; platform_system == "Windows"
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.64.0/llama_cpp_binaries-0.64.0+vulkanavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"

View file

@ -275,6 +275,22 @@ if __name__ == "__main__":
if extension not in shared.args.extensions:
shared.args.extensions.append(extension)
# Load image model if specified via CLI
if shared.args.image_model:
logger.info(f"Loading image model: {shared.args.image_model}")
result = load_image_model(
shared.args.image_model,
dtype=shared.settings.get('image_dtype', 'bfloat16'),
attn_backend=shared.settings.get('image_attn_backend', 'sdpa'),
cpu_offload=shared.settings.get('image_cpu_offload', False),
compile_model=shared.settings.get('image_compile', False),
quant_method=shared.settings.get('image_quant', 'none')
)
if result is not None:
shared.image_model_name = shared.args.image_model
else:
logger.error(f"Failed to load image model: {shared.args.image_model}")
available_models = utils.get_available_models()
# Model defined through --model
@ -321,22 +337,6 @@ if __name__ == "__main__":
if shared.args.lora:
add_lora_to_model(shared.args.lora)
# Load image model if specified via CLI
if shared.args.image_model:
logger.info(f"Loading image model: {shared.args.image_model}")
result = load_image_model(
shared.args.image_model,
dtype=shared.settings.get('image_dtype', 'bfloat16'),
attn_backend=shared.settings.get('image_attn_backend', 'sdpa'),
cpu_offload=shared.settings.get('image_cpu_offload', False),
compile_model=shared.settings.get('image_compile', False),
quant_method=shared.settings.get('image_quant', 'none')
)
if result is not None:
shared.image_model_name = shared.args.image_model
else:
logger.error(f"Failed to load image model: {shared.args.image_model}")
shared.generation_lock = Lock()
if shared.args.idle_timeout > 0: