diff --git a/extensions/openai/images.py b/extensions/openai/images.py
index 92bd85f0..3a5288e6 100644
--- a/extensions/openai/images.py
+++ b/extensions/openai/images.py
@@ -1,70 +1,154 @@
+"""
+OpenAI-compatible image generation using local diffusion models.
+"""
+
+import base64
+import io
+import json
 import os
 import time
+from datetime import datetime
 
-import requests
-
+import numpy as np
 from extensions.openai.errors import ServiceUnavailableError
+from modules import shared
+from modules.logging_colors import logger
+from PIL.PngImagePlugin import PngInfo
 
 
-def generations(prompt: str, size: str, response_format: str, n: int):
-    # Stable Diffusion callout wrapper for txt2img
-    # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
-    # the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
-    # If you want high quality tailored results you should just use the Stable Diffusion API directly.
-    # it's too general an API to try and shape the result with specific tags like negative prompts
-    # or "masterpiece", etc. SD configuration is beyond the scope of this API.
-    # At this point I will not add the edits and variations endpoints (ie. img2img) because they
-    # require changing the form data handling to accept multipart form data, also to properly support
-    # url return types will require file management and a web serving files... Perhaps later!
-    base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
-    sd_defaults = {
-        'sampler_name': 'DPM++ 2M Karras',  # vast improvement
-        'steps': 30,
+def generations(prompt: str, size: str, response_format: str, n: int,
+                negative_prompt: str = "", steps: int = 9, seed: int = -1,
+                cfg_scale: float = 0.0, batch_count: int = 1):
+    """
+    Generate images using the loaded diffusion model.
+
+    Args:
+        prompt: Text description of the desired image
+        size: Image dimensions as "WIDTHxHEIGHT"
+        response_format: 'url' or 'b64_json'
+        n: Number of images per batch
+        negative_prompt: What to avoid in the image
+        steps: Number of inference steps
+        seed: Random seed (-1 for random)
+        cfg_scale: Classifier-free guidance scale
+        batch_count: Number of sequential batches
+
+    Returns:
+        dict with 'created' timestamp and 'data' list of images
+    """
+    import torch
+    from modules.image_models import get_pipeline_type
+    from modules.torch_utils import clear_torch_cache, get_device
+
+    if shared.image_model is None:
+        raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")
+
+    clear_torch_cache()
+
+    # Parse dimensions
+    try:
+        width, height = [int(x) for x in size.split('x')]
+    except (ValueError, IndexError):
+        width, height = 1024, 1024
+
+    # Handle seed
+    if seed == -1:
+        seed = np.random.randint(0, 2**32 - 1)
+
+    device = get_device() or "cpu"
+    generator = torch.Generator(device).manual_seed(int(seed))
+
+    # Get pipeline type for CFG parameter name
+    pipeline_type = getattr(shared, 'image_pipeline_type', None) or get_pipeline_type(shared.image_model)
+
+    # Build generation kwargs
+    gen_kwargs = {
+        "prompt": prompt,
+        "negative_prompt": negative_prompt,
+        "height": height,
+        "width": width,
+        "num_inference_steps": steps,
+        "num_images_per_prompt": n,
+        "generator": generator,
     }
-    width, height = [int(x) for x in size.split('x')]  # ignore the restrictions on size
 
+    # Pipeline-specific CFG parameter
+    if pipeline_type == 'qwenimage':
+        gen_kwargs["true_cfg_scale"] = cfg_scale
+    else:
+        gen_kwargs["guidance_scale"] = cfg_scale
 
-    # to hack on better generation, edit default payload.
-    payload = {
-        'prompt': prompt,  # ignore prompt limit of 1000 characters
-        'width': width,
-        'height': height,
-        'batch_size': n,
-    }
-    payload.update(sd_defaults)
+    # Generate
+    all_images = []
+    t0 = time.time()
 
-    scale = min(width, height) / base_model_size
-    if scale >= 1.2:
-        # for better performance with the default size (1024), and larger res.
-        scaler = {
-            'width': width // scale,
-            'height': height // scale,
-            'hr_scale': scale,
-            'enable_hr': True,
-            'hr_upscaler': 'Latent',
-            'denoising_strength': 0.68,
-        }
-        payload.update(scaler)
+    shared.stop_everything = False
 
+    def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
+        if shared.stop_everything:
+            pipe._interrupt = True
+        return callback_kwargs
+
+    gen_kwargs["callback_on_step_end"] = interrupt_callback
+
+    for i in range(batch_count):
+        if shared.stop_everything:
+            break
+        generator.manual_seed(int(seed + i))
+        batch_results = shared.image_model(**gen_kwargs).images
+        all_images.extend(batch_results)
+
+    t1 = time.time()
+    total_images = len(all_images)
+    total_steps = steps * batch_count
+    logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
+
+    # Save images
+    _save_images(all_images, prompt, negative_prompt, width, height, steps, seed, cfg_scale)
+
+    # Build response
     resp = {
         'created': int(time.time()),
         'data': []
     }
-    from extensions.openai.script import params
-    # TODO: support SD_WEBUI_AUTH username:password pair.
-    sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"
-
-    response = requests.post(url=sd_url, json=payload)
-    r = response.json()
-    if response.status_code != 200 or 'images' not in r:
-        print(r)
-        raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None))
-    # r['parameters']...
-    for b64_json in r['images']:
+    for img in all_images:
+        b64 = _image_to_base64(img)
         if response_format == 'b64_json':
-            resp['data'].extend([{'b64_json': b64_json}])
+            resp['data'].append({'b64_json': b64})
         else:
-            resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}])  # yeah it's lazy. requests.get() will not work with this
+            resp['data'].append({'url': f'data:image/png;base64,{b64}'})
 
     return resp
+
+
+def _image_to_base64(image) -> str:
+    buffered = io.BytesIO()
+    image.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+
+def _save_images(images, prompt, negative_prompt, width, height, steps, seed, cfg_scale):
+    """Save images with metadata."""
+    date_str = datetime.now().strftime("%Y-%m-%d")
+    folder = os.path.join("user_data", "image_outputs", date_str)
+    os.makedirs(folder, exist_ok=True)
+
+    metadata = {
+        'image_prompt': prompt,
+        'image_neg_prompt': negative_prompt,
+        'image_width': width,
+        'image_height': height,
+        'image_steps': steps,
+        'image_seed': seed,
+        'image_cfg_scale': cfg_scale,
+        'model': getattr(shared, 'image_model_name', 'unknown'),
+    }
+
+    for idx, img in enumerate(images):
+        ts = datetime.now().strftime("%H-%M-%S")
+        filepath = os.path.join(folder, f"{ts}_{seed:010d}_{idx:03d}.png")
+
+        png_info = PngInfo()
+        png_info.add_text("image_gen_settings", json.dumps(metadata))
+        img.save(filepath, pnginfo=png_info)
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 9440cb1e..1e982731 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -7,26 +7,23 @@ import traceback
 from collections import deque
 from threading import Thread
 
+import extensions.openai.completions as OAIcompletions
+import extensions.openai.logits as OAIlogits
+import extensions.openai.models as OAImodels
 import uvicorn
+from extensions.openai.tokens import token_count, token_decode, token_encode
+from extensions.openai.utils import _start_cloudflared
 from fastapi import Depends, FastAPI, Header, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.requests import Request
 from fastapi.responses import JSONResponse
-from pydub import AudioSegment
-from sse_starlette import EventSourceResponse
-from starlette.concurrency import iterate_in_threadpool
-
-import extensions.openai.completions as OAIcompletions
-import extensions.openai.images as OAIimages
-import extensions.openai.logits as OAIlogits
-import extensions.openai.models as OAImodels
-from extensions.openai.errors import ServiceUnavailableError
-from extensions.openai.tokens import token_count, token_decode, token_encode
-from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger
 from modules.models import unload_model
 from modules.text_generation import stop_everything_event
+from pydub import AudioSegment
+from sse_starlette import EventSourceResponse
+from starlette.concurrency import iterate_in_threadpool
 
 from .typing import (
     ChatCompletionRequest,
@@ -40,6 +37,8 @@ from .typing import (
     EmbeddingsResponse,
     EncodeRequest,
     EncodeResponse,
+    ImageGenerationRequest,
+    ImageGenerationResponse,
     LoadLorasRequest,
     LoadModelRequest,
     LogitsRequest,
@@ -228,19 +227,24 @@ async def handle_audio_transcription(request: Request):
     return JSONResponse(content=transcription)
 
 
-@app.post('/v1/images/generations', dependencies=check_key)
-async def handle_image_generation(request: Request):
+@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
+async def handle_image_generation(request_data: ImageGenerationRequest):
+    import extensions.openai.images as OAIimages
 
-    if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
-        raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
SD_WEBUI_URL not set.") + width, height = request_data.get_width_height() - body = await request.json() - prompt = body['prompt'] - size = body.get('size', '1024x1024') - response_format = body.get('response_format', 'url') # or b64_json - n = body.get('n', 1) # ignore the batch limits of max 10 - - response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n) + response = await asyncio.to_thread( + OAIimages.generations, + prompt=request_data.prompt, + size=f"{width}x{height}", + response_format=request_data.response_format, + n=request_data.batch_size, # <-- use resolved batch_size + negative_prompt=request_data.negative_prompt, + steps=request_data.steps, + seed=request_data.seed, + cfg_scale=request_data.cfg_scale, + batch_count=request_data.batch_count, + ) return JSONResponse(response) diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py index 56d91582..a24b844b 100644 --- a/extensions/openai/typing.py +++ b/extensions/openai/typing.py @@ -264,6 +264,46 @@ class LoadLorasRequest(BaseModel): lora_names: List[str] +class ImageGenerationRequest(BaseModel): + """OpenAI-compatible image generation request with extended parameters.""" + # Required + prompt: str + + # Generation parameters + negative_prompt: str = "" + size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'") + steps: int = Field(default=9, ge=1) + cfg_scale: float = Field(default=0.0, ge=0.0) + seed: int = Field(default=-1, description="-1 for random") + batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)") + n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)") + batch_count: int = Field(default=1, ge=1, description="Sequential batch count") + + # OpenAI compatibility (unused) + model: str | None = None + response_format: str = "b64_json" + user: str | None = None + + @model_validator(mode='after') + def resolve_batch_size(self): + """Use batch_size if provided, otherwise fall back to n.""" + if self.batch_size is None: + self.batch_size = self.n + return self + + def get_width_height(self) -> tuple[int, int]: + try: + parts = self.size.lower().split('x') + return int(parts[0]), int(parts[1]) + except (ValueError, IndexError): + return 1024, 1024 + + +class ImageGenerationResponse(BaseModel): + created: int = int(time.time()) + data: List[dict] + + def to_json(obj): return json.dumps(obj.__dict__, indent=4)