mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-06 13:43:49 +01:00
Add apply_chat_template() support for LoRA training
- Support multi-turn conversations (OpenAI messages + ShareGPT formats)
- Automatic assistant-only label masking via incremental tokenization
- Use tokenizer.apply_chat_template() for proper special token handling
- Add "Chat Template" option to the Data Format dropdown
- Also accept instruction/output datasets (auto-converted to messages)
- Validate chat template availability and dataset format upfront
- Fix after_tokens[-1] IndexError when train_only_after is at end of prompt
- Update docs
This commit is contained in:
parent
b16a1a874a
commit
d278bb46a2
|
|
@ -79,6 +79,54 @@ If you have different sets of key inputs, you can make your own format file to m
|
|||
When using raw text files as your dataset, the text is split into sections by the `Hard Cut String` (default `\n\n\n`), tokenized, concatenated into one long token sequence, and then split into non-overlapping chunks of exactly `Cutoff Length` tokens (any remainder shorter than the cutoff is dropped). This is the standard concatenate-and-split approach used by HuggingFace `run_clm.py`.
|
||||
- `Hard Cut String` sets a string that indicates a boundary between unrelated sections of text. This defaults to `\n\n\n`, meaning 3 newlines. When `Add EOS token` is enabled, an EOS token is appended after each section before concatenation. This allows you to insert unrelated sections of text in the same text file, ensuring the model learns proper boundaries between them.
|
||||
|
||||
## Chat Template Format
|
||||
|
||||
Select **Chat Template** as the Data Format to use the model's built-in chat template (via `apply_chat_template()`) instead of a format file. This works with instruct/chat models that ship with a chat template in their tokenizer (Llama 3, Qwen, Mistral, etc.).
|
||||
|
||||
**Advantages over format files:**
|
||||
- Special tokens are handled correctly by the tokenizer itself
|
||||
- Multi-turn conversations are supported natively
|
||||
- Labels are automatically masked so only assistant responses are trained on (no need for `Train Only After`)
|
||||
|
||||
**Dataset formats:** Your JSON dataset can use any of these structures:
|
||||
|
||||
OpenAI messages format (multi-turn):
|
||||
```json
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is Python?"},
|
||||
{"role": "assistant", "content": "A programming language."},
|
||||
{"role": "user", "content": "What's it used for?"},
|
||||
{"role": "assistant", "content": "Web dev, data science, scripting, and more."}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
The conversation gets tokenized with the model's own chat template (correct special tokens), and the labels are automatically masked so the model only trains on the assistant responses — the system prompt and user turns get `-100` labels and contribute no gradient.
|
||||
|
||||
ShareGPT format (`conversations` key with `from`/`value` fields):
|
||||
```json
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{"from": "system", "value": "You are a helpful assistant."},
|
||||
{"from": "human", "value": "What is Python?"},
|
||||
{"from": "gpt", "value": "A programming language."},
|
||||
{"from": "human", "value": "What's it used for?"},
|
||||
{"from": "gpt", "value": "Web dev, data science, scripting, and more."}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Simple instruction/output format (auto-converted to a single-turn conversation):
|
||||
```json
|
||||
[{"instruction": "What is 2+2?", "output": "4"}]
|
||||
```
|
||||
|
||||
## Target Modules
|
||||
|
||||
By default, **Target all linear layers** is enabled. This uses peft's `all-linear` mode, which applies LoRA to every `nn.Linear` layer in the model except the output head (`lm_head`). It works for any model architecture.
|
||||
|
|
|
|||
|
|
@ -107,8 +107,8 @@ def create_ui():
|
|||
with gr.Column():
|
||||
with gr.Tab(label='Formatted Dataset'):
|
||||
with gr.Row():
|
||||
format = gr.Dropdown(choices=utils.get_datasets('user_data/training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/formats', 'json')}, 'refresh-button', interactive=not mu)
|
||||
format = gr.Dropdown(choices=['None', 'Chat Template'] + [x for x in utils.get_datasets('user_data/training/formats', 'json') if x != 'None'], value='None', label='Data Format', info='The format file used to decide how to format the dataset input. "Chat Template" uses the model\'s built-in chat template via apply_chat_template().', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': ['None', 'Chat Template'] + [x for x in utils.get_datasets('user_data/training/formats', 'json') if x != 'None']}, 'refresh-button', interactive=not mu)
|
||||
|
||||
with gr.Row():
|
||||
dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||
|
|
@ -353,7 +353,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
before_tokens = encode(prompt[:ind], True)
|
||||
after_tokens = encode(prompt[ind:], False)
|
||||
|
||||
if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
|
||||
if append_eos_token and len(after_tokens) > 0 and after_tokens[-1] != shared.tokenizer.eos_token_id:
|
||||
after_tokens.append(shared.tokenizer.eos_token_id)
|
||||
|
||||
full_length = len(after_tokens) + len(before_tokens)
|
||||
|
|
@ -371,6 +371,73 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
"attention_mask": [0 if t == shared.tokenizer.pad_token_id else 1 for t in input_ids],
|
||||
}
|
||||
|
||||
def normalize_messages(data_point):
    """Convert a dataset row to the OpenAI messages format for apply_chat_template().

    Accepts three row shapes:
      - {"messages": [...]}                          -> returned as-is
      - {"conversations": [{"from", "value"}, ...]}  -> ShareGPT; roles remapped
      - {"instruction", "output"[, "system"]}        -> single-turn conversation

    Raises RuntimeError when none of the known keys are present.
    """
    if "messages" in data_point:
        return data_point["messages"]

    if "conversations" in data_point:
        # ShareGPT role names -> OpenAI role names; unknown roles
        # (e.g. "system") pass through unchanged.
        role_map = {"human": "user", "gpt": "assistant"}
        return [
            {"role": role_map.get(turn.get("from", ""), turn.get("from", "")), "content": turn["value"]}
            for turn in data_point["conversations"]
        ]

    if "instruction" in data_point and "output" in data_point:
        messages = []
        # `or ""` guards against an explicit `"system": null` in the JSON:
        # .get() returns the stored None (not the default), which would
        # otherwise crash on .strip().
        if (data_point.get("system") or "").strip():
            messages.append({"role": "system", "content": data_point["system"]})
        messages.append({"role": "user", "content": data_point["instruction"]})
        messages.append({"role": "assistant", "content": data_point["output"]})
        return messages

    raise RuntimeError(
        f'Dataset row must contain "messages", "conversations", or "instruction"/"output" keys. '
        f'Found: {list(data_point.keys())}'
    )
|
||||
|
||||
def tokenize_conversation(data_point):
    """Tokenize one conversation via the model's chat template.

    Labels start fully masked (-100) and only assistant turns are
    revealed, so gradient flows solely through assistant responses.
    Relies on apply_chat_template(messages[:i]) being a token-for-token
    prefix of apply_chat_template(messages[:i+1]) — true for standard
    templates (Llama, ChatML, Mistral, ...).
    """
    messages = normalize_messages(data_point)
    token_ids = shared.tokenizer.apply_chat_template(messages, tokenize=True)

    # Everything masked by default; unmask each assistant span below.
    label_ids = [-100] * len(token_ids)
    for turn_index, message in enumerate(messages):
        if message["role"] != "assistant":
            continue
        # Token count of everything before this assistant turn, including
        # the assistant header produced by add_generation_prompt.
        prefix_ids = shared.tokenizer.apply_chat_template(
            messages[:turn_index], tokenize=True, add_generation_prompt=True
        )
        # Token count through the end of this assistant turn.
        span_ids = shared.tokenizer.apply_chat_template(
            messages[:turn_index + 1], tokenize=True
        )
        lo = len(prefix_ids)
        hi = min(len(span_ids), len(token_ids))
        label_ids[lo:hi] = token_ids[lo:hi]

    # Right-truncate: the system prompt and early turns survive.
    token_ids = token_ids[:cutoff_len]
    label_ids = label_ids[:cutoff_len]

    # Left-pad out to cutoff_len.
    pad_count = cutoff_len - len(token_ids)
    return {
        "input_ids": [shared.tokenizer.pad_token_id] * pad_count + token_ids,
        "labels": [-100] * pad_count + label_ids,
        "attention_mask": [0] * pad_count + [1] * len(token_ids),
    }
|
||||
|
||||
train_template.clear()
|
||||
|
||||
# == Prep the dataset, format, etc ==
|
||||
|
|
@ -437,38 +504,73 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
yield "Missing format choice input, cannot continue."
|
||||
return
|
||||
|
||||
train_template["template_type"] = "dataset"
|
||||
if format == 'Chat Template':
|
||||
# Use the model's built-in chat template via apply_chat_template()
|
||||
if not getattr(shared.tokenizer, 'chat_template', None):
|
||||
yield "Error: this model's tokenizer does not have a chat template. Use a format file instead, or load an instruct/chat model."
|
||||
return
|
||||
|
||||
with open(clean_path('user_data/training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
|
||||
format_data: dict[str, str] = json.load(formatFile)
|
||||
train_template["template_type"] = "chat_template"
|
||||
|
||||
# == store training prompt ==
|
||||
for _, value in format_data.items():
|
||||
prompt_key = f"template_{len(train_template)}"
|
||||
train_template[prompt_key] = value
|
||||
logger.info("Loading JSON dataset with Chat Template format")
|
||||
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
|
||||
|
||||
def generate_prompt(data_point: dict[str, str]):
|
||||
for options, data in format_data.items():
|
||||
if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
|
||||
for key, val in data_point.items():
|
||||
if type(val) is str:
|
||||
data = data.replace(f'%{key}%', val)
|
||||
return data
|
||||
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
|
||||
# Validate the first row
|
||||
try:
|
||||
normalize_messages(data['train'][0])
|
||||
except (RuntimeError, KeyError, IndexError) as e:
|
||||
yield f"Error: {e}"
|
||||
return
|
||||
|
||||
def generate_and_tokenize_prompt(data_point):
|
||||
prompt = generate_prompt(data_point)
|
||||
return tokenize(prompt, add_eos_token)
|
||||
train_data = data['train'].map(
|
||||
tokenize_conversation,
|
||||
remove_columns=data['train'].column_names,
|
||||
new_fingerprint='%030x' % random.randrange(16**30)
|
||||
)
|
||||
|
||||
logger.info("Loading JSON datasets")
|
||||
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
|
||||
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||
|
||||
if eval_dataset == 'None':
|
||||
eval_data = None
|
||||
if eval_dataset == 'None':
|
||||
eval_data = None
|
||||
else:
|
||||
eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
|
||||
eval_data = eval_data['train'].map(
|
||||
tokenize_conversation,
|
||||
remove_columns=eval_data['train'].column_names,
|
||||
new_fingerprint='%030x' % random.randrange(16**30)
|
||||
)
|
||||
else:
|
||||
eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
|
||||
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||
# Use format file for prompt generation
|
||||
train_template["template_type"] = "dataset"
|
||||
|
||||
with open(clean_path('user_data/training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
|
||||
format_data: dict[str, str] = json.load(formatFile)
|
||||
|
||||
# == store training prompt ==
|
||||
for _, value in format_data.items():
|
||||
prompt_key = f"template_{len(train_template)}"
|
||||
train_template[prompt_key] = value
|
||||
|
||||
def generate_prompt(data_point: dict[str, str]):
|
||||
for options, data in format_data.items():
|
||||
if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)):
|
||||
for key, val in data_point.items():
|
||||
if type(val) is str:
|
||||
data = data.replace(f'%{key}%', val)
|
||||
return data
|
||||
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
|
||||
|
||||
def generate_and_tokenize_prompt(data_point):
|
||||
prompt = generate_prompt(data_point)
|
||||
return tokenize(prompt, add_eos_token)
|
||||
|
||||
logger.info("Loading JSON datasets")
|
||||
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
|
||||
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||
|
||||
if eval_dataset == 'None':
|
||||
eval_data = None
|
||||
else:
|
||||
eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
|
||||
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||
|
||||
# == We MUST reload model if it went through any previous training, even failed one ==
|
||||
if shared.model_dirty_from_training:
|
||||
|
|
|
|||
Loading…
Reference in a new issue