diff --git a/extensions/openai/images.py b/extensions/openai/images.py
index 3a5288e6..1ecb1e63 100644
--- a/extensions/openai/images.py
+++ b/extensions/openai/images.py
@@ -4,117 +4,50 @@ OpenAI-compatible image generation using local diffusion models.
 import base64
 import io
-import json
-import os
 import time
-from datetime import datetime
 
-import numpy as np
 from extensions.openai.errors import ServiceUnavailableError
 from modules import shared
-from modules.logging_colors import logger
-from PIL.PngImagePlugin import PngInfo
 
 
-def generations(prompt: str, size: str, response_format: str, n: int,
-                negative_prompt: str = "", steps: int = 9, seed: int = -1,
-                cfg_scale: float = 0.0, batch_count: int = 1):
+def generations(request):
     """
     Generate images using the loaded diffusion model.
-
-    Args:
-        prompt: Text description of the desired image
-        size: Image dimensions as "WIDTHxHEIGHT"
-        response_format: 'url' or 'b64_json'
-        n: Number of images per batch
-        negative_prompt: What to avoid in the image
-        steps: Number of inference steps
-        seed: Random seed (-1 for random)
-        cfg_scale: Classifier-free guidance scale
-        batch_count: Number of sequential batches
-
-    Returns:
-        dict with 'created' timestamp and 'data' list of images
+
+    Returns dict with 'created' timestamp and 'data' list of images.
     """
-    import torch
-
-    from modules.image_models import get_pipeline_type
-    from modules.torch_utils import clear_torch_cache, get_device
+    from modules.ui_image_generation import generate
 
     if shared.image_model is None:
         raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")
 
-    clear_torch_cache()
+    width, height = request.get_width_height()
 
-    # Parse dimensions
-    try:
-        width, height = [int(x) for x in size.split('x')]
-    except (ValueError, IndexError):
-        width, height = 1024, 1024
+    # Build state dict: GenerationOptions fields + image-specific keys
+    state = request.model_dump()
+    state.update({
+        'image_model_menu': shared.image_model_name,
+        'image_prompt': request.prompt,
+        'image_neg_prompt': request.negative_prompt,
+        'image_width': width,
+        'image_height': height,
+        'image_steps': request.steps,
+        'image_seed': request.image_seed,
+        'image_batch_size': request.batch_size,
+        'image_batch_count': request.batch_count,
+        'image_cfg_scale': request.cfg_scale,
+        'image_llm_variations': request.llm_variations,
+    })
 
-    # Handle seed
-    if seed == -1:
-        seed = np.random.randint(0, 2**32 - 1)
-
-    device = get_device() or "cpu"
-    generator = torch.Generator(device).manual_seed(int(seed))
-
-    # Get pipeline type for CFG parameter name
-    pipeline_type = getattr(shared, 'image_pipeline_type', None) or get_pipeline_type(shared.image_model)
-
-    # Build generation kwargs
-    gen_kwargs = {
-        "prompt": prompt,
-        "negative_prompt": negative_prompt,
-        "height": height,
-        "width": width,
-        "num_inference_steps": steps,
-        "num_images_per_prompt": n,
-        "generator": generator,
-    }
-
-    # Pipeline-specific CFG parameter
-    if pipeline_type == 'qwenimage':
-        gen_kwargs["true_cfg_scale"] = cfg_scale
-    else:
-        gen_kwargs["guidance_scale"] = cfg_scale
-
-    # Generate
-    all_images = []
-    t0 = time.time()
-
-    shared.stop_everything = False
-
-    def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
-        if shared.stop_everything:
-            pipe._interrupt = True
-        return callback_kwargs
-
-    gen_kwargs["callback_on_step_end"] = interrupt_callback
-
-    for i in range(batch_count):
-        if shared.stop_everything:
-            break
-        generator.manual_seed(int(seed + i))
-        batch_results = shared.image_model(**gen_kwargs).images
-        all_images.extend(batch_results)
-
-    t1 = time.time()
-    total_images = len(all_images)
-    total_steps = steps * batch_count
-    logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
-
-    # Save images
-    _save_images(all_images, prompt, negative_prompt, width, height, steps, seed, cfg_scale)
+    # Exhaust generator, keep final result
+    images = []
+    for images, _ in generate(state, save_images=False):
+        pass
 
     # Build response
-    resp = {
-        'created': int(time.time()),
-        'data': []
-    }
-
-    for img in all_images:
+    resp = {'created': int(time.time()), 'data': []}
+    for img in images:
         b64 = _image_to_base64(img)
-        if response_format == 'b64_json':
+        if request.response_format == 'b64_json':
             resp['data'].append({'b64_json': b64})
         else:
             resp['data'].append({'url': f'data:image/png;base64,{b64}'})
@@ -126,29 +59,3 @@ def _image_to_base64(image) -> str:
     buffered = io.BytesIO()
     image.save(buffered, format="PNG")
     return base64.b64encode(buffered.getvalue()).decode('utf-8')
-
-
-def _save_images(images, prompt, negative_prompt, width, height, steps, seed, cfg_scale):
-    """Save images with metadata."""
-    date_str = datetime.now().strftime("%Y-%m-%d")
-    folder = os.path.join("user_data", "image_outputs", date_str)
-    os.makedirs(folder, exist_ok=True)
-
-    metadata = {
-        'image_prompt': prompt,
-        'image_neg_prompt': negative_prompt,
-        'image_width': width,
-        'image_height': height,
-        'image_steps': steps,
-        'image_seed': seed,
-        'image_cfg_scale': cfg_scale,
-        'model': getattr(shared, 'image_model_name', 'unknown'),
-    }
-
-    for idx, img in enumerate(images):
-        ts = datetime.now().strftime("%H-%M-%S")
-        filepath = os.path.join(folder, f"{ts}_{seed:010d}_{idx:03d}.png")
-
-        png_info = PngInfo()
-        png_info.add_text("image_gen_settings", json.dumps(metadata))
-        img.save(filepath, pnginfo=png_info)
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 65805629..12f99ba4 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -7,23 +7,24 @@ import traceback
 from collections import deque
 from threading import Thread
 
-import extensions.openai.completions as OAIcompletions
-import extensions.openai.logits as OAIlogits
-import extensions.openai.models as OAImodels
 import uvicorn
-from extensions.openai.tokens import token_count, token_decode, token_encode
-from extensions.openai.utils import _start_cloudflared
 from fastapi import Depends, FastAPI, Header, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.requests import Request
 from fastapi.responses import JSONResponse
+from pydub import AudioSegment
+from sse_starlette import EventSourceResponse
+from starlette.concurrency import iterate_in_threadpool
+
+import extensions.openai.completions as OAIcompletions
+import extensions.openai.logits as OAIlogits
+import extensions.openai.models as OAImodels
+from extensions.openai.tokens import token_count, token_decode, token_encode
+from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger
 from modules.models import unload_model
 from modules.text_generation import stop_everything_event
-from pydub import AudioSegment
-from sse_starlette import EventSourceResponse
-from starlette.concurrency import iterate_in_threadpool
 
 from .typing import (
     ChatCompletionRequest,
@@ -232,20 +233,7 @@ async def handle_image_generation(request_data: ImageGenerationRequest):
     import extensions.openai.images as OAIimages
 
     async with image_generation_semaphore:
-        width, height = request_data.get_width_height()
-
-        response = await asyncio.to_thread(
-            OAIimages.generations,
-            prompt=request_data.prompt,
-            size=f"{width}x{height}",
-            response_format=request_data.response_format,
-            n=request_data.batch_size,  # <-- use resolved batch_size
-            negative_prompt=request_data.negative_prompt,
-            steps=request_data.steps,
-            seed=request_data.seed,
-            cfg_scale=request_data.cfg_scale,
-            batch_count=request_data.batch_count,
-        )
+        response = await asyncio.to_thread(OAIimages.generations, request_data)
 
     return JSONResponse(response)
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index a24b844b..dfdb9a7e 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -264,20 +264,18 @@ class LoadLorasRequest(BaseModel):
     lora_names: List[str]
 
 
-class ImageGenerationRequest(BaseModel):
-    """OpenAI-compatible image generation request with extended parameters."""
-    # Required
+class ImageGenerationRequestParams(BaseModel):
+    """Image-specific parameters for generation."""
     prompt: str
-
-    # Generation parameters
     negative_prompt: str = ""
     size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
     steps: int = Field(default=9, ge=1)
     cfg_scale: float = Field(default=0.0, ge=0.0)
-    seed: int = Field(default=-1, description="-1 for random")
+    image_seed: int = Field(default=-1, description="-1 for random")
     batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
     n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
     batch_count: int = Field(default=1, ge=1, description="Sequential batch count")
+    llm_variations: bool = False
 
     # OpenAI compatibility (unused)
     model: str | None = None
@@ -286,7 +284,6 @@
 
     @model_validator(mode='after')
     def resolve_batch_size(self):
-        """Use batch_size if provided, otherwise fall back to n."""
         if self.batch_size is None:
             self.batch_size = self.n
         return self
@@ -299,6 +296,10 @@
         return 1024, 1024
 
 
+class ImageGenerationRequest(GenerationOptions, ImageGenerationRequestParams):
+    pass
+
+
 class ImageGenerationResponse(BaseModel):
     created: int = int(time.time())
     data: List[dict]
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index ceb470ff..6ac0bc24 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -10,7 +10,6 @@ import numpy as np
 from PIL.PngImagePlugin import PngInfo
 
 from modules import shared, ui, utils
-from modules.utils import check_model_loaded
 from modules.image_models import (
     get_pipeline_type,
     load_image_model,
@@ -19,7 +18,7 @@ from modules.image_models import (
 from modules.image_utils import open_image_safely
 from modules.logging_colors import logger
 from modules.text_generation import stop_everything_event
-from modules.utils import gradio
+from modules.utils import check_model_loaded, gradio
 
 ASPECT_RATIOS = {
     "1:1 Square": (1, 1),
@@ -725,13 +724,13 @@ def progress_bar_html(progress=0, text=""):
     return f'''
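---

Reviewer note (not part of the patch): the `typing.py` change relies on pydantic multiple inheritance — `ImageGenerationRequest(GenerationOptions, ImageGenerationRequestParams)` merges the fields of both bases into one model, which is why `images.py` can seed the UI `state` with a single `request.model_dump()`. A minimal, self-contained sketch of that pattern; the `*Stub` classes are hypothetical stand-ins, not code from this PR:

```python
# Sketch of the pydantic v2 pattern behind the new ImageGenerationRequest.
from pydantic import BaseModel


class GenerationOptionsStub(BaseModel):
    # Stand-in for the repo's GenerationOptions (sampler settings etc.)
    temperature: float = 1.0


class ImageParamsStub(BaseModel):
    # Stand-in for ImageGenerationRequestParams
    prompt: str
    image_seed: int = -1


class ImageRequestStub(GenerationOptionsStub, ImageParamsStub):
    # Fields from both bases are merged into one model
    pass


req = ImageRequestStub(prompt="a red fox")
state = req.model_dump()  # one flat dict: prompt, image_seed, temperature
print(state)
```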
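For completeness, a hedged sketch of what a client call looks like after this PR. The host, port, and endpoint path are assumptions based on the project's usual OpenAI-compatible server defaults and may differ in your setup; note that `seed` is renamed to `image_seed`, and that `GenerationOptions` fields can now be sent alongside the image parameters:

```python
# Hypothetical client call against the refactored endpoint (address assumed).
import base64

import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/images/generations",  # assumed default address
    json={
        "prompt": "a watercolor lighthouse at dusk",
        "size": "768x1024",
        "steps": 9,
        "image_seed": 42,        # renamed from 'seed' in this PR
        "batch_count": 2,        # sequential batches
        "response_format": "b64_json",
    },
    timeout=600,
)
resp.raise_for_status()

# Decode each returned image from base64 and write it to disk
for i, item in enumerate(resp.json()["data"]):
    with open(f"image_{i:03d}.png", "wb") as f:
        f.write(base64.b64decode(item["b64_json"]))
```

With `response_format="url"` the server instead returns `data:image/png;base64,...` URIs, as the new response-building loop in `images.py` shows.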