Add apply_chat_template() support for LoRA training

- Support multi-turn conversations (OpenAI messages + ShareGPT formats)
- Automatic assistant-only label masking via incremental tokenization
- Use tokenizer.apply_chat_template() for proper special token handling
- Add "Chat Template" option to the Data Format dropdown
- Also accept instruction/output datasets (auto-converted to messages)
- Validate chat template availability and dataset format upfront
- Fix after_tokens[-1] IndexError when train_only_after is at end of prompt
- Update docs
This commit is contained in:
oobabooga 2026-03-05 11:46:45 -03:00
parent b16a1a874a
commit d278bb46a2
2 changed files with 179 additions and 29 deletions

View file

@ -107,8 +107,8 @@ def create_ui():
with gr.Column():
with gr.Tab(label='Formatted Dataset'):
with gr.Row():
format = gr.Dropdown(choices=utils.get_datasets('user_data/training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/formats', 'json')}, 'refresh-button', interactive=not mu)
format = gr.Dropdown(choices=['None', 'Chat Template'] + [x for x in utils.get_datasets('user_data/training/formats', 'json') if x != 'None'], value='None', label='Data Format', info='The format file used to decide how to format the dataset input. "Chat Template" uses the model\'s built-in chat template via apply_chat_template().', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(format, lambda: None, lambda: {'choices': ['None', 'Chat Template'] + [x for x in utils.get_datasets('user_data/training/formats', 'json') if x != 'None']}, 'refresh-button', interactive=not mu)
with gr.Row():
dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'], interactive=not mu)
@ -353,7 +353,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
before_tokens = encode(prompt[:ind], True)
after_tokens = encode(prompt[ind:], False)
if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
if append_eos_token and len(after_tokens) > 0 and after_tokens[-1] != shared.tokenizer.eos_token_id:
after_tokens.append(shared.tokenizer.eos_token_id)
full_length = len(after_tokens) + len(before_tokens)
@ -371,6 +371,73 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
"attention_mask": [0 if t == shared.tokenizer.pad_token_id else 1 for t in input_ids],
}
def normalize_messages(data_point):
    """Convert a dataset row to the OpenAI messages format expected by
    apply_chat_template().

    Supported row shapes, checked in order:
      - {"messages": [...]}                          -> returned as-is
      - {"conversations": [{"from", "value"}, ...]}  -> ShareGPT; roles remapped
      - {"instruction", "output"[, "system"]}        -> Alpaca-style; wrapped

    Raises:
        RuntimeError: if the row matches none of the supported shapes.
    """
    if "messages" in data_point:
        return data_point["messages"]

    if "conversations" in data_point:
        # ShareGPT format: map "human"/"gpt" onto OpenAI role names, passing
        # any other role (e.g. "system") through unchanged.
        role_map = {"human": "user", "gpt": "assistant"}
        return [
            {"role": role_map.get(turn.get("from", ""), turn.get("from", "")), "content": turn["value"]}
            for turn in data_point["conversations"]
        ]

    if "instruction" in data_point and "output" in data_point:
        messages = []
        # HF datasets fill absent columns with None, so guard against both a
        # missing key and an explicit None before calling .strip().
        if (data_point.get("system") or "").strip():
            messages.append({"role": "system", "content": data_point["system"]})
        messages.append({"role": "user", "content": data_point["instruction"]})
        messages.append({"role": "assistant", "content": data_point["output"]})
        return messages

    raise RuntimeError(
        f'Dataset row must contain "messages", "conversations", or "instruction"/"output" keys. '
        f'Found: {list(data_point.keys())}'
    )
def tokenize_conversation(data_point):
    """Tokenize a conversation row with apply_chat_template(), building
    labels that are masked (-100) everywhere except assistant-turn tokens.

    Returns a dict with "input_ids", "labels" and "attention_mask", each
    left-padded to exactly cutoff_len entries.
    """
    messages = normalize_messages(data_point)
    tokenizer = shared.tokenizer
    full_ids = tokenizer.apply_chat_template(messages, tokenize=True)

    # Start fully masked, then unmask each assistant turn individually.
    # This relies on apply_chat_template(messages[:i]) being a
    # token-for-token prefix of apply_chat_template(messages[:i+1]),
    # which holds for standard chat templates (Llama, ChatML, Mistral, ...).
    labels = [-100] * len(full_ids)
    for idx, message in enumerate(messages):
        if message["role"] != "assistant":
            continue
        # Number of tokens before this assistant reply begins
        prefix_len = len(tokenizer.apply_chat_template(
            messages[:idx], tokenize=True, add_generation_prompt=True
        ))
        # Number of tokens through the end of this assistant reply,
        # clamped in case the template appends trailing tokens.
        turn_end = min(
            len(tokenizer.apply_chat_template(messages[:idx + 1], tokenize=True)),
            len(full_ids),
        )
        labels[prefix_len:turn_end] = full_ids[prefix_len:turn_end]

    # Right-truncate so the system prompt and earliest turns are kept
    full_ids = full_ids[:cutoff_len]
    labels = labels[:cutoff_len]

    # Left-pad every field out to cutoff_len
    pad_len = cutoff_len - len(full_ids)
    return {
        "input_ids": [tokenizer.pad_token_id] * pad_len + full_ids,
        "labels": [-100] * pad_len + labels,
        "attention_mask": [0] * pad_len + [1] * len(full_ids),
    }
train_template.clear()
# == Prep the dataset, format, etc ==
@ -437,38 +504,73 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
yield "Missing format choice input, cannot continue."
return
train_template["template_type"] = "dataset"
if format == 'Chat Template':
# Use the model's built-in chat template via apply_chat_template()
if not getattr(shared.tokenizer, 'chat_template', None):
yield "Error: this model's tokenizer does not have a chat template. Use a format file instead, or load an instruct/chat model."
return
with open(clean_path('user_data/training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
format_data: dict[str, str] = json.load(formatFile)
train_template["template_type"] = "chat_template"
# == store training prompt ==
for _, value in format_data.items():
prompt_key = f"template_{len(train_template)}"
train_template[prompt_key] = value
logger.info("Loading JSON dataset with Chat Template format")
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
def generate_prompt(data_point: dict[str, str]):
for options, data in format_data.items():
if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
for key, val in data_point.items():
if type(val) is str:
data = data.replace(f'%{key}%', val)
return data
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
# Validate the first row
try:
normalize_messages(data['train'][0])
except (RuntimeError, KeyError, IndexError) as e:
yield f"Error: {e}"
return
def generate_and_tokenize_prompt(data_point):
prompt = generate_prompt(data_point)
return tokenize(prompt, add_eos_token)
train_data = data['train'].map(
tokenize_conversation,
remove_columns=data['train'].column_names,
new_fingerprint='%030x' % random.randrange(16**30)
)
logger.info("Loading JSON datasets")
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
if eval_dataset == 'None':
eval_data = None
if eval_dataset == 'None':
eval_data = None
else:
eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
eval_data = eval_data['train'].map(
tokenize_conversation,
remove_columns=eval_data['train'].column_names,
new_fingerprint='%030x' % random.randrange(16**30)
)
else:
eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
# Use format file for prompt generation
train_template["template_type"] = "dataset"
with open(clean_path('user_data/training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
format_data: dict[str, str] = json.load(formatFile)
# == store training prompt ==
for _, value in format_data.items():
prompt_key = f"template_{len(train_template)}"
train_template[prompt_key] = value
def generate_prompt(data_point: dict[str, str]):
for options, data in format_data.items():
if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
for key, val in data_point.items():
if type(val) is str:
data = data.replace(f'%{key}%', val)
return data
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
def generate_and_tokenize_prompt(data_point):
prompt = generate_prompt(data_point)
return tokenize(prompt, add_eos_token)
logger.info("Loading JSON datasets")
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
if eval_dataset == 'None':
eval_data = None
else:
eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
# == We MUST reload model if it went through any previous training, even failed one ==
if shared.model_dirty_from_training: