API: Move OpenAI-compatible API from extensions/openai to modules/api

This commit is contained in:
oobabooga 2026-03-20 14:46:00 -03:00
parent 2e4232e02b
commit bf6fbc019d
23 changed files with 51 additions and 65 deletions

0
modules/api/__init__.py Normal file
View file

View file

@@ -0,0 +1,11 @@
#!/usr/bin/env python3
# Preload the embedding model: useful for Docker images to prevent a re-download on config change.
# Dockerfile:
# ENV OPENEDAI_EMBEDDING_MODEL="sentence-transformers/all-mpnet-base-v2" # Optional
# RUN python3 cache_embedded_model.py
import os
import sentence_transformers
# Same env var / default as the embeddings API module, so the cached weights match at runtime.
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
# Instantiating the model downloads it into the local HF cache; the object itself is discarded.
model = sentence_transformers.SentenceTransformer(st_model)

907
modules/api/completions.py Normal file
View file

@@ -0,0 +1,907 @@
import copy
import functools
import json
import time
from collections import deque
from pathlib import Path
import tiktoken
import yaml
from pydantic import ValidationError
from .errors import InvalidRequestError
from .typing import ToolDefinition
from .utils import debug_msg
from modules.tool_parsing import get_tool_call_id, parse_tool_call, detect_tool_call_format
from modules import shared
from modules.reasoning import extract_reasoning
from modules.chat import (
generate_chat_prompt,
generate_chat_reply,
load_character_memoized,
load_instruction_template_memoized
)
from modules.image_utils import convert_openai_messages_to_images
from modules.logging_colors import logger
from modules.presets import load_preset_memoized
from modules.text_generation import decode, encode, generate_reply
@functools.cache
def load_chat_template_file(filepath):
    """Load a chat template from a file path (.jinja, .jinja2, or .yaml/.yml).

    YAML files are expected to hold the template under the
    'instruction_template' key; any other extension is returned verbatim.
    Results are cached per path for the lifetime of the process.
    """
    path = Path(filepath)
    raw = path.read_text(encoding='utf-8')
    if path.suffix.lower() not in ('.yaml', '.yml'):
        return raw
    return yaml.safe_load(raw).get('instruction_template', '')
def _get_raw_logprob_entries(offset=0):
    """Get raw logprob entries from llama.cpp/ExLlamav3 backend, starting from offset.

    Returns (new_entries, new_offset).
    """
    history = getattr(shared.model, 'last_completion_probabilities', None)
    if not history:
        return [], offset
    return history[offset:], len(history)
def _dict_to_logprob_entries(token_dict):
"""Convert a flat {token: logprob} dict (from LogprobProcessor) to raw entry format."""
if not token_dict:
return []
return [{"top_logprobs": [{"token": t, "logprob": lp} for t, lp in token_dict.items()]}]
def _parse_entry_top(entry):
"""Extract the top logprobs list from a raw entry, handling both key names."""
return entry.get('top_logprobs', entry.get('top_probs', []))
def format_chat_logprobs(entries):
    """Format logprob entries into OpenAI chat completions logprobs format.

    Output: {"content": [{"token", "logprob", "bytes", "top_logprobs": [...]}]}
    Returns None when there is nothing to report.
    """
    if not entries:
        return None

    def _as_bytes(text):
        return list(text.encode('utf-8')) if text else None

    formatted = []
    for raw in entries:
        # Accept both the modern and the legacy key for the alternatives list.
        alternatives = raw.get('top_logprobs', raw.get('top_probs', []))
        if not alternatives:
            continue
        alt_list = [
            {
                "token": alt.get('token', ''),
                "logprob": alt.get('logprob', alt.get('prob', 0)),
                "bytes": _as_bytes(alt.get('token', '')),
            }
            for alt in alternatives
        ]
        # The first alternative is the sampled token.
        chosen = alt_list[0]
        formatted.append({
            "token": chosen["token"],
            "logprob": chosen["logprob"],
            "bytes": chosen["bytes"],
            "top_logprobs": alt_list,
        })
    return {"content": formatted, "refusal": None} if formatted else None
def format_completion_logprobs(entries):
    """Format logprob entries into OpenAI completions logprobs format.

    Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "text_offset"}
    Returns None when there is nothing to report.
    """
    if not entries:
        return None
    result = {
        "tokens": [],
        "token_logprobs": [],
        "top_logprobs": [],
        "text_offset": [],
    }
    cursor = 0
    for raw in entries:
        # Accept both the modern and the legacy key for the alternatives list.
        candidates = raw.get('top_logprobs', raw.get('top_probs', []))
        if not candidates:
            continue
        # The first candidate is the sampled token; offsets track its position in the text.
        best = candidates[0]
        best_token = best.get('token', '')
        result["tokens"].append(best_token)
        result["token_logprobs"].append(best.get('logprob', best.get('prob', 0)))
        result["text_offset"].append(cursor)
        cursor += len(best_token)
        result["top_logprobs"].append({
            c.get('token', ''): c.get('logprob', c.get('prob', 0)) for c in candidates
        })
    if not result["tokens"]:
        return None
    return result
def process_parameters(body, is_legacy=False):
    """Translate an OpenAI-style request body into this project's generation params.

    NOTE: mutates *body* in place and returns it (generate_params IS body).
    Pops 'max_tokens' (or legacy 'length') into 'max_new_tokens' — assumes the
    key is present (request schema presumably guarantees it; confirm in caller).
    """
    generate_params = body
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    generate_params['max_new_tokens'] = body.pop(max_tokens_str)
    # 0 means "use the server-configured truncation length"
    if generate_params['truncation_length'] == 0:
        generate_params['truncation_length'] = shared.settings['truncation_length']
    # temperature == 0 is OpenAI's way of requesting greedy decoding
    if generate_params['temperature'] == 0:
        generate_params['do_sample'] = False
        generate_params['top_k'] = 1
    # A named preset overrides individual sampling parameters from the request
    if body['preset'] is not None:
        preset = load_preset_memoized(body['preset'])
        generate_params.update(preset)
    generate_params['custom_stopping_strings'] = []
    if 'stop' in body:  # str or array, max len 4 (ignored)
        if isinstance(body['stop'], str):
            generate_params['custom_stopping_strings'] = [body['stop']]
        elif isinstance(body['stop'], list):
            generate_params['custom_stopping_strings'] = body['stop']
    # Resolve logprobs: for chat completions, logprobs is a bool and the count
    # comes from top_logprobs. Normalize to an int for all backends.
    logprobs = body.get('logprobs', None)
    top_logprobs = body.get('top_logprobs', None)
    if logprobs is True:
        logprobs = max(top_logprobs, 1) if top_logprobs is not None else 5
    generate_params['logprobs'] = logprobs
    # For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
    if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
        # Imported lazily: transformers is only needed for non-native loaders
        from transformers import LogitsProcessorList
        from modules.transformers_loader import (
            LogitsBiasProcessor,
            LogprobProcessor
        )
        logits_processor = []
        logit_bias = body.get('logit_bias', None)
        if logit_bias:  # {str: float, ...}
            logits_processor = [LogitsBiasProcessor(logit_bias)]
        if logprobs is not None and logprobs > 0:
            # Kept in generate_params so the caller can pop it and read results later
            generate_params['logprob_proc'] = LogprobProcessor(logprobs)
            logits_processor.extend([generate_params['logprob_proc']])
        if logits_processor:  # requires logits_processor support
            generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
    return generate_params
def process_multimodal_content(content):
    """Extract text and add image placeholders from OpenAI multimodal format."""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        # Anything unexpected is coerced to its string form.
        return str(content)
    texts = []
    placeholder_run = ""
    for part in content:
        if not isinstance(part, dict):
            continue
        kind = part.get('type', '')
        if kind == 'text':
            texts.append(part.get('text', ''))
        elif kind == 'image_url':
            # One placeholder per image; actual image data is handled elsewhere.
            placeholder_run += "<__media__>"
    joined = ' '.join(texts)
    if not placeholder_run:
        return joined
    return f"{placeholder_run}\n\n{joined}"
def convert_history(history):
    '''
    Chat histories in this program are in the format [message, reply].
    This function converts OpenAI histories to that format.

    Returns (user_input, system_message, history_dict) where history_dict has
    'internal'/'visible' rows of [user_text, assistant_text, tool_text, meta]
    and 'messages' keeps the untouched OpenAI messages.
    '''
    chat_dialogue = []
    current_message = ""
    current_reply = ""
    user_input = ""
    user_input_last = True  # True while the most recent message was from the user
    system_message = ""
    seen_non_system = False
    for entry in history:
        content = entry["content"]
        role = entry["role"]
        if role == "user":
            seen_non_system = True
            # Extract text content (images handled by model-specific code)
            content = process_multimodal_content(content)
            user_input = content
            user_input_last = True
            # Two user messages in a row: flush the pending one with an empty reply
            if current_message:
                chat_dialogue.append([current_message, '', '', {}])
                current_message = ""
            current_message = content
        elif role == "assistant":
            seen_non_system = True
            meta = {}
            tool_calls = entry.get("tool_calls")
            if tool_calls and isinstance(tool_calls, list):
                meta["tool_calls"] = tool_calls
            if content.strip() == "":
                content = ""  # keep empty content, don't skip
            current_reply = content
            user_input_last = False
            if current_message:
                # Pair the reply with the preceding user message
                chat_dialogue.append([current_message, current_reply, '', meta])
                current_message = ""
                current_reply = ""
            else:
                # Assistant message with no preceding user turn
                chat_dialogue.append(['', current_reply, '', meta])
        elif role == "tool":
            seen_non_system = True
            user_input_last = False
            meta = {}
            if "tool_call_id" in entry:
                meta["tool_call_id"] = entry["tool_call_id"]
            # Tool results occupy the third slot of a row by themselves
            chat_dialogue.append(['', '', content, meta])
        elif role in ("system", "developer"):
            if not seen_non_system:
                # Leading system messages go to custom_system_message (placed at top)
                system_message += f"\n{content}" if system_message else content
            else:
                # Mid-conversation system messages: preserve position in history
                if current_message:
                    chat_dialogue.append([current_message, '', '', {}])
                    current_message = ""
                chat_dialogue.append([content, '', '', {"role": "system"}])
    # If the conversation does not end with a user turn, there is no pending input
    if not user_input_last:
        user_input = ""
    return user_input, system_message, {
        'internal': chat_dialogue,
        'visible': copy.deepcopy(chat_dialogue),
        'messages': history  # Store original messages for multimodal models
    }
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False, stop_event=None) -> dict:
    """Core generator behind the chat completions endpoint.

    Depending on flags, yields: the formatted prompt only (prompt_only=True),
    a stream of 'chat.completion.chunk' dicts (stream=True), or a single
    final 'chat.completion' response dict. Raises InvalidRequestError on
    malformed or unsupported input.
    """
    # Unsupported legacy OpenAI features are rejected outright
    if body.get('functions', []):
        raise InvalidRequestError(message="functions is not supported.", param='functions')
    if body.get('function_call', ''):
        raise InvalidRequestError(message="function_call is not supported.", param='function_call')
    if 'messages' not in body:
        raise InvalidRequestError(message="messages is required", param='messages')
    tools = None
    if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and body['tools']:
        tools = validateTools(body['tools'])  # raises InvalidRequestError if validation fails
    tool_choice = body.get('tool_choice', None)
    if tool_choice == "none":
        tools = None  # Disable tool detection entirely
    messages = body['messages']
    for m in messages:
        if 'role' not in m:
            raise InvalidRequestError(message="messages: missing role", param='messages')
        elif m['role'] == 'function':
            raise InvalidRequestError(message="role: function is not supported.", param='messages')
        # Handle multimodal content validation
        content = m.get('content')
        if content is None:
            # OpenAI allows content: null on assistant messages when tool_calls is present
            if m['role'] == 'assistant' and m.get('tool_calls'):
                m['content'] = ''
            else:
                raise InvalidRequestError(message="messages: missing content", param='messages')
        # Validate multimodal content structure
        if isinstance(content, list):
            for item in content:
                if not isinstance(item, dict) or 'type' not in item:
                    raise InvalidRequestError(message="messages: invalid content item format", param='messages')
                if item['type'] not in ['text', 'image_url']:
                    raise InvalidRequestError(message="messages: unsupported content type", param='messages')
                if item['type'] == 'text' and 'text' not in item:
                    raise InvalidRequestError(message="messages: missing text in content item", param='messages')
                if item['type'] == 'image_url' and ('image_url' not in item or 'url' not in item['image_url']):
                    raise InvalidRequestError(message="messages: missing image_url in content item", param='messages')
    # Chat Completions
    object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
    created_time = int(time.time())
    cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'
    # generation parameters (NOTE: process_parameters mutates body in place)
    generate_params = process_parameters(body, is_legacy=is_legacy)
    if stop_event is not None:
        generate_params['stop_event'] = stop_event
    continue_ = body['continue_']
    # Instruction template resolution: explicit string > named template > CLI file > settings
    if body['instruction_template_str']:
        instruction_template_str = body['instruction_template_str']
    elif body['instruction_template']:
        instruction_template = body['instruction_template']
        instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
        instruction_template_str = load_instruction_template_memoized(instruction_template)
    elif shared.args.chat_template_file:
        instruction_template_str = load_chat_template_file(shared.args.chat_template_file)
    else:
        instruction_template_str = shared.settings['instruction_template_str']
    chat_template_str = body['chat_template_str'] or shared.default_settings['chat_template_str']
    chat_instruct_command = body['chat_instruct_command'] or shared.default_settings['chat-instruct_command']
    # Chat character
    character = body['character'] or shared.default_settings['character']
    character = "Assistant" if character == "None" else character
    name1 = body['user_name'] or shared.default_settings['name1']
    name1, name2, _, greeting, context = load_character_memoized(character, name1, '')
    name2 = body['bot_name'] or name2
    context = body['context'] or context
    greeting = body['greeting'] or greeting
    user_bio = body['user_bio'] or ''
    # History
    user_input, custom_system_message, history = convert_history(messages)
    generate_params.update({
        'mode': body['mode'],
        'name1': name1,
        'name2': name2,
        'context': context,
        'greeting': greeting,
        'user_bio': user_bio,
        'instruction_template_str': instruction_template_str,
        'custom_system_message': custom_system_message,
        'chat_template_str': chat_template_str,
        'chat-instruct_command': chat_instruct_command,
        'tools': tools,
        'history': history,
        'stream': stream
    })
    max_tokens = generate_params['max_new_tokens']
    if max_tokens in [None, 0]:
        # No explicit limit: default to 512 and let auto mode extend as needed
        generate_params['max_new_tokens'] = 512
        generate_params['auto_max_new_tokens'] = True
    requested_model = generate_params.pop('model')
    logprob_proc = generate_params.pop('logprob_proc', None)
    if logprob_proc:
        logprob_proc.token_alternatives_history.clear()
    chat_logprobs_offset = [0]  # mutable for closure access in streaming

    def chat_streaming_chunk(content=None, chunk_tool_calls=None, include_role=False, reasoning_content=None):
        # Build one streaming delta chunk in OpenAI chat.completion.chunk format
        delta = {}
        if include_role:
            delta['role'] = 'assistant'
            delta['refusal'] = None
        if content is not None:
            delta['content'] = content
        if reasoning_content is not None:
            delta['reasoning_content'] = reasoning_content
        if chunk_tool_calls:
            delta['tool_calls'] = chunk_tool_calls
        chunk = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            "system_fingerprint": None,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                "delta": delta,
                "logprobs": None,
            }],
        }
        # Attach per-chunk logprobs from whichever backend produced them
        if logprob_proc:
            entries = _dict_to_logprob_entries(logprob_proc.token_alternatives)
            formatted = format_chat_logprobs(entries)
            if formatted:
                chunk[resp_list][0]["logprobs"] = formatted
        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
            entries, chat_logprobs_offset[0] = _get_raw_logprob_entries(chat_logprobs_offset[0])
            if entries:
                formatted = format_chat_logprobs(entries)
                if formatted:
                    chunk[resp_list][0]["logprobs"] = formatted
        return chunk

    # Check if usage should be included in streaming chunks per OpenAI spec
    stream_options = body.get('stream_options')
    include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
    # generate reply #######################################
    if prompt_only:
        prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
        yield {'prompt': prompt}
        return
    if stream:
        # Opening chunk carries the role per the OpenAI streaming convention
        chunk = chat_streaming_chunk('', include_role=True)
        if include_usage:
            chunk['usage'] = None
        yield chunk
    generator = generate_chat_reply(
        user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
    answer = ''
    seen_content = ''
    seen_reasoning = ''
    tool_calls = []
    end_last_tool_call = 0  # index in answer past the last parsed tool call
    supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
    _tool_parsers = None
    # Filter supported_tools when tool_choice specifies a particular function
    if supported_tools and isinstance(tool_choice, dict):
        specified_func = tool_choice.get("function", {}).get("name")
        if specified_func and specified_func in supported_tools:
            supported_tools = [specified_func]
    if supported_tools is not None:
        _template_str = generate_params.get('instruction_template_str', '') if generate_params.get('mode') == 'instruct' else generate_params.get('chat_template_str', '')
        _tool_parsers, _, _ = detect_tool_call_format(_template_str)
    for a in generator:
        answer = a['internal'][-1][1]
        if supported_tools is not None:
            tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools, parsers=_tool_parsers) if len(answer) > 0 else []
            if len(tool_call) > 0:
                for tc in tool_call:
                    tc["id"] = get_tool_call_id()
                    if stream:
                        tc["index"] = len(tool_calls)
                    tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
                    tool_calls.append(tc)
                end_last_tool_call = len(answer)
            # Stop generation before streaming content if tool_calls were detected,
            # so that raw tool markup is not sent as content deltas.
            if len(tool_calls) > 0:
                break
        if stream:
            # Strip reasoning/thinking blocks so only final content is streamed.
            # Reasoning is emitted separately as reasoning_content deltas.
            reasoning, content = extract_reasoning(answer)
            if reasoning is not None:
                new_reasoning = reasoning[len(seen_reasoning):]
                new_content = content[len(seen_content):]
            else:
                new_reasoning = None
                new_content = answer[len(seen_content):]
            # Hold back empty deltas and partial unicode characters (U+FFFD)
            if (not new_content and not new_reasoning) or chr(0xfffd) in (new_content or '') + (new_reasoning or ''):
                continue
            chunk = chat_streaming_chunk(
                content=new_content if new_content else None,
                reasoning_content=new_reasoning if new_reasoning else None,
            )
            if include_usage:
                chunk['usage'] = None
            if reasoning is not None:
                seen_reasoning = reasoning
                seen_content = content
            else:
                seen_content = answer
            yield chunk
    token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
    completion_token_count = len(encode(answer)[0])
    # Determine finish_reason per OpenAI semantics
    if len(tool_calls) > 0:
        stop_reason = "tool_calls"
    elif token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
        stop_reason = "length"
    else:
        stop_reason = "stop"
    if stream:
        # Final chunk carries tool calls (if any) and the finish reason
        chunk = chat_streaming_chunk(chunk_tool_calls=tool_calls)
        chunk[resp_list][0]['finish_reason'] = stop_reason
        usage = {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }
        if include_usage:
            chunk['usage'] = None
            yield chunk
            # Separate usage-only chunk with choices: [] per OpenAI spec
            yield {
                "id": cmpl_id,
                "object": object_type,
                "created": created_time,
                "model": shared.model_name,
                "system_fingerprint": None,
                resp_list: [],
                "usage": usage
            }
        else:
            yield chunk
    else:
        reasoning, content = extract_reasoning(answer)
        message = {
            "role": "assistant",
            "refusal": None,
            "content": None if tool_calls else content,
            **({"reasoning_content": reasoning} if reasoning else {}),
            **({"tool_calls": tool_calls} if tool_calls else {}),
        }
        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            "system_fingerprint": None,
            resp_list: [{
                "index": 0,
                "finish_reason": stop_reason,
                "message": message,
                "logprobs": None,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
        }
        # Attach accumulated logprobs for the whole completion
        if logprob_proc:
            all_entries = []
            for alt in logprob_proc.token_alternatives_history:
                all_entries.extend(_dict_to_logprob_entries(alt))
            formatted = format_chat_logprobs(all_entries)
            if formatted:
                resp[resp_list][0]["logprobs"] = formatted
        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
            raw = getattr(shared.model, 'last_completion_probabilities', None)
            if raw:
                formatted = format_chat_logprobs(raw)
                if formatted:
                    resp[resp_list][0]["logprobs"] = formatted
        yield resp
def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None):
    """Core generator behind the legacy text completions endpoint.

    Yields one final 'text_completion' response dict (stream=False) or a
    sequence of streaming chunks (stream=True). Raises InvalidRequestError
    when no prompt (or messages fallback) is provided.
    """
    object_type = 'text_completion'
    created_time = int(time.time())
    cmpl_id = "cmpl-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'
    prompt_str = 'context' if is_legacy else 'prompt'
    # Handle both prompt and messages format for unified multimodal support
    if prompt_str not in body or body[prompt_str] is None:
        if 'messages' in body:
            # Convert messages format to prompt for completions endpoint
            prompt_text = ""
            for message in body.get('messages', []):
                if isinstance(message, dict) and 'content' in message:
                    # Extract text content from multimodal messages
                    content = message['content']
                    if isinstance(content, str):
                        prompt_text += content
                    elif isinstance(content, list):
                        for item in content:
                            if isinstance(item, dict) and item.get('type') == 'text':
                                prompt_text += item.get('text', '')
            # Allow empty prompts for image-only requests
            body[prompt_str] = prompt_text
        else:
            raise InvalidRequestError("Missing required input", param=prompt_str)
    # common params (NOTE: process_parameters mutates body in place)
    generate_params = process_parameters(body, is_legacy=is_legacy)
    max_tokens = generate_params['max_new_tokens']
    generate_params['stream'] = stream
    if stop_event is not None:
        generate_params['stop_event'] = stop_event
    requested_model = generate_params.pop('model')
    logprob_proc = generate_params.pop('logprob_proc', None)
    if logprob_proc:
        logprob_proc.token_alternatives_history.clear()
    suffix = body['suffix'] if body['suffix'] else ''
    echo = body['echo']
    # Add messages to generate_params if present for multimodal processing
    if body.get('messages'):
        generate_params['messages'] = body['messages']
        raw_images = convert_openai_messages_to_images(generate_params['messages'])
        if raw_images:
            logger.info(f"Found {len(raw_images)} image(s) in request.")
            generate_params['raw_images'] = raw_images
    n_completions = body.get('n', 1) or 1
    if not stream:
        prompt_arg = body[prompt_str]
        # Handle empty/None prompts (e.g., image-only requests)
        if prompt_arg is None:
            prompt_arg = ""
        # Normalize a single prompt (string or one token list) into a batch of one
        if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and len(prompt_arg) > 0 and isinstance(prompt_arg[0], int)):
            prompt_arg = [prompt_arg]
        resp_list_data = []
        total_completion_token_count = 0
        total_prompt_token_count = 0
        choice_index = 0
        for idx, prompt in enumerate(prompt_arg, start=0):
            if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
                # token lists: decode with the local tokenizer, or tiktoken if a
                # foreign model was requested and tiktoken knows it
                if requested_model == shared.model_name:
                    prompt = decode(prompt)[0]
                else:
                    try:
                        encoder = tiktoken.encoding_for_model(requested_model)
                        prompt = encoder.decode(prompt)
                    except KeyError:
                        prompt = decode(prompt)[0]
            prefix = prompt if echo else ''
            token_count = len(encode(prompt)[0])
            total_prompt_token_count += token_count
            original_seed = generate_params.get('seed', -1)
            for _n in range(n_completions):
                # Increment seed for each completion to ensure diversity (matches llama.cpp native behavior)
                if original_seed >= 0:
                    generate_params['seed'] = original_seed + _n
                if logprob_proc:
                    logprob_proc.token_alternatives_history.clear()
                # generate reply #######################################
                debug_msg({'prompt': prompt, 'generate_params': generate_params})
                generator = generate_reply(prompt, generate_params, is_chat=False)
                answer = ''
                # Exhaust the generator; only the final state matters here
                for a in generator:
                    answer = a
                completion_token_count = len(encode(answer)[0])
                total_completion_token_count += completion_token_count
                stop_reason = "stop"
                if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
                    stop_reason = "length"
                # Collect logprobs from whichever backend produced them
                if logprob_proc:
                    all_entries = []
                    for alt in logprob_proc.token_alternatives_history:
                        all_entries.extend(_dict_to_logprob_entries(alt))
                    completion_logprobs = format_completion_logprobs(all_entries)
                elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
                    raw = getattr(shared.model, 'last_completion_probabilities', None)
                    completion_logprobs = format_completion_logprobs(raw)
                else:
                    completion_logprobs = None
                respi = {
                    "index": choice_index,
                    "finish_reason": stop_reason,
                    "text": prefix + answer + suffix,
                    "logprobs": completion_logprobs,
                }
                resp_list_data.append(respi)
                choice_index += 1
        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            "system_fingerprint": None,
            resp_list: resp_list_data,
            "usage": {
                "prompt_tokens": total_prompt_token_count,
                "completion_tokens": total_completion_token_count,
                "total_tokens": total_prompt_token_count + total_completion_token_count
            }
        }
        yield resp
    else:
        prompt = body[prompt_str]
        if isinstance(prompt, list):
            if prompt and isinstance(prompt[0], int):
                try:
                    encoder = tiktoken.encoding_for_model(requested_model)
                    prompt = encoder.decode(prompt)
                except KeyError:
                    prompt = decode(prompt)[0]
            else:
                raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
        prefix = prompt if echo else ''
        token_count = len(encode(prompt)[0])
        # Check if usage should be included in streaming chunks per OpenAI spec
        stream_options = body.get('stream_options')
        include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
        cmpl_logprobs_offset = [0]  # mutable for closure access in streaming

        def text_streaming_chunk(content):
            # Build one streaming chunk, attaching any new logprob entries
            if logprob_proc:
                chunk_logprobs = format_completion_logprobs(_dict_to_logprob_entries(logprob_proc.token_alternatives))
            elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
                entries, cmpl_logprobs_offset[0] = _get_raw_logprob_entries(cmpl_logprobs_offset[0])
                chunk_logprobs = format_completion_logprobs(entries) if entries else None
            else:
                chunk_logprobs = None
            chunk = {
                "id": cmpl_id,
                "object": object_type,
                "created": created_time,
                "model": shared.model_name,
                "system_fingerprint": None,
                resp_list: [{
                    "index": 0,
                    "finish_reason": None,
                    "text": content,
                    "logprobs": chunk_logprobs,
                }],
            }
            return chunk

        # First chunk echoes the prompt when echo=True (empty otherwise)
        chunk = text_streaming_chunk(prefix)
        if include_usage:
            chunk['usage'] = None
        yield chunk
        # generate reply #######################################
        debug_msg({'prompt': prompt, 'generate_params': generate_params})
        generator = generate_reply(prompt, generate_params, is_chat=False)
        answer = ''
        seen_content = ''
        completion_token_count = 0
        for a in generator:
            answer = a
            len_seen = len(seen_content)
            new_content = answer[len_seen:]
            if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                continue
            seen_content = answer
            chunk = text_streaming_chunk(new_content)
            if include_usage:
                chunk['usage'] = None
            yield chunk
        completion_token_count = len(encode(answer)[0])
        stop_reason = "stop"
        if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
            stop_reason = "length"
        # Final chunk carries the suffix and the finish reason
        chunk = text_streaming_chunk(suffix)
        chunk[resp_list][0]["finish_reason"] = stop_reason
        usage = {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }
        if include_usage:
            chunk['usage'] = None
            yield chunk
            # Separate usage-only chunk with choices: [] per OpenAI spec
            yield {
                "id": cmpl_id,
                "object": object_type,
                "created": created_time,
                "model": shared.model_name,
                "system_fingerprint": None,
                resp_list: [],
                "usage": usage
            }
        else:
            yield chunk
def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
    """Run a non-streaming chat completion and return only the final response dict."""
    responses = chat_completions_common(body, is_legacy, stream=False, stop_event=stop_event)
    # A maxlen-1 deque keeps just the last yielded item
    return deque(responses, maxlen=1).pop()
def stream_chat_completions(body: dict, is_legacy: bool = False, stop_event=None):
    """Stream chat completion chunks from the common generator.

    Yields 'chat.completion.chunk' dicts as they are produced.
    """
    # yield from is the idiomatic delegation form (PEP 380), replacing the manual for/yield loop
    yield from chat_completions_common(body, is_legacy, stream=True, stop_event=stop_event)
def completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
    """Run a non-streaming text completion and return only the final response dict."""
    responses = completions_common(body, is_legacy, stream=False, stop_event=stop_event)
    # A maxlen-1 deque keeps just the last yielded item
    return deque(responses, maxlen=1).pop()
def stream_completions(body: dict, is_legacy: bool = False, stop_event=None):
    """Stream text completion chunks from the common generator.

    Yields 'text_completion' chunk dicts as they are produced.
    """
    # yield from is the idiomatic delegation form (PEP 380), replacing the manual for/yield loop
    yield from completions_common(body, is_legacy, stream=True, stop_event=stop_event)
def validateTools(tools: list[dict]):
    """Validate each tool definition in the JSON array.

    Returns the (mutated) list of valid tools, or None if *tools* is empty.
    Raises InvalidRequestError naming the index of the first invalid entry.
    """
    valid_tools = None
    # enumerate() instead of range(len(...)); the ToolDefinition instance is
    # constructed purely for its pydantic validation side effect.
    for idx, tool in enumerate(tools):
        try:
            ToolDefinition(**tool)
        except ValidationError:
            raise InvalidRequestError(message=f"Invalid tool specification at index {idx}.", param='tools')
        # Backfill defaults so Jinja2 templates don't crash on missing fields
        func = tool.get("function", {})
        func.setdefault("description", "")
        func.setdefault("parameters", {"type": "object", "properties": {}})
        if valid_tools is None:
            valid_tools = []
        valid_tools.append(tool)
    return valid_tools

96
modules/api/embeddings.py Normal file
View file

@@ -0,0 +1,96 @@
import os
import numpy as np
from transformers import AutoModel
from .errors import ServiceUnavailableError
from .utils import debug_msg, float_list_to_base64
from modules.logging_colors import logger
embeddings_params_initialized = False
def initialize_embedding_params():
    '''
    using 'lazy loading' to avoid circular import
    so this function will be executed only once
    '''
    global embeddings_params_initialized
    if embeddings_params_initialized:
        return
    global st_model, embeddings_model, embeddings_device
    st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", 'sentence-transformers/all-mpnet-base-v2')
    embeddings_model = None
    # OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
    embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", 'cpu')
    if embeddings_device.lower() == 'auto':
        embeddings_device = None  # None lets the library pick the best device
    embeddings_params_initialized = True
def load_embedding_model(model: str):
    """Load *model* into the module-global embedding model slot.

    Raises ModuleNotFoundError when sentence_transformers is not installed,
    and ServiceUnavailableError when the model itself fails to load.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ModuleNotFoundError:
        logger.error("The sentence_transformers module has not been found. Please install it manually with pip install -U sentence-transformers.")
        # Re-raise the original exception instead of a fresh empty one so the
        # traceback and message (which module was missing) are preserved.
        raise
    initialize_embedding_params()
    global embeddings_device, embeddings_model
    try:
        print(f"Try embedding model: {model} on {embeddings_device}")
        if 'jina-embeddings' in model:
            embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True)  # trust_remote_code is needed to use the encode method
            embeddings_model = embeddings_model.to(embeddings_device)
        else:
            embeddings_model = SentenceTransformer(model, device=embeddings_device)
        print(f"Loaded embedding model: {model}")
    except Exception as e:
        # Reset the global so callers don't see a half-loaded model
        embeddings_model = None
        raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
def get_embeddings_model():
    """Return the embedding model instance, lazily loading it on first use."""
    initialize_embedding_params()
    global embeddings_model, st_model
    needs_load = bool(st_model) and not embeddings_model
    if needs_load:
        load_embedding_model(st_model)  # lazy load the model
    return embeddings_model
def get_embeddings_model_name() -> str:
    """Return the configured embedding model name (env OPENEDAI_EMBEDDING_MODEL or default)."""
    initialize_embedding_params()
    # st_model is only read here, so no global declaration is required
    return st_model
def get_embeddings(input: list) -> np.ndarray:
    """Encode *input* texts into L2-normalized embedding vectors (numpy array)."""
    encoder = get_embeddings_model()
    debug_msg(f"embedding model : {encoder}")
    vectors = encoder.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
    # May be very verbose even for debug output
    debug_msg(f"embedding result : {vectors}")
    return vectors
def embeddings(input: list, encoding_format: str) -> dict:
    """Build an OpenAI-style embeddings response for *input*.

    encoding_format == "base64" packs each vector as base64; anything else
    returns plain float lists. Token usage is not tracked (always 0).
    """
    # Renamed from 'embeddings' so the local no longer shadows this function's own name
    vectors = get_embeddings(input)
    if encoding_format == "base64":
        data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(vectors)]
    else:
        data = [{"object": "embedding", "embedding": emb.tolist(), "index": n} for n, emb in enumerate(vectors)]
    response = {
        "object": "list",
        "data": data,
        "model": st_model,  # return the real model
        "usage": {
            "prompt_tokens": 0,
            "total_tokens": 0,
        }
    }
    debug_msg(f"Embeddings return size: {len(vectors[0])}, number: {len(vectors)}")
    return response

31
modules/api/errors.py Normal file
View file

@@ -0,0 +1,31 @@
class OpenAIError(Exception):
    """Base exception for errors reported through the OpenAI-compatible API.

    *message* and *code* mirror the HTTP response sent to the client;
    *internal_message* is intended for server-side logging only.
    """

    def __init__(self, message=None, code=500, internal_message=''):
        self.message = message
        self.code = code
        self.internal_message = internal_message

    def __repr__(self):
        return f"{self.__class__.__name__}(message={self.message!r}, code={self.code:d})"
class InvalidRequestError(OpenAIError):
    """400-level error for malformed client requests; *param* names the offending field."""

    def __init__(self, message, param, code=400, internal_message=''):
        super().__init__(message, code, internal_message)
        self.param = param

    def __repr__(self):
        return f"{self.__class__.__name__}(message={self.message!r}, code={self.code:d}, param={self.param})"
class ServiceUnavailableError(OpenAIError):
    """503 error: the server cannot satisfy the request right now (e.g. no model loaded)."""

    def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
        super().__init__(message, code, internal_message)

69
modules/api/images.py Normal file
View file

@ -0,0 +1,69 @@
"""
OpenAI-compatible image generation using local diffusion models.
"""
import base64
import io
import time
from .errors import ServiceUnavailableError
from modules import shared
def generations(request):
    """
    Generate images using the loaded diffusion model.
    Returns dict with 'created' timestamp and 'data' list of images.

    Raises ServiceUnavailableError when no image model is loaded or when
    generation yields no images.
    """
    # Imported lazily so this module can be imported without the image UI stack.
    from modules.ui_image_generation import generate

    if shared.image_model is None:
        raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")

    width, height = request.get_width_height()

    # Build state dict: GenerationOptions fields + image-specific keys
    state = request.model_dump()
    state.update({
        'image_model_menu': shared.image_model_name,
        'image_prompt': request.prompt,
        'image_neg_prompt': request.negative_prompt,
        'image_width': width,
        'image_height': height,
        'image_steps': request.steps,
        'image_seed': request.image_seed,
        'image_batch_size': request.batch_size,
        'image_batch_count': request.batch_count,
        'image_cfg_scale': request.cfg_scale,
        'image_llm_variations': False,
    })

    # Exhaust generator, keep final result
    images = []
    for images, _ in generate(state, save_images=False):
        pass

    if not images:
        raise ServiceUnavailableError("Image generation failed or produced no images.")

    # Build response
    resp = {'created': int(time.time()), 'data': []}
    for img in images:
        b64 = _image_to_base64(img)
        image_obj = {'revised_prompt': request.prompt}
        if request.response_format == 'b64_json':
            image_obj['b64_json'] = b64
        else:
            # NOTE(review): 'url' is a data: URI here, not a hosted URL.
            image_obj['url'] = f'data:image/png;base64,{b64}'
        resp['data'].append(image_obj)

    return resp
def _image_to_base64(image) -> str:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode('utf-8')

9
modules/api/logits.py Normal file
View file

@ -0,0 +1,9 @@
from .completions import process_parameters
from modules.logits import get_next_logits
def _get_next_logits(body):
    """Compute the top-token logits for body['prompt'] after applying its params."""
    # Pre-process the input payload to simulate a real generation
    sampled = body['use_samplers']
    state = process_parameters(body)
    return get_next_logits(
        body['prompt'],
        state,
        sampled,
        "",
        top_logits=body['top_logits'],
        return_dict=True,
    )

85
modules/api/models.py Normal file
View file

@ -0,0 +1,85 @@
from modules import loaders, shared
from modules.logging_colors import logger
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.utils import get_available_loras, get_available_models
def get_current_model_info():
    """Snapshot of the loaded model name, active LoRAs, and loader backend."""
    info = {
        'model_name': shared.model_name,
        'lora_names': shared.lora_names,
        'loader': shared.args.loader,
    }
    return info
def list_models():
    """List every model available on disk (internal API format)."""
    names = get_available_models()
    return {'model_names': names}
def list_models_openai_format():
    """Returns model list in OpenAI API format"""
    # Only the currently loaded model is advertised; 'None' means nothing loaded.
    loaded = shared.model_name
    data = []
    if loaded and loaded != 'None':
        data.append(model_info_dict(loaded))

    return {"object": "list", "data": data}
def model_info_dict(model_name: str) -> dict:
    """Describe one model in the OpenAI /v1/models entry format."""
    return dict(
        id=model_name,
        object="model",
        created=0,  # creation time is unknown for local models
        owned_by="user",
    )
def _load_model(data):
    """Unload the current model and load data['model_name'] in its place.

    `data` may also carry:
      - 'args': dict of model-loading flags applied to shared.args (filtered
        against an allow-list; see below)
      - 'settings': dict of generation defaults merged into shared.settings

    Mutates global state (shared.args, shared.settings, shared.model,
    shared.tokenizer). NOTE(review): not obviously thread-safe — confirm callers
    serialize access.
    """
    model_name = data["model_name"]
    args = data["args"]
    settings = data["settings"]

    unload_model()
    model_settings = get_model_metadata(model_name)
    update_model_parameters(model_settings)

    # Update shared.args with custom model loading settings
    # Security: only allow keys that correspond to model loading
    # parameters exposed in the UI. Never allow security-sensitive
    # flags like trust_remote_code or extra_flags to be set via the API.
    blocked_keys = {'extra_flags'}
    allowed_keys = set(loaders.list_model_elements()) - blocked_keys
    if args:
        for k in args:
            if k in allowed_keys and hasattr(shared.args, k):
                setattr(shared.args, k, args[k])

    shared.model, shared.tokenizer = load_model(model_name)

    # Update shared.settings with custom generation defaults
    if settings:
        for k in settings:
            if k in shared.settings:
                shared.settings[k] = settings[k]
                # Log the two settings that most commonly need confirmation.
                if k == 'truncation_length':
                    logger.info(f"TRUNCATION LENGTH (UPDATED): {shared.settings['truncation_length']}")
                elif k == 'instruction_template':
                    logger.info(f"INSTRUCTION TEMPLATE (UPDATED): {shared.settings['instruction_template']}")
def list_loras():
    """List available LoRA names, skipping the first entry of get_available_loras()."""
    available = get_available_loras()
    # First entry is dropped — presumably a 'None' placeholder; confirm in utils.
    return {'lora_names': available[1:]}
def load_loras(lora_names):
    """Apply the given LoRAs to the loaded model."""
    add_lora_to_model(lora_names)
def unload_all_loras():
    """Detach all LoRAs by applying an empty list."""
    add_lora_to_model([])

View file

@ -0,0 +1,69 @@
import time
import numpy as np
from numpy.linalg import norm
from .embeddings import get_embeddings
moderations_disabled = False  # return 0/false
category_embeddings = None  # lazy cache: category name -> embedding vector
antonym_embeddings = None  # NOTE(review): unused in this module's visible code — confirm before removing
categories = ["sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence"]
flag_threshold = 0.5  # scores above this mark a category as flagged
def get_category_embeddings() -> dict:
    """Return (building lazily on first call) the category-name -> embedding map."""
    global category_embeddings, categories
    if category_embeddings is None:
        vectors = get_embeddings(categories).tolist()
        category_embeddings = dict(zip(categories, vectors))
    return category_embeddings
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between a and b (1.0 means identical direction)."""
    denominator = norm(a) * norm(b)
    return np.dot(a, b) / denominator
def mod_score(a: np.ndarray, b: np.ndarray) -> float:
    """Moderation score: a scaled dot product (2 * a.b).

    Empirically this scaling seems most OpenAI-like with all-mpnet-base-v2.
    """
    return np.dot(a, b) * 2.0
def moderations(input):
    """Score `input` (a string or list of strings) against the moderation categories.

    Returns an OpenAI /v1/moderations-shaped dict. Scores are raw embedding
    similarities, not a trained classifier — treat them as rough estimates.
    """
    global category_embeddings, categories, flag_threshold, moderations_disabled

    results = {
        "id": f"modr-{int(time.time()*1e9)}",  # pseudo-unique id from the nanosecond clock
        "model": "text-moderation-001",
        "results": [],
    }

    # Kill switch: report everything unflagged without computing embeddings.
    if moderations_disabled:
        results['results'] = [{
            'categories': dict([(C, False) for C in categories]),
            'category_scores': dict([(C, 0.0) for C in categories]),
            'flagged': False,
        }]
        return results

    category_embeddings = get_category_embeddings()

    # input, string or array
    if isinstance(input, str):
        input = [input]

    for in_str in input:
        for ine in get_embeddings([in_str]):
            category_scores = dict([(C, mod_score(category_embeddings[C], ine)) for C in categories])
            category_flags = dict([(C, bool(category_scores[C] > flag_threshold)) for C in categories])
            flagged = any(category_flags.values())

            results['results'].extend([{
                'flagged': flagged,
                'categories': category_flags,
                'category_scores': category_scores,
            }])

    print(results)  # NOTE(review): looks like leftover debug output — consider removing or gating

    return results

509
modules/api/script.py Normal file
View file

@ -0,0 +1,509 @@
import asyncio
import io
import json
import logging
import os
import secrets
import socket
import threading
import traceback
from collections import deque
from threading import Thread

import uvicorn
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool

import modules.api.completions as OAIcompletions
import modules.api.logits as OAIlogits
import modules.api.models as OAImodels
from modules import shared
from modules.logging_colors import logger
from modules.models import unload_model
from modules.text_generation import stop_everything_event  # used by /v1/internal/stop-generation

from .errors import OpenAIError
from .tokens import token_count, token_decode, token_encode
from .typing import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatPromptResponse,
    CompletionRequest,
    CompletionResponse,
    DecodeRequest,
    DecodeResponse,
    EmbeddingsRequest,
    EmbeddingsResponse,
    EncodeRequest,
    EncodeResponse,
    ImageGenerationRequest,
    ImageGenerationResponse,
    LoadLorasRequest,
    LoadModelRequest,
    LogitsRequest,
    LogitsResponse,
    LoraListResponse,
    ModelInfoResponse,
    ModelListResponse,
    TokenCountResponse,
    to_dict
)
from .utils import _start_cloudflared
async def _wait_for_disconnect(request: Request, stop_event: threading.Event):
    """Block until the client disconnects, then signal the stop_event."""
    disconnected = False
    while not disconnected:
        message = await request.receive()
        disconnected = message["type"] == "http.disconnect"
    stop_event.set()
def verify_api_key(authorization: str = Header(None)) -> None:
    """Authenticate a request against shared.args.api_key.

    No-op when no key is configured; otherwise requires an exact
    "Bearer <key>" Authorization header and raises HTTP 401 on mismatch.
    """
    expected_api_key = shared.args.api_key
    if not expected_api_key:
        return

    provided = authorization or ""
    # Constant-time comparison (secrets.compare_digest) avoids leaking key
    # prefixes via response timing; bytes form accepts non-ASCII keys too.
    if not secrets.compare_digest(provided.encode("utf-8"), f"Bearer {expected_api_key}".encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized")
def verify_admin_key(authorization: str = Header(None)) -> None:
    """Authenticate a request against shared.args.admin_key (model/LoRA admin routes).

    No-op when no admin key is configured; otherwise requires an exact
    "Bearer <key>" Authorization header and raises HTTP 401 on mismatch.
    """
    expected_api_key = shared.args.admin_key
    if not expected_api_key:
        return

    provided = authorization or ""
    # Constant-time comparison to avoid timing side channels on the key check.
    if not secrets.compare_digest(provided.encode("utf-8"), f"Bearer {expected_api_key}".encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized")
app = FastAPI()

# Route-dependency lists: user-level API key vs. admin key protection.
check_key = [Depends(verify_api_key)]
check_admin_key = [Depends(verify_admin_key)]

# Configure CORS settings to allow all origins, methods, and headers
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)
@app.exception_handler(OpenAIError)
async def openai_error_handler(request: Request, exc: OpenAIError):
    """Translate OpenAIError exceptions into OpenAI-style JSON error bodies."""
    if exc.code >= 500:
        error_type = "server_error"
    else:
        error_type = "invalid_request_error"

    payload = {
        "message": exc.message,
        "type": error_type,
        "param": getattr(exc, 'param', None),  # only InvalidRequestError carries .param
        "code": None,
    }
    return JSONResponse(status_code=exc.code, content={"error": payload})
@app.middleware("http")
async def validate_host_header(request: Request, call_next):
    """DNS-rebinding protection: in local-only mode, require a localhost Host header.

    Skipped when --listen or --public-api is set, since the server is then
    intentionally reachable from other hosts.
    """
    # Be strict about only approving access to localhost by default
    if not (shared.args.listen or shared.args.public_api):
        raw_host = request.headers.get("host", "")
        if raw_host.startswith("["):
            # Bracketed IPv6 literal, e.g. "[::1]:5000" -> "::1".
            # A plain split(":") would mangle these and reject IPv6 loopback.
            hostname = raw_host.split("]")[0].lstrip("[")
        else:
            hostname = raw_host.split(":")[0]

        # "::1" is accepted because run_server() may bind the IPv6 loopback.
        if hostname not in ("localhost", "127.0.0.1", "::1"):
            return JSONResponse(
                status_code=400,
                content={"detail": "Invalid host header"}
            )

    return await call_next(request)
@app.options("/", dependencies=check_key)
async def options_route():
    # Minimal handler for OPTIONS preflight requests against the API root.
    return JSONResponse(content="OK")
@app.post('/v1/completions', response_model=CompletionResponse, dependencies=check_key)
async def openai_completions(request: Request, request_data: CompletionRequest):
    """OpenAI-compatible /v1/completions endpoint (streaming and non-streaming)."""
    path = request.url.path
    is_legacy = "/generate" in path  # legacy route variant, detected from the URL

    if request_data.stream:
        # Multiple parallel completions cannot be interleaved into one SSE stream.
        if (request_data.n or 1) > 1:
            return JSONResponse(
                status_code=400,
                content={"error": {"message": "n > 1 is not supported with streaming.", "type": "invalid_request_error", "param": "n", "code": None}}
            )

        stop_event = threading.Event()

        async def generator():
            # stream_completions is a synchronous generator; iterate it off the
            # event loop so generation does not block other requests.
            response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
            try:
                async for resp in iterate_in_threadpool(response):
                    disconnected = await request.is_disconnected()
                    if disconnected:
                        break

                    yield {"data": json.dumps(resp)}
                yield {"data": "[DONE]"}
            finally:
                # Signal the backend and close the generator on every exit path
                # (normal completion, client disconnect, or error).
                stop_event.set()
                response.close()

        return EventSourceResponse(generator(), sep="\n")  # SSE streaming
    else:
        stop_event = threading.Event()
        # Watch for client disconnects so a long generation can be aborted early.
        monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
        try:
            response = await asyncio.to_thread(
                OAIcompletions.completions,
                to_dict(request_data),
                is_legacy=is_legacy,
                stop_event=stop_event
            )
        finally:
            stop_event.set()
            monitor.cancel()

        return JSONResponse(response)
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse, dependencies=check_key)
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
    """OpenAI-compatible /v1/chat/completions endpoint (streaming and non-streaming)."""
    path = request.url.path
    is_legacy = "/generate" in path  # legacy route variant, detected from the URL

    if request_data.stream:
        stop_event = threading.Event()

        async def generator():
            # stream_chat_completions is a synchronous generator; iterate it off
            # the event loop so generation does not block other requests.
            response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
            try:
                async for resp in iterate_in_threadpool(response):
                    disconnected = await request.is_disconnected()
                    if disconnected:
                        break

                    yield {"data": json.dumps(resp)}
                yield {"data": "[DONE]"}
            finally:
                # Signal the backend and close the generator on every exit path.
                stop_event.set()
                response.close()

        return EventSourceResponse(generator(), sep="\n")  # SSE streaming
    else:
        stop_event = threading.Event()
        # Watch for client disconnects so a long generation can be aborted early.
        monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
        try:
            response = await asyncio.to_thread(
                OAIcompletions.chat_completions,
                to_dict(request_data),
                is_legacy=is_legacy,
                stop_event=stop_event
            )
        finally:
            stop_event.set()
            monitor.cancel()

        return JSONResponse(response)
@app.get("/v1/models", dependencies=check_key)
@app.get("/v1/models/{model}", dependencies=check_key)
async def handle_models(request: Request):
    """Serve both the model-list route and the single-model info route."""
    path = request.url.path
    # Strip any query/fragment before comparing against the bare list path.
    is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models'

    if is_list:
        return JSONResponse(OAImodels.list_models_openai_format())

    model_name = path[len('/v1/models/'):]
    return JSONResponse(OAImodels.model_info_dict(model_name))
@app.get('/v1/billing/usage', dependencies=check_key)
def handle_billing_usage():
    '''
    Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31

    Stub for OpenAI-client compatibility: local usage is always reported as 0.
    '''
    return JSONResponse(content={"total_usage": 0})
@app.post('/v1/audio/transcriptions', dependencies=check_key)
async def handle_audio_transcription(request: Request):
    """Transcribe an uploaded audio file with local Whisper (via speech_recognition).

    Expects multipart form data with a 'file' field and optional 'language'
    and 'model' fields.
    """
    # Imported lazily: these are optional, heavy dependencies.
    import speech_recognition as sr
    from pydub import AudioSegment

    r = sr.Recognizer()

    form = await request.form()
    # UploadFile.read() returns raw bytes; pydub's from_file expects a
    # file-like object, so wrap the bytes in a BytesIO buffer.
    audio_bytes = await form["file"].read()
    audio_data = AudioSegment.from_file(io.BytesIO(audio_bytes))

    # Convert AudioSegment to raw data
    raw_data = audio_data.raw_data

    # Create AudioData object
    audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
    # Starlette's FormData exposes .get(), not the cgi-style .getvalue().
    whisper_language = form.get('language', None)
    whisper_model = form.get('model', 'tiny')  # Use the model from the form data if it exists, otherwise default to tiny

    transcription = {"text": ""}
    try:
        transcription["text"] = r.recognize_whisper(audio_data, language=whisper_language, model=whisper_model)
    except sr.UnknownValueError:
        print("Whisper could not understand audio")
        transcription["text"] = "Whisper could not understand audio UnknownValueError"
    except sr.RequestError as e:
        print("Could not request results from Whisper", e)
        transcription["text"] = "Whisper could not understand audio RequestError"

    return JSONResponse(content=transcription)
@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
async def handle_image_generation(request_data: ImageGenerationRequest):
    """OpenAI-compatible image generation endpoint backed by a local diffusion model."""
    # Imported lazily so the API can run without the image-generation stack.
    import modules.api.images as OAIimages

    # Generation is synchronous and slow; run it off the event loop.
    response = await asyncio.to_thread(OAIimages.generations, request_data)
    return JSONResponse(response)
@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
    """OpenAI-compatible /v1/embeddings endpoint."""
    # Imported lazily: sentence-transformers is a heavy, optional dependency.
    import modules.api.embeddings as OAIembeddings

    input = request_data.input
    if not input:
        raise HTTPException(status_code=400, detail="Missing required argument input")

    # Normalize the single-string form to a one-element batch.
    # isinstance (not `type(...) is str`) also accepts str subclasses.
    if isinstance(input, str):
        input = [input]

    response = OAIembeddings.embeddings(input, request_data.encoding_format)
    return JSONResponse(response)
@app.post("/v1/moderations", dependencies=check_key)
async def handle_moderations(request: Request):
    """OpenAI-compatible /v1/moderations endpoint."""
    # Imported lazily: moderation pulls in the embedding model machinery.
    import modules.api.moderations as OAImoderations

    body = await request.json()
    # .get() instead of [] so a missing field yields the intended 400,
    # not an unhandled KeyError (HTTP 500).
    input = body.get("input")
    if not input:
        raise HTTPException(status_code=400, detail="Missing required argument input")

    response = OAImoderations.moderations(input)
    return JSONResponse(response)
@app.get("/v1/internal/health", dependencies=check_key)
async def handle_health_check():
    # Liveness probe: a fixed payload confirming the API server is responding.
    return JSONResponse(content={"status": "ok"})
@app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key)
async def handle_token_encode(request_data: EncodeRequest):
    # Tokenize text with the currently loaded model's tokenizer.
    response = token_encode(request_data.text)
    return JSONResponse(response)
@app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key)
async def handle_token_decode(request_data: DecodeRequest):
    # Convert token ids back into text with the loaded tokenizer.
    response = token_decode(request_data.tokens)
    return JSONResponse(response)
@app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key)
async def handle_token_count(request_data: EncodeRequest):
    # Return only the token count for the given text.
    response = token_count(request_data.text)
    return JSONResponse(response)
@app.post("/v1/internal/logits", response_model=LogitsResponse, dependencies=check_key)
async def handle_logits(request_data: LogitsRequest):
    '''
    Given a prompt, returns the top 50 most likely logits as a dict.
    The keys are the tokens, and the values are the probabilities.
    '''
    response = OAIlogits._get_next_logits(to_dict(request_data))
    return JSONResponse(response)
@app.post('/v1/internal/chat-prompt', response_model=ChatPromptResponse, dependencies=check_key)
async def handle_chat_prompt(request: Request, request_data: ChatCompletionRequest):
    """Return the fully formatted prompt for a chat request without generating."""
    path = request.url.path
    is_legacy = "/generate" in path
    generator = OAIcompletions.chat_completions_common(to_dict(request_data), is_legacy=is_legacy, prompt_only=True)
    # Drain the generator, keeping only the final yielded item.
    response = deque(generator, maxlen=1).pop()
    return JSONResponse(response)
@app.post("/v1/internal/stop-generation", dependencies=check_key)
async def handle_stop_generation(request: Request):
    # Abort any generation currently in progress, server-wide.
    stop_everything_event()
    return JSONResponse(content="OK")
@app.get("/v1/internal/model/info", dependencies=check_key)
async def handle_model_info():
    # Report the loaded model, active LoRAs, and loader backend.
    payload = OAImodels.get_current_model_info()
    return JSONResponse(content=payload)
@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_admin_key)
async def handle_list_models():
    # Admin-only: enumerate all models available on disk.
    payload = OAImodels.list_models()
    return JSONResponse(content=payload)
@app.post("/v1/internal/model/load", dependencies=check_admin_key)
async def handle_load_model(request_data: LoadModelRequest):
    '''
    This endpoint is experimental and may change in the future.

    The "args" parameter can be used to modify flags like "--load-in-4bit"
    or "--n-gpu-layers" before loading a model. Example:

    ```
    "args": {
      "load_in_4bit": true,
      "n_gpu_layers": 12
    }
    ```

    Note that those settings will remain after loading the model. So you
    may need to change them back to load a second model.

    The "settings" parameter is also a dict but with keys for the
    shared.settings object. It can be used to modify the default instruction
    template like this:

    ```
    "settings": {
      "instruction_template": "Alpaca"
    }
    ```
    '''
    try:
        OAImodels._load_model(to_dict(request_data))
        return JSONResponse(content="OK")
    except Exception:
        # Load failures are reported as a 400 with the traceback on the server log.
        traceback.print_exc()
        raise HTTPException(status_code=400, detail="Failed to load the model.")
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
async def handle_unload_model():
    """Unload the currently loaded model, freeing its memory."""
    unload_model()
    # Return an explicit body for consistency with the other admin endpoints
    # (previously this implicitly returned null).
    return JSONResponse(content="OK")
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
async def handle_list_loras():
    # Admin-only: enumerate LoRAs available on disk.
    response = OAImodels.list_loras()
    return JSONResponse(content=response)
@app.post("/v1/internal/lora/load", dependencies=check_admin_key)
async def handle_load_loras(request_data: LoadLorasRequest):
    # Admin-only: apply the named LoRAs to the loaded model.
    try:
        OAImodels.load_loras(request_data.lora_names)
        return JSONResponse(content="OK")
    except Exception:
        # Failures surface as a 400; the traceback goes to the server log.
        traceback.print_exc()
        raise HTTPException(status_code=400, detail="Failed to apply the LoRA(s).")
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
async def handle_unload_loras():
    # Admin-only: detach every active LoRA.
    OAImodels.unload_all_loras()
    return JSONResponse(content="OK")
def find_available_port(starting_port):
    """Try the starting port, then find an available one if it's taken."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.bind(('', starting_port))
            return starting_port
        except OSError:
            pass  # port taken; fall through to OS-assigned fallback

    # Bind to port 0 so the OS hands us any currently free port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(('', 0))
        fallback_port = probe.getsockname()[1]

    logger.warning(f"Port {starting_port} is already in use. Using port {fallback_port} instead.")
    return fallback_port
def run_server():
    """Configure and start the uvicorn server hosting the API (blocking call)."""
    # Parse configuration; OPENEDAI_* environment variables override CLI flags.
    port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
    port = find_available_port(port)
    ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
    ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)

    # In the server configuration:
    server_addrs = []
    if shared.args.listen and shared.args.listen_host:
        server_addrs.append(shared.args.listen_host)
    else:
        # --listen binds all interfaces; otherwise only the loopback address.
        if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
            server_addrs.append('[::]' if shared.args.listen else '[::1]')
        if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
            server_addrs.append('0.0.0.0' if shared.args.listen else '127.0.0.1')

    if not server_addrs:
        raise Exception('you MUST enable IPv6 or IPv4 for the API to work')

    # Log server information
    if shared.args.public_api:
        _start_cloudflared(
            port,
            shared.args.public_api_id,
            max_attempts=3,
            on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}/v1\n')
        )
    else:
        url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://'
        urls = [f'{url_proto}{addr}:{port}/v1' for addr in server_addrs]
        if len(urls) > 1:
            logger.info('OpenAI-compatible API URLs:\n\n' + '\n'.join(urls) + '\n')
        else:
            logger.info('OpenAI-compatible API URL:\n\n' + '\n'.join(urls) + '\n')

    # Log API keys
    if shared.args.api_key:
        # The admin key falls back to the regular API key when not set separately.
        if not shared.args.admin_key:
            shared.args.admin_key = shared.args.api_key

        logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')

    if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
        logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')

    # Start server
    logging.getLogger("uvicorn.error").propagate = False
    # NOTE(review): uvicorn.run's `host` parameter is documented as a single
    # string; a list is passed here — confirm the pinned uvicorn version
    # supports multi-address binds.
    uvicorn.run(app, host=server_addrs, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile, access_log=False)
_server_started = False  # guards against starting the server twice


def setup():
    """Start the API server once; in UI mode it runs on a daemon thread."""
    global _server_started
    if _server_started:
        return

    _server_started = True
    if shared.args.nowebui:
        # API-only mode: run in the current thread (blocking).
        run_server()
    else:
        Thread(target=run_server, daemon=True).start()

26
modules/api/tokens.py Normal file
View file

@ -0,0 +1,26 @@
from modules.text_generation import decode, encode
def token_count(prompt):
    """Count the tokens in `prompt` using the loaded model's tokenizer."""
    tokens = encode(prompt)[0]
    return {'length': len(tokens)}
def token_encode(input):
    """Tokenize `input`, returning the token ids and their count.

    Array-like outputs (torch tensors, numpy arrays, ...) are converted to
    plain lists so the result is JSON-serializable.
    """
    tokens = encode(input)[0]
    # Duck-typed conversion: any array type exposing .tolist() is flattened.
    # More robust than the previous check of __class__.__name__ against
    # ['Tensor', 'ndarray'], which missed other array implementations.
    if hasattr(tokens, 'tolist'):
        tokens = tokens.tolist()

    return {
        'tokens': tokens,
        'length': len(tokens),
    }
def token_decode(tokens):
    """Convert a list of token ids back into text."""
    decoded_text = decode(tokens)
    return {'text': decoded_text}

326
modules/api/typing.py Normal file
View file

@ -0,0 +1,326 @@
import json
import time
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field, model_validator, validator
from modules import shared
class GenerationOptions(BaseModel):
    """Sampler/generation settings shared by every generation endpoint.

    Defaults are read from shared.args once, when this class is defined, so
    they reflect the server's startup flags rather than per-request state.
    """
    preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/user_data/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.")
    # --- temperature shaping ---
    dynatemp_low: float = shared.args.dynatemp_low
    dynatemp_high: float = shared.args.dynatemp_high
    dynatemp_exponent: float = shared.args.dynatemp_exponent
    smoothing_factor: float = shared.args.smoothing_factor
    smoothing_curve: float = shared.args.smoothing_curve
    # --- truncation / filtering samplers ---
    min_p: float = shared.args.min_p
    top_k: int = shared.args.top_k
    typical_p: float = shared.args.typical_p
    xtc_threshold: float = shared.args.xtc_threshold
    xtc_probability: float = shared.args.xtc_probability
    epsilon_cutoff: float = shared.args.epsilon_cutoff
    eta_cutoff: float = shared.args.eta_cutoff
    tfs: float = shared.args.tfs
    top_a: float = shared.args.top_a
    top_n_sigma: float = shared.args.top_n_sigma
    adaptive_target: float = shared.args.adaptive_target
    adaptive_decay: float = shared.args.adaptive_decay
    # --- repetition control ---
    dry_multiplier: float = shared.args.dry_multiplier
    dry_allowed_length: int = shared.args.dry_allowed_length
    dry_base: float = shared.args.dry_base
    repetition_penalty: float = shared.args.repetition_penalty
    encoder_repetition_penalty: float = shared.args.encoder_repetition_penalty
    no_repeat_ngram_size: int = shared.args.no_repeat_ngram_size
    repetition_penalty_range: int = shared.args.repetition_penalty_range
    # --- decoding strategy knobs ---
    penalty_alpha: float = shared.args.penalty_alpha
    guidance_scale: float = shared.args.guidance_scale
    mirostat_mode: int = shared.args.mirostat_mode
    mirostat_tau: float = shared.args.mirostat_tau
    mirostat_eta: float = shared.args.mirostat_eta
    prompt_lookup_num_tokens: int = 0
    max_tokens_second: int = 0
    do_sample: bool = shared.args.do_sample
    dynamic_temperature: bool = shared.args.dynamic_temperature
    temperature_last: bool = shared.args.temperature_last
    # --- misc generation behavior ---
    auto_max_new_tokens: bool = False
    ban_eos_token: bool = False
    add_bos_token: bool = True
    enable_thinking: bool = shared.args.enable_thinking
    reasoning_effort: str = shared.args.reasoning_effort
    skip_special_tokens: bool = True
    static_cache: bool = False
    truncation_length: int = 0
    seed: int = -1
    sampler_priority: List[str] | str | None = Field(default=shared.args.sampler_priority, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
    custom_token_bans: str = ""
    negative_prompt: str = ''
    dry_sequence_breakers: str = shared.args.dry_sequence_breakers
    grammar_string: str = ""
class ToolDefinition(BaseModel):
    # One entry of the OpenAI "tools" array: {"type": "function", "function": {...}}.
    function: 'ToolFunction'  # forward reference; ToolFunction is defined below
    type: str


class ToolFunction(BaseModel):
    # extra='allow' preserves unknown vendor-specific keys instead of rejecting them.
    model_config = ConfigDict(extra='allow')
    description: Optional[str] = None
    name: str
    parameters: Optional['ToolParameters'] = None


class ToolParameters(BaseModel):
    # JSON-Schema-like parameter specification; unknown keys are preserved.
    model_config = ConfigDict(extra='allow')
    properties: Optional[Dict[str, Any]] = None
    required: Optional[list[str]] = None
    type: str
    description: Optional[str] = None
class FunctionCall(BaseModel):
    """A function invocation payload; requires 'arguments' or 'parameters'."""
    name: str
    arguments: Optional[str] = None
    parameters: Optional[str] = None

    @model_validator(mode='after')
    def check_property_args_or_params(self):
        # A model-level (mode='after') validator sees both fields regardless of
        # declaration order. The previous per-field @validator on 'arguments'
        # ran before 'parameters' was parsed, so values.get('parameters') was
        # always None and valid {"parameters": ...} payloads were rejected.
        # (@validator is also the deprecated pydantic-v1 API; the rest of this
        # module uses model_validator.)
        if not self.arguments and not self.parameters:
            raise ValueError("At least one of 'arguments' or 'parameters' must be provided as property in FunctionCall type")
        return self
class ToolCall(BaseModel):
    # A tool invocation entry as emitted in OpenAI-style "tool_calls" lists.
    id: str
    index: int
    type: str
    function: FunctionCall


class StreamOptions(BaseModel):
    # Mirrors OpenAI's stream_options.include_usage flag.
    include_usage: bool | None = False
class CompletionRequestParams(BaseModel):
    """OpenAI /v1/completions request fields (the non-sampler part).

    Several OpenAI fields are accepted for compatibility but unused
    (model, best_of, user) — see their Field descriptions.
    """
    model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
    prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
    messages: List[dict] | None = Field(default=None, description="OpenAI messages format for multimodal support. Alternative to 'prompt'.")
    best_of: int | None = Field(default=1, description="Unused parameter.")
    echo: bool | None = False
    frequency_penalty: float | None = shared.args.frequency_penalty
    logit_bias: dict | None = None
    logprobs: int | None = None
    max_tokens: int | None = 512
    n: int | None = Field(default=1, description="Number of completions to generate. Only supported without streaming.")
    presence_penalty: float | None = shared.args.presence_penalty
    stop: str | List[str] | None = None
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    suffix: str | None = None
    temperature: float | None = shared.args.temperature
    top_p: float | None = shared.args.top_p
    user: str | None = Field(default=None, description="Unused parameter.")

    @model_validator(mode='after')
    def validate_prompt_or_messages(self):
        # At least one input channel is required; reject requests with neither.
        if self.prompt is None and self.messages is None:
            raise ValueError("Either 'prompt' or 'messages' must be provided")
        return self
class CompletionRequest(GenerationOptions, CompletionRequestParams):
    """Full /v1/completions request: OpenAI fields plus this server's sampler options."""
    pass
class CompletionResponse(BaseModel):
    # OpenAI text_completion response envelope.
    id: str
    choices: List[dict]
    created: int = Field(default_factory=lambda: int(time.time()))  # unix timestamp at build time
    model: str
    object: str = "text_completion"
    usage: dict
class ChatCompletionRequestParams(BaseModel):
    """OpenAI /v1/chat/completions request fields plus this server's chat extras.

    NOTE(review): several fields are declared after the resolve_max_tokens
    validator below — pydantic still registers them normally; the placement is
    just unconventional.
    """
    messages: List[dict]
    model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
    frequency_penalty: float | None = shared.args.frequency_penalty
    function_call: str | dict | None = Field(default=None, description="Unused parameter.")
    functions: List[dict] | None = Field(default=None, description="Unused parameter.")
    tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
    tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.")
    logit_bias: dict | None = None
    logprobs: bool | None = None
    top_logprobs: int | None = None
    max_tokens: int | None = None
    max_completion_tokens: int | None = None
    n: int | None = Field(default=1, description="Unused parameter.")
    presence_penalty: float | None = shared.args.presence_penalty
    stop: str | List[str] | None = None
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    temperature: float | None = shared.args.temperature
    top_p: float | None = shared.args.top_p
    user: str | None = Field(default=None, description="Unused parameter.")

    @model_validator(mode='after')
    def resolve_max_tokens(self):
        # Accept OpenAI's newer max_completion_tokens as an alias for max_tokens.
        if self.max_tokens is None and self.max_completion_tokens is not None:
            self.max_tokens = self.max_completion_tokens
        return self

    # --- server-specific chat fields ---
    mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
    instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.")
    instruction_template_str: str | None = Field(default=None, description="A Jinja2 instruction template. If set, will take precedence over everything else.")
    character: str | None = Field(default=None, description="A character defined under text-generation-webui/user_data/characters. If not set, the default \"Assistant\" character will be used.")
    bot_name: str | None = Field(default=None, description="Overwrites the value set by character field.", alias="name2")
    context: str | None = Field(default=None, description="Overwrites the value set by character field.")
    greeting: str | None = Field(default=None, description="Overwrites the value set by character field.")
    user_name: str | None = Field(default=None, description="Your name (the user). By default, it's \"You\".", alias="name1")
    user_bio: str | None = Field(default=None, description="The user description/personality.")
    chat_template_str: str | None = Field(default=None, description="Jinja2 template for chat.")
    chat_instruct_command: str | None = "Continue the chat dialogue below. Write a single reply for the character \"<|character|>\".\n\n<|prompt|>"
    continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")
class ChatCompletionRequest(GenerationOptions, ChatCompletionRequestParams):
    """Full /v1/chat/completions request: OpenAI fields plus sampler options."""
    pass
class ChatCompletionResponse(BaseModel):
    # OpenAI chat.completion response envelope.
    id: str
    choices: List[dict]
    created: int = Field(default_factory=lambda: int(time.time()))  # unix timestamp at build time
    model: str
    object: str = "chat.completion"
    usage: dict


class ChatPromptResponse(BaseModel):
    # /v1/internal/chat-prompt: the fully formatted prompt string.
    prompt: str
class EmbeddingsRequest(BaseModel):
input: str | List[str] | List[int] | List[List[int]]
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
encoding_format: str = Field(default="float", description="Can be float or base64.")
user: str | None = Field(default=None, description="Unused parameter.")
class EmbeddingsResponse(BaseModel):
index: int
embedding: List[float]
object: str = "embedding"
class EncodeRequest(BaseModel):
text: str
class EncodeResponse(BaseModel):
tokens: List[int]
length: int
class DecodeRequest(BaseModel):
tokens: List[int]
class DecodeResponse(BaseModel):
text: str
class TokenCountResponse(BaseModel):
    """Response carrying only a token count."""
    length: int  # number of tokens in the submitted text
class LogitsRequestParams(BaseModel):
    """Parameters specific to a raw-logits request."""
    prompt: str  # prompt whose next-token logits are requested
    use_samplers: bool = False  # presumably applies the sampling parameters below before reading logits — confirm against handler
    top_logits: int | None = 50  # number of top entries to return
    # NOTE: the defaults below are snapshotted from shared.args when this
    # module is imported, not re-read per request.
    frequency_penalty: float | None = shared.args.frequency_penalty
    max_tokens: int | None = 512
    presence_penalty: float | None = shared.args.presence_penalty
    temperature: float | None = shared.args.temperature
    top_p: float | None = shared.args.top_p
class LogitsRequest(GenerationOptions, LogitsRequestParams):
    """Complete logits request: logits-specific parameters combined with the shared sampling/generation options."""
class LogitsResponse(BaseModel):
    """Response mapping token strings to their logit values."""
    logits: Dict[str, float]  # token -> logit (or probability when samplers are applied — TODO confirm)
class ModelInfoResponse(BaseModel):
    """Information about the currently loaded model."""
    model_name: str  # name of the loaded model
    lora_names: List[str]  # names of the currently applied LoRAs
class ModelListResponse(BaseModel):
    """List of available model names."""
    model_names: List[str]
class LoadModelRequest(BaseModel):
    """Request to load a model, optionally overriding loader args and settings."""
    model_name: str  # model to load (required)
    args: dict | None = None  # optional loader argument overrides — TODO confirm accepted keys
    settings: dict | None = None  # optional settings overrides — TODO confirm accepted keys
class LoraListResponse(BaseModel):
    """List of available LoRA names."""
    lora_names: List[str]
class LoadLorasRequest(BaseModel):
    """Request to apply the given LoRAs."""
    lora_names: List[str]  # LoRAs to load/apply
class ImageGenerationRequest(BaseModel):
    """Image-specific parameters for generation."""
    prompt: str  # text prompt describing the desired image
    negative_prompt: str = ""  # concepts to steer generation away from
    size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
    steps: int = Field(default=9, ge=1)  # number of sampling steps (>= 1)
    cfg_scale: float = Field(default=0.0, ge=0.0)  # guidance strength; presumably 0.0 disables guidance — TODO confirm
    image_seed: int = Field(default=-1, description="-1 for random")
    batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
    n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
    batch_count: int = Field(default=1, ge=1, description="Sequential batch count")
    # OpenAI compatibility (unused)
    model: str | None = None
    response_format: str = "b64_json"
    user: str | None = None
    @model_validator(mode='after')
    def resolve_batch_size(self):
        # When batch_size is not given explicitly, fall back to the
        # OpenAI-style `n` alias (default 1).
        if self.batch_size is None:
            self.batch_size = self.n
        return self
    def get_width_height(self) -> tuple[int, int]:
        """Parse `size` ('WIDTHxHEIGHT', case-insensitive) into (width, height); falls back to (1024, 1024) on malformed input."""
        try:
            parts = self.size.lower().split('x')
            return int(parts[0]), int(parts[1])
        except (ValueError, IndexError):
            return 1024, 1024
class ImageGenerationResponse(BaseModel):
    """Response for image generation."""
    created: int = Field(default_factory=lambda: int(time.time()))  # Unix timestamp set at response creation
    data: List[dict]  # one dict per generated image — presumably b64_json payloads; confirm against handler
def to_json(obj):
    """Serialize an object's attribute dictionary as pretty-printed JSON (4-space indent)."""
    return json.dumps(vars(obj), indent=4)
def to_dict(obj):
    """Return the object's attribute dictionary (the live __dict__, not a copy)."""
    return vars(obj)

53
modules/api/utils.py Normal file
View file

@ -0,0 +1,53 @@
import base64
import os
import time
import traceback
from typing import Callable, Optional
import numpy as np
def float_list_to_base64(float_array: np.ndarray) -> str:
    """Encode a numpy array's raw buffer as an ASCII base64 string.

    The array is serialized with whatever dtype it already has; callers are
    expected to pass float32 data, which is what the OpenAI-compatible
    clients decode on the other end.

    :param float_array: numpy array to serialize (expected dtype float32)
    :return: base64 encoding of the array's raw bytes, as an ASCII str
    """
    # tobytes() yields the raw buffer; b64encode returns bytes, which we
    # decode to a plain str suitable for embedding in a JSON response.
    return base64.b64encode(float_array.tobytes()).decode('ascii')
def debug_msg(*args, **kwargs):
    """Forward all arguments to print(), but only when the OPENEDAI_DEBUG
    environment variable is set to a non-empty string."""
    debug_enabled = os.environ.get("OPENEDAI_DEBUG", 0)
    if debug_enabled:
        print(*args, **kwargs)
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
try:
from flask_cloudflared import _run_cloudflared
except ImportError:
print('You should install flask_cloudflared manually')
raise Exception(
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
for _ in range(max_attempts):
try:
if tunnel_id is not None:
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
else:
public_url = _run_cloudflared(port, port + 1)
if on_start:
on_start(public_url)
return
except Exception:
traceback.print_exc()
time.sleep(3)
raise Exception('Could not start cloudflared.')

View file

@ -32,8 +32,7 @@ def load_extensions():
if name not in available_extensions:
continue
if name != 'api':
logger.info(f'Loading the extension "{name}"')
logger.info(f'Loading the extension "{name}"')
try:
# Prefer user extension, fall back to system extension

View file

@ -156,7 +156,7 @@ group.add_argument('--portable', action='store_true', help='Hide features not av
# API
group = parser.add_argument_group('API')
group.add_argument('--api', action='store_true', help='Enable the API extension.')
group.add_argument('--api', action='store_true', help='Enable the API server.')
group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.')
group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
@ -435,16 +435,6 @@ def fix_loader_name(name):
return 'TensorRT-LLM'
def add_extension(name, last=False):
    """Register `name` in args.extensions; with last=True, (re)position it at the end of the list."""
    if args.extensions is None:
        args.extensions = [name]
    elif last:
        # Drop any existing occurrence so the extension ends up last.
        args.extensions = [x for x in args.extensions if x != name]
        args.extensions.append(name)
    elif name not in args.extensions:
        args.extensions.append(name)
def is_chat():
    # NOTE(review): looks like a legacy compatibility hook that now always
    # reports chat mode — confirm remaining callers before removing.
    return True
@ -464,10 +454,6 @@ def load_user_config():
args.loader = fix_loader_name(args.loader)
# Activate the API extension
if args.api or args.public_api:
add_extension('openai', last=True)
# Load model-specific settings
p = Path(f'{args.model_dir}/config.yaml')
if p.exists():

View file

@ -95,8 +95,6 @@ def set_interface_arguments(extensions, bool_active):
setattr(shared.args, k, False)
for k in bool_active:
setattr(shared.args, k, True)
if k == 'api':
shared.add_extension('openai', last=True)
shared.need_restart = True