mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-06 07:03:37 +00:00
Add apply_chat_template() support for LoRA training
- Support multi-turn conversations (OpenAI messages + ShareGPT formats) - Automatic assistant-only label masking via incremental tokenization - Use tokenizer.apply_chat_template() for proper special token handling - Add "Chat Template" option to the Data Format dropdown - Also accept instruction/output datasets (auto-converted to messages) - Validate chat template availability and dataset format upfront - Fix after_tokens[-1] IndexError when train_only_after is at end of prompt - Update docs
This commit is contained in:
parent
b16a1a874a
commit
d278bb46a2
2 changed files with 179 additions and 29 deletions
|
|
@ -107,8 +107,8 @@ def create_ui():
|
|||
with gr.Column():
|
||||
with gr.Tab(label='Formatted Dataset'):
|
||||
with gr.Row():
|
||||
format = gr.Dropdown(choices=utils.get_datasets('user_data/training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/formats', 'json')}, 'refresh-button', interactive=not mu)
|
||||
format = gr.Dropdown(choices=['None', 'Chat Template'] + [x for x in utils.get_datasets('user_data/training/formats', 'json') if x != 'None'], value='None', label='Data Format', info='The format file used to decide how to format the dataset input. "Chat Template" uses the model\'s built-in chat template via apply_chat_template().', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': ['None', 'Chat Template'] + [x for x in utils.get_datasets('user_data/training/formats', 'json') if x != 'None']}, 'refresh-button', interactive=not mu)
|
||||
|
||||
with gr.Row():
|
||||
dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||
|
|
@ -353,7 +353,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
before_tokens = encode(prompt[:ind], True)
|
||||
after_tokens = encode(prompt[ind:], False)
|
||||
|
||||
if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
|
||||
if append_eos_token and len(after_tokens) > 0 and after_tokens[-1] != shared.tokenizer.eos_token_id:
|
||||
after_tokens.append(shared.tokenizer.eos_token_id)
|
||||
|
||||
full_length = len(after_tokens) + len(before_tokens)
|
||||
|
|
@ -371,6 +371,73 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
"attention_mask": [0 if t == shared.tokenizer.pad_token_id else 1 for t in input_ids],
|
||||
}
|
||||
|
||||
def normalize_messages(data_point):
    """Convert a dataset row to OpenAI messages format for apply_chat_template().

    Accepted row shapes, checked in order:
      1. {"messages": [{"role": ..., "content": ...}, ...]}  — returned as-is.
      2. ShareGPT: {"conversations": [{"from": "human"|"gpt"|..., "value": ...}, ...]}
         — "human"/"gpt" are mapped to "user"/"assistant"; unknown roles pass through.
      3. Alpaca-style: {"instruction": ..., "output": ..., ["system": ...]}
         — converted to an optional system turn, a user turn, and an assistant turn.

    Raises:
        RuntimeError: if the row matches none of the shapes above.
    """
    if "messages" in data_point:
        return data_point["messages"]

    if "conversations" in data_point:
        role_map = {"human": "user", "gpt": "assistant"}
        return [
            {"role": role_map.get(turn.get("from", ""), turn.get("from", "")), "content": turn["value"]}
            for turn in data_point["conversations"]
        ]

    if "instruction" in data_point and "output" in data_point:
        messages = []
        # Use `or ""` rather than a .get() default: JSON rows frequently carry an
        # explicit "system": null, which .get(..., "") returns unchanged and
        # None.strip() would raise AttributeError.
        if (data_point.get("system") or "").strip():
            messages.append({"role": "system", "content": data_point["system"]})
        messages.append({"role": "user", "content": data_point["instruction"]})
        messages.append({"role": "assistant", "content": data_point["output"]})
        return messages

    raise RuntimeError(
        f'Dataset row must contain "messages", "conversations", or "instruction"/"output" keys. '
        f'Found: {list(data_point.keys())}'
    )
|
||||
|
||||
def tokenize_conversation(data_point):
    """Tokenize using apply_chat_template() with assistant-only label masking.

    Produces a fixed-length training example of `cutoff_len` tokens (closure
    variable from the enclosing training function — TODO confirm): input_ids
    left-padded with the tokenizer's pad token, labels set to -100 everywhere
    except inside assistant turns, and a matching attention_mask.

    Returns:
        dict with "input_ids", "labels", "attention_mask", each a list of
        length cutoff_len.
    """
    messages = normalize_messages(data_point)
    full_ids = shared.tokenizer.apply_chat_template(messages, tokenize=True)

    # Build labels: -100 for everything, then unmask assistant turns.
    # This assumes apply_chat_template(messages[:i]) is a token-for-token
    # prefix of apply_chat_template(messages[:i+1]), which holds for all
    # standard chat templates (Llama, ChatML, Mistral, etc.).
    # NOTE(review): if a template violates that prefix property, start/end
    # below would land on the wrong token offsets — no check is done here.
    labels = [-100] * len(full_ids)
    for i, msg in enumerate(messages):
        if msg["role"] == "assistant":
            # Tokens up to where this assistant turn starts.
            # add_generation_prompt=True appends the assistant header so the
            # header tokens themselves stay masked (-100).
            header_ids = shared.tokenizer.apply_chat_template(
                messages[:i], tokenize=True, add_generation_prompt=True
            )
            # Tokens through end of this assistant turn
            through_ids = shared.tokenizer.apply_chat_template(
                messages[:i + 1], tokenize=True
            )
            # Unmask assistant tokens; min() guards against through_ids
            # overrunning full_ids (e.g. trailing template differences).
            start = len(header_ids)
            end = min(len(through_ids), len(full_ids))
            labels[start:end] = full_ids[start:end]

    # Truncate from the right: keeps the system prompt and early turns
    if len(full_ids) > cutoff_len:
        full_ids = full_ids[:cutoff_len]
        labels = labels[:cutoff_len]

    # Left-pad to cutoff_len; pad positions get attention 0 and label -100
    # so they contribute nothing to attention or the loss.
    pad_len = cutoff_len - len(full_ids)
    attention_mask = [0] * pad_len + [1] * len(full_ids)
    labels = [-100] * pad_len + labels
    input_ids = [shared.tokenizer.pad_token_id] * pad_len + full_ids

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
    }
|
||||
|
||||
train_template.clear()
|
||||
|
||||
# == Prep the dataset, format, etc ==
|
||||
|
|
@ -437,38 +504,73 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
yield "Missing format choice input, cannot continue."
|
||||
return
|
||||
|
||||
train_template["template_type"] = "dataset"
|
||||
if format == 'Chat Template':
|
||||
# Use the model's built-in chat template via apply_chat_template()
|
||||
if not getattr(shared.tokenizer, 'chat_template', None):
|
||||
yield "Error: this model's tokenizer does not have a chat template. Use a format file instead, or load an instruct/chat model."
|
||||
return
|
||||
|
||||
with open(clean_path('user_data/training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
|
||||
format_data: dict[str, str] = json.load(formatFile)
|
||||
train_template["template_type"] = "chat_template"
|
||||
|
||||
# == store training prompt ==
|
||||
for _, value in format_data.items():
|
||||
prompt_key = f"template_{len(train_template)}"
|
||||
train_template[prompt_key] = value
|
||||
logger.info("Loading JSON dataset with Chat Template format")
|
||||
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
|
||||
|
||||
def generate_prompt(data_point: dict[str, str]):
|
||||
for options, data in format_data.items():
|
||||
if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
|
||||
for key, val in data_point.items():
|
||||
if type(val) is str:
|
||||
data = data.replace(f'%{key}%', val)
|
||||
return data
|
||||
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
|
||||
# Validate the first row
|
||||
try:
|
||||
normalize_messages(data['train'][0])
|
||||
except (RuntimeError, KeyError, IndexError) as e:
|
||||
yield f"Error: {e}"
|
||||
return
|
||||
|
||||
def generate_and_tokenize_prompt(data_point):
|
||||
prompt = generate_prompt(data_point)
|
||||
return tokenize(prompt, add_eos_token)
|
||||
train_data = data['train'].map(
|
||||
tokenize_conversation,
|
||||
remove_columns=data['train'].column_names,
|
||||
new_fingerprint='%030x' % random.randrange(16**30)
|
||||
)
|
||||
|
||||
logger.info("Loading JSON datasets")
|
||||
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
|
||||
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||
|
||||
if eval_dataset == 'None':
|
||||
eval_data = None
|
||||
if eval_dataset == 'None':
|
||||
eval_data = None
|
||||
else:
|
||||
eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
|
||||
eval_data = eval_data['train'].map(
|
||||
tokenize_conversation,
|
||||
remove_columns=eval_data['train'].column_names,
|
||||
new_fingerprint='%030x' % random.randrange(16**30)
|
||||
)
|
||||
else:
|
||||
eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
|
||||
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||
# Use format file for prompt generation
|
||||
train_template["template_type"] = "dataset"
|
||||
|
||||
with open(clean_path('user_data/training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
|
||||
format_data: dict[str, str] = json.load(formatFile)
|
||||
|
||||
# == store training prompt ==
|
||||
for _, value in format_data.items():
|
||||
prompt_key = f"template_{len(train_template)}"
|
||||
train_template[prompt_key] = value
|
||||
|
||||
def generate_prompt(data_point: dict[str, str]):
|
||||
for options, data in format_data.items():
|
||||
if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
|
||||
for key, val in data_point.items():
|
||||
if type(val) is str:
|
||||
data = data.replace(f'%{key}%', val)
|
||||
return data
|
||||
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
|
||||
|
||||
def generate_and_tokenize_prompt(data_point):
|
||||
prompt = generate_prompt(data_point)
|
||||
return tokenize(prompt, add_eos_token)
|
||||
|
||||
logger.info("Loading JSON datasets")
|
||||
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
|
||||
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||
|
||||
if eval_dataset == 'None':
|
||||
eval_data = None
|
||||
else:
|
||||
eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
|
||||
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||
|
||||
# == We MUST reload model if it went through any previous training, even failed one ==
|
||||
if shared.model_dirty_from_training:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue