mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-05 14:45:28 +00:00
API: Move OpenAI-compatible API from extensions/openai to modules/api
This commit is contained in:
parent
2e4232e02b
commit
bf6fbc019d
23 changed files with 51 additions and 65 deletions
0
modules/api/__init__.py
Normal file
0
modules/api/__init__.py
Normal file
11
modules/api/cache_embedding_model.py
Normal file
11
modules/api/cache_embedding_model.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
#!/usr/bin/env python3
# Preload the embedding model, useful for Docker images to prevent re-download on config change.
# The model name is read from OPENEDAI_EMBEDDING_MODEL (defaults to all-mpnet-base-v2);
# instantiating SentenceTransformer triggers the download into the local cache.
# Dockerfile:
# ENV OPENEDAI_EMBEDDING_MODEL="sentence-transformers/all-mpnet-base-v2" # Optional
# RUN python3 cache_embedding_model.py
import os

import sentence_transformers

st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
model = sentence_transformers.SentenceTransformer(st_model)
907
modules/api/completions.py
Normal file
907
modules/api/completions.py
Normal file
|
|
@ -0,0 +1,907 @@
|
|||
import copy
|
||||
import functools
|
||||
import json
|
||||
import time
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
import tiktoken
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .errors import InvalidRequestError
|
||||
from .typing import ToolDefinition
|
||||
from .utils import debug_msg
|
||||
from modules.tool_parsing import get_tool_call_id, parse_tool_call, detect_tool_call_format
|
||||
from modules import shared
|
||||
from modules.reasoning import extract_reasoning
|
||||
from modules.chat import (
|
||||
generate_chat_prompt,
|
||||
generate_chat_reply,
|
||||
load_character_memoized,
|
||||
load_instruction_template_memoized
|
||||
)
|
||||
from modules.image_utils import convert_openai_messages_to_images
|
||||
from modules.logging_colors import logger
|
||||
from modules.presets import load_preset_memoized
|
||||
from modules.text_generation import decode, encode, generate_reply
|
||||
|
||||
|
||||
@functools.cache
def load_chat_template_file(filepath):
    """Load a chat template from a file path (.jinja, .jinja2, or .yaml/.yml).

    YAML files are expected to carry the template under the
    'instruction_template' key; any other extension is treated as a raw
    Jinja template and returned verbatim. Results are cached per path for
    the lifetime of the process.
    """
    filepath = Path(filepath)
    ext = filepath.suffix.lower()
    text = filepath.read_text(encoding='utf-8')
    if ext in ['.yaml', '.yml']:
        # safe_load returns None for an empty/blank document; guard so that
        # .get() doesn't raise AttributeError on an empty YAML file.
        data = yaml.safe_load(text) or {}
        return data.get('instruction_template', '')
    return text
|
||||
|
||||
def _get_raw_logprob_entries(offset=0):
    """Return logprob entries produced by the llama.cpp/ExLlamav3 backend since *offset*.

    Returns a tuple ``(new_entries, new_offset)``. When the backend exposes
    no ``last_completion_probabilities`` (or the list is empty), an empty
    list is returned and *offset* is left unchanged.
    """
    recorded = getattr(shared.model, 'last_completion_probabilities', None)
    if not recorded:
        return [], offset

    return recorded[offset:], len(recorded)
|
||||
|
||||
def _dict_to_logprob_entries(token_dict):
|
||||
"""Convert a flat {token: logprob} dict (from LogprobProcessor) to raw entry format."""
|
||||
if not token_dict:
|
||||
return []
|
||||
|
||||
return [{"top_logprobs": [{"token": t, "logprob": lp} for t, lp in token_dict.items()]}]
|
||||
|
||||
|
||||
def _parse_entry_top(entry):
|
||||
"""Extract the top logprobs list from a raw entry, handling both key names."""
|
||||
return entry.get('top_logprobs', entry.get('top_probs', []))
|
||||
|
||||
|
||||
def format_chat_logprobs(entries):
    """Format logprob entries into the OpenAI chat completions logprobs shape.

    Output: ``{"content": [{"token", "logprob", "bytes", "top_logprobs": [...]}],
    "refusal": None}``, or ``None`` when there is nothing to report.
    """
    if not entries:
        return None

    def _token_bytes(text):
        # OpenAI reports the UTF-8 bytes of each token, or null for an empty token
        return list(text.encode('utf-8')) if text else None

    content = []
    for raw in entries:
        candidates = _parse_entry_top(raw)
        if not candidates:
            continue

        alternatives = [
            {
                "token": cand.get('token', ''),
                "logprob": cand.get('logprob', cand.get('prob', 0)),
                "bytes": _token_bytes(cand.get('token', '')),
            }
            for cand in candidates
        ]

        # The first candidate is the token that was actually sampled
        best = alternatives[0]
        content.append({
            "token": best["token"],
            "logprob": best["logprob"],
            "bytes": best["bytes"],
            "top_logprobs": alternatives,
        })

    return {"content": content, "refusal": None} if content else None
|
||||
|
||||
def format_completion_logprobs(entries):
    """Format logprob entries into the legacy OpenAI completions logprobs shape.

    Output: ``{"tokens", "token_logprobs", "top_logprobs": [{token: prob}],
    "text_offset"}``, or ``None`` when there is nothing to report.
    """
    if not entries:
        return None

    result = {
        "tokens": [],
        "token_logprobs": [],
        "top_logprobs": [],
        "text_offset": [],
    }
    cursor = 0

    for raw in entries:
        candidates = _parse_entry_top(raw)
        if not candidates:
            continue

        # First candidate is the sampled token; offsets accumulate its length
        best = candidates[0]
        chosen = best.get('token', '')
        result["tokens"].append(chosen)
        result["token_logprobs"].append(best.get('logprob', best.get('prob', 0)))
        result["text_offset"].append(cursor)
        cursor += len(chosen)

        result["top_logprobs"].append({
            cand.get('token', ''): cand.get('logprob', cand.get('prob', 0))
            for cand in candidates
        })

    return result if result["tokens"] else None
|
||||
|
||||
def process_parameters(body, is_legacy=False):
    """Translate an OpenAI-style request body into internal generation parameters.

    Mutates *body* in place and returns it: renames max_tokens/length to
    max_new_tokens, resolves server defaults, normalizes stop strings and
    logprobs, applies a named preset, and (for non-native loaders) attaches
    transformers logits processors.

    NOTE(review): body.pop() assumes the max-tokens key is always present —
    presumably guaranteed by the request schema; confirm against the caller.
    """
    generate_params = body
    # 'length' is the legacy completions-era name for 'max_tokens'
    max_tokens_str = 'length' if is_legacy else 'max_tokens'
    generate_params['max_new_tokens'] = body.pop(max_tokens_str)
    # 0 means "use the server-wide truncation length"
    if generate_params['truncation_length'] == 0:
        generate_params['truncation_length'] = shared.settings['truncation_length']

    # temperature == 0 is treated as greedy decoding
    if generate_params['temperature'] == 0:
        generate_params['do_sample'] = False
        generate_params['top_k'] = 1

    # A named preset overrides individual sampling parameters
    if body['preset'] is not None:
        preset = load_preset_memoized(body['preset'])
        generate_params.update(preset)

    generate_params['custom_stopping_strings'] = []
    if 'stop' in body:  # str or array, max len 4 (ignored)
        if isinstance(body['stop'], str):
            generate_params['custom_stopping_strings'] = [body['stop']]
        elif isinstance(body['stop'], list):
            generate_params['custom_stopping_strings'] = body['stop']

    # Resolve logprobs: for chat completions, logprobs is a bool and the count
    # comes from top_logprobs. Normalize to an int for all backends.
    logprobs = body.get('logprobs', None)
    top_logprobs = body.get('top_logprobs', None)
    if logprobs is True:
        # Fall back to 5 alternatives when top_logprobs is omitted
        logprobs = max(top_logprobs, 1) if top_logprobs is not None else 5
    generate_params['logprobs'] = logprobs

    # For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
    if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
        from transformers import LogitsProcessorList

        from modules.transformers_loader import (
            LogitsBiasProcessor,
            LogprobProcessor
        )

        logits_processor = []
        logit_bias = body.get('logit_bias', None)
        if logit_bias:  # {str: float, ...}
            logits_processor = [LogitsBiasProcessor(logit_bias)]

        if logprobs is not None and logprobs > 0:
            # The processor records per-token alternatives during generation;
            # it is popped out of generate_params later by the callers.
            generate_params['logprob_proc'] = LogprobProcessor(logprobs)
            logits_processor.extend([generate_params['logprob_proc']])

        if logits_processor:  # requires logits_processor support
            generate_params['logits_processor'] = LogitsProcessorList(logits_processor)

    return generate_params
|
||||
|
||||
def process_multimodal_content(content):
    """Flatten OpenAI multimodal message content into plain text.

    Plain strings pass through unchanged. For the list form, text parts are
    joined with single spaces and one ``<__media__>`` placeholder per image
    is prepended, separated by a blank line. Any other type is coerced with
    str().
    """
    if isinstance(content, str):
        return content

    if not isinstance(content, list):
        return str(content)

    texts = []
    markers = []
    for part in content:
        if not isinstance(part, dict):
            continue

        kind = part.get('type', '')
        if kind == 'text':
            texts.append(part.get('text', ''))
        elif kind == 'image_url':
            markers.append("<__media__>")

    joined = ' '.join(texts)
    if markers:
        return "{}\n\n{}".format(''.join(markers), joined)

    return joined
|
||||
|
||||
def convert_history(history):
    '''
    Chat histories in this program are in the format [message, reply].
    This function converts OpenAI histories to that format.

    Returns (user_input, system_message, history_dict) where:
    - user_input is the last user message, or "" when the conversation does
      not end on a user turn;
    - system_message concatenates all leading system/developer messages;
    - history_dict holds 'internal'/'visible' dialogue rows of the form
      [message, reply, tool_result, metadata] plus the original 'messages'
      (kept for multimodal models).
    '''
    chat_dialogue = []        # accumulated [message, reply, tool, meta] rows
    current_message = ""      # pending user message awaiting its reply
    current_reply = ""
    user_input = ""
    user_input_last = True    # True while the most recent turn is a user turn
    system_message = ""
    seen_non_system = False   # flips once any non-system role has appeared

    for entry in history:
        content = entry["content"]
        role = entry["role"]

        if role == "user":
            seen_non_system = True
            # Extract text content (images handled by model-specific code)
            content = process_multimodal_content(content)
            user_input = content
            user_input_last = True

            # Two user messages in a row: flush the previous one with an empty reply
            if current_message:
                chat_dialogue.append([current_message, '', '', {}])
                current_message = ""

            current_message = content
        elif role == "assistant":
            seen_non_system = True
            meta = {}
            tool_calls = entry.get("tool_calls")
            if tool_calls and isinstance(tool_calls, list):
                meta["tool_calls"] = tool_calls
                # NOTE(review): whitespace-only content is normalized only when
                # tool_calls are present — confirm this is the intended scope
                if content.strip() == "":
                    content = ""  # keep empty content, don't skip
            current_reply = content
            user_input_last = False
            if current_message:
                # Pair the reply with the pending user message
                chat_dialogue.append([current_message, current_reply, '', meta])
                current_message = ""
                current_reply = ""
            else:
                # Assistant message without a preceding user turn
                chat_dialogue.append(['', current_reply, '', meta])
        elif role == "tool":
            seen_non_system = True
            user_input_last = False
            meta = {}
            if "tool_call_id" in entry:
                meta["tool_call_id"] = entry["tool_call_id"]
            # Tool results occupy the third slot of a dialogue row
            chat_dialogue.append(['', '', content, meta])
        elif role in ("system", "developer"):
            if not seen_non_system:
                # Leading system messages go to custom_system_message (placed at top)
                system_message += f"\n{content}" if system_message else content
            else:
                # Mid-conversation system messages: preserve position in history
                if current_message:
                    chat_dialogue.append([current_message, '', '', {}])
                    current_message = ""
                chat_dialogue.append([content, '', '', {"role": "system"}])

    # Only report user_input when the history actually ends on a user turn
    if not user_input_last:
        user_input = ""

    return user_input, system_message, {
        'internal': chat_dialogue,
        'visible': copy.deepcopy(chat_dialogue),
        'messages': history  # Store original messages for multimodal models
    }
|
||||
|
||||
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False, stop_event=None) -> dict:
    """Shared implementation behind the OpenAI-compatible chat completions endpoint.

    Generator. Validates *body*, builds generation parameters, then yields:
    - streaming chunk dicts (``chat.completion.chunk``) when stream=True,
    - one final response dict (``chat.completion``) when stream=False,
    - a single ``{'prompt': ...}`` dict when prompt_only=True.

    Raises InvalidRequestError for unsupported or malformed request fields.
    """
    if body.get('functions', []):
        raise InvalidRequestError(message="functions is not supported.", param='functions')

    if body.get('function_call', ''):
        raise InvalidRequestError(message="function_call is not supported.", param='function_call')

    if 'messages' not in body:
        raise InvalidRequestError(message="messages is required", param='messages')

    tools = None
    if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and body['tools']:
        # validateTools is defined elsewhere in this module
        tools = validateTools(body['tools'])  # raises InvalidRequestError if validation fails

    tool_choice = body.get('tool_choice', None)
    if tool_choice == "none":
        tools = None  # Disable tool detection entirely

    messages = body['messages']
    for m in messages:
        if 'role' not in m:
            raise InvalidRequestError(message="messages: missing role", param='messages')
        elif m['role'] == 'function':
            raise InvalidRequestError(message="role: function is not supported.", param='messages')

        # Handle multimodal content validation
        content = m.get('content')
        if content is None:
            # OpenAI allows content: null on assistant messages when tool_calls is present
            if m['role'] == 'assistant' and m.get('tool_calls'):
                m['content'] = ''
            else:
                raise InvalidRequestError(message="messages: missing content", param='messages')

        # Validate multimodal content structure
        if isinstance(content, list):
            for item in content:
                if not isinstance(item, dict) or 'type' not in item:
                    raise InvalidRequestError(message="messages: invalid content item format", param='messages')
                if item['type'] not in ['text', 'image_url']:
                    raise InvalidRequestError(message="messages: unsupported content type", param='messages')
                if item['type'] == 'text' and 'text' not in item:
                    raise InvalidRequestError(message="messages: missing text in content item", param='messages')
                if item['type'] == 'image_url' and ('image_url' not in item or 'url' not in item['image_url']):
                    raise InvalidRequestError(message="messages: missing image_url in content item", param='messages')

    # Chat Completions
    object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
    created_time = int(time.time())
    cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    # generation parameters
    generate_params = process_parameters(body, is_legacy=is_legacy)
    if stop_event is not None:
        generate_params['stop_event'] = stop_event
    continue_ = body['continue_']

    # Instruction template: explicit string > named template > CLI file > settings
    if body['instruction_template_str']:
        instruction_template_str = body['instruction_template_str']
    elif body['instruction_template']:
        instruction_template = body['instruction_template']
        instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
        instruction_template_str = load_instruction_template_memoized(instruction_template)
    elif shared.args.chat_template_file:
        instruction_template_str = load_chat_template_file(shared.args.chat_template_file)
    else:
        instruction_template_str = shared.settings['instruction_template_str']

    chat_template_str = body['chat_template_str'] or shared.default_settings['chat_template_str']
    chat_instruct_command = body['chat_instruct_command'] or shared.default_settings['chat-instruct_command']

    # Chat character
    character = body['character'] or shared.default_settings['character']
    character = "Assistant" if character == "None" else character
    name1 = body['user_name'] or shared.default_settings['name1']
    name1, name2, _, greeting, context = load_character_memoized(character, name1, '')
    name2 = body['bot_name'] or name2
    context = body['context'] or context
    greeting = body['greeting'] or greeting
    user_bio = body['user_bio'] or ''

    # History
    user_input, custom_system_message, history = convert_history(messages)

    generate_params.update({
        'mode': body['mode'],
        'name1': name1,
        'name2': name2,
        'context': context,
        'greeting': greeting,
        'user_bio': user_bio,
        'instruction_template_str': instruction_template_str,
        'custom_system_message': custom_system_message,
        'chat_template_str': chat_template_str,
        'chat-instruct_command': chat_instruct_command,
        'tools': tools,
        'history': history,
        'stream': stream
    })

    # max_tokens absent/0 enables auto mode with a 512-token starting value
    max_tokens = generate_params['max_new_tokens']
    if max_tokens in [None, 0]:
        generate_params['max_new_tokens'] = 512
        generate_params['auto_max_new_tokens'] = True

    requested_model = generate_params.pop('model')
    logprob_proc = generate_params.pop('logprob_proc', None)
    if logprob_proc:
        logprob_proc.token_alternatives_history.clear()
    chat_logprobs_offset = [0]  # mutable for closure access in streaming

    def chat_streaming_chunk(content=None, chunk_tool_calls=None, include_role=False, reasoning_content=None):
        # Build one SSE chunk; only fields explicitly passed end up in the delta.
        delta = {}
        if include_role:
            delta['role'] = 'assistant'
            delta['refusal'] = None
        if content is not None:
            delta['content'] = content
        if reasoning_content is not None:
            delta['reasoning_content'] = reasoning_content
        if chunk_tool_calls:
            delta['tool_calls'] = chunk_tool_calls

        chunk = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            "system_fingerprint": None,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                "delta": delta,
                "logprobs": None,
            }],
        }

        # Attach logprobs for this chunk, from whichever source is active
        if logprob_proc:
            entries = _dict_to_logprob_entries(logprob_proc.token_alternatives)
            formatted = format_chat_logprobs(entries)
            if formatted:
                chunk[resp_list][0]["logprobs"] = formatted
        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
            entries, chat_logprobs_offset[0] = _get_raw_logprob_entries(chat_logprobs_offset[0])
            if entries:
                formatted = format_chat_logprobs(entries)
                if formatted:
                    chunk[resp_list][0]["logprobs"] = formatted

        return chunk

    # Check if usage should be included in streaming chunks per OpenAI spec
    stream_options = body.get('stream_options')
    include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))

    # generate reply #######################################
    if prompt_only:
        prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
        yield {'prompt': prompt}
        return

    if stream:
        # Initial role-bearing chunk with empty content
        chunk = chat_streaming_chunk('', include_role=True)
        if include_usage:
            chunk['usage'] = None
        yield chunk

    generator = generate_chat_reply(
        user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)

    answer = ''
    seen_content = ''
    seen_reasoning = ''

    tool_calls = []
    end_last_tool_call = 0  # index into answer past the last parsed tool call
    supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
    _tool_parsers = None

    # Filter supported_tools when tool_choice specifies a particular function
    if supported_tools and isinstance(tool_choice, dict):
        specified_func = tool_choice.get("function", {}).get("name")
        if specified_func and specified_func in supported_tools:
            supported_tools = [specified_func]

    if supported_tools is not None:
        # Choose the template that governs the tool-call markup format
        _template_str = generate_params.get('instruction_template_str', '') if generate_params.get('mode') == 'instruct' else generate_params.get('chat_template_str', '')
        _tool_parsers, _, _ = detect_tool_call_format(_template_str)

    for a in generator:
        answer = a['internal'][-1][1]

        if supported_tools is not None:
            tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools, parsers=_tool_parsers) if len(answer) > 0 else []
            if len(tool_call) > 0:
                for tc in tool_call:
                    tc["id"] = get_tool_call_id()
                    if stream:
                        tc["index"] = len(tool_calls)
                    tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
                    tool_calls.append(tc)
                end_last_tool_call = len(answer)

            # Stop generation before streaming content if tool_calls were detected,
            # so that raw tool markup is not sent as content deltas.
            if len(tool_calls) > 0:
                break

        if stream:
            # Strip reasoning/thinking blocks so only final content is streamed.
            # Reasoning is emitted separately as reasoning_content deltas.
            reasoning, content = extract_reasoning(answer)
            if reasoning is not None:
                new_reasoning = reasoning[len(seen_reasoning):]
                new_content = content[len(seen_content):]
            else:
                new_reasoning = None
                new_content = answer[len(seen_content):]

            # U+FFFD means a partial multi-byte character: wait for more output
            if (not new_content and not new_reasoning) or chr(0xfffd) in (new_content or '') + (new_reasoning or ''):
                continue

            chunk = chat_streaming_chunk(
                content=new_content if new_content else None,
                reasoning_content=new_reasoning if new_reasoning else None,
            )
            if include_usage:
                chunk['usage'] = None

            if reasoning is not None:
                seen_reasoning = reasoning
                seen_content = content
            else:
                seen_content = answer
            yield chunk

    token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
    completion_token_count = len(encode(answer)[0])
    # Finish reason: tool_calls > length (truncation or max tokens) > stop
    if len(tool_calls) > 0:
        stop_reason = "tool_calls"
    elif token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
        stop_reason = "length"
    else:
        stop_reason = "stop"

    if stream:
        chunk = chat_streaming_chunk(chunk_tool_calls=tool_calls)
        chunk[resp_list][0]['finish_reason'] = stop_reason
        usage = {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }

        if include_usage:
            chunk['usage'] = None
            yield chunk
            # Separate usage-only chunk with choices: [] per OpenAI spec
            yield {
                "id": cmpl_id,
                "object": object_type,
                "created": created_time,
                "model": shared.model_name,
                "system_fingerprint": None,
                resp_list: [],
                "usage": usage
            }
        else:
            yield chunk
    else:
        reasoning, content = extract_reasoning(answer)
        message = {
            "role": "assistant",
            "refusal": None,
            "content": None if tool_calls else content,
            **({"reasoning_content": reasoning} if reasoning else {}),
            **({"tool_calls": tool_calls} if tool_calls else {}),
        }
        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            "system_fingerprint": None,
            resp_list: [{
                "index": 0,
                "finish_reason": stop_reason,
                "message": message,
                "logprobs": None,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
        }
        # Non-streaming: aggregate logprobs over the whole completion
        if logprob_proc:
            all_entries = []
            for alt in logprob_proc.token_alternatives_history:
                all_entries.extend(_dict_to_logprob_entries(alt))
            formatted = format_chat_logprobs(all_entries)
            if formatted:
                resp[resp_list][0]["logprobs"] = formatted
        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
            raw = getattr(shared.model, 'last_completion_probabilities', None)
            if raw:
                formatted = format_chat_logprobs(raw)
                if formatted:
                    resp[resp_list][0]["logprobs"] = formatted

        yield resp
|
||||
|
||||
def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None):
|
||||
object_type = 'text_completion'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "cmpl-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
prompt_str = 'context' if is_legacy else 'prompt'
|
||||
|
||||
# Handle both prompt and messages format for unified multimodal support
|
||||
if prompt_str not in body or body[prompt_str] is None:
|
||||
if 'messages' in body:
|
||||
# Convert messages format to prompt for completions endpoint
|
||||
prompt_text = ""
|
||||
for message in body.get('messages', []):
|
||||
if isinstance(message, dict) and 'content' in message:
|
||||
# Extract text content from multimodal messages
|
||||
content = message['content']
|
||||
if isinstance(content, str):
|
||||
prompt_text += content
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get('type') == 'text':
|
||||
prompt_text += item.get('text', '')
|
||||
|
||||
# Allow empty prompts for image-only requests
|
||||
body[prompt_str] = prompt_text
|
||||
else:
|
||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||
|
||||
# common params
|
||||
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||
max_tokens = generate_params['max_new_tokens']
|
||||
generate_params['stream'] = stream
|
||||
if stop_event is not None:
|
||||
generate_params['stop_event'] = stop_event
|
||||
requested_model = generate_params.pop('model')
|
||||
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||
if logprob_proc:
|
||||
logprob_proc.token_alternatives_history.clear()
|
||||
suffix = body['suffix'] if body['suffix'] else ''
|
||||
echo = body['echo']
|
||||
|
||||
# Add messages to generate_params if present for multimodal processing
|
||||
if body.get('messages'):
|
||||
generate_params['messages'] = body['messages']
|
||||
raw_images = convert_openai_messages_to_images(generate_params['messages'])
|
||||
if raw_images:
|
||||
logger.info(f"Found {len(raw_images)} image(s) in request.")
|
||||
generate_params['raw_images'] = raw_images
|
||||
|
||||
n_completions = body.get('n', 1) or 1
|
||||
|
||||
if not stream:
|
||||
prompt_arg = body[prompt_str]
|
||||
|
||||
# Handle empty/None prompts (e.g., image-only requests)
|
||||
if prompt_arg is None:
|
||||
prompt_arg = ""
|
||||
|
||||
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and len(prompt_arg) > 0 and isinstance(prompt_arg[0], int)):
|
||||
prompt_arg = [prompt_arg]
|
||||
|
||||
resp_list_data = []
|
||||
total_completion_token_count = 0
|
||||
total_prompt_token_count = 0
|
||||
choice_index = 0
|
||||
|
||||
for idx, prompt in enumerate(prompt_arg, start=0):
|
||||
if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
|
||||
# token lists
|
||||
if requested_model == shared.model_name:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
|
||||
prefix = prompt if echo else ''
|
||||
token_count = len(encode(prompt)[0])
|
||||
total_prompt_token_count += token_count
|
||||
|
||||
original_seed = generate_params.get('seed', -1)
|
||||
for _n in range(n_completions):
|
||||
# Increment seed for each completion to ensure diversity (matches llama.cpp native behavior)
|
||||
if original_seed >= 0:
|
||||
generate_params['seed'] = original_seed + _n
|
||||
|
||||
if logprob_proc:
|
||||
logprob_proc.token_alternatives_history.clear()
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
answer = ''
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
total_completion_token_count += completion_token_count
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
if logprob_proc:
|
||||
all_entries = []
|
||||
for alt in logprob_proc.token_alternatives_history:
|
||||
all_entries.extend(_dict_to_logprob_entries(alt))
|
||||
completion_logprobs = format_completion_logprobs(all_entries)
|
||||
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
|
||||
raw = getattr(shared.model, 'last_completion_probabilities', None)
|
||||
completion_logprobs = format_completion_logprobs(raw)
|
||||
else:
|
||||
completion_logprobs = None
|
||||
|
||||
respi = {
|
||||
"index": choice_index,
|
||||
"finish_reason": stop_reason,
|
||||
"text": prefix + answer + suffix,
|
||||
"logprobs": completion_logprobs,
|
||||
}
|
||||
|
||||
resp_list_data.append(respi)
|
||||
choice_index += 1
|
||||
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
"system_fingerprint": None,
|
||||
resp_list: resp_list_data,
|
||||
"usage": {
|
||||
"prompt_tokens": total_prompt_token_count,
|
||||
"completion_tokens": total_completion_token_count,
|
||||
"total_tokens": total_prompt_token_count + total_completion_token_count
|
||||
}
|
||||
}
|
||||
|
||||
yield resp
|
||||
else:
|
||||
prompt = body[prompt_str]
|
||||
if isinstance(prompt, list):
|
||||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||
|
||||
prefix = prompt if echo else ''
|
||||
token_count = len(encode(prompt)[0])
|
||||
|
||||
# Check if usage should be included in streaming chunks per OpenAI spec
|
||||
stream_options = body.get('stream_options')
|
||||
include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
|
||||
cmpl_logprobs_offset = [0] # mutable for closure access in streaming
|
||||
|
||||
def text_streaming_chunk(content):
|
||||
# begin streaming
|
||||
if logprob_proc:
|
||||
chunk_logprobs = format_completion_logprobs(_dict_to_logprob_entries(logprob_proc.token_alternatives))
|
||||
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
|
||||
entries, cmpl_logprobs_offset[0] = _get_raw_logprob_entries(cmpl_logprobs_offset[0])
|
||||
chunk_logprobs = format_completion_logprobs(entries) if entries else None
|
||||
else:
|
||||
chunk_logprobs = None
|
||||
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
"system_fingerprint": None,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"text": content,
|
||||
"logprobs": chunk_logprobs,
|
||||
}],
|
||||
}
|
||||
|
||||
return chunk
|
||||
|
||||
chunk = text_streaming_chunk(prefix)
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
completion_token_count = 0
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
|
||||
seen_content = answer
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
chunk = text_streaming_chunk(suffix)
|
||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||
usage = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
# Separate usage-only chunk with choices: [] per OpenAI spec
|
||||
yield {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
"system_fingerprint": None,
|
||||
resp_list: [],
|
||||
"usage": usage
|
||||
}
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
|
||||
def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
    """Run a non-streaming chat completion and return the final response dict."""
    # The common generator yields intermediate states; only the last one matters.
    final_only = deque(chat_completions_common(body, is_legacy, stream=False, stop_event=stop_event), maxlen=1)
    return final_only.pop()
|
||||
|
||||
|
||||
def stream_chat_completions(body: dict, is_legacy: bool = False, stop_event=None):
    """Yield OpenAI-format chat-completion chunks as they are produced."""
    yield from chat_completions_common(body, is_legacy, stream=True, stop_event=stop_event)
|
||||
|
||||
|
||||
def completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
    """Run a non-streaming text completion and return the final response dict."""
    # Keep only the most recent yield from the common generator.
    final_only = deque(completions_common(body, is_legacy, stream=False, stop_event=stop_event), maxlen=1)
    return final_only.pop()
|
||||
|
||||
|
||||
def stream_completions(body: dict, is_legacy: bool = False, stop_event=None):
    """Yield OpenAI-format completion chunks as they are produced."""
    yield from completions_common(body, is_legacy, stream=True, stop_event=stop_event)
|
||||
|
||||
|
||||
def validateTools(tools: list[dict]):
    """Validate an OpenAI-style ``tools`` array and backfill optional fields.

    Each entry is checked against the ``ToolDefinition`` schema; an invalid
    entry aborts the request with an ``InvalidRequestError`` naming its index.
    After validation, missing ``description``/``parameters`` fields are filled
    in on the original dicts so Jinja2 chat templates don't crash on them.

    Returns the (mutated) list of valid tools, or ``None`` when ``tools`` is
    empty (callers rely on the None sentinel).
    """
    valid_tools = None
    for idx, tool in enumerate(tools):
        try:
            # Constructed only for validation; the raw dict is what gets used.
            ToolDefinition(**tool)
        except ValidationError:
            raise InvalidRequestError(message=f"Invalid tool specification at index {idx}.", param='tools')

        # Backfill defaults so Jinja2 templates don't crash on missing fields
        func = tool.get("function", {})
        func.setdefault("description", "")
        func.setdefault("parameters", {"type": "object", "properties": {}})

        if valid_tools is None:
            valid_tools = []
        valid_tools.append(tool)

    return valid_tools
|
||||
96
modules/api/embeddings.py
Normal file
96
modules/api/embeddings.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
from transformers import AutoModel
|
||||
|
||||
from .errors import ServiceUnavailableError
|
||||
from .utils import debug_msg, float_list_to_base64
|
||||
from modules.logging_colors import logger
|
||||
|
||||
# One-shot guard for initialize_embedding_params() (lazy, idempotent init).
embeddings_params_initialized = False
|
||||
|
||||
|
||||
def initialize_embedding_params():
    '''
    using 'lazy loading' to avoid circular import
    so this function will be executed only once
    '''
    global embeddings_params_initialized
    if embeddings_params_initialized:
        return  # already configured: nothing to do

    global st_model, embeddings_model, embeddings_device

    st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", 'sentence-transformers/all-mpnet-base-v2')
    embeddings_model = None  # loaded lazily by get_embeddings_model()
    # OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
    embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", 'cpu')
    if embeddings_device.lower() == 'auto':
        embeddings_device = None  # let the backend pick

    embeddings_params_initialized = True
|
||||
|
||||
|
||||
def load_embedding_model(model: str):
    """Load *model* into the module-global ``embeddings_model``.

    Jina models go through AutoModel (their encode method lives in remote
    code); everything else through SentenceTransformer. On any failure the
    global is reset to None and a ServiceUnavailableError is raised with the
    original exception preserved in internal_message.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ModuleNotFoundError:
        logger.error("The sentence_transformers module has not been found. Please install it manually with pip install -U sentence-transformers.")
        raise ModuleNotFoundError

    initialize_embedding_params()
    global embeddings_device, embeddings_model
    try:
        print(f"Try embedding model: {model} on {embeddings_device}")
        if 'jina-embeddings' in model:
            embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True)  # trust_remote_code is needed to use the encode method
            embeddings_model = embeddings_model.to(embeddings_device)
        else:
            embeddings_model = SentenceTransformer(model, device=embeddings_device)

        print(f"Loaded embedding model: {model}")
    except Exception as e:
        # Leave the global unset so the next call retries the load.
        embeddings_model = None
        raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
||||
|
||||
|
||||
def get_embeddings_model():
    """Return the embedding model, loading it on first use (lazy singleton)."""
    initialize_embedding_params()
    global embeddings_model, st_model
    if st_model and not embeddings_model:
        # First access: materialize the configured model.
        load_embedding_model(st_model)

    return embeddings_model
|
||||
|
||||
|
||||
def get_embeddings_model_name() -> str:
    """Return the configured embedding model name (OPENEDAI_EMBEDDING_MODEL)."""
    initialize_embedding_params()
    global st_model
    return st_model
|
||||
|
||||
|
||||
def get_embeddings(input: list) -> np.ndarray:
    """Embed a batch of strings into L2-normalized numpy vectors."""
    model = get_embeddings_model()
    debug_msg(f"embedding model : {model}")
    vectors = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
    debug_msg(f"embedding result : {vectors}")  # might be too long even for debug, use at you own will
    return vectors
|
||||
|
||||
|
||||
def embeddings(input: list, encoding_format: str) -> dict:
    """Build an OpenAI-format embeddings response for a batch of inputs."""
    vectors = get_embeddings(input)

    # Pick the per-vector encoder once instead of branching inside the loop.
    if encoding_format == "base64":
        encode_one = float_list_to_base64
    else:
        def encode_one(emb):
            return emb.tolist()

    data = [{"object": "embedding", "embedding": encode_one(emb), "index": n} for n, emb in enumerate(vectors)]

    response = {
        "object": "list",
        "data": data,
        "model": st_model,  # return the real model
        "usage": {
            "prompt_tokens": 0,
            "total_tokens": 0,
        }
    }

    debug_msg(f"Embeddings return size: {len(vectors[0])}, number: {len(vectors)}")
    return response
|
||||
31
modules/api/errors.py
Normal file
31
modules/api/errors.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
class OpenAIError(Exception):
    """Base exception for OpenAI-compatible API errors (carries an HTTP status)."""

    def __init__(self, message=None, code=500, internal_message=''):
        self.message = message                    # user-facing error text
        self.code = code                          # HTTP status code to return
        self.internal_message = internal_message  # server-side detail, never sent to clients

    def __repr__(self):
        return f"{self.__class__.__name__}(message={self.message!r}, code={self.code:d})"
|
||||
|
||||
|
||||
class InvalidRequestError(OpenAIError):
    """400-level error for malformed requests; records the offending parameter."""

    def __init__(self, message, param, code=400, internal_message=''):
        super().__init__(message, code, internal_message)
        self.param = param  # name of the request field that failed validation

    def __repr__(self):
        # param is interpolated with str() (not repr) to match the wire format.
        return f"{self.__class__.__name__}(message={self.message!r}, code={self.code:d}, param={self.param})"
|
||||
|
||||
|
||||
class ServiceUnavailableError(OpenAIError):
    """503 error raised when a backend (e.g. an embedding or image model) is unavailable."""

    def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
        super().__init__(message, code, internal_message)
|
||||
69
modules/api/images.py
Normal file
69
modules/api/images.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""
|
||||
OpenAI-compatible image generation using local diffusion models.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
|
||||
from .errors import ServiceUnavailableError
|
||||
from modules import shared
|
||||
|
||||
|
||||
def generations(request):
    """
    Generate images using the loaded diffusion model.
    Returns dict with 'created' timestamp and 'data' list of images.

    Raises ServiceUnavailableError when no image model is loaded or when
    generation yields no images.
    """
    from modules.ui_image_generation import generate

    if shared.image_model is None:
        raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")

    width, height = request.get_width_height()

    # Build state dict: GenerationOptions fields + image-specific keys
    state = request.model_dump()
    state.update({
        'image_model_menu': shared.image_model_name,
        'image_prompt': request.prompt,
        'image_neg_prompt': request.negative_prompt,
        'image_width': width,
        'image_height': height,
        'image_steps': request.steps,
        'image_seed': request.image_seed,
        'image_batch_size': request.batch_size,
        'image_batch_count': request.batch_count,
        'image_cfg_scale': request.cfg_scale,
        'image_llm_variations': False,
    })

    # Exhaust generator, keep final result (generate() yields progressive updates)
    images = []
    for images, _ in generate(state, save_images=False):
        pass

    if not images:
        raise ServiceUnavailableError("Image generation failed or produced no images.")

    # Build response: each image is base64 PNG, either inline (b64_json) or
    # as a data URI under 'url' for clients that expect a URL field.
    resp = {'created': int(time.time()), 'data': []}
    for img in images:
        b64 = _image_to_base64(img)

        image_obj = {'revised_prompt': request.prompt}

        if request.response_format == 'b64_json':
            image_obj['b64_json'] = b64
        else:
            image_obj['url'] = f'data:image/png;base64,{b64}'

        resp['data'].append(image_obj)

    return resp
|
||||
|
||||
|
||||
def _image_to_base64(image) -> str:
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
9
modules/api/logits.py
Normal file
9
modules/api/logits.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from .completions import process_parameters
|
||||
from modules.logits import get_next_logits
|
||||
|
||||
|
||||
def _get_next_logits(body):
    """Return the top next-token logits for body['prompt'] as a dict.

    use_samplers is read before process_parameters() in case preprocessing
    alters the payload.
    """
    # Pre-process the input payload to simulate a real generation
    use_samplers = body['use_samplers']
    state = process_parameters(body)
    return get_next_logits(body['prompt'], state, use_samplers, "", top_logits=body['top_logits'], return_dict=True)
|
||||
85
modules/api/models.py
Normal file
85
modules/api/models.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
from modules import loaders, shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||
from modules.utils import get_available_loras, get_available_models
|
||||
|
||||
|
||||
def get_current_model_info():
    """Snapshot of the loaded model name, its LoRAs, and the active loader."""
    return dict(
        model_name=shared.model_name,
        lora_names=shared.lora_names,
        loader=shared.args.loader,
    )
|
||||
|
||||
|
||||
def list_models():
    """Internal-API listing of every model name available on disk."""
    names = get_available_models()
    return {'model_names': names}
|
||||
|
||||
|
||||
def list_models_openai_format():
    """Returns model list in OpenAI API format"""
    data = []
    # Only the currently loaded model is advertised; 'None' means nothing loaded.
    if shared.model_name and shared.model_name != 'None':
        data.append(model_info_dict(shared.model_name))

    return {"object": "list", "data": data}
|
||||
|
||||
|
||||
def model_info_dict(model_name: str) -> dict:
    """Build a single OpenAI-format model record for *model_name*."""
    return dict(id=model_name, object="model", created=0, owned_by="user")
|
||||
|
||||
|
||||
def _load_model(data):
    """Load a model from an admin API request.

    data keys: "model_name" (required), "args" (optional dict of shared.args
    overrides), "settings" (optional dict of shared.settings overrides).
    Unloads the current model first; arg overrides persist after this call.
    """
    model_name = data["model_name"]
    args = data["args"]
    settings = data["settings"]

    unload_model()
    model_settings = get_model_metadata(model_name)
    update_model_parameters(model_settings)

    # Update shared.args with custom model loading settings
    # Security: only allow keys that correspond to model loading
    # parameters exposed in the UI. Never allow security-sensitive
    # flags like trust_remote_code or extra_flags to be set via the API.
    blocked_keys = {'extra_flags'}
    allowed_keys = set(loaders.list_model_elements()) - blocked_keys
    if args:
        for k in args:
            if k in allowed_keys and hasattr(shared.args, k):
                setattr(shared.args, k, args[k])

    shared.model, shared.tokenizer = load_model(model_name)

    # Update shared.settings with custom generation defaults
    if settings:
        for k in settings:
            if k in shared.settings:
                shared.settings[k] = settings[k]
                # Log the two settings that most commonly cause surprises.
                if k == 'truncation_length':
                    logger.info(f"TRUNCATION LENGTH (UPDATED): {shared.settings['truncation_length']}")
                elif k == 'instruction_template':
                    logger.info(f"INSTRUCTION TEMPLATE (UPDATED): {shared.settings['instruction_template']}")
|
||||
|
||||
|
||||
def list_loras():
    """List available LoRA names for the internal API."""
    available = get_available_loras()
    # First entry is skipped -- presumably a placeholder; confirm against get_available_loras().
    return {'lora_names': available[1:]}
|
||||
|
||||
|
||||
def load_loras(lora_names):
    """Apply the given LoRAs to the currently loaded model."""
    add_lora_to_model(lora_names)
|
||||
|
||||
|
||||
def unload_all_loras():
    """Detach every LoRA by applying an empty list."""
    add_lora_to_model([])
|
||||
69
modules/api/moderations.py
Normal file
69
modules/api/moderations.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
import time
|
||||
|
||||
import numpy as np
|
||||
from numpy.linalg import norm
|
||||
|
||||
from .embeddings import get_embeddings
|
||||
|
||||
moderations_disabled = False  # return 0/false
category_embeddings = None  # lazy cache: category label -> embedding vector
antonym_embeddings = None  # not referenced in this module -- TODO confirm before removing
categories = ["sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence"]
flag_threshold = 0.5  # a category is flagged when its mod_score exceeds this
|
||||
|
||||
|
||||
def get_category_embeddings() -> dict:
    """Embed the moderation category labels once and cache them module-wide."""
    global category_embeddings, categories
    if category_embeddings is None:
        vectors = get_embeddings(categories).tolist()
        category_embeddings = dict(zip(categories, vectors))

    return category_embeddings
|
||||
|
||||
|
||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between vectors *a* and *b*: dot product over norms."""
    denominator = norm(a) * norm(b)
    return np.dot(a, b) / denominator
|
||||
|
||||
|
||||
# seems most openai like with all-mpnet-base-v2
|
||||
# seems most openai like with all-mpnet-base-v2
def mod_score(a: np.ndarray, b: np.ndarray) -> float:
    """Moderation score: dot product scaled by 2 (calibrated for mpnet embeddings)."""
    return np.dot(a, b) * 2.0
|
||||
|
||||
|
||||
def moderations(input):
    """OpenAI-compatible /v1/moderations scorer.

    Embeds each input string and scores it against cached category
    embeddings; a category is flagged when its score exceeds flag_threshold.
    *input* may be a single string or a list of strings.
    """
    global category_embeddings, categories, flag_threshold, moderations_disabled
    results = {
        "id": f"modr-{int(time.time()*1e9)}",
        "model": "text-moderation-001",
        "results": [],
    }

    if moderations_disabled:
        # Short-circuit: report every category as unflagged with zero score.
        results['results'] = [{
            'categories': dict([(C, False) for C in categories]),
            'category_scores': dict([(C, 0.0) for C in categories]),
            'flagged': False,
        }]
        return results

    category_embeddings = get_category_embeddings()

    # input, string or array
    if isinstance(input, str):
        input = [input]

    for in_str in input:
        for ine in get_embeddings([in_str]):
            category_scores = dict([(C, mod_score(category_embeddings[C], ine)) for C in categories])
            category_flags = dict([(C, bool(category_scores[C] > flag_threshold)) for C in categories])
            flagged = any(category_flags.values())

            results['results'].extend([{
                'flagged': flagged,
                'categories': category_flags,
                'category_scores': category_scores,
            }])

    # NOTE: removed a leftover debug print() that dumped the full moderation
    # payload to stdout on every request.
    return results
|
||||
509
modules/api/script.py
Normal file
509
modules/api/script.py
Normal file
|
|
@ -0,0 +1,509 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import traceback
|
||||
from collections import deque
|
||||
from threading import Thread
|
||||
|
||||
import uvicorn
|
||||
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sse_starlette import EventSourceResponse
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
|
||||
import modules.api.completions as OAIcompletions
|
||||
import modules.api.logits as OAIlogits
|
||||
import modules.api.models as OAImodels
|
||||
from .tokens import token_count, token_decode, token_encode
|
||||
from .errors import OpenAIError
|
||||
from .utils import _start_cloudflared
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.models import unload_model
|
||||
from modules.text_generation import stop_everything_event # used by /v1/internal/stop-generation
|
||||
|
||||
from .typing import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatPromptResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
DecodeRequest,
|
||||
DecodeResponse,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
EncodeRequest,
|
||||
EncodeResponse,
|
||||
ImageGenerationRequest,
|
||||
ImageGenerationResponse,
|
||||
LoadLorasRequest,
|
||||
LoadModelRequest,
|
||||
LogitsRequest,
|
||||
LogitsResponse,
|
||||
LoraListResponse,
|
||||
ModelInfoResponse,
|
||||
ModelListResponse,
|
||||
TokenCountResponse,
|
||||
to_dict
|
||||
)
|
||||
|
||||
|
||||
async def _wait_for_disconnect(request: Request, stop_event: threading.Event):
    """Block until the client disconnects, then signal the stop_event."""
    while True:
        message = await request.receive()
        if message["type"] != "http.disconnect":
            continue
        stop_event.set()
        return
|
||||
|
||||
|
||||
def verify_api_key(authorization: str = Header(None)) -> None:
    """Reject with 401 unless the Authorization header bears the configured API key."""
    expected_api_key = shared.args.api_key
    if not expected_api_key:
        return  # no key configured: endpoint is open
    # A missing header (None) never equals the Bearer string, so it is rejected too.
    if authorization != f"Bearer {expected_api_key}":
        raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
def verify_admin_key(authorization: str = Header(None)) -> None:
    """Reject with 401 unless the Authorization header bears the configured admin key."""
    expected_api_key = shared.args.admin_key
    if not expected_api_key:
        return  # no admin key configured: endpoint is open
    # A missing header (None) never equals the Bearer string, so it is rejected too.
    if authorization != f"Bearer {expected_api_key}":
        raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
# Application singleton plus reusable auth dependency lists for the route
# decorators below (check_key = user endpoints, check_admin_key = admin ones).
app = FastAPI()
check_key = [Depends(verify_api_key)]
check_admin_key = [Depends(verify_admin_key)]

# Configure CORS settings to allow all origins, methods, and headers
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)
|
||||
|
||||
|
||||
@app.exception_handler(OpenAIError)
async def openai_error_handler(request: Request, exc: OpenAIError):
    """Translate an OpenAIError into an OpenAI-style JSON error response."""
    is_server_error = exc.code >= 500
    payload = {
        "error": {
            "message": exc.message,
            "type": "server_error" if is_server_error else "invalid_request_error",
            # Only InvalidRequestError carries a param attribute.
            "param": getattr(exc, 'param', None),
            "code": None,
        }
    }
    return JSONResponse(status_code=exc.code, content=payload)
|
||||
|
||||
|
||||
@app.middleware("http")
async def validate_host_header(request: Request, call_next):
    """Reject foreign Host headers when serving only on localhost (rebinding guard)."""
    # Be strict about only approving access to localhost by default
    publicly_exposed = shared.args.listen or shared.args.public_api
    if not publicly_exposed:
        host = request.headers.get("host", "").split(":")[0]
        if host not in ("localhost", "127.0.0.1"):
            return JSONResponse(status_code=400, content={"detail": "Invalid host header"})

    return await call_next(request)
|
||||
|
||||
|
||||
@app.options("/", dependencies=check_key)
async def options_route():
    """Answer OPTIONS preflight requests on the root path."""
    return JSONResponse(content="OK")
|
||||
|
||||
|
||||
@app.post('/v1/completions', response_model=CompletionResponse, dependencies=check_key)
async def openai_completions(request: Request, request_data: CompletionRequest):
    """OpenAI-compatible /v1/completions endpoint (streaming and blocking)."""
    path = request.url.path
    is_legacy = "/generate" in path  # legacy route uses a slightly different schema

    if request_data.stream:
        # Streaming produces a single choice; reject multi-choice requests.
        if (request_data.n or 1) > 1:
            return JSONResponse(
                status_code=400,
                content={"error": {"message": "n > 1 is not supported with streaming.", "type": "invalid_request_error", "param": "n", "code": None}}
            )

        stop_event = threading.Event()

        async def generator():
            # The blocking generator runs in a threadpool; each chunk is
            # forwarded as an SSE data frame.
            response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
            try:
                async for resp in iterate_in_threadpool(response):
                    disconnected = await request.is_disconnected()
                    if disconnected:
                        break

                    yield {"data": json.dumps(resp)}

                yield {"data": "[DONE]"}
            finally:
                # Always stop the worker and close the generator, even on
                # client disconnect or mid-stream errors.
                stop_event.set()
                response.close()

        return EventSourceResponse(generator(), sep="\n")  # SSE streaming

    else:
        stop_event = threading.Event()
        # Watch for client disconnect so the blocking generation can abort early.
        monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
        try:
            response = await asyncio.to_thread(
                OAIcompletions.completions,
                to_dict(request_data),
                is_legacy=is_legacy,
                stop_event=stop_event
            )
        finally:
            stop_event.set()
            monitor.cancel()

        return JSONResponse(response)
|
||||
|
||||
|
||||
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse, dependencies=check_key)
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
    """OpenAI-compatible /v1/chat/completions endpoint (streaming and blocking)."""
    path = request.url.path
    is_legacy = "/generate" in path  # legacy route uses a slightly different schema

    if request_data.stream:
        stop_event = threading.Event()

        async def generator():
            # The blocking generator runs in a threadpool; each chunk is
            # forwarded as an SSE data frame.
            response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy, stop_event=stop_event)
            try:
                async for resp in iterate_in_threadpool(response):
                    disconnected = await request.is_disconnected()
                    if disconnected:
                        break

                    yield {"data": json.dumps(resp)}

                yield {"data": "[DONE]"}
            finally:
                # Always stop the worker and close the generator, even on
                # client disconnect or mid-stream errors.
                stop_event.set()
                response.close()

        return EventSourceResponse(generator(), sep="\n")  # SSE streaming

    else:
        stop_event = threading.Event()
        # Watch for client disconnect so the blocking generation can abort early.
        monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
        try:
            response = await asyncio.to_thread(
                OAIcompletions.chat_completions,
                to_dict(request_data),
                is_legacy=is_legacy,
                stop_event=stop_event
            )
        finally:
            stop_event.set()
            monitor.cancel()

        return JSONResponse(response)
|
||||
|
||||
|
||||
@app.get("/v1/models", dependencies=check_key)
@app.get("/v1/models/{model}", dependencies=check_key)
async def handle_models(request: Request):
    """List models (OpenAI format) or describe one model named by the path suffix."""
    # Starlette's request.url.path never includes the query string or
    # fragment, so the previous '?'/'#' stripping was dead code.
    path = request.url.path
    if path == '/v1/models':
        response = OAImodels.list_models_openai_format()
    else:
        # Everything after the prefix is the model name (it may contain slashes).
        model_name = path[len('/v1/models/'):]
        response = OAImodels.model_info_dict(model_name)

    return JSONResponse(response)
|
||||
|
||||
|
||||
@app.get('/v1/billing/usage', dependencies=check_key)
def handle_billing_usage():
    '''
    Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31

    Stub for OpenAI-client compatibility: local generation has no billing,
    so usage is always zero.
    '''
    return JSONResponse(content={"total_usage": 0})
|
||||
|
||||
|
||||
@app.post('/v1/audio/transcriptions', dependencies=check_key)
async def handle_audio_transcription(request: Request):
    """Transcribe an uploaded audio file with Whisper (via speech_recognition).

    Form fields: "file" (required upload), "language" and "model" (optional
    Whisper settings). Always returns {"text": ...}; recognition failures are
    reported in the text rather than as HTTP errors.
    """
    import io

    import speech_recognition as sr
    from pydub import AudioSegment

    r = sr.Recognizer()

    form = await request.form()
    audio_bytes = await form["file"].read()
    # AudioSegment.from_file() expects a path or file-like object, not raw
    # bytes -- wrap the upload in a BytesIO buffer.
    audio_data = AudioSegment.from_file(io.BytesIO(audio_bytes))

    # Convert AudioSegment to raw data
    raw_data = audio_data.raw_data

    # Create AudioData object
    audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
    # Starlette FormData has no getvalue() (that was an AttributeError at
    # runtime); .get() provides the same default-on-missing semantics.
    whisper_language = form.get('language', None)
    whisper_model = form.get('model', 'tiny')  # Use the model from the form data if it exists, otherwise default to tiny

    transcription = {"text": ""}

    try:
        transcription["text"] = r.recognize_whisper(audio_data, language=whisper_language, model=whisper_model)
    except sr.UnknownValueError:
        print("Whisper could not understand audio")
        transcription["text"] = "Whisper could not understand audio UnknownValueError"
    except sr.RequestError as e:
        print("Could not request results from Whisper", e)
        transcription["text"] = "Whisper could not understand audio RequestError"

    return JSONResponse(content=transcription)
|
||||
|
||||
|
||||
@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
async def handle_image_generation(request_data: ImageGenerationRequest):
    """Generate images with the loaded diffusion model (runs in a worker thread)."""
    import modules.api.images as OAIimages

    result = await asyncio.to_thread(OAIimages.generations, request_data)
    return JSONResponse(result)
|
||||
|
||||
|
||||
@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
    """Compute embeddings for a string or a list of strings."""
    import modules.api.embeddings as OAIembeddings

    # Renamed from `input` to avoid shadowing the builtin.
    text_input = request_data.input
    if not text_input:
        raise HTTPException(status_code=400, detail="Missing required argument input")

    # Normalize: a bare string becomes a one-element batch (isinstance, not
    # `type(...) is str`, so str subclasses are accepted too).
    if isinstance(text_input, str):
        text_input = [text_input]

    response = OAIembeddings.embeddings(text_input, request_data.encoding_format)
    return JSONResponse(response)
|
||||
|
||||
|
||||
@app.post("/v1/moderations", dependencies=check_key)
async def handle_moderations(request: Request):
    """Run the local moderation scorer on the request body's "input" field."""
    import modules.api.moderations as OAImoderations

    body = await request.json()
    # .get() so a missing "input" key produces the intended 400 response
    # instead of an unhandled KeyError (HTTP 500).
    mod_input = body.get("input")
    if not mod_input:
        raise HTTPException(status_code=400, detail="Missing required argument input")

    response = OAImoderations.moderations(mod_input)
    return JSONResponse(response)
|
||||
|
||||
|
||||
@app.get("/v1/internal/health", dependencies=check_key)
async def handle_health_check():
    """Liveness probe: returns ok while the server is responsive."""
    return JSONResponse(content={"status": "ok"})
|
||||
|
||||
|
||||
@app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key)
async def handle_token_encode(request_data: EncodeRequest):
    """Tokenize text with the loaded model's tokenizer."""
    return JSONResponse(token_encode(request_data.text))
|
||||
|
||||
|
||||
@app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key)
async def handle_token_decode(request_data: DecodeRequest):
    """Convert token ids back into text with the loaded model's tokenizer."""
    return JSONResponse(token_decode(request_data.tokens))
|
||||
|
||||
|
||||
@app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key)
async def handle_token_count(request_data: EncodeRequest):
    """Count the tokens in the given text."""
    return JSONResponse(token_count(request_data.text))
|
||||
|
||||
|
||||
@app.post("/v1/internal/logits", response_model=LogitsResponse, dependencies=check_key)
async def handle_logits(request_data: LogitsRequest):
    '''
    Given a prompt, returns the top 50 most likely logits as a dict.
    The keys are the tokens, and the values are the probabilities.
    '''
    payload = to_dict(request_data)
    return JSONResponse(OAIlogits._get_next_logits(payload))
|
||||
|
||||
|
||||
@app.post('/v1/internal/chat-prompt', response_model=ChatPromptResponse, dependencies=check_key)
async def handle_chat_prompt(request: Request, request_data: ChatCompletionRequest):
    """Return the fully rendered chat prompt without generating a reply."""
    is_legacy = "/generate" in request.url.path
    generator = OAIcompletions.chat_completions_common(to_dict(request_data), is_legacy=is_legacy, prompt_only=True)
    # Keep only the final yield of the common generator.
    final = deque(generator, maxlen=1).pop()
    return JSONResponse(final)
|
||||
|
||||
|
||||
@app.post("/v1/internal/stop-generation", dependencies=check_key)
async def handle_stop_generation(request: Request):
    """Interrupt any in-progress generation server-wide."""
    stop_everything_event()
    return JSONResponse(content="OK")
|
||||
|
||||
|
||||
@app.get("/v1/internal/model/info", response_model=ModelInfoResponse, dependencies=check_key)
async def handle_model_info():
    """Describe the currently loaded model, its LoRAs, and the loader in use."""
    return JSONResponse(content=OAImodels.get_current_model_info())
|
||||
|
||||
|
||||
@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_admin_key)
async def handle_list_models():
    """Admin-only listing of every model name available on disk."""
    return JSONResponse(content=OAImodels.list_models())
|
||||
|
||||
|
||||
@app.post("/v1/internal/model/load", dependencies=check_admin_key)
async def handle_load_model(request_data: LoadModelRequest):
    '''
    This endpoint is experimental and may change in the future.

    The "args" parameter can be used to modify flags like "--load-in-4bit"
    or "--n-gpu-layers" before loading a model. Example:

    ```
    "args": {
      "load_in_4bit": true,
      "n_gpu_layers": 12
    }
    ```

    Note that those settings will remain after loading the model. So you
    may need to change them back to load a second model.

    The "settings" parameter is also a dict but with keys for the
    shared.settings object. It can be used to modify the default instruction
    template like this:

    ```
    "settings": {
      "instruction_template": "Alpaca"
    }
    ```
    '''

    try:
        OAImodels._load_model(to_dict(request_data))
        return JSONResponse(content="OK")
    except Exception:
        # Full traceback goes to the server log; the client only gets a 400.
        traceback.print_exc()
        raise HTTPException(status_code=400, detail="Failed to load the model.")
|
||||
|
||||
|
||||
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
async def handle_unload_model():
    """Unload the current model.

    Returns "OK" on success, for consistency with the other
    model-management endpoints (previously this returned an empty/null
    body, which made it hard for clients to confirm completion).
    """
    unload_model()
    return JSONResponse(content="OK")
|
||||
|
||||
|
||||
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
async def handle_list_loras():
    """List the LoRA adapters available on disk (admin only)."""
    payload = OAImodels.list_loras()
    return JSONResponse(content=payload)
|
||||
|
||||
|
||||
@app.post("/v1/internal/lora/load", dependencies=check_admin_key)
async def handle_load_loras(request_data: LoadLorasRequest):
    """Apply the requested LoRA adapters to the loaded model.

    Responds with HTTP 400 if any adapter fails to apply; the traceback is
    written to the server log.
    """
    try:
        OAImodels.load_loras(request_data.lora_names)
    except Exception:
        traceback.print_exc()
        raise HTTPException(status_code=400, detail="Failed to apply the LoRA(s).")

    return JSONResponse(content="OK")
|
||||
|
||||
|
||||
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
async def handle_unload_loras():
    """Detach every active LoRA adapter from the model."""
    OAImodels.unload_all_loras()
    return JSONResponse(content="OK")
|
||||
|
||||
|
||||
def find_available_port(starting_port):
    """Return `starting_port` if it is free, otherwise an OS-chosen free port.

    A warning is logged when falling back to a different port.
    """
    try:
        # Probe the requested port by binding to it briefly.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(('', starting_port))
        return starting_port
    except OSError:
        pass

    # Requested port is taken: binding to port 0 makes the OS pick a free one.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as fallback:
        fallback.bind(('', 0))
        chosen_port = fallback.getsockname()[1]

    logger.warning(f"Port {starting_port} is already in use. Using port {chosen_port} instead.")
    return chosen_port
|
||||
|
||||
|
||||
def run_server():
    """Configure and start the OpenAI-compatible API server (blocking).

    Environment variables (OPENEDAI_*) override the corresponding
    command-line flags from shared.args.
    """
    # Resolve port and TLS material, preferring environment overrides.
    port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
    port = find_available_port(port)
    ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
    ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)

    # Decide which addresses to bind. An explicit --listen-host wins;
    # otherwise bind IPv6 and/or IPv4 loopback or wildcard depending on
    # whether --listen was given.
    server_addrs = []
    if shared.args.listen and shared.args.listen_host:
        server_addrs.append(shared.args.listen_host)
    else:
        if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
            server_addrs.append('[::]' if shared.args.listen else '[::1]')
        if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
            server_addrs.append('0.0.0.0' if shared.args.listen else '127.0.0.1')

    if not server_addrs:
        raise Exception('you MUST enable IPv6 or IPv4 for the API to work')

    # Announce where the API can be reached.
    if shared.args.public_api:
        _start_cloudflared(
            port,
            shared.args.public_api_id,
            max_attempts=3,
            on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}/v1\n')
        )
    else:
        url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://'
        urls = [f'{url_proto}{addr}:{port}/v1' for addr in server_addrs]
        label = 'OpenAI-compatible API URLs:' if len(urls) > 1 else 'OpenAI-compatible API URL:'
        logger.info(label + '\n\n' + '\n'.join(urls) + '\n')

    # Print the configured API keys. The admin key defaults to the API key
    # when not set separately.
    if shared.args.api_key:
        if not shared.args.admin_key:
            shared.args.admin_key = shared.args.api_key

        logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')

    if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
        logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')

    # Hand off to uvicorn; its error logger is silenced to avoid duplicate output.
    logging.getLogger("uvicorn.error").propagate = False
    uvicorn.run(app, host=server_addrs, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile, access_log=False)
|
||||
|
||||
|
||||
# Guard so the server is launched at most once per process.
_server_started = False


def setup():
    """Start the API server exactly once; later calls are no-ops.

    With --nowebui the server runs in the foreground (blocking); otherwise
    it is launched on a daemon thread alongside the web UI.
    """
    global _server_started
    if _server_started:
        return

    _server_started = True
    if shared.args.nowebui:
        run_server()
    else:
        Thread(target=run_server, daemon=True).start()
|
||||
26
modules/api/tokens.py
Normal file
26
modules/api/tokens.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from modules.text_generation import decode, encode
|
||||
|
||||
|
||||
def token_count(prompt):
    """Return {'length': N} where N is the token count of `prompt`."""
    return {'length': len(encode(prompt)[0])}
|
||||
|
||||
|
||||
def token_encode(input):
    """Tokenize `input` and return both the token ids and their count."""
    tokens = encode(input)[0]
    # Torch tensors and numpy arrays are not JSON-serializable; convert
    # them to plain Python lists before returning.
    if tokens.__class__.__name__ in ['Tensor', 'ndarray']:
        tokens = tokens.tolist()

    return {
        'tokens': tokens,
        'length': len(tokens),
    }
|
||||
|
||||
|
||||
def token_decode(tokens):
    """Convert a list of token ids back into text."""
    return {'text': decode(tokens)}
|
||||
326
modules/api/typing.py
Normal file
326
modules/api/typing.py
Normal file
|
|
@ -0,0 +1,326 @@
|
|||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator, validator
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
||||
class GenerationOptions(BaseModel):
    """Sampling/generation parameters shared by the generation endpoints.

    Defaults mirror the server's command-line flags (shared.args), so a
    request only needs to include the parameters it wants to override.
    """
    preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/user_data/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.")

    # --- Temperature shaping ---
    dynatemp_low: float = shared.args.dynatemp_low
    dynatemp_high: float = shared.args.dynatemp_high
    dynatemp_exponent: float = shared.args.dynatemp_exponent
    smoothing_factor: float = shared.args.smoothing_factor
    smoothing_curve: float = shared.args.smoothing_curve

    # --- Truncation samplers ---
    min_p: float = shared.args.min_p
    top_k: int = shared.args.top_k
    typical_p: float = shared.args.typical_p
    xtc_threshold: float = shared.args.xtc_threshold
    xtc_probability: float = shared.args.xtc_probability
    epsilon_cutoff: float = shared.args.epsilon_cutoff
    eta_cutoff: float = shared.args.eta_cutoff
    tfs: float = shared.args.tfs
    top_a: float = shared.args.top_a
    top_n_sigma: float = shared.args.top_n_sigma
    adaptive_target: float = shared.args.adaptive_target
    adaptive_decay: float = shared.args.adaptive_decay

    # --- Repetition control (DRY and classic penalties) ---
    dry_multiplier: float = shared.args.dry_multiplier
    dry_allowed_length: int = shared.args.dry_allowed_length
    dry_base: float = shared.args.dry_base
    repetition_penalty: float = shared.args.repetition_penalty
    encoder_repetition_penalty: float = shared.args.encoder_repetition_penalty
    no_repeat_ngram_size: int = shared.args.no_repeat_ngram_size
    repetition_penalty_range: int = shared.args.repetition_penalty_range

    # --- Alternative decoding strategies ---
    penalty_alpha: float = shared.args.penalty_alpha
    guidance_scale: float = shared.args.guidance_scale
    mirostat_mode: int = shared.args.mirostat_mode
    mirostat_tau: float = shared.args.mirostat_tau
    mirostat_eta: float = shared.args.mirostat_eta
    prompt_lookup_num_tokens: int = 0
    max_tokens_second: int = 0

    # --- Toggles ---
    do_sample: bool = shared.args.do_sample
    dynamic_temperature: bool = shared.args.dynamic_temperature
    temperature_last: bool = shared.args.temperature_last
    auto_max_new_tokens: bool = False
    ban_eos_token: bool = False
    add_bos_token: bool = True
    enable_thinking: bool = shared.args.enable_thinking
    reasoning_effort: str = shared.args.reasoning_effort
    skip_special_tokens: bool = True
    static_cache: bool = False

    # --- Misc ---
    truncation_length: int = 0
    seed: int = -1
    sampler_priority: List[str] | str | None = Field(default=shared.args.sampler_priority, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
    custom_token_bans: str = ""
    negative_prompt: str = ''
    dry_sequence_breakers: str = shared.args.dry_sequence_breakers
    grammar_string: str = ""
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
    """A tool made available to the model (OpenAI "tools" array entry)."""
    function: 'ToolFunction'  # forward reference; resolved once ToolFunction is defined below
    type: str


class ToolFunction(BaseModel):
    """The callable described by a ToolDefinition."""
    model_config = ConfigDict(extra='allow')
    description: Optional[str] = None
    name: str
    parameters: Optional['ToolParameters'] = None


class ToolParameters(BaseModel):
    """JSON-Schema-style parameter description for a tool function."""
    model_config = ConfigDict(extra='allow')
    properties: Optional[Dict[str, Any]] = None
    required: Optional[list[str]] = None
    type: str
    description: Optional[str] = None
|
||||
|
||||
|
||||
|
||||
class FunctionCall(BaseModel):
    """A tool/function invocation emitted by the model.

    The call's input may arrive under either 'arguments' (OpenAI
    convention) or 'parameters'; at least one must be present.
    """
    name: str
    arguments: Optional[str] = None
    parameters: Optional[str] = None

    @model_validator(mode='after')
    def check_args_or_params(self):
        # Fixed: the previous field-level @validator('arguments') ran before
        # 'parameters' had been validated, so values.get('parameters') was
        # always missing and a payload supplying only 'parameters' was
        # wrongly rejected. A mode='after' model validator (the pydantic-v2
        # style used elsewhere in this file) sees both fields.
        if not self.arguments and not self.parameters:
            raise ValueError("At least one of 'arguments' or 'parameters' must be provided as property in FunctionCall type")
        return self
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
    """One entry of a message's tool_calls array."""
    id: str
    index: int
    type: str
    function: FunctionCall


class StreamOptions(BaseModel):
    """Options controlling streaming responses."""
    # When true, a final usage chunk is emitted at the end of the stream.
    include_usage: bool | None = False
|
||||
|
||||
|
||||
class CompletionRequestParams(BaseModel):
    """OpenAI-compatible /v1/completions request body (non-sampling fields)."""
    model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
    prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
    messages: List[dict] | None = Field(default=None, description="OpenAI messages format for multimodal support. Alternative to 'prompt'.")
    best_of: int | None = Field(default=1, description="Unused parameter.")
    echo: bool | None = False
    frequency_penalty: float | None = shared.args.frequency_penalty
    logit_bias: dict | None = None
    logprobs: int | None = None
    max_tokens: int | None = 512
    n: int | None = Field(default=1, description="Number of completions to generate. Only supported without streaming.")
    presence_penalty: float | None = shared.args.presence_penalty
    stop: str | List[str] | None = None
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    suffix: str | None = None
    temperature: float | None = shared.args.temperature
    top_p: float | None = shared.args.top_p
    user: str | None = Field(default=None, description="Unused parameter.")

    @model_validator(mode='after')
    def validate_prompt_or_messages(self):
        # Exactly zero inputs is an error; either field (or both) may be set.
        if self.prompt is None and self.messages is None:
            raise ValueError("Either 'prompt' or 'messages' must be provided")

        return self


class CompletionRequest(GenerationOptions, CompletionRequestParams):
    """Full completions request: endpoint fields plus sampling options."""
    pass
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
    """OpenAI-compatible /v1/completions response body."""
    id: str
    choices: List[dict]
    # Unix timestamp captured when the response object is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    object: str = "text_completion"
    usage: dict
|
||||
|
||||
|
||||
class ChatCompletionRequestParams(BaseModel):
    """OpenAI-compatible /v1/chat/completions request body (non-sampling fields)."""
    messages: List[dict]
    model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
    frequency_penalty: float | None = shared.args.frequency_penalty
    function_call: str | dict | None = Field(default=None, description="Unused parameter.")
    functions: List[dict] | None = Field(default=None, description="Unused parameter.")
    tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
    tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.")
    logit_bias: dict | None = None
    logprobs: bool | None = None
    top_logprobs: int | None = None
    max_tokens: int | None = None
    max_completion_tokens: int | None = None
    n: int | None = Field(default=1, description="Unused parameter.")
    presence_penalty: float | None = shared.args.presence_penalty
    stop: str | List[str] | None = None
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    temperature: float | None = shared.args.temperature
    top_p: float | None = shared.args.top_p
    user: str | None = Field(default=None, description="Unused parameter.")

    # --- Extensions beyond the OpenAI schema ---
    mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")

    instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.")
    instruction_template_str: str | None = Field(default=None, description="A Jinja2 instruction template. If set, will take precedence over everything else.")

    character: str | None = Field(default=None, description="A character defined under text-generation-webui/user_data/characters. If not set, the default \"Assistant\" character will be used.")
    bot_name: str | None = Field(default=None, description="Overwrites the value set by character field.", alias="name2")
    context: str | None = Field(default=None, description="Overwrites the value set by character field.")
    greeting: str | None = Field(default=None, description="Overwrites the value set by character field.")
    user_name: str | None = Field(default=None, description="Your name (the user). By default, it's \"You\".", alias="name1")
    user_bio: str | None = Field(default=None, description="The user description/personality.")
    chat_template_str: str | None = Field(default=None, description="Jinja2 template for chat.")

    chat_instruct_command: str | None = "Continue the chat dialogue below. Write a single reply for the character \"<|character|>\".\n\n<|prompt|>"

    continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")

    @model_validator(mode='after')
    def resolve_max_tokens(self):
        # Accept the newer OpenAI name 'max_completion_tokens' as an alias
        # for 'max_tokens' when only the former is given.
        if self.max_tokens is None and self.max_completion_tokens is not None:
            self.max_tokens = self.max_completion_tokens

        return self


class ChatCompletionRequest(GenerationOptions, ChatCompletionRequestParams):
    """Full chat completions request: endpoint fields plus sampling options."""
    pass
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible /v1/chat/completions response body."""
    id: str
    choices: List[dict]
    # Unix timestamp captured when the response object is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    object: str = "chat.completion"
    usage: dict


class ChatPromptResponse(BaseModel):
    """Response carrying only the fully rendered prompt text."""
    prompt: str
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
    """OpenAI-compatible /v1/embeddings request body."""
    input: str | List[str] | List[int] | List[List[int]]
    model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
    encoding_format: str = Field(default="float", description="Can be float or base64.")
    user: str | None = Field(default=None, description="Unused parameter.")


class EmbeddingsResponse(BaseModel):
    """One embedding entry of an embeddings response."""
    index: int
    embedding: List[float]
    object: str = "embedding"
|
||||
|
||||
|
||||
class EncodeRequest(BaseModel):
    """Request body for the tokenize (encode) endpoint."""
    text: str


class EncodeResponse(BaseModel):
    """Token ids and their count for an encoded text."""
    tokens: List[int]
    length: int


class DecodeRequest(BaseModel):
    """Request body for the detokenize (decode) endpoint."""
    tokens: List[int]


class DecodeResponse(BaseModel):
    """Decoded text for a list of token ids."""
    text: str


class TokenCountResponse(BaseModel):
    """Number of tokens in a prompt."""
    length: int
|
||||
|
||||
|
||||
class LogitsRequestParams(BaseModel):
    """Request body for the raw-logits inspection endpoint."""
    prompt: str
    # When true, the logits are taken after the sampling stack is applied.
    use_samplers: bool = False
    top_logits: int | None = 50
    frequency_penalty: float | None = shared.args.frequency_penalty
    max_tokens: int | None = 512
    presence_penalty: float | None = shared.args.presence_penalty
    temperature: float | None = shared.args.temperature
    top_p: float | None = shared.args.top_p


class LogitsRequest(GenerationOptions, LogitsRequestParams):
    """Full logits request: endpoint fields plus sampling options."""
    pass


class LogitsResponse(BaseModel):
    """Mapping of token string to its logit/probability value."""
    logits: Dict[str, float]
|
||||
|
||||
|
||||
class ModelInfoResponse(BaseModel):
    """Currently loaded model and active LoRAs."""
    model_name: str
    lora_names: List[str]


class ModelListResponse(BaseModel):
    """Model names available on disk."""
    model_names: List[str]


class LoadModelRequest(BaseModel):
    """Request body for /v1/internal/model/load.

    'args' overrides shared.args flags; 'settings' overrides shared.settings
    keys. Both are optional.
    """
    model_name: str
    args: dict | None = None
    settings: dict | None = None


class LoraListResponse(BaseModel):
    """LoRA adapter names available on disk."""
    lora_names: List[str]


class LoadLorasRequest(BaseModel):
    """Request body for /v1/internal/lora/load."""
    lora_names: List[str]
|
||||
|
||||
|
||||
class ImageGenerationRequest(BaseModel):
    """Image-specific parameters for generation."""
    prompt: str
    negative_prompt: str = ""
    size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
    steps: int = Field(default=9, ge=1)
    cfg_scale: float = Field(default=0.0, ge=0.0)
    image_seed: int = Field(default=-1, description="-1 for random")
    batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
    n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
    batch_count: int = Field(default=1, ge=1, description="Sequential batch count")

    # OpenAI compatibility (unused)
    model: str | None = None
    response_format: str = "b64_json"
    user: str | None = None

    @model_validator(mode='after')
    def resolve_batch_size(self):
        # 'n' is the OpenAI-style alias; it only applies when batch_size
        # was not given explicitly.
        if self.batch_size is None:
            self.batch_size = self.n

        return self

    def get_width_height(self) -> tuple[int, int]:
        """Parse 'WIDTHxHEIGHT' from `size`; fall back to (1024, 1024)."""
        try:
            width, height = self.size.lower().split('x')[:2]
            return int(width), int(height)
        except (ValueError, IndexError):
            return 1024, 1024


class ImageGenerationResponse(BaseModel):
    """OpenAI-compatible image generation response body."""
    created: int = Field(default_factory=lambda: int(time.time()))
    data: List[dict]
|
||||
|
||||
|
||||
def to_json(obj):
    """Serialize an object's attribute dict as pretty-printed JSON."""
    return json.dumps(vars(obj), indent=4)


def to_dict(obj):
    """Return the object's attribute dict (shallow; not a copy)."""
    return vars(obj)
|
||||
53
modules/api/utils.py
Normal file
53
modules/api/utils.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
import base64
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def float_list_to_base64(float_array: np.ndarray) -> str:
    """Encode a numpy float array's raw bytes as an ASCII base64 string.

    The OpenAI embeddings API expects base64 over the raw float32 buffer;
    the caller is responsible for passing a float32 array.
    """
    raw = float_array.tobytes()
    return base64.b64encode(raw).decode('ascii')
|
||||
|
||||
|
||||
def debug_msg(*args, **kwargs):
    """Print only when the OPENEDAI_DEBUG environment variable enables it.

    Unset, empty, or "0" disables output; any other value enables it.
    (The previous truthiness check treated the string "0" as enabled,
    because any non-empty environment string is truthy.)
    """
    if os.environ.get("OPENEDAI_DEBUG", "") not in ("", "0"):
        print(*args, **kwargs)
|
||||
|
||||
|
||||
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
    """Open a Cloudflare tunnel to `port`, retrying up to `max_attempts` times.

    Calls `on_start` with the public URL once the tunnel is up. Raises if
    flask_cloudflared is missing or every attempt fails.
    """
    try:
        from flask_cloudflared import _run_cloudflared
    except ImportError:
        print('You should install flask_cloudflared manually')
        raise Exception(
            'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')

    for _attempt in range(max_attempts):
        try:
            # A named tunnel is used when a tunnel_id was supplied.
            extra = {'tunnel_id': tunnel_id} if tunnel_id is not None else {}
            public_url = _run_cloudflared(port, port + 1, **extra)

            if on_start:
                on_start(public_url)

            return
        except Exception:
            traceback.print_exc()
            time.sleep(3)

    raise Exception('Could not start cloudflared.')
|
||||
|
|
@ -32,8 +32,7 @@ def load_extensions():
|
|||
if name not in available_extensions:
|
||||
continue
|
||||
|
||||
if name != 'api':
|
||||
logger.info(f'Loading the extension "{name}"')
|
||||
logger.info(f'Loading the extension "{name}"')
|
||||
|
||||
try:
|
||||
# Prefer user extension, fall back to system extension
|
||||
|
|
|
|||
|
|
@ -156,7 +156,7 @@ group.add_argument('--portable', action='store_true', help='Hide features not av
|
|||
|
||||
# API
|
||||
group = parser.add_argument_group('API')
|
||||
group.add_argument('--api', action='store_true', help='Enable the API extension.')
|
||||
group.add_argument('--api', action='store_true', help='Enable the API server.')
|
||||
group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.')
|
||||
group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
|
||||
group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
|
||||
|
|
@ -435,16 +435,6 @@ def fix_loader_name(name):
|
|||
return 'TensorRT-LLM'
|
||||
|
||||
|
||||
def add_extension(name, last=False):
    """Register `name` in args.extensions; with last=True, move it to the end.

    'last' matters because extensions load in list order, so an extension
    appended last can override earlier ones.
    """
    current = args.extensions
    if current is None:
        args.extensions = [name]
    elif last:
        # Drop any existing occurrence, then re-append so it loads last.
        reordered = [ext for ext in current if ext != name]
        reordered.append(name)
        args.extensions = reordered
    elif name not in current:
        current.append(name)
|
||||
|
||||
|
||||
def is_chat():
    """Legacy compatibility shim: the UI is always in chat mode now."""
    return True
|
||||
|
||||
|
|
@ -464,10 +454,6 @@ def load_user_config():
|
|||
|
||||
args.loader = fix_loader_name(args.loader)
|
||||
|
||||
# Activate the API extension
|
||||
if args.api or args.public_api:
|
||||
add_extension('openai', last=True)
|
||||
|
||||
# Load model-specific settings
|
||||
p = Path(f'{args.model_dir}/config.yaml')
|
||||
if p.exists():
|
||||
|
|
|
|||
|
|
@ -95,8 +95,6 @@ def set_interface_arguments(extensions, bool_active):
|
|||
setattr(shared.args, k, False)
|
||||
for k in bool_active:
|
||||
setattr(shared.args, k, True)
|
||||
if k == 'api':
|
||||
shared.add_extension('openai', last=True)
|
||||
|
||||
shared.need_restart = True
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue