diff --git a/docs/12 - OpenAI API.md b/docs/12 - OpenAI API.md index 0a076c35..727f6ece 100644 --- a/docs/12 - OpenAI API.md +++ b/docs/12 - OpenAI API.md @@ -232,6 +232,17 @@ curl -k http://127.0.0.1:5000/v1/internal/model/load \ }' ``` +You can also set a default instruction template for all subsequent API requests by passing `instruction_template` (a template name from `user_data/instruction-templates/`) or `instruction_template_str` (a raw Jinja2 string): + +```shell +curl -k http://127.0.0.1:5000/v1/internal/model/load \ + -H "Content-Type: application/json" \ + -d '{ + "model_name": "Qwen_Qwen3-0.6B-Q4_K_M.gguf", + "instruction_template": "Alpaca" + }' +``` + #### Python chat example ```python diff --git a/modules/api/models.py b/modules/api/models.py index 5dd77850..bfcd2c31 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,7 +1,8 @@ from modules import loaders, shared +from modules.logging_colors import logger from modules.LoRA import add_lora_to_model from modules.models import load_model, unload_model -from modules.models_settings import get_model_metadata, update_model_parameters +from modules.models_settings import get_model_metadata, load_instruction_template, update_model_parameters from modules.utils import get_available_loras, get_available_models @@ -69,6 +70,13 @@ def _load_model(data): shared.model, shared.tokenizer = load_model(model_name) + if data.get("instruction_template_str") is not None: + shared.settings['instruction_template_str'] = data["instruction_template_str"] + logger.info("INSTRUCTION TEMPLATE: set to custom Jinja2 string") + elif data.get("instruction_template") is not None: + shared.settings['instruction_template_str'] = load_instruction_template(data["instruction_template"]) + logger.info(f"INSTRUCTION TEMPLATE: {data['instruction_template']}") + def list_loras(): return {'lora_names': get_available_loras()[1:]} diff --git a/modules/api/script.py b/modules/api/script.py index e79a1967..1f41d0cd 100644 --- a/modules/api/script.py +++ b/modules/api/script.py @@ -487,6 +487,11 @@ async def handle_load_model(request_data: LoadModelRequest): Loader args are reset to their startup defaults between loads, so settings from a previous load do not leak into the next one. + + The "instruction_template" parameter sets the default instruction + template by name (from user_data/instruction-templates/). The + "instruction_template_str" parameter sets it as a raw Jinja2 string + and takes precedence over "instruction_template". ''' try: diff --git a/modules/api/typing.py b/modules/api/typing.py index a758743e..56d7f2bc 100644 --- a/modules/api/typing.py +++ b/modules/api/typing.py @@ -271,6 +271,8 @@ class ModelListResponse(BaseModel): class LoadModelRequest(BaseModel): model_name: str args: dict | None = None + instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. Sets the default template for all subsequent API requests.") + instruction_template_str: str | None = Field(default=None, description="A Jinja2 instruction template string. If set, takes precedence over instruction_template.") class LoraListResponse(BaseModel): diff --git a/modules/models_settings.py b/modules/models_settings.py index eafa0581..b10d780c 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -400,14 +400,19 @@ def load_instruction_template(template): if template == 'None': return '' - for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']: - if filepath.exists(): - break + for name in (template, 'Alpaca'): + path = shared.user_data_dir / 'instruction-templates' / f'{name}.yaml' + try: + with open(path, 'r', encoding='utf-8') as f: + file_contents = f.read() + except FileNotFoundError: + if name == template: + logger.warning(f"Instruction template '{template}' not found, falling back to Alpaca") + continue + + break else: return '' - - with open(filepath, 'r', encoding='utf-8') as f: - file_contents = f.read() data = yaml.safe_load(file_contents) if 'instruction_template' in data: return data['instruction_template']