mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-09 15:13:56 +01:00
Add --chat-template-file flag to override the default instruction template for API requests
Matches llama.cpp's flag name. Supports .jinja, .jinja2, and .yaml/.yml files. Priority: per-request params > --chat-template-file > model's built-in template.
This commit is contained in:
parent
3531069824
commit
f5acf55207
|
|
@ -1,9 +1,12 @@
|
|||
import copy
|
||||
import functools
|
||||
import json
|
||||
import time
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
import tiktoken
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from extensions.openai.errors import InvalidRequestError
|
||||
|
|
@ -22,6 +25,18 @@ from modules.presets import load_preset_memoized
|
|||
from modules.text_generation import decode, encode, generate_reply
|
||||
|
||||
|
||||
@functools.cache
def load_chat_template_file(filepath):
    """Load a chat template from a file path (.jinja, .jinja2, or .yaml/.yml).

    Files whose suffix is ``.yaml`` or ``.yml`` (compared case-insensitively)
    are parsed as YAML and the value stored under ``instruction_template`` is
    returned. Every other suffix is treated as a raw template and its text is
    returned verbatim.

    The result is memoized per ``filepath`` argument by ``functools.cache``,
    so changes made to the file after the first call are not seen by later
    calls with the same argument.

    Args:
        filepath: Location of the template file. Must be hashable (``str`` or
            ``Path``) because it is used as the cache key.

    Returns:
        The template text. For YAML input, ``''`` is returned when the
        document is empty, is not a mapping, or has a missing or null
        ``instruction_template`` entry.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file content is not valid UTF-8.
        yaml.YAMLError: If a YAML file is syntactically invalid.
    """
    filepath = Path(filepath)
    text = filepath.read_text(encoding='utf-8')

    if filepath.suffix.lower() not in ('.yaml', '.yml'):
        return text

    # A blank YAML document carries no template; skip parsing entirely.
    if not text.strip():
        return ''

    data = yaml.safe_load(text)

    # safe_load yields None for comment-only documents and a list/scalar for
    # non-mapping documents; neither has a usable .get().
    if not isinstance(data, dict):
        return ''

    template = data.get('instruction_template', '')
    # "instruction_template:" with no value parses to None.
    return '' if template is None else template
|
||||
|
||||
|
||||
def convert_logprobs_to_tiktoken(model, logprobs):
|
||||
# more problems than it's worth.
|
||||
# try:
|
||||
|
|
@ -234,6 +249,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
instruction_template = body['instruction_template']
|
||||
instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
|
||||
instruction_template_str = load_instruction_template_memoized(instruction_template)
|
||||
elif shared.args.chat_template_file:
|
||||
instruction_template_str = load_chat_template_file(shared.args.chat_template_file)
|
||||
else:
|
||||
instruction_template_str = shared.settings['instruction_template_str']
|
||||
|
||||
|
|
|
|||
|
|
@ -215,6 +215,7 @@ group.add_argument('--sampler-priority', type=str, default=_d['sampler_priority'
|
|||
group.add_argument('--dry-sequence-breakers', type=str, default=_d['dry_sequence_breakers'], metavar='N', help='DRY sequence breakers')
|
||||
group.add_argument('--enable-thinking', action=argparse.BooleanOptionalAction, default=True, help='Enable thinking')
|
||||
group.add_argument('--reasoning-effort', type=str, default='medium', metavar='N', help='Reasoning effort')
|
||||
group.add_argument('--chat-template-file', type=str, default=None, help='Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. Overrides the model\'s built-in template.')
|
||||
|
||||
# Handle CMD_FLAGS.txt
|
||||
cmd_flags_path = user_data_dir / "CMD_FLAGS.txt"
|
||||
|
|
@ -376,6 +377,11 @@ default_settings = copy.deepcopy(settings)
|
|||
|
||||
|
||||
def do_cmd_flags_warnings():
|
||||
# Validate --chat-template-file
|
||||
if args.chat_template_file and not Path(args.chat_template_file).is_file():
|
||||
logger.error(f"--chat-template-file: file not found: {args.chat_template_file}")
|
||||
sys.exit(1)
|
||||
|
||||
# Security warnings
|
||||
if args.trust_remote_code:
|
||||
logger.warning(
|
||||
|
|
|
|||
|
|
@ -226,7 +226,7 @@ def clean_path(base_path: str, path: str):
|
|||
def get_instruction_templates():
    """List the instruction template names available to the user.

    Scans ``{user_data_dir}/instruction-templates`` for files ending in
    ``.yaml``, ``.yml``, ``.jinja`` or ``.jinja2`` and collects their stems,
    so a template that exists under several extensions is listed once.

    Returns:
        A list starting with the fixed entries ``'None'`` and
        ``'Chat Template'``, followed by the discovered template names
        sorted with ``utils.natural_keys``.
    """
    templates_dir = shared.user_data_dir / 'instruction-templates'
    names = {
        f.stem
        for ext in ('yaml', 'yml', 'jinja', 'jinja2')
        for f in templates_dir.glob(f'*.{ext}')
    }
    return ['None', 'Chat Template'] + sorted(names, key=utils.natural_keys)
|
||||
|
|
@ -235,10 +235,10 @@ def get_instruction_templates():
|
|||
def load_template(name):
|
||||
"""Load a Jinja2 template string from {user_data_dir}/instruction-templates/."""
|
||||
path = shared.user_data_dir / 'instruction-templates'
|
||||
for ext in ['jinja', 'yaml', 'yml']:
|
||||
for ext in ['jinja', 'jinja2', 'yaml', 'yml']:
|
||||
filepath = path / f'{name}.{ext}'
|
||||
if filepath.exists():
|
||||
if ext == 'jinja':
|
||||
if ext in ['jinja', 'jinja2']:
|
||||
return filepath.read_text(encoding='utf-8')
|
||||
else:
|
||||
data = yaml.safe_load(filepath.read_text(encoding='utf-8'))
|
||||
|
|
|
|||
Loading…
Reference in a new issue