2023-07-12 16:19:12 +02:00
import os

# Force Weights & Biases into offline mode so training metrics are written
# locally and never uploaded. Must be set before wandb is first imported.
os.environ["WANDB_MODE"] = "offline"
# os.environ["WANDB_DISABLED"] = "true"
2023-07-12 16:19:12 +02:00
2023-03-28 02:24:39 +02:00
import json
2023-04-16 07:35:13 +02:00
import math
2023-05-24 17:43:22 +02:00
import random
2023-07-12 03:49:06 +02:00
import shutil
2023-03-28 02:24:39 +02:00
import sys
import threading
import time
2023-03-29 16:55:34 +02:00
import traceback
2023-07-12 03:49:06 +02:00
from datetime import datetime
2023-03-25 20:08:26 +01:00
from pathlib import Path
2023-03-28 02:24:39 +02:00
2026-03-05 18:39:37 +01:00
import yaml
2023-03-25 20:08:26 +01:00
import gradio as gr
2023-03-28 02:24:39 +02:00
2023-05-06 04:14:56 +02:00
from modules import shared , ui , utils
2023-06-25 06:44:36 +02:00
from modules . evaluate import (
calculate_perplexity ,
generate_markdown_table ,
save_past_evaluations
)
2023-05-22 03:42:34 +02:00
from modules . logging_colors import logger
2023-08-18 21:58:38 +02:00
from modules . models import reload_model
2023-03-25 20:08:26 +01:00
2026-03-05 18:56:27 +01:00
# Ordered list of training parameter names. The order must match both the
# positional signature of do_train() and the `all_params` list in create_ui(),
# since do_copy_params() maps saved JSON values back onto UI components by index.
PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "excess_length", "report_to"]

# Set True by do_interrupt(); polled by the training loop to stop early.
WANT_INTERRUPT = False

# Mutable module-level dicts shared with the training run: accumulated log
# values and the dataset/template metadata saved alongside the adapter.
train_log = {}
train_template = {}
2023-03-27 19:25:08 +02:00
2023-04-16 07:35:13 +02:00
2023-08-07 02:49:27 +02:00
def create_ui():
    """Build the "Training" tab: the LoRA training UI and the perplexity
    evaluation UI, plus all of their event wiring.

    Must be called inside a gr.Blocks() context. Reads shared.args for the
    multi-user flag and the LoRA directory.
    """
    mu = shared.args.multi_user

    with gr.Tab("Training", elem_id="training-tab"):
        with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
            # Scratch state used to feed "current model" into the evaluation chain.
            tmp = gr.State('')
            with gr.Row():
                with gr.Column():
                    gr.Markdown("[Tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)")

                    with gr.Row():
                        copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=utils.get_available_loras(), elem_classes=['slim-dropdown'], interactive=not mu)
                        ui.create_refresh_button(copy_from, lambda: None, lambda: {'choices': utils.get_available_loras()}, 'refresh-button', interactive=not mu)

                    with gr.Row():
                        with gr.Column(scale=5):
                            lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
                        with gr.Column():
                            always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])

                    with gr.Accordion(label='Target Modules', open=False, elem_classes='tgw-accordion'):
                        gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM and adapter size.")
                        all_linear = gr.Checkbox(label='Target all linear layers', value=True, info='Targets every nn.Linear layer except lm_head. Works for any model architecture. When checked, the individual module checkboxes below are ignored.', elem_classes=['no-background'])
                        with gr.Row():
                            with gr.Column():
                                q_proj_en = gr.Checkbox(label='Enable q_proj', value=True)
                            with gr.Column():
                                v_proj_en = gr.Checkbox(label='Enable v_proj', value=True)
                            with gr.Column():
                                k_proj_en = gr.Checkbox(label='Enable k_proj', value=False)
                            with gr.Column():
                                o_proj_en = gr.Checkbox(label='Enable o_proj', value=False)
                            with gr.Column():
                                gate_proj_en = gr.Checkbox(label='Enable gate_proj', value=False)
                            with gr.Column():
                                down_proj_en = gr.Checkbox(label='Enable down_proj', value=False)
                            with gr.Column():
                                up_proj_en = gr.Checkbox(label='Enable up_proj', value=False)

                    with gr.Row():
                        with gr.Column():
                            lora_rank = gr.Slider(label='LoRA Rank', value=8, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
                            lora_alpha = gr.Slider(label='LoRA Alpha', value=16, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
                            batch_size = gr.Slider(label='Batch Size', value=32, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
                            micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
                            cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=512, step=32, info='Maximum sequence length in tokens. For instruction datasets, conversations longer than this are dropped. For text datasets, documents are split into chunks of this size. Higher values require more VRAM.')

                        with gr.Column():
                            save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a full training checkpoint (adapter weights, optimizer, scheduler) will be saved every time this many steps pass. Training can be resumed from these checkpoints.')
                            epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
                            learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
                            with gr.Row():
                                lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='cosine', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown'])

                    with gr.Accordion(label='Advanced Options', open=False, elem_classes='tgw-accordion'):
                        with gr.Row():
                            with gr.Column():
                                lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.0, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
                                stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
                                with gr.Row():
                                    optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Optimizer algorithm. adamw_torch is the standard choice. adamw_bnb_8bit uses less VRAM. adafactor is memory-efficient for large models.', elem_classes=['slim-dropdown'])

                            with gr.Column():
                                warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate is gradually ramped up from 0 to the target value. This prevents unstable updates early in training.')
                                add_eos_token = gr.Checkbox(label='Add EOS token', value=True, info="Adds EOS token for each document in text datasets.")
                                excess_length = gr.Dropdown(label='Excess length', value='drop', choices=['drop', 'truncate'], info='What to do with conversations that exceed the cutoff length. "Drop" removes them entirely (recommended). "Truncate" cuts from the right, which may produce incomplete responses.', elem_classes=['slim-dropdown'])
                                higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
                                report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)

                with gr.Column():
                    with gr.Tab(label='Chat Dataset'):
                        with gr.Row():
                            dataset = gr.Dropdown(choices=utils.get_chat_datasets('user_data/training/datasets'), value='None', label='Dataset File', info='A JSON file with chat conversations (messages or ShareGPT format). Each row is one conversation.', elem_classes=['slim-dropdown'], interactive=not mu)
                            ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_chat_datasets('user_data/training/datasets')}, 'refresh-button', interactive=not mu)

                        with gr.Row():
                            format = gr.Dropdown(choices=get_instruction_templates(), value='None', label='Instruction Template', info='Select an instruction template for formatting the dataset, or "Chat Template" to use the model\'s built-in chat template.', elem_classes=['slim-dropdown'], interactive=not mu)
                            ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_instruction_templates()}, 'refresh-button', interactive=not mu)

                    with gr.Tab(label="Text Dataset"):
                        with gr.Row():
                            text_dataset = gr.Dropdown(choices=utils.get_text_datasets('user_data/training/datasets'), value='None', label='Dataset File', info='A JSON file with a "text" key per row, for pretraining-style training. Each row is one document.', elem_classes=['slim-dropdown'], interactive=not mu)
                            ui.create_refresh_button(text_dataset, lambda: None, lambda: {'choices': utils.get_text_datasets('user_data/training/datasets')}, 'refresh-button', interactive=not mu)

                        stride_length = gr.Slider(label='Stride Length', minimum=0, maximum=2048, value=256, step=32, info='Overlap between chunks in tokens. 0 = no overlap. Values like 256 or 512 help preserve context across chunk boundaries.')

                    with gr.Row():
                        eval_dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu)
                        ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu)

                    eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')

                    with gr.Row():
                        start_button = gr.Button("Start LoRA Training", variant='primary', interactive=not mu)
                        stop_button = gr.Button("Interrupt", interactive=not mu)

                    output = gr.Markdown(value="Ready")

        with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
            with gr.Row():
                with gr.Column():
                    models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True, interactive=not mu)
                    evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('user_data/training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under user_data/training/datasets.', interactive=not mu)

                    with gr.Row():
                        with gr.Column():
                            # BUGFIX: this slider was previously also named `stride_length`,
                            # shadowing the training-tab slider of the same name. Because
                            # `all_params` is built after both are defined, do_train() was
                            # silently receiving this evaluation slider's value instead of
                            # the training stride. A distinct local name fixes that.
                            evaluate_stride_length = gr.Slider(label='Stride', minimum=0, maximum=32768, value=512, step=256, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')

                        with gr.Column():
                            max_length = gr.Number(label='max_length', precision=0, step=256, value=0, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')

                    with gr.Row():
                        start_current_evaluation = gr.Button("Evaluate loaded model", interactive=not mu)
                        start_evaluation = gr.Button("Evaluate selected models", interactive=not mu)
                        stop_evaluation = gr.Button("Interrupt", interactive=not mu)

                with gr.Column():
                    evaluation_log = gr.Markdown(value='')

            evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)

            with gr.Row():
                save_comments = gr.Button('Save comments', elem_classes="small-button", interactive=not mu)
                refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)

    # Training events. `all_params` order must match PARAMETERS and do_train().
    all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, excess_length, report_to]

    copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
    start_button.click(do_train, all_params, output)
    stop_button.click(do_interrupt, None, None, queue=False)
    higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])

    # Evaluation events. For some reason, the interrupt event
    # doesn't work with the .then() syntax, so I write them one
    # by one in this ugly but functional way.
    ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, evaluate_stride_length, max_length], evaluation_log, show_progress=False)
    ev.then(generate_markdown_table, None, evaluation_table, show_progress=False)

    ev_cur = start_current_evaluation.click(
        lambda: ['current model'], None, tmp).then(
        calculate_perplexity, [tmp, evaluate_text_file, evaluate_stride_length, max_length], evaluation_log, show_progress=False)
    ev_cur.then(generate_markdown_table, None, evaluation_table, show_progress=False)

    stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False)
    refresh_table.click(generate_markdown_table, None, evaluation_table, show_progress=True)

    save_comments.click(
        save_past_evaluations, evaluation_table, None).then(
        lambda: "Comments saved.", None, evaluation_log, show_progress=False)
2023-04-16 07:35:13 +02:00
2023-04-07 05:15:45 +02:00
2023-03-28 03:19:06 +02:00
def do_interrupt():
    """Request that the running training loop stop at its next checkpoint."""
    global WANT_INTERRUPT
    WANT_INTERRUPT = True
2023-03-25 20:08:26 +01:00
2023-04-07 05:15:45 +02:00
2023-04-20 00:39:03 +02:00
def do_copy_params(lora_name: str, *args):
    """Return the UI parameter values copied from an existing LoRA.

    Reads `training_parameters.json` from the named LoRA's folder. For each
    entry in PARAMETERS, the saved value is used when present; otherwise the
    caller's current UI value (from *args, in PARAMETERS order) is kept.
    """
    f_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json"
    if Path(f_name).is_file():
        with open(f_name, 'r', encoding='utf-8') as format_file:
            params: dict[str, str] = json.load(format_file)
    else:
        params = {}

    # Saved value wins; fall back to the live UI value for missing keys.
    return [params.get(key, current) for key, current in zip(PARAMETERS, args)]
def change_rank_limit(use_higher_ranks: bool):
    """Return gr.update payloads raising (or restoring) the Rank/Alpha maxima."""
    factor = 2 if use_higher_ranks else 1
    rank_update = {"maximum": 1024 * factor, "__type__": "update"}
    alpha_update = {"maximum": 2048 * factor, "__type__": "update"}
    return rank_update, alpha_update
2023-03-28 03:19:06 +02:00
def clean_path(base_path: str, path: str):
    """Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
    # Normalize to forward slashes and neutralize parent-directory escapes.
    for bad, safe in (('\\', '/'), ('..', '_')):
        path = path.replace(bad, safe)

    if base_path is None:
        return path

    return f'{Path(base_path).absolute()}/{path}'
2023-03-25 20:08:26 +01:00
2023-04-07 05:15:45 +02:00
2026-03-05 18:39:37 +01:00
def get_instruction_templates():
    """List selectable instruction templates: the sentinels 'None' and
    'Chat Template', followed by file stems found under
    user_data/instruction-templates (.yaml/.yml/.jinja), naturally sorted."""
    template_dir = Path('user_data/instruction-templates')
    stems = {
        f.stem
        for ext in ('yaml', 'yml', 'jinja')
        for f in template_dir.glob(f'*.{ext}')
    }
    return ['None', 'Chat Template'] + sorted(stems, key=utils.natural_keys)
def load_template(name):
    """Load a Jinja2 template string from user_data/instruction-templates/."""
    base = Path('user_data/instruction-templates')
    # .jinja files hold the template verbatim; YAML files store it under
    # the 'instruction_template' key. First matching extension wins.
    for ext in ('jinja', 'yaml', 'yml'):
        candidate = base / f'{name}.{ext}'
        if not candidate.exists():
            continue
        text = candidate.read_text(encoding='utf-8')
        if ext == 'jinja':
            return text
        return yaml.safe_load(text).get('instruction_template', '')
    return ''
2023-06-27 23:24:04 +02:00
def backup_adapter(input_folder):
    """Copy an existing adapter's files into a dated backup subfolder.

    Looks for adapter_model.safetensors (preferred) or adapter_model.bin in
    `input_folder`; if found, copies every file in the folder into a
    'Backup-YYYY-MM-DD' subfolder named after the adapter file's creation
    date. Best-effort: all errors are logged, never raised.
    """
    try:
        # Prefer the newer safetensors format, fall back to the legacy .bin.
        adapter_file = Path(f"{input_folder}/adapter_model.safetensors")
        if not adapter_file.is_file():
            adapter_file = Path(f"{input_folder}/adapter_model.bin")

        if adapter_file.is_file():
            logger.info("Backing up existing LoRA adapter")
            creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
            creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")

            # Create the dated backup subfolder
            subfolder_path = Path(f"{input_folder}/{creation_date_str}")
            subfolder_path.mkdir(parents=True, exist_ok=True)

            # Skip if a backup of this adapter already exists there
            backup_adapter_file = subfolder_path / adapter_file.name
            if backup_adapter_file.is_file():
                # Consistency fix: use the module logger instead of bare print()
                logger.info("Backup already exists. Skipping backup process.")
                return

            # Copy existing files (not subfolders) into the backup folder
            for file in Path(input_folder).iterdir():
                if file.is_file():
                    shutil.copy2(file, subfolder_path)
    except Exception as e:
        # Consistency fix: route errors through the logger rather than print()
        logger.error("An error occurred in backup_adapter: %s", str(e))
2023-07-12 03:49:06 +02:00
2023-07-03 22:38:36 +02:00
def calc_trainable_parameters(model):
    """Return (trainable_parameter_count, total_parameter_count) for *model*."""
    trainable, total = 0, 0
    for _, param in model.named_parameters():
        count = param.numel()
        # Under DeepSpeed ZeRO-3 weights may be initialized empty; the true
        # element count is then exposed via ds_numel.
        if count == 0 and hasattr(param, "ds_numel"):
            count = param.ds_numel
        total += count
        if param.requires_grad:
            trainable += count

    return trainable, total
2023-07-03 22:38:36 +02:00
2023-06-27 23:24:04 +02:00
2026-03-05 18:56:27 +01:00
def do_train ( lora_name : str , always_override : bool , all_linear : bool , q_proj_en : bool , v_proj_en : bool , k_proj_en : bool , o_proj_en : bool , gate_proj_en : bool , down_proj_en : bool , up_proj_en : bool , save_steps : int , micro_batch_size : int , batch_size : int , epochs : int , learning_rate : str , lr_scheduler_type : str , lora_rank : int , lora_alpha : int , lora_dropout : float , cutoff_len : int , dataset : str , eval_dataset : str , format : str , eval_steps : int , text_dataset : str , higher_rank_limit : bool , warmup_steps : int , optimizer : str , stride_length : int , stop_at_loss : float , add_eos_token : bool , excess_length : str , report_to : str ) :
2023-04-20 00:39:03 +02:00
2025-04-20 18:33:47 +02:00
import torch
import transformers
from datasets import Dataset , load_dataset
2024-09-04 04:40:53 +02:00
from peft import (
LoraConfig ,
get_peft_model ,
prepare_model_for_kbit_training ,
set_peft_model_state_dict
)
2023-04-16 07:35:13 +02:00
global WANT_INTERRUPT
2023-03-27 19:43:01 +02:00
WANT_INTERRUPT = False
2023-03-28 03:17:42 +02:00
2023-03-27 19:25:08 +02:00
# == Input validation / processing ==
2023-08-18 21:58:38 +02:00
yield " Preparing the input... "
2023-04-20 00:39:03 +02:00
lora_file_path = clean_path ( None , lora_name )
if lora_file_path . strip ( ) == ' ' :
yield " Missing or invalid LoRA file name input. "
return
2023-08-18 21:58:38 +02:00
lora_file_path = f " { Path ( shared . args . lora_dir ) } / { lora_file_path } "
2023-03-28 07:15:32 +02:00
actual_lr = float ( learning_rate )
2023-04-06 07:04:11 +02:00
model_type = type ( shared . model ) . __name__
2023-04-16 08:08:37 +02:00
2026-03-05 14:49:08 +01:00
if model_type == " PeftModelForCausalLM " :
if len ( shared . lora_names ) > 0 :
yield " You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)* "
logger . warning ( " Training LoRA over top of another LoRA. May have unexpected effects. " )
2023-04-06 07:04:11 +02:00
else :
2026-03-05 14:49:08 +01:00
yield " Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)* "
logger . warning ( " Model ID not matched due to LoRA loading. Consider reloading base model. " )
2023-05-04 02:43:17 +02:00
2023-04-06 07:04:11 +02:00
time . sleep ( 5 )
2023-03-28 07:15:32 +02:00
if cutoff_len < = 0 or micro_batch_size < = 0 or batch_size < = 0 or actual_lr < = 0 or lora_rank < = 0 or lora_alpha < = 0 :
2023-04-03 05:54:56 +02:00
yield " Cannot input zeroes. "
2023-03-28 07:15:32 +02:00
return
2023-03-28 03:17:42 +02:00
gradient_accumulation_steps = batch_size / / micro_batch_size
2026-03-05 14:49:08 +01:00
if shared . tokenizer . pad_token_id is None :
shared . tokenizer . pad_token_id = shared . tokenizer . eos_token_id
2026-03-05 16:36:10 +01:00
shared . tokenizer . padding_side = " right "
2023-03-28 03:17:42 +02:00
2026-03-05 14:49:08 +01:00
def list_target_modules ( ) :
if all_linear :
return " all-linear "
target_mods = [ f " { name } _proj " for name , enabled in {
" q " : q_proj_en , " k " : k_proj_en , " v " : v_proj_en , " o " : o_proj_en ,
" gate " : gate_proj_en , " down " : down_proj_en , " up " : up_proj_en ,
} . items ( ) if enabled ]
2023-10-22 20:57:19 +02:00
return target_mods
2026-03-05 15:46:45 +01:00
def normalize_messages ( data_point ) :
""" Convert a dataset row to OpenAI messages format for apply_chat_template(). """
if " messages " in data_point :
return data_point [ " messages " ]
if " conversations " in data_point :
role_map = { " human " : " user " , " gpt " : " assistant " }
return [
{ " role " : role_map . get ( turn . get ( " from " , " " ) , turn . get ( " from " , " " ) ) , " content " : turn [ " value " ] }
for turn in data_point [ " conversations " ]
]
raise RuntimeError (
2026-03-05 18:39:37 +01:00
f ' Dataset row must contain " messages " or " conversations " key. '
2026-03-05 15:46:45 +01:00
f ' Found: { list ( data_point . keys ( ) ) } '
)
def tokenize_conversation ( data_point ) :
""" Tokenize using apply_chat_template() with assistant-only label masking. """
messages = normalize_messages ( data_point )
2026-03-05 17:45:05 +01:00
full_ids = list ( shared . tokenizer . apply_chat_template ( messages , tokenize = True , return_dict = False ) )
2026-03-05 15:46:45 +01:00
# Build labels: -100 for everything, then unmask assistant turns.
# This assumes apply_chat_template(messages[:i]) is a token-for-token
# prefix of apply_chat_template(messages[:i+1]), which holds for all
# standard chat templates (Llama, ChatML, Mistral, etc.).
labels = [ - 100 ] * len ( full_ids )
for i , msg in enumerate ( messages ) :
if msg [ " role " ] == " assistant " :
# Tokens up to where this assistant turn starts
header_ids = shared . tokenizer . apply_chat_template (
2026-03-05 17:45:05 +01:00
messages [ : i ] , tokenize = True , return_dict = False , add_generation_prompt = True
2026-03-05 15:46:45 +01:00
)
# Tokens through end of this assistant turn
through_ids = shared . tokenizer . apply_chat_template (
2026-03-05 17:45:05 +01:00
messages [ : i + 1 ] , tokenize = True , return_dict = False
2026-03-05 15:46:45 +01:00
)
# Unmask assistant tokens
start = len ( header_ids )
end = min ( len ( through_ids ) , len ( full_ids ) )
labels [ start : end ] = full_ids [ start : end ]
if len ( full_ids ) > cutoff_len :
2026-03-05 18:56:27 +01:00
if excess_length == ' truncate ' :
full_ids = full_ids [ : cutoff_len ]
labels = labels [ : cutoff_len ]
else :
return { " input_ids " : [ ] , " labels " : [ ] , " attention_mask " : [ ] }
2026-03-05 15:46:45 +01:00
return {
2026-03-05 16:36:10 +01:00
" input_ids " : full_ids ,
2026-03-05 15:46:45 +01:00
" labels " : labels ,
2026-03-05 16:36:10 +01:00
" attention_mask " : [ 1 ] * len ( full_ids ) ,
2026-03-05 15:46:45 +01:00
}
2023-06-25 20:34:46 +02:00
train_template . clear ( )
2023-03-28 07:15:32 +02:00
# == Prep the dataset, format, etc ==
2026-03-05 20:15:16 +01:00
has_text_dataset = text_dataset not in [ ' None ' , ' ' ]
has_chat_dataset = dataset not in [ ' None ' , ' ' ]
if has_text_dataset and has_chat_dataset :
yield " Error: select either a Chat Dataset or a Text Dataset, not both. "
return
2023-04-16 07:46:27 +02:00
2026-03-05 20:15:16 +01:00
def tokenize_text_data ( data ) :
""" Tokenize text dataset rows, concatenate, and split into chunks. """
2026-03-05 14:49:08 +01:00
all_tokens = [ ]
2026-03-05 20:15:16 +01:00
for row in data :
2026-03-05 16:32:17 +01:00
tokens = shared . tokenizer . encode ( row [ ' text ' ] )
2023-07-12 20:29:43 +02:00
if add_eos_token :
tokens . append ( shared . tokenizer . eos_token_id )
2026-03-05 14:49:08 +01:00
all_tokens . extend ( tokens )
2023-05-19 17:58:54 +02:00
2026-03-05 16:32:17 +01:00
stride = int ( stride_length )
step = cutoff_len - stride if stride > 0 else cutoff_len
2023-07-12 20:29:43 +02:00
2026-03-05 16:32:17 +01:00
if step < = 0 :
2026-03-05 20:15:16 +01:00
return None , " Error: stride length must be smaller than cutoff length. "
2026-03-05 16:32:17 +01:00
if len ( all_tokens ) < cutoff_len :
2026-03-05 20:15:16 +01:00
return None , " Error: dataset is too short to fill even one chunk of the given cutoff length. "
2026-03-05 14:49:08 +01:00
2026-03-05 16:32:17 +01:00
chunks = [ ]
for start in range ( 0 , len ( all_tokens ) , step ) :
chunk = all_tokens [ start : start + cutoff_len ]
if len ( chunk ) == 0 :
break
if len ( chunk ) < cutoff_len :
pad_len = cutoff_len - len ( chunk )
chunks . append ( {
" input_ids " : chunk + [ shared . tokenizer . pad_token_id ] * pad_len ,
" labels " : list ( chunk ) + [ - 100 ] * pad_len ,
" attention_mask " : [ 1 ] * len ( chunk ) + [ 0 ] * pad_len ,
} )
else :
chunks . append ( {
" input_ids " : chunk ,
" labels " : list ( chunk ) ,
" attention_mask " : [ 1 ] * cutoff_len ,
} )
2026-03-05 20:15:16 +01:00
return Dataset . from_list ( chunks ) , None
if has_text_dataset :
train_template [ " template_type " ] = " text_dataset "
logger . info ( " Loading text dataset " )
data = load_dataset ( " json " , data_files = clean_path ( ' user_data/training/datasets ' , f ' { text_dataset } .json ' ) )
if " text " not in data [ ' train ' ] . column_names :
yield " Error: text dataset must have a \" text \" key per row. "
2023-03-28 07:15:32 +02:00
return
2023-03-29 16:48:17 +02:00
2026-03-05 20:15:16 +01:00
train_data , err = tokenize_text_data ( data [ ' train ' ] )
if err :
yield err
return
if eval_dataset == ' None ' :
eval_data = None
else :
eval_raw = load_dataset ( " json " , data_files = clean_path ( ' user_data/training/datasets ' , f ' { eval_dataset } .json ' ) )
if " text " not in eval_raw [ ' train ' ] . column_names :
yield " Error: evaluation dataset must have a \" text \" key per row. "
return
eval_data , err = tokenize_text_data ( eval_raw [ ' train ' ] )
if err :
yield err
return
elif has_chat_dataset :
2023-03-29 16:48:17 +02:00
if format in [ ' None ' , ' ' ] :
2023-08-18 21:58:38 +02:00
yield " Missing format choice input, cannot continue. "
2023-03-28 07:15:32 +02:00
return
2023-03-28 03:17:42 +02:00
2026-03-05 15:46:45 +01:00
if format == ' Chat Template ' :
if not getattr ( shared . tokenizer , ' chat_template ' , None ) :
2026-03-05 18:39:37 +01:00
yield " Error: this model ' s tokenizer does not have a chat template. Select an instruction template instead, or load an instruct/chat model. "
2026-03-05 15:46:45 +01:00
return
2026-03-05 18:39:37 +01:00
else :
# Load custom instruction template and set on tokenizer
template_str = load_template ( format )
if not template_str :
yield f " Error: could not load instruction template ' { format } ' . "
return
shared . tokenizer . chat_template = template_str
2023-03-29 16:48:17 +02:00
2026-03-05 18:39:37 +01:00
# Unified path — both cases use tokenize_conversation()
train_template [ " template_type " ] = " chat_template "
2023-06-25 20:34:46 +02:00
2026-03-05 18:39:37 +01:00
logger . info ( " Loading JSON dataset with chat template format " )
data = load_dataset ( " json " , data_files = clean_path ( ' user_data/training/datasets ' , f ' { dataset } .json ' ) )
2023-03-28 03:17:42 +02:00
2026-03-05 18:39:37 +01:00
# Validate the first row
try :
normalize_messages ( data [ ' train ' ] [ 0 ] )
except ( RuntimeError , KeyError , IndexError ) as e :
yield f " Error: { e } "
return
2023-03-28 07:15:32 +02:00
2026-03-05 18:56:27 +01:00
total = len ( data [ ' train ' ] )
2026-03-05 18:39:37 +01:00
train_data = data [ ' train ' ] . map (
tokenize_conversation ,
remove_columns = data [ ' train ' ] . column_names ,
new_fingerprint = ' %030x ' % random . randrange ( 16 * * 30 )
)
2026-03-05 18:56:27 +01:00
train_data = train_data . filter ( lambda x : len ( x [ ' input_ids ' ] ) > 0 )
dropped = total - len ( train_data )
if dropped > 0 :
logger . warning ( f " Dropped { dropped } / { total } conversations exceeding cutoff length of { cutoff_len } tokens. " )
if len ( train_data ) == 0 :
yield f " Error: all { total } conversations exceed the cutoff length of { cutoff_len } tokens. Increase the cutoff length or shorten your data. "
return
2026-03-05 18:39:37 +01:00
if eval_dataset == ' None ' :
eval_data = None
else :
eval_data = load_dataset ( " json " , data_files = clean_path ( ' user_data/training/datasets ' , f ' { eval_dataset } .json ' ) )
eval_data = eval_data [ ' train ' ] . map (
2026-03-05 15:46:45 +01:00
tokenize_conversation ,
2026-03-05 18:39:37 +01:00
remove_columns = eval_data [ ' train ' ] . column_names ,
2026-03-05 15:46:45 +01:00
new_fingerprint = ' %030x ' % random . randrange ( 16 * * 30 )
)
2026-03-05 18:56:27 +01:00
eval_data = eval_data . filter ( lambda x : len ( x [ ' input_ids ' ] ) > 0 )
2026-03-05 20:15:16 +01:00
else :
yield " No dataset selected. Choose a Chat Dataset or a Text Dataset. "
return
2023-03-28 07:15:32 +02:00
2023-07-12 20:29:43 +02:00
# == We MUST reload model if it went through any previous training, even failed one ==
if shared . model_dirty_from_training :
selected_model = shared . model_name
if selected_model :
print ( " \033 [1;31;1m(Model has been modified by previous training, it needs to be reloaded...) \033 [0;37;0m " )
try :
yield f " Reloading { selected_model } ... "
2023-08-18 21:58:38 +02:00
reload_model ( )
2023-07-12 20:29:43 +02:00
if shared . model is not None :
print ( " Model reloaded OK, continue with training. " )
else :
2026-03-05 18:41:44 +01:00
yield f " Failed to load { selected_model } . "
return
2026-03-04 22:06:17 +01:00
except Exception :
2023-07-12 20:29:43 +02:00
exc = traceback . format_exc ( )
logger . error ( ' Failed to reload the model. ' )
print ( exc )
2026-03-05 18:41:44 +01:00
yield exc . replace ( ' \n ' , ' \n \n ' )
return
2023-07-12 20:29:43 +02:00
2023-03-27 19:25:08 +02:00
# == Start prepping the model itself ==
2023-03-25 20:57:36 +01:00
if not hasattr ( shared . model , ' lm_head ' ) or hasattr ( shared . model . lm_head , ' weight ' ) :
2023-12-20 05:54:32 +01:00
logger . info ( " Getting model ready " )
2024-01-04 00:42:20 +01:00
if ' quantization_config ' in shared . model . config . to_dict ( ) :
prepare_model_for_kbit_training ( shared . model )
2023-04-07 05:15:45 +02:00
2023-07-12 20:29:43 +02:00
# base model is now frozen and should not be reused for any other LoRA training than this one
shared . model_dirty_from_training = True
2023-12-20 05:54:32 +01:00
logger . info ( " Preparing for training " )
2026-03-05 14:49:08 +01:00
target_modules = list_target_modules ( )
if not target_modules :
yield " No target modules selected. Enable at least one module or check ' Target all linear layers ' . "
return
2023-03-25 20:08:26 +01:00
config = LoraConfig (
2023-03-28 03:17:42 +02:00
r = lora_rank ,
lora_alpha = lora_alpha ,
2026-03-05 14:49:08 +01:00
target_modules = target_modules ,
2023-03-28 03:17:42 +02:00
lora_dropout = lora_dropout ,
2023-03-25 20:08:26 +01:00
bias = " none " ,
task_type = " CAUSAL_LM "
)
2023-03-29 16:55:34 +02:00
2023-06-27 23:24:04 +02:00
# == Backup the existing adapter ==
if not always_override :
backup_adapter ( lora_file_path )
2023-07-03 22:38:36 +02:00
# == get model trainable params
model_trainable_params , model_all_params = calc_trainable_parameters ( shared . model )
2026-03-05 19:32:49 +01:00
# == Determine if we can resume from a checkpoint ==
resume_checkpoint = None
2023-03-29 16:55:34 +02:00
try :
2023-12-20 05:54:32 +01:00
logger . info ( " Creating LoRA model " )
2023-04-20 00:39:03 +02:00
lora_model = get_peft_model ( shared . model , config )
2026-03-05 19:32:49 +01:00
if not always_override and Path ( lora_file_path ) . exists ( ) :
# Look for HF Trainer checkpoint dirs (full resumption)
checkpoints = sorted ( Path ( lora_file_path ) . glob ( " checkpoint-* " ) , key = os . path . getmtime )
if checkpoints :
resume_checkpoint = str ( checkpoints [ - 1 ] )
logger . info ( f " Will resume from checkpoint: { resume_checkpoint } " )
else :
# Legacy fallback: load bare adapter weights only
safetensors_path = Path ( f " { lora_file_path } /adapter_model.safetensors " )
bin_path = Path ( f " { lora_file_path } /adapter_model.bin " )
if safetensors_path . is_file ( ) :
logger . info ( " Loading existing LoRA data (safetensors) " )
from safetensors . torch import load_file
state_dict_peft = load_file ( str ( safetensors_path ) )
set_peft_model_state_dict ( lora_model , state_dict_peft )
elif bin_path . is_file ( ) :
logger . info ( " Loading existing LoRA data (bin) " )
state_dict_peft = torch . load ( str ( bin_path ) , weights_only = True )
set_peft_model_state_dict ( lora_model , state_dict_peft )
2026-03-04 22:06:17 +01:00
except Exception :
2023-08-03 15:57:21 +02:00
yield traceback . format_exc ( ) . replace ( ' \n ' , ' \n \n ' )
2023-03-29 16:55:34 +02:00
return
2023-04-16 07:35:13 +02:00
class Tracked ( ) :
def __init__ ( self ) :
self . current_steps = 0
self . max_steps = 0
2023-04-20 00:39:03 +02:00
self . did_save = False
2023-04-16 07:35:13 +02:00
tracked = Tracked ( )
2023-04-20 00:39:03 +02:00
actual_save_steps = math . ceil ( save_steps / gradient_accumulation_steps )
2023-04-16 07:35:13 +02:00
    class Callbacks(transformers.TrainerCallback):
        """Trainer hooks that mirror progress into `tracked`/`train_log`,
        honor the global WANT_INTERRUPT flag, auto-stop at a target loss,
        and persist run metadata next to each checkpoint.
        """

        def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
            # state.global_step counts optimizer steps; scale by the
            # accumulation factor so the UI reports micro-batch substeps.
            tracked.current_steps = state.global_step * gradient_accumulation_steps
            tracked.max_steps = state.max_steps * gradient_accumulation_steps
            if WANT_INTERRUPT:
                control.should_epoch_stop = True
                control.should_training_stop = True

        def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
            # Fires once per gradient-accumulation substep between optimizer steps.
            tracked.current_steps += 1
            if WANT_INTERRUPT:
                control.should_epoch_stop = True
                control.should_training_stop = True

        def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs):
            # Merge the Trainer's log dict (loss, lr, ...) into the shared run log.
            train_log.update(logs)
            train_log.update({"current_steps": tracked.current_steps})
            if WANT_INTERRUPT:
                print("\033[1;31;1mInterrupted by user\033[0;37;0m")

            print(f"\033[1;30;40mStep: {tracked.current_steps}\033[0;37;0m", end='')
            if 'loss' in logs:
                loss = float(logs['loss'])
                # stop_at_loss <= 0 disables the auto-stop feature.
                if stop_at_loss > 0 and loss <= stop_at_loss:
                    control.should_epoch_stop = True
                    control.should_training_stop = True
                    print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")

        def on_save(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
            # Drop the current training log/template next to each HF checkpoint
            # directory so a resumed run keeps its metadata.
            checkpoint_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}"
            if checkpoint_dir.exists():
                with open(checkpoint_dir / "training_log.json", 'w', encoding='utf-8') as file:
                    json.dump(train_log, file, indent=2)
                with open(checkpoint_dir / "training_prompt.json", 'w', encoding='utf-8') as file:
                    json.dump(train_template, file, indent=2)
2024-01-17 21:11:49 +01:00
# Fix training for mixed precision models
for param in shared . model . parameters ( ) :
if param . requires_grad :
param . data = param . data . float ( )
2026-03-05 14:49:08 +01:00
lora_model . config . use_cache = False
def collate_fn ( batch ) :
2026-03-05 16:36:10 +01:00
max_len = max ( len ( item [ ' input_ids ' ] ) for item in batch )
input_ids , labels , attention_mask = [ ] , [ ] , [ ]
for item in batch :
pad_len = max_len - len ( item [ ' input_ids ' ] )
input_ids . append ( item [ ' input_ids ' ] + [ shared . tokenizer . pad_token_id ] * pad_len )
labels . append ( item [ ' labels ' ] + [ - 100 ] * pad_len )
attention_mask . append ( item [ ' attention_mask ' ] + [ 0 ] * pad_len )
2026-03-05 14:49:08 +01:00
return {
2026-03-05 16:36:10 +01:00
' input_ids ' : torch . tensor ( input_ids ) ,
' labels ' : torch . tensor ( labels ) ,
' attention_mask ' : torch . tensor ( attention_mask ) ,
2026-03-05 14:49:08 +01:00
}
2023-03-25 20:08:26 +01:00
trainer = transformers . Trainer (
2023-03-28 03:17:42 +02:00
model = lora_model ,
2023-03-25 20:08:26 +01:00
train_dataset = train_data ,
2023-03-28 03:17:42 +02:00
eval_dataset = eval_data ,
2023-03-25 20:08:26 +01:00
args = transformers . TrainingArguments (
2026-03-05 14:49:08 +01:00
report_to = report_to if report_to != " None " else " none " ,
2023-03-28 03:17:42 +02:00
per_device_train_batch_size = micro_batch_size ,
gradient_accumulation_steps = gradient_accumulation_steps ,
2023-04-20 00:39:03 +02:00
warmup_steps = math . ceil ( warmup_steps / gradient_accumulation_steps ) ,
2023-03-25 20:08:26 +01:00
num_train_epochs = epochs ,
2023-03-28 03:17:42 +02:00
learning_rate = actual_lr ,
2024-01-04 00:42:20 +01:00
fp16 = False if shared . args . cpu or shared . args . bf16 else True ,
bf16 = shared . args . bf16 ,
2023-04-20 00:39:03 +02:00
optim = optimizer ,
2026-03-05 19:12:32 +01:00
logging_steps = 1 ,
2025-10-28 20:48:04 +01:00
eval_strategy = " steps " if eval_data is not None else " no " ,
2023-04-16 07:35:13 +02:00
eval_steps = math . ceil ( eval_steps / gradient_accumulation_steps ) if eval_data is not None else None ,
2026-03-05 19:32:49 +01:00
save_strategy = " steps " if save_steps > 0 or eval_data is not None else " no " ,
save_steps = actual_save_steps if save_steps > 0 else None ,
2023-04-16 07:35:13 +02:00
output_dir = lora_file_path ,
2023-04-20 00:39:03 +02:00
lr_scheduler_type = lr_scheduler_type ,
2023-05-16 18:40:19 +02:00
load_best_model_at_end = eval_data is not None ,
2023-03-25 20:08:26 +01:00
# TODO: Enable multi-device support
2023-04-10 22:29:00 +02:00
ddp_find_unused_parameters = None ,
2026-03-05 14:49:08 +01:00
use_cpu = shared . args . cpu ,
remove_unused_columns = False ,
2023-03-25 20:08:26 +01:00
) ,
2026-03-05 14:49:08 +01:00
data_collator = collate_fn ,
2026-03-05 18:41:44 +01:00
callbacks = [ Callbacks ( ) ]
2023-03-25 20:08:26 +01:00
)
2023-03-28 03:17:42 +02:00
2023-04-16 07:35:13 +02:00
# == Save parameters for reuse ==
with open ( f " { lora_file_path } /training_parameters.json " , ' w ' , encoding = ' utf-8 ' ) as file :
2026-03-05 18:41:44 +01:00
local_vars = locals ( )
json . dump ( { x : local_vars [ x ] for x in PARAMETERS } , file , indent = 2 )
2023-04-16 07:35:13 +02:00
2023-06-25 20:34:46 +02:00
# == Save training prompt ==
with open ( f " { lora_file_path } /training_prompt.json " , ' w ' , encoding = ' utf-8 ' ) as file :
json . dump ( train_template , file , indent = 2 )
2023-03-27 19:25:08 +02:00
# == Main run and monitor loop ==
2023-12-20 05:54:32 +01:00
logger . info ( " Starting training " )
2023-03-27 19:25:08 +02:00
yield " Starting... "
2023-06-25 20:34:46 +02:00
2023-07-03 22:38:36 +02:00
lora_trainable_param , lora_all_param = calc_trainable_parameters ( lora_model )
2026-03-05 14:49:08 +01:00
if target_modules == " all-linear " :
projections_string = " all-linear "
else :
projections_string = " , " . join ( [ projection . replace ( " _proj " , " " ) for projection in target_modules ] )
2023-07-12 20:29:43 +02:00
2026-03-05 14:49:08 +01:00
print ( f " Training ' { model_type } ' model using ( { projections_string } ) projections " )
2023-07-12 20:29:43 +02:00
2023-07-12 03:49:06 +02:00
if lora_all_param > 0 :
print ( f " Trainable params: { lora_trainable_param : ,d } ( { 100 * lora_trainable_param / lora_all_param : .4f } %), All params: { lora_all_param : ,d } (Model: { model_all_params : ,d } ) " )
2023-07-03 22:38:36 +02:00
2023-06-25 20:34:46 +02:00
train_log . update ( { " base_model_name " : shared . model_name } )
train_log . update ( { " base_model_class " : shared . model . __class__ . __name__ } )
train_log . update ( { " base_loaded_in_4bit " : getattr ( lora_model , " is_loaded_in_4bit " , False ) } )
train_log . update ( { " base_loaded_in_8bit " : getattr ( lora_model , " is_loaded_in_8bit " , False ) } )
2023-07-12 20:29:43 +02:00
train_log . update ( { " projections " : projections_string } )
2023-06-25 20:34:46 +02:00
if stop_at_loss > 0 :
print ( f " Monitoring loss \033 [1;31;1m(Auto-Stop at: { stop_at_loss } ) \033 [0;37;0m " )
2023-04-06 07:04:11 +02:00
if WANT_INTERRUPT :
yield " Interrupted before start. "
return
2023-07-12 16:44:30 +02:00
2023-07-12 16:26:45 +02:00
def log_train_dataset ( trainer ) :
decoded_entries = [ ]
# Try to decode the entries and write the log file
try :
# Iterate over the first 10 elements in the dataset (or fewer if there are less than 10)
for i in range ( min ( 10 , len ( trainer . train_dataset ) ) ) :
decoded_text = shared . tokenizer . decode ( trainer . train_dataset [ i ] [ ' input_ids ' ] )
decoded_entries . append ( { " value " : decoded_text } )
# Write the log file
2025-06-17 16:11:59 +02:00
Path ( ' user_data/logs ' ) . mkdir ( exist_ok = True )
2025-04-26 13:56:54 +02:00
with open ( Path ( ' user_data/logs/train_dataset_sample.json ' ) , ' w ' ) as json_file :
2023-07-12 16:26:45 +02:00
json . dump ( decoded_entries , json_file , indent = 4 )
2025-04-26 13:56:54 +02:00
logger . info ( " Log file ' train_dataset_sample.json ' created in the ' user_data/logs ' directory. " )
2023-07-12 16:26:45 +02:00
except Exception as e :
logger . error ( f " Failed to create log file due to error: { e } " )
2023-03-28 02:24:39 +02:00
2026-03-05 19:50:39 +01:00
    # Populated by threaded_run() if training raises; polled by the monitor
    # loop below so the UI can surface the failure to the user.
    thread_error = None

    def threaded_run():
        """Run the full training job in a background thread.

        Saves the adapter and the training log from inside the thread so the
        result survives even if the gradio thread breaks (e.g. browser
        closed); any exception is captured into `thread_error` rather than
        propagating out of the thread.
        """
        nonlocal thread_error
        try:
            log_train_dataset(trainer)
            trainer.train(resume_from_checkpoint=resume_checkpoint)
            # Note: save in the thread in case the gradio thread breaks (eg browser closed)
            lora_model.save_pretrained(lora_file_path)
            tracked.did_save = True
            logger.info("LoRA training run is completed and saved.")
            # Save log
            with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
                json.dump(train_log, file, indent=2)
        except Exception as e:
            thread_error = e
            logger.error(f"Training error: {e}")
2023-03-28 02:24:39 +02:00
2023-04-06 07:04:11 +02:00
thread = threading . Thread ( target = threaded_run )
2023-03-27 19:25:08 +02:00
thread . start ( )
2023-04-06 07:04:11 +02:00
last_step = 0
start_time = time . perf_counter ( )
2023-03-28 02:24:39 +02:00
2023-03-27 19:25:08 +02:00
while thread . is_alive ( ) :
time . sleep ( 0.5 )
2023-03-27 19:43:01 +02:00
if WANT_INTERRUPT :
yield " Interrupting, please wait... *(Run will stop after the current training step completes.)* "
2023-04-06 07:04:11 +02:00
2023-04-16 07:35:13 +02:00
elif tracked . current_steps != last_step :
last_step = tracked . current_steps
2023-04-06 07:04:11 +02:00
time_elapsed = time . perf_counter ( ) - start_time
if time_elapsed < = 0 :
timer_info = " "
total_time_estimate = 999
2023-03-27 19:25:08 +02:00
else :
2023-04-16 07:35:13 +02:00
its = tracked . current_steps / time_elapsed
2023-03-27 19:25:08 +02:00
if its > 1 :
2023-04-06 07:04:11 +02:00
timer_info = f " ` { its : .2f } ` it/s "
2023-03-27 19:25:08 +02:00
else :
2023-04-06 07:04:11 +02:00
timer_info = f " ` { 1.0 / its : .2f } ` s/it "
2023-04-16 07:46:27 +02:00
2023-04-16 07:35:13 +02:00
total_time_estimate = ( 1.0 / its ) * ( tracked . max_steps )
2023-04-16 07:46:27 +02:00
2023-04-16 07:35:13 +02:00
yield f " Running... ** { tracked . current_steps } ** / ** { tracked . max_steps } ** ... { timer_info } , { format_time ( time_elapsed ) } / { format_time ( total_time_estimate ) } ... { format_time ( total_time_estimate - time_elapsed ) } remaining "
2023-03-28 02:24:39 +02:00
2026-03-05 19:50:39 +01:00
# Check for errors from the training thread
if thread_error is not None :
yield f " Training failed: { thread_error } "
return
2023-04-20 00:39:03 +02:00
# Saving in the train thread might fail if an error occurs, so save here if so.
if not tracked . did_save :
2023-12-20 05:54:32 +01:00
logger . info ( " Training complete, saving " )
2023-04-20 00:39:03 +02:00
lora_model . save_pretrained ( lora_file_path )
2023-03-28 02:24:39 +02:00
2023-03-27 19:43:01 +02:00
if WANT_INTERRUPT :
2023-05-22 03:42:34 +02:00
logger . info ( " Training interrupted. " )
2023-08-18 21:58:38 +02:00
yield f " Interrupted. Incomplete LoRA saved to ` { lora_file_path } `. "
2023-03-27 19:43:01 +02:00
else :
2023-05-22 03:42:34 +02:00
logger . info ( " Training complete! " )
2023-08-18 21:58:38 +02:00
yield f " Done! LoRA saved to ` { lora_file_path } `. \n \n Before testing your new LoRA, make sure to first reload the model, as it is currently dirty from training. "
2023-03-28 07:15:32 +02:00
2023-04-07 05:15:45 +02:00
2023-04-06 07:04:11 +02:00
def format_time(seconds: float):
    """Render a duration as a short markdown-formatted string.

    Values under 120 of a unit are shown in that unit; otherwise the next
    larger unit is tried (seconds -> minutes -> hours).
    """
    value = seconds
    for unit, limit in (("seconds", 120), ("minutes", 120)):
        if value < limit:
            return f"`{value:.0f}` {unit}"
        value /= 60

    return f"`{value:.0f}` hours"