From 1c89376370b63ad32fef472114ec036edaaf8d1c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 16 Mar 2026 15:23:24 -0700 Subject: [PATCH] training: Add gradient_checkpointing for lower VRAM by default --- docs/05 - Training Tab.md | 2 ++ modules/training.py | 8 +++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/05 - Training Tab.md b/docs/05 - Training Tab.md index 0bfc59aa..46424eab 100644 --- a/docs/05 - Training Tab.md +++ b/docs/05 - Training Tab.md @@ -100,6 +100,8 @@ Each parameter has a description in the UI. Below is guidance on the most import VRAM usage during training is roughly similar to inference with ~1000 tokens of context. If you can run the model, you can probably train LoRAs with the default settings. If you run out of VRAM, reduce `Micro Batch Size` or `Cutoff Length`. Training 4-bit quantized models uses more VRAM — set `Micro Batch Size` to `1` to compensate. +**Gradient checkpointing** is enabled by default. It reduces VRAM usage by recomputing activations during the backward pass instead of storing them in memory. The tradeoff is ~20-30% slower training. There is no impact on accuracy — the results are mathematically identical. The savings are most noticeable with longer sequences and larger batch sizes. You can disable it if you have VRAM to spare and want faster training. + ### Rank Higher rank = more learning capacity = larger adapter = more VRAM. Use 4–8 for style/format, 128–256 to teach factual knowledge. diff --git a/modules/training.py b/modules/training.py index 878bb222..6549b35e 100644 --- a/modules/training.py +++ b/modules/training.py @@ -26,7 +26,7 @@ from modules.evaluate import ( from modules.logging_colors import logger from modules.models import reload_model -PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "excess_length", "report_to"] +PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "excess_length", "report_to", "gradient_checkpointing"] WANT_INTERRUPT = False train_log = {} @@ -101,6 +101,7 @@ def create_ui(): add_eos_token = gr.Checkbox(label='Add EOS token', value=True, info="Adds EOS token for each document in text datasets.") excess_length = gr.Dropdown(label='Excess length', value='drop', choices=['drop', 'truncate'], info='What to do with conversations that exceed the cutoff length. "Drop" removes them entirely (recommended). "Truncate" cuts from the right, which may produce incomplete responses.', elem_classes=['slim-dropdown']) + gradient_checkpointing = gr.Checkbox(label='Gradient checkpointing', value=True, info='Trades ~20-30% training speed for reduced VRAM usage by recomputing activations during the backward pass instead of storing them. No impact on accuracy.') higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.') report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True) @@ -159,7 +160,7 @@ def create_ui(): refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu) # Training events - all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, excess_length, report_to] + all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, excess_length, report_to, gradient_checkpointing] copy_from.change(do_copy_params, [copy_from] + all_params, all_params) start_button.click(do_train, all_params, output) @@ -293,7 +294,7 @@ def calc_trainable_parameters(model): return trainable_params, all_param -def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, stride_length: int, stop_at_loss: float, add_eos_token: bool, excess_length: str, report_to: str): +def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, stride_length: int, stop_at_loss: float, add_eos_token: bool, excess_length: str, report_to: str, gradient_checkpointing: bool = True): import torch import transformers @@ -708,6 +709,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: load_best_model_at_end=eval_data is not None, # TODO: Enable multi-device support ddp_find_unused_parameters=None, + gradient_checkpointing=gradient_checkpointing, use_cpu=shared.args.cpu, remove_unused_columns=False, ),