Training: drop conversations exceeding cutoff length instead of truncating

oobabooga 2026-03-05 14:56:27 -03:00
parent c2e494963f
commit 33a38d7ece

@@ -26,7 +26,7 @@ from modules.evaluate import (
 from modules.logging_colors import logger
 from modules.models import reload_model
 
-PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "report_to"]
+PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "excess_length", "report_to"]
 
 WANT_INTERRUPT = False
 train_log = {}
@@ -99,6 +99,7 @@ def create_ui():
                 warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
                 add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each document in text datasets.")
+                excess_length = gr.Dropdown(label='Excess length', value='drop', choices=['drop', 'truncate'], info='What to do with conversations that exceed the cutoff length. "Drop" removes them entirely (recommended). "Truncate" cuts from the right, which may produce incomplete responses.', elem_classes=['slim-dropdown'])
                 higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
                 report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
@@ -158,7 +159,7 @@ def create_ui():
         refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)
 
     # Training events
-    all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, report_to]
+    all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, excess_length, report_to]
     copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
     start_button.click(do_train, all_params, output)
@@ -292,7 +293,7 @@ def calc_trainable_parameters(model):
     return trainable_params, all_param
 
-def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, stride_length: int, stop_at_loss: float, add_eos_token: bool, report_to: str):
+def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, stride_length: int, stop_at_loss: float, add_eos_token: bool, excess_length: str, report_to: str):
 
     import torch
     import transformers
@@ -389,10 +390,12 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
             end = min(len(through_ids), len(full_ids))
             labels[start:end] = full_ids[start:end]
 
-        # Truncate from the right: keeps the system prompt and early turns
         if len(full_ids) > cutoff_len:
-            full_ids = full_ids[:cutoff_len]
-            labels = labels[:cutoff_len]
+            if excess_length == 'truncate':
+                full_ids = full_ids[:cutoff_len]
+                labels = labels[:cutoff_len]
+            else:
+                return {"input_ids": [], "labels": [], "attention_mask": []}
 
         return {
             "input_ids": full_ids,
@@ -490,11 +493,19 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
         yield f"Error: {e}"
         return
 
+    total = len(data['train'])
     train_data = data['train'].map(
         tokenize_conversation,
         remove_columns=data['train'].column_names,
         new_fingerprint='%030x' % random.randrange(16**30)
     )
+    train_data = train_data.filter(lambda x: len(x['input_ids']) > 0)
+    dropped = total - len(train_data)
+    if dropped > 0:
+        logger.warning(f"Dropped {dropped}/{total} conversations exceeding cutoff length of {cutoff_len} tokens.")
+    if len(train_data) == 0:
+        yield f"Error: all {total} conversations exceed the cutoff length of {cutoff_len} tokens. Increase the cutoff length or shorten your data."
+        return
 
     if eval_dataset == 'None':
         eval_data = None
@@ -505,6 +516,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
             remove_columns=eval_data['train'].column_names,
             new_fingerprint='%030x' % random.randrange(16**30)
         )
+        eval_data = eval_data.filter(lambda x: len(x['input_ids']) > 0)
 
     # == We MUST reload model if it went through any previous training, even failed one ==
     if shared.model_dirty_from_training:
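
For reference, the tokenizer-side change reduces to the following standalone sketch. The helper name and arguments are illustrative only; in the patch this logic sits inline inside do_train's tokenize_conversation. The 'drop'/'truncate' values and the empty-dict convention for dropped rows mirror the diff above.

# Minimal sketch of the drop-vs-truncate policy introduced by this commit.
# handle_excess() is a hypothetical helper, not part of the patch.
def handle_excess(full_ids, labels, cutoff_len, excess_length='drop'):
    if len(full_ids) > cutoff_len:
        if excess_length == 'truncate':
            # Cut from the right: keeps the system prompt and early turns,
            # but the final response may be left incomplete.
            full_ids = full_ids[:cutoff_len]
            labels = labels[:cutoff_len]
        else:
            # 'drop': emit empty sequences; a later .filter() removes the row.
            return {"input_ids": [], "labels": [], "attention_mask": []}
    return {
        "input_ids": full_ids,
        "labels": labels,
        "attention_mask": [1] * len(full_ids),
    }

# With cutoff_len=4, a 6-token conversation is dropped by default...
print(handle_excess(list(range(6)), list(range(6)), 4))
# ...and cut to 4 tokens when excess_length='truncate'.
print(handle_excess(list(range(6)), list(range(6)), 4, excess_length='truncate'))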
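The dataset-level accounting can likewise be illustrated end to end, assuming the Hugging Face datasets library (which the file already uses via data['train'].map). The toy dataset and the tokenize() stub below are made up for the example; only the map -> filter -> count pattern comes from the diff.

from datasets import Dataset

cutoff_len = 4

def tokenize(example):
    ids = list(range(example["length"]))  # stand-in for real token ids
    if len(ids) > cutoff_len:
        # Dropped rows come back empty, like the patch's 'drop' branch
        return {"input_ids": [], "labels": [], "attention_mask": []}
    return {"input_ids": ids, "labels": ids, "attention_mask": [1] * len(ids)}

data = Dataset.from_dict({"length": [2, 3, 7, 9]})
total = len(data)
train_data = data.map(tokenize, remove_columns=data.column_names)
train_data = train_data.filter(lambda x: len(x["input_ids"]) > 0)

dropped = total - len(train_data)
print(f"Dropped {dropped}/{total} conversations")  # -> Dropped 2/4 conversations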