mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-06 07:03:37 +00:00
API: Move OpenAI-compatible API from extensions/openai to modules/api
This commit is contained in:
parent
2e4232e02b
commit
bf6fbc019d
23 changed files with 51 additions and 65 deletions
85
modules/api/models.py
Normal file
85
modules/api/models.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
from modules import loaders, shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||
from modules.utils import get_available_loras, get_available_models
|
||||
|
||||
|
||||
def get_current_model_info():
    """Report the model currently loaded in this process.

    Returns a dict with the active model name, the list of attached
    LoRA names, and the loader selected via command-line args.
    """
    info = {}
    info['model_name'] = shared.model_name
    info['lora_names'] = shared.lora_names
    info['loader'] = shared.args.loader
    return info
|
||||
|
||||
|
||||
def list_models():
    """Return every locally available model name under the 'model_names' key."""
    available = get_available_models()
    return dict(model_names=available)
|
||||
|
||||
|
||||
def list_models_openai_format():
    """Describe the loaded model (if any) as an OpenAI-style model list.

    Returns {"object": "list", "data": [...]} where data holds at most one
    entry — the currently loaded model — or is empty when nothing is loaded.
    """
    loaded = shared.model_name
    if not loaded or loaded == 'None':
        entries = []
    else:
        entries = [model_info_dict(loaded)]

    return {"object": "list", "data": entries}
|
||||
|
||||
|
||||
def model_info_dict(model_name: str) -> dict:
    """Build an OpenAI-style "model" object for *model_name*.

    The "created" timestamp and "owned_by" fields are fixed placeholders,
    since local models have no creation epoch or remote owner.
    """
    return dict(id=model_name, object="model", created=0, owned_by="user")
|
||||
|
||||
|
||||
def _load_model(data):
    """Load a model from an API request payload, mutating global shared state.

    *data* is expected to be a dict with keys:
      - "model_name": name of the model to load
      - "args": optional dict of model-loading flags applied to shared.args
      - "settings": optional dict of generation defaults applied to shared.settings

    Side effects: unloads the current model, rewrites shared.args /
    shared.settings, and rebinds shared.model / shared.tokenizer.
    """
    model_name = data["model_name"]
    args = data["args"]
    settings = data["settings"]

    # Drop the current model first, then derive and apply the new
    # model's metadata-based parameters before loading.
    unload_model()
    model_settings = get_model_metadata(model_name)
    update_model_parameters(model_settings)

    # Update shared.args with custom model loading settings
    # Security: only allow keys that correspond to model loading
    # parameters exposed in the UI. Never allow security-sensitive
    # flags like trust_remote_code or extra_flags to be set via the API.
    blocked_keys = {'extra_flags'}
    allowed_keys = set(loaders.list_model_elements()) - blocked_keys
    if args:
        for k in args:
            # hasattr guard: skip UI elements that have no matching CLI arg.
            if k in allowed_keys and hasattr(shared.args, k):
                setattr(shared.args, k, args[k])

    shared.model, shared.tokenizer = load_model(model_name)

    # Update shared.settings with custom generation defaults
    # (unknown keys are silently ignored; only a couple of notable
    # overrides are logged).
    # NOTE(review): indentation was reconstructed — logging is assumed to
    # apply only when the key was actually accepted; confirm upstream.
    if settings:
        for k in settings:
            if k in shared.settings:
                shared.settings[k] = settings[k]
                if k == 'truncation_length':
                    logger.info(f"TRUNCATION LENGTH (UPDATED): {shared.settings['truncation_length']}")
                elif k == 'instruction_template':
                    logger.info(f"INSTRUCTION TEMPLATE (UPDATED): {shared.settings['instruction_template']}")
|
||||
|
||||
|
||||
def list_loras():
    """Return the available LoRA names under the 'lora_names' key.

    The first entry from get_available_loras() is skipped —
    presumably a 'None'/placeholder item; verify against that helper.
    """
    available = get_available_loras()
    return {'lora_names': available[1:]}
|
||||
|
||||
|
||||
def load_loras(lora_names):
    """Attach the given LoRAs to the currently loaded model."""
    add_lora_to_model(lora_names)
|
||||
|
||||
|
||||
def unload_all_loras():
    """Detach every LoRA by applying an empty LoRA list."""
    add_lora_to_model([])
|
||||
Loading…
Add table
Add a link
Reference in a new issue