mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-07 15:43:49 +00:00
API: add instruction_template support to the model load endpoint
This commit is contained in:
parent
4d6230a944
commit
c26ffdd24c
5 changed files with 38 additions and 7 deletions
|
|
@ -1,7 +1,8 @@
|
|||
from modules import loaders, shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||
from modules.models_settings import get_model_metadata, load_instruction_template, update_model_parameters
|
||||
from modules.utils import get_available_loras, get_available_models
|
||||
|
||||
|
||||
|
|
@ -69,6 +70,13 @@ def _load_model(data):
|
|||
|
||||
shared.model, shared.tokenizer = load_model(model_name)
|
||||
|
||||
if data.get("instruction_template_str") is not None:
|
||||
shared.settings['instruction_template_str'] = data["instruction_template_str"]
|
||||
logger.info("INSTRUCTION TEMPLATE: set to custom Jinja2 string")
|
||||
elif data.get("instruction_template") is not None:
|
||||
shared.settings['instruction_template_str'] = load_instruction_template(data["instruction_template"])
|
||||
logger.info(f"INSTRUCTION TEMPLATE: {data['instruction_template']}")
|
||||
|
||||
|
||||
def list_loras():
|
||||
return {'lora_names': get_available_loras()[1:]}
|
||||
|
|
|
|||
|
|
@ -487,6 +487,11 @@ async def handle_load_model(request_data: LoadModelRequest):
|
|||
|
||||
Loader args are reset to their startup defaults between loads, so
|
||||
settings from a previous load do not leak into the next one.
|
||||
|
||||
The "instruction_template" parameter sets the default instruction
|
||||
template by name (from user_data/instruction-templates/). The
|
||||
"instruction_template_str" parameter sets it as a raw Jinja2 string
|
||||
and takes precedence over "instruction_template".
|
||||
'''
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -271,6 +271,8 @@ class ModelListResponse(BaseModel):
|
|||
class LoadModelRequest(BaseModel):
|
||||
model_name: str
|
||||
args: dict | None = None
|
||||
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. Sets the default template for all subsequent API requests.")
|
||||
instruction_template_str: str | None = Field(default=None, description="A Jinja2 instruction template string. If set, takes precedence over instruction_template.")
|
||||
|
||||
|
||||
class LoraListResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -400,14 +400,19 @@ def load_instruction_template(template):
|
|||
if template == 'None':
|
||||
return ''
|
||||
|
||||
for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']:
|
||||
if filepath.exists():
|
||||
break
|
||||
for name in (template, 'Alpaca'):
|
||||
path = shared.user_data_dir / 'instruction-templates' / f'{name}.yaml'
|
||||
try:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
file_contents = f.read()
|
||||
except FileNotFoundError:
|
||||
if name == template:
|
||||
logger.warning(f"Instruction template '{template}' not found, falling back to Alpaca")
|
||||
continue
|
||||
|
||||
break
|
||||
else:
|
||||
return ''
|
||||
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
file_contents = f.read()
|
||||
data = yaml.safe_load(file_contents)
|
||||
if 'instruction_template' in data:
|
||||
return data['instruction_template']
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue