From 1c2548fd892ad888676910af990020b9a968eae3 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 5 Mar 2026 12:36:10 -0300
Subject: [PATCH] Training: use dynamic padding (pad to batch max instead of
 cutoff_len)

- Remove pre-padding from tokenize() and tokenize_conversation()
- Collate function now right-pads each batch to the longest sequence
- Set tokenizer padding_side to "right" (standard for training)
- Remove dead natural_keys import
- Reduces wasted compute on batches with short sequences
- Aligns with axolotl/unsloth approach
---
 modules/training.py | 37 +++++++++++++++----------------------
 1 file changed, 15 insertions(+), 22 deletions(-)

diff --git a/modules/training.py b/modules/training.py
index 21f7fb9b..292ee484 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -24,7 +24,6 @@ from modules.evaluate import (
 )
 from modules.logging_colors import logger
 from modules.models import reload_model
-from modules.utils import natural_keys
 
 PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "train_only_after", "stop_at_loss", "add_eos_token", "report_to"]
 WANT_INTERRUPT = False
@@ -313,7 +312,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
     gradient_accumulation_steps = batch_size // micro_batch_size
     if shared.tokenizer.pad_token_id is None:
         shared.tokenizer.pad_token_id = shared.tokenizer.eos_token_id
-    shared.tokenizer.padding_side = "left"
+    shared.tokenizer.padding_side = "right"
 
     def list_target_modules():
         if all_linear:
@@ -343,9 +342,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
             if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
                 input_ids.append(shared.tokenizer.eos_token_id)
 
-            pad_len = cutoff_len - len(input_ids)
-            labels = [-100] * pad_len + list(input_ids)
-            input_ids = [shared.tokenizer.pad_token_id] * pad_len + input_ids
+            labels = list(input_ids)
         else:
             ind = prompt.index(train_only_after) + len(train_only_after)
 
@@ -358,8 +355,6 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
             full_length = len(after_tokens) + len(before_tokens)
             if full_length > cutoff_len:
                 after_tokens = after_tokens[:cutoff_len - len(before_tokens)]
-            else:
-                before_tokens = [shared.tokenizer.pad_token_id] * (cutoff_len - full_length) + before_tokens
 
             input_ids = before_tokens + after_tokens
             labels = [-100] * len(before_tokens) + list(after_tokens)
@@ -367,7 +362,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
         return {
             "input_ids": input_ids,
             "labels": labels,
-            "attention_mask": [0 if t == shared.tokenizer.pad_token_id else 1 for t in input_ids],
+            "attention_mask": [1] * len(input_ids),
         }
 
     def normalize_messages(data_point):
@@ -425,16 +420,10 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
             full_ids = full_ids[:cutoff_len]
             labels = labels[:cutoff_len]
 
-        # Left-pad to cutoff_len
-        pad_len = cutoff_len - len(full_ids)
-        attention_mask = [0] * pad_len + [1] * len(full_ids)
-        labels = [-100] * pad_len + labels
-        input_ids = [shared.tokenizer.pad_token_id] * pad_len + full_ids
-
         return {
-            "input_ids": input_ids,
+            "input_ids": full_ids,
             "labels": labels,
-            "attention_mask": attention_mask,
+            "attention_mask": [1] * len(full_ids),
         }
 
     train_template.clear()
@@ -694,13 +683,17 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
         lora_model = torch.compile(lora_model)
 
     def collate_fn(batch):
-        input_ids = torch.stack([torch.as_tensor(item['input_ids']) for item in batch])
-        labels = torch.stack([torch.as_tensor(item['labels']) for item in batch])
-        attention_mask = torch.stack([torch.as_tensor(item['attention_mask']) for item in batch])
+        max_len = max(len(item['input_ids']) for item in batch)
+        input_ids, labels, attention_mask = [], [], []
+        for item in batch:
+            pad_len = max_len - len(item['input_ids'])
+            input_ids.append(item['input_ids'] + [shared.tokenizer.pad_token_id] * pad_len)
+            labels.append(item['labels'] + [-100] * pad_len)
+            attention_mask.append(item['attention_mask'] + [0] * pad_len)
         return {
-            'input_ids': input_ids,
-            'labels': labels,
-            'attention_mask': attention_mask,
+            'input_ids': torch.tensor(input_ids),
+            'labels': torch.tensor(labels),
+            'attention_mask': torch.tensor(attention_mask),
         }
 
     trainer = transformers.Trainer(
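
A minimal standalone sketch of the new collate behavior, for anyone reviewing
this outside the webui. The PAD_TOKEN_ID constant and the toy batch are made
up for illustration; in the patch the pad id comes from shared.tokenizer and
the examples come from the tokenized dataset:

    import torch

    PAD_TOKEN_ID = 0  # hypothetical; the patch uses shared.tokenizer.pad_token_id

    def collate_fn(batch):
        # Right-pad every example to the longest sequence in this batch,
        # instead of pre-padding everything to cutoff_len.
        max_len = max(len(item['input_ids']) for item in batch)
        input_ids, labels, attention_mask = [], [], []
        for item in batch:
            pad_len = max_len - len(item['input_ids'])
            input_ids.append(item['input_ids'] + [PAD_TOKEN_ID] * pad_len)
            labels.append(item['labels'] + [-100] * pad_len)  # -100 is ignored by the loss
            attention_mask.append(item['attention_mask'] + [0] * pad_len)
        return {
            'input_ids': torch.tensor(input_ids),
            'labels': torch.tensor(labels),
            'attention_mask': torch.tensor(attention_mask),
        }

    # Toy batch: lengths 3 and 5 pad to 5 (the batch max), not to cutoff_len.
    batch = [
        {'input_ids': [5, 6, 7], 'labels': [5, 6, 7], 'attention_mask': [1, 1, 1]},
        {'input_ids': [5, 6, 7, 8, 9], 'labels': [5, 6, 7, 8, 9], 'attention_mask': [1] * 5},
    ]
    print(collate_fn(batch)['input_ids'].shape)  # torch.Size([2, 5])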