@functools.cache
def load_chat_template_file(filepath):
    """Load a chat template from a file path (.jinja, .jinja2, or .yaml/.yml).

    Args:
        filepath: Location of the template file (anything ``Path`` accepts
            and that is hashable, since the result is memoized on it).

    Returns:
        For ``.yaml``/``.yml`` files (suffix compared case-insensitively),
        the value stored under the ``instruction_template`` key, or ``''``
        when that key is missing or the document is not a mapping (for
        example an empty file). For every other suffix, the file's text.

    Raises:
        OSError: If the file cannot be opened or read.

    Note:
        ``functools.cache`` memoizes the result per ``filepath`` argument, so
        the file is read once and later edits are not picked up by this
        function within the same process.
    """
    path = Path(filepath)
    text = path.read_text(encoding='utf-8')
    if path.suffix.lower() not in ('.yaml', '.yml'):
        return text

    data = yaml.safe_load(text)
    # safe_load yields None for an empty document and may yield a list or a
    # scalar; only a mapping can carry the 'instruction_template' key, so
    # anything else falls back to the same default as a missing key instead
    # of raising AttributeError on `.get`.
    if not isinstance(data, dict):
        return ''
    return data.get('instruction_template', '')
# try: @@ -234,6 +249,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p instruction_template = body['instruction_template'] instruction_template = "Alpaca" if instruction_template == "None" else instruction_template instruction_template_str = load_instruction_template_memoized(instruction_template) + elif shared.args.chat_template_file: + instruction_template_str = load_chat_template_file(shared.args.chat_template_file) else: instruction_template_str = shared.settings['instruction_template_str'] diff --git a/modules/shared.py b/modules/shared.py index de0820af..bc7ea8ba 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -215,6 +215,7 @@ group.add_argument('--sampler-priority', type=str, default=_d['sampler_priority' group.add_argument('--dry-sequence-breakers', type=str, default=_d['dry_sequence_breakers'], metavar='N', help='DRY sequence breakers') group.add_argument('--enable-thinking', action=argparse.BooleanOptionalAction, default=True, help='Enable thinking') group.add_argument('--reasoning-effort', type=str, default='medium', metavar='N', help='Reasoning effort') +group.add_argument('--chat-template-file', type=str, default=None, help='Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. 
def get_instruction_templates():
    """List the instruction template names available for selection.

    Collects the stems of every ``*.yaml``, ``*.yml``, ``*.jinja`` and
    ``*.jinja2`` file directly inside ``{user_data_dir}/instruction-templates``
    (duplicates across extensions collapse to one name) and returns them,
    ordered with ``utils.natural_keys``, after the two fixed entries
    ``'None'`` and ``'Chat Template'``.
    """
    template_dir = shared.user_data_dir / 'instruction-templates'
    stems = {
        entry.stem
        for suffix in ('yaml', 'yml', 'jinja', 'jinja2')
        for entry in template_dir.glob(f'*.{suffix}')
    }
    ordered = sorted(stems, key=utils.natural_keys)
    return ['None', 'Chat Template'] + ordered