diff --git a/extensions/modelslab/script.py b/extensions/modelslab/script.py
new file mode 100644
index 00000000..74445dcf
--- /dev/null
+++ b/extensions/modelslab/script.py
@@ -0,0 +1,610 @@
+import base64
+import io
+import re
+import time
+import json
+from datetime import date
+from pathlib import Path
+
+import gradio as gr
+import requests
+from PIL import Image
+
+from modules import shared
+from modules.ui import create_refresh_button
+
+# ModelsLab API parameters - can be customized in settings.json
+params = {
+    # API Configuration
+    'api_key': '',
+    'base_url': 'https://modelslab.com/api/v6',
+    'model': 'flux',
+
+    # Generation Parameters
+    'prompt_prefix': '(Masterpiece:1.1), detailed, intricate, colorful',
+    'negative_prompt': '(worst quality, low quality:1.3)',
+    'width': 1024,
+    'height': 1024,
+    'steps': 25,
+    'cfg_scale': 7.5,
+    'seed': -1,
+
+    # Behavior Settings
+    'mode': 1,  # 0=Manual, 1=Interactive, 2=Always-on
+    'save_images': True,
+    'enhance_prompt': True,
+    'safety_checker': True,
+
+    # Text Integration
+    'textgen_prefix': 'Please provide a detailed and vivid description of [subject]',
+    'trigger_words': ['image', 'picture', 'photo', 'generate', 'draw', 'create', 'show me'],
+
+    # Available Models
+    'available_models': ['flux', 'sdxl', 'playground-v2', 'stable-diffusion'],
+    'model_info': {
+        'flux': {'name': 'Flux', 'desc': 'Best quality, prompt adherence (~$0.018)', 'speed': '2-4s'},
+        'sdxl': {'name': 'Stable Diffusion XL', 'desc': 'Artistic, creative style (~$0.015)', 'speed': '3-5s'},
+        'playground-v2': {'name': 'Playground v2.5', 'desc': 'UI mockups, designs (~$0.016)', 'speed': '2-3s'},
+        'stable-diffusion': {'name': 'Stable Diffusion', 'desc': 'General purpose, fast (~$0.012)', 'speed': '2-4s'}
+    }
+}
+
+# Global state
+picture_response = False
+current_generation_task = None
+
+def load_settings():
+    """Load settings from webui settings.json"""
+    try:
+        settings_file = Path('settings.json')
+        if settings_file.exists():
+            with open(settings_file, 'r') as f:
+                data = json.load(f)
+                modelslab_settings = data.get('modelslab_api_pictures', {})
+                for key, value in modelslab_settings.items():
+                    if key in params:
+                        params[key] = value
+    except Exception as e:
+        print(f"ModelsLab: Could not load settings: {e}")
+
+def save_settings():
+    """Save current params to webui settings.json"""
+    try:
+        settings_file = Path('settings.json')
+        if settings_file.exists():
+            with open(settings_file, 'r') as f:
+                data = json.load(f)
+        else:
+            data = {}
+
+        data['modelslab_api_pictures'] = params.copy()
+
+        with open(settings_file, 'w') as f:
+            json.dump(data, f, indent=2)
+    except Exception as e:
+        print(f"ModelsLab: Could not save settings: {e}")
+
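+# For reference, a minimal sketch of the block the two helpers above read and
+# write inside settings.json (illustrative values only; save_settings() actually
+# persists every key in `params`):
+#
+#   {
+#     "modelslab_api_pictures": {
+#       "api_key": "your-modelslab-key",
+#       "model": "flux",
+#       "mode": 1,
+#       "width": 1024,
+#       "height": 1024
+#     }
+#   }
+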
+class ModelsLabClient:
+    """ModelsLab API client with async support"""
+
+    def __init__(self, api_key, base_url):
+        self.api_key = api_key
+        self.base_url = base_url
+        self.session = requests.Session()
+
+    def generate_image(self, prompt, **kwargs):
+        """Generate image using ModelsLab API"""
+        if not self.api_key:
+            raise Exception("ModelsLab API key not configured")
+
+        payload = {
+            "key": self.api_key,
+            "model_id": kwargs.get('model', 'flux'),
+            "prompt": prompt,
+            "negative_prompt": kwargs.get('negative_prompt', ''),
+            "width": kwargs.get('width', 1024),
+            "height": kwargs.get('height', 1024),
+            "samples": 1,
+            "num_inference_steps": kwargs.get('steps', 25),
+            "guidance_scale": kwargs.get('cfg_scale', 7.5),
+            "enhance_prompt": kwargs.get('enhance_prompt', True),
+            "safety_checker": kwargs.get('safety_checker', True)
+        }
+
+        # Add seed if specified (0 is a valid seed, so check explicitly)
+        seed = kwargs.get('seed')
+        if seed is not None and seed != -1:
+            payload["seed"] = seed
+
+        try:
+            response = self.session.post(
+                f"{self.base_url}/images/text2img",
+                json=payload,
+                headers={"Content-Type": "application/json"},
+                timeout=30
+            )
+
+            if response.status_code != 200:
+                raise Exception(f"API request failed: {response.status_code} - {response.text}")
+
+            return self.handle_response(response.json())
+
+        except requests.exceptions.Timeout:
+            raise Exception("Request timed out. Please try again.")
+        except requests.exceptions.ConnectionError:
+            raise Exception("Connection error. Check your internet connection.")
+        except Exception as e:
+            raise Exception(f"API Error: {str(e)}")
+
+    def handle_response(self, data):
+        """Handle API response including async polling"""
+        if data.get('status') == 'success':
+            output = data.get('output', [])
+            if output:
+                return output[0]  # Return first image URL
+            else:
+                raise Exception("No image generated")
+
+        elif data.get('status') == 'processing':
+            task_id = data.get('id')
+            if task_id:
+                return self.poll_result(task_id)
+            else:
+                raise Exception("Generation started but no task ID received")
+
+        elif data.get('status') == 'error':
+            raise Exception(f"Generation failed: {data.get('message', 'Unknown error')}")
+
+        else:
+            # Direct image URL response
+            output = data.get('output', [])
+            if output:
+                return output[0]
+            else:
+                raise Exception("Unexpected API response format")
+
+    def poll_result(self, task_id):
+        """Poll for async generation completion"""
+        max_attempts = 30  # Maximum 1 minute wait (30 attempts x 2s)
+        attempt = 0
+
+        while attempt < max_attempts:
+            try:
+                response = self.session.post(
+                    f"{self.base_url}/images/fetch/{task_id}",
+                    json={"key": self.api_key},
+                    timeout=10
+                )
+
+                if response.status_code == 200:
+                    data = response.json()
+
+                    if data.get('status') == 'success':
+                        output = data.get('output', [])
+                        if output:
+                            return output[0]
+                        else:
+                            raise Exception("Generation completed but no image received")
+
+                    elif data.get('status') == 'failed':
+                        raise Exception(f"Generation failed: {data.get('message', 'Unknown error')}")
+
+                    # Still processing, continue polling
+                    time.sleep(2)
+                    attempt += 1
+
+                else:
+                    raise Exception(f"Polling failed: {response.status_code}")
+
+            except requests.exceptions.Timeout:
+                attempt += 1
+                time.sleep(2)
+                continue
+
+        raise Exception("Generation timed out after 1 minute")
+
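+# Typical standalone use of the client above (illustrative sketch; the extension
+# itself always goes through generate_modelslab_image() further below):
+#
+#   client = ModelsLabClient(params['api_key'], params['base_url'])
+#   url = client.generate_image(
+#       "a watercolor lighthouse at dusk",
+#       model='flux', width=1024, height=1024, steps=25, cfg_scale=7.5
+#   )
+#   # `url` is a direct image URL; queued jobs are polled transparently via poll_result()
+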
+def remove_surrounded_chars(string):
+    """Remove text between asterisks (actions)"""
+    return re.sub(r'\*[^\*]*?(\*|$)', '', string)
+
+def contains_trigger_words(string):
+    """Check if string contains image generation trigger words"""
+    string = remove_surrounded_chars(string).lower()
+
+    # Check for trigger word patterns
+    for trigger in params['trigger_words']:
+        if trigger.lower() in string:
+            return True
+
+    # More sophisticated patterns
+    patterns = [
+        r'\b(send|mail|message|me)\b.+?\b(image|pic(ture)?|photo|snap(shot)?|selfie)\b',
+        r'\b(generate|create|draw|make)\b.+?\b(image|picture|art|drawing)\b',
+        r'\b(show|display)\s+me\b.+?\b(image|picture|photo)\b'
+    ]
+
+    for pattern in patterns:
+        if re.search(pattern, string, re.IGNORECASE):
+            return True
+
+    return False
+
+def extract_subject(string):
+    """Extract subject from user request"""
+    string = remove_surrounded_chars(string).strip()
+
+    # Try various extraction patterns
+    patterns = [
+        r'(?:image|picture|photo|drawing)\s+of\s+(.+?)(?:\.|$|,)',
+        r'(?:generate|create|draw|make)\s+(?:an?|the)?\s*(.+?)(?:\.|$|,)',
+        r'(?:show|display)\s+me\s+(?:an?|the)?\s*(.+?)(?:\.|$|,)',
+        r'(?:want|like)\s+to\s+see\s+(?:an?|the)?\s*(.+?)(?:\.|$|,)'
+    ]
+
+    for pattern in patterns:
+        match = re.search(pattern, string, re.IGNORECASE)
+        if match:
+            subject = match.group(1).strip()
+            if subject:
+                return subject
+
+    # Fallback: return the whole string cleaned up
+    return re.sub(r'\b(image|picture|photo|generate|create|draw|show|me|of|an?|the)\b', '',
+                  string, flags=re.IGNORECASE).strip()
+
+def schedule_image_generation(prompt):
+    """Mark that next response should include an image"""
+    global picture_response, current_generation_task
+    picture_response = True
+    current_generation_task = prompt
+    print(f"ModelsLab: Scheduled image generation for: {prompt[:100]}...")
+
+def state_modifier(state):
+    """Modify generation state when image generation is scheduled"""
+    global picture_response
+    if picture_response:
+        state['stream'] = False  # Disable streaming for image generation
+    return state
+
+def input_modifier(string):
+    """Process user input to detect image generation requests"""
+    global params
+
+    if not params['api_key']:
+        return string  # No API key, skip processing
+
+    if params['mode'] == 0:  # Manual mode only
+        return string
+
+    # Check for trigger words
+    if contains_trigger_words(string):
+        subject = extract_subject(string)
+
+        if subject:
+            # Generate enhanced prompt using textgen prefix
+            enhanced_prompt = params['textgen_prefix'].replace('[subject]', subject)
+        else:
+            enhanced_prompt = string
+
+        # Schedule image generation
+        schedule_image_generation(enhanced_prompt)
+
+        # Modify response based on mode
+        if params['mode'] == 2:  # Always-on mode
+            return f"I'll generate an image of: {subject if subject else 'your request'}"
+        elif params['mode'] == 1:  # Interactive mode
+            # Let the LLM respond naturally, image will be added
+            pass
+
+    return string
+
+def output_modifier(string):
+    """Process model output to inject generated images"""
+    global picture_response, current_generation_task
+
+    if not picture_response or not current_generation_task:
+        return string
+
+    try:
+        # Generate the image
+        image_html = generate_modelslab_image(current_generation_task)
+
+        # Reset state
+        picture_response = False
+        current_generation_task = None
+
+        # Inject image into response
+        if params['mode'] == 2:  # Always-on mode
+            return image_html
+        else:  # Interactive mode
+            return f"{string}\n\n{image_html}"
+
+    except Exception as e:
+        # Reset state on error
+        picture_response = False
+        current_generation_task = None
+
+        error_msg = f"<div style='color: red;'>Image generation failed: {str(e)}</div>"
+        return f"{string}\n\n{error_msg}"
+
+def generate_modelslab_image(prompt):
+    """Generate image using ModelsLab API and return HTML"""
+    try:
+        print(f"ModelsLab: Generating image for prompt: {prompt[:100]}...")
+
+        # Initialize client
+        client = ModelsLabClient(params['api_key'], params['base_url'])
+
+        # Prepare full prompt
+        if params['prompt_prefix']:
+            full_prompt = f"{params['prompt_prefix']}, {prompt}"
+        else:
+            full_prompt = prompt
+
+        print(f"ModelsLab: Full prompt: {full_prompt[:150]}...")
+
+        # Generate image
+        image_url = client.generate_image(
+            prompt=full_prompt,
+            negative_prompt=params['negative_prompt'],
+            model=params['model'],
+            width=params['width'],
+            height=params['height'],
+            steps=params['steps'],
+            cfg_scale=params['cfg_scale'],
+            seed=params['seed'] if params['seed'] != -1 else None,
+            enhance_prompt=params['enhance_prompt'],
+            safety_checker=params['safety_checker']
+        )
+
+        print(f"ModelsLab: Image generated successfully: {image_url}")
+
+        # Save image if enabled
+        if params['save_images']:
+            try:
+                save_generated_image(image_url, prompt)
+            except Exception as e:
+                print(f"ModelsLab: Could not save image: {e}")
+
+        # Create HTML for display
+        html = create_image_html(image_url, prompt)
+        return html
+
+    except Exception as e:
+        print(f"ModelsLab: Generation error: {e}")
+        raise e
+
+def save_generated_image(image_url, prompt):
+    """Save generated image to local storage"""
+    try:
+        # Create output directory
+        output_dir = Path("outputs/modelslab")
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Download image
+        response = requests.get(image_url, timeout=30)
+        response.raise_for_status()
+
+        # Generate filename
+        timestamp = date.today().strftime("%Y%m%d")
+        safe_prompt = re.sub(r'[^\w\s-]', '', prompt[:50]).strip()
+        safe_prompt = re.sub(r'[\s-]+', '_', safe_prompt)
+        filename = f"{timestamp}_{safe_prompt}_{params['model']}.png"
+
+        # Save file
+        filepath = output_dir / filename
+        with open(filepath, 'wb') as f:
+            f.write(response.content)
+
+        print(f"ModelsLab: Image saved to {filepath}")
+
+    except Exception as e:
+        print(f"ModelsLab: Save error: {e}")
+
+def create_image_html(image_url, prompt):
+    """Create HTML for displaying generated image"""
+    model_name = params['model_info'].get(params['model'], {}).get('name', params['model'])
+
+    html = f"""
+    <div class="modelslab-image" style="max-width: 512px; margin: 8px 0;">
+        <div style="font-size: 0.85em; opacity: 0.8; margin-bottom: 4px;">
+            Generated with ModelsLab {model_name}
+        </div>
+        <img src="{image_url}" alt="Generated image" style="max-width: 100%; border-radius: 8px;">
+        <div style="font-size: 0.8em; opacity: 0.7; margin-top: 4px;">
+            Prompt: {prompt[:100]}{'...' if len(prompt) > 100 else ''}
+        </div>
+    </div>
+    """
+
+    return html
+
+def validate_api_key(api_key):
+    """Validate API key by making a test request"""
+    if not api_key:
+        return False, "API key is required"
+
+    try:
+        client = ModelsLabClient(api_key, params['base_url'])
+        # Make a minimal test request
+        test_payload = {
+            "key": api_key,
+            "model_id": "flux",
+            "prompt": "test",
+            "width": 256,
+            "height": 256,
+            "samples": 1
+        }
+
+        response = requests.post(
+            f"{params['base_url']}/images/text2img",
+            json=test_payload,
+            headers={"Content-Type": "application/json"},
+            timeout=10
+        )
+
+        if response.status_code == 200:
+            data = response.json()
+            if 'error' in data:
+                return False, f"API Error: {data.get('error', 'Unknown error')}"
+            return True, "API key is valid"
+        else:
+            return False, f"HTTP {response.status_code}: {response.text[:100]}"
+
+    except Exception as e:
+        return False, f"Validation failed: {str(e)}"
+
+def update_param(key, value):
+    """Update parameter and save settings"""
+    params[key] = value
+    save_settings()
+    print(f"ModelsLab: Updated {key} = {value}")
+
+def test_generation():
+    """Test image generation with current settings"""
+    if not params['api_key']:
+        return "<div style='color: red;'>Please configure your API key first</div>"
+
+    try:
+        test_prompt = "A cute cat sitting in a garden, photorealistic"
+        html = generate_modelslab_image(test_prompt)
+        return html
+    except Exception as e:
+        return f"<div style='color: red;'>Test generation failed: {str(e)}</div>"
+
".join([ + f"{info['name']}: {info['desc']} | Speed: {info['speed']}" + for info in params['model_info'].values() + ]) + ) + + with gr.Accordion("⚙️ Generation Settings", open=False): + with gr.Row(): + width_slider = gr.Slider( + minimum=256, maximum=1536, step=64, + value=params['width'], label="Width" + ) + height_slider = gr.Slider( + minimum=256, maximum=1536, step=64, + value=params['height'], label="Height" + ) + + with gr.Row(): + steps_slider = gr.Slider( + minimum=10, maximum=50, step=1, + value=params['steps'], label="Steps" + ) + cfg_slider = gr.Slider( + minimum=1.0, maximum=20.0, step=0.1, + value=params['cfg_scale'], label="CFG Scale" + ) + + with gr.Row(): + seed_input = gr.Number( + value=params['seed'], label="Seed (-1 for random)", + precision=0 + ) + + with gr.Column(): + prompt_prefix_input = gr.Textbox( + label="Prompt Prefix", + value=params['prompt_prefix'], + placeholder="Added to beginning of every prompt" + ) + negative_prompt_input = gr.Textbox( + label="Negative Prompt", + value=params['negative_prompt'], + placeholder="What to avoid in generation" + ) + + with gr.Row(): + enhance_prompt_check = gr.Checkbox( + label="Enhance Prompt", value=params['enhance_prompt'] + ) + safety_checker_check = gr.Checkbox( + label="Safety Checker", value=params['safety_checker'] + ) + save_images_check = gr.Checkbox( + label="Save Images", value=params['save_images'] + ) + + with gr.Row(): + test_button = gr.Button("🧪 Test Generation", variant="secondary") + test_output = gr.HTML() + + # Event handlers + def validate_key(key): + if not key: + return gr.update(visible=True, value="

Please enter an API key

") + + is_valid, message = validate_api_key(key) + color = "green" if is_valid else "red" + return gr.update( + visible=True, + value=f"

{message}

" + ) + + # Connect event handlers + api_key_input.change(lambda x: update_param('api_key', x), inputs=[api_key_input]) + model_dropdown.change(lambda x: update_param('model', x), inputs=[model_dropdown]) + mode_radio.change(lambda x: update_param('mode', x), inputs=[mode_radio]) + + width_slider.change(lambda x: update_param('width', x), inputs=[width_slider]) + height_slider.change(lambda x: update_param('height', x), inputs=[height_slider]) + steps_slider.change(lambda x: update_param('steps', x), inputs=[steps_slider]) + cfg_slider.change(lambda x: update_param('cfg_scale', x), inputs=[cfg_slider]) + seed_input.change(lambda x: update_param('seed', x), inputs=[seed_input]) + + prompt_prefix_input.change(lambda x: update_param('prompt_prefix', x), inputs=[prompt_prefix_input]) + negative_prompt_input.change(lambda x: update_param('negative_prompt', x), inputs=[negative_prompt_input]) + + enhance_prompt_check.change(lambda x: update_param('enhance_prompt', x), inputs=[enhance_prompt_check]) + safety_checker_check.change(lambda x: update_param('safety_checker', x), inputs=[safety_checker_check]) + save_images_check.change(lambda x: update_param('save_images', x), inputs=[save_images_check]) + + validate_button.click(validate_key, inputs=[api_key_input], outputs=[validation_output]) + test_button.click(test_generation, outputs=[test_output]) + +# Initialize settings on load +load_settings() +print("ModelsLab API extension loaded successfully!") +print(f"Current model: {params['model']}, Mode: {params['mode']}") +if not params['api_key']: + print("⚠️ Please configure your ModelsLab API key in the extension settings") \ No newline at end of file