Training: UI cleanup and better defaults

This commit is contained in:
oobabooga 2026-03-05 11:15:16 -08:00
parent 33ff3773a0
commit 86d8291e58
2 changed files with 95 additions and 47 deletions

View file

@ -53,8 +53,8 @@ def create_ui():
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
with gr.Accordion(label='Target Modules', open=False, elem_classes='tgw-accordion'):
gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM requirements and adapter size.\n\"Target all linear layers\" uses peft's `all-linear` option, which targets every `nn.Linear` layer except `lm_head` and works for any model architecture. Uncheck it to manually select individual projection modules below.")
all_linear = gr.Checkbox(label='Target all linear layers', value=True, info='Targets every nn.Linear layer except lm_head. Works for any model architecture.', elem_classes=['no-background'])
gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM and adapter size.")
all_linear = gr.Checkbox(label='Target all linear layers', value=True, info='Targets every nn.Linear layer except lm_head. Works for any model architecture. When checked, the individual module checkboxes below are ignored.', elem_classes=['no-background'])
with gr.Row():
with gr.Column():
q_proj_en = gr.Checkbox(label='Enable q_proj', value=True)
@ -75,17 +75,17 @@ def create_ui():
with gr.Column():
lora_rank = gr.Slider(label='LoRA Rank', value=8, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
lora_alpha = gr.Slider(label='LoRA Alpha', value=16, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
batch_size = gr.Slider(label='Batch Size', value=32, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=512, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=512, step=32, info='Maximum sequence length in tokens. For instruction datasets, conversations longer than this are dropped. For text datasets, documents are split into chunks of this size. Higher values require more VRAM.')
with gr.Column():
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a full training checkpoint (adapter weights, optimizer, scheduler) will be saved every time this many steps pass. Training can be resumed from these checkpoints.')
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
with gr.Row():
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown'])
lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='cosine', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown'])
with gr.Accordion(label='Advanced Options', open=False, elem_classes='tgw-accordion'):
with gr.Row():
@ -93,10 +93,10 @@ def create_ui():
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.0, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
with gr.Row():
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.', elem_classes=['slim-dropdown'])
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Optimizer algorithm. adamw_torch is the standard choice. adamw_bnb_8bit uses less VRAM. adafactor is memory-efficient for large models.', elem_classes=['slim-dropdown'])
with gr.Column():
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate is gradually ramped up from 0 to the target value. This prevents unstable updates early in training.')
add_eos_token = gr.Checkbox(label='Add EOS token', value=True, info="Adds EOS token for each document in text datasets.")
excess_length = gr.Dropdown(label='Excess length', value='drop', choices=['drop', 'truncate'], info='What to do with conversations that exceed the cutoff length. "Drop" removes them entirely (recommended). "Truncate" cuts from the right, which may produce incomplete responses.', elem_classes=['slim-dropdown'])
@ -105,27 +105,27 @@ def create_ui():
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
with gr.Column():
with gr.Tab(label='Formatted Dataset'):
with gr.Tab(label='Chat Dataset'):
with gr.Row():
format = gr.Dropdown(choices=get_instruction_templates(), value='None', label='Data Format', info='Select an instruction template for formatting the dataset, or "Chat Template" to use the model\'s built-in chat template.', elem_classes=['slim-dropdown'], interactive=not mu)
dataset = gr.Dropdown(choices=utils.get_chat_datasets('user_data/training/datasets'), value='None', label='Dataset File', info='A JSON file with chat conversations (messages or ShareGPT format). Each row is one conversation.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_chat_datasets('user_data/training/datasets')}, 'refresh-button', interactive=not mu)
with gr.Row():
format = gr.Dropdown(choices=get_instruction_templates(), value='None', label='Instruction Template', info='Select an instruction template for formatting the dataset, or "Chat Template" to use the model\'s built-in chat template.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_instruction_templates()}, 'refresh-button', interactive=not mu)
with gr.Row():
dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu)
with gr.Row():
eval_dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu)
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
with gr.Tab(label="Text Dataset"):
with gr.Row():
text_dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Text Dataset', info='A JSONL file with a "text" key per row, for pretraining-style training. Each row is one document.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(text_dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu)
text_dataset = gr.Dropdown(choices=utils.get_text_datasets('user_data/training/datasets'), value='None', label='Dataset File', info='A JSON file with a "text" key per row, for pretraining-style training. Each row is one document.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(text_dataset, lambda: None, lambda: {'choices': utils.get_text_datasets('user_data/training/datasets')}, 'refresh-button', interactive=not mu)
stride_length = gr.Slider(label='Stride Length', minimum=0, maximum=2048, value=0, step=32, info='Overlap between chunks in tokens. 0 = no overlap. Values like 256 or 512 help preserve context across chunk boundaries.')
stride_length = gr.Slider(label='Stride Length', minimum=0, maximum=2048, value=256, step=32, info='Overlap between chunks in tokens. 0 = no overlap. Values like 256 or 512 help preserve context across chunk boundaries.')
with gr.Row():
eval_dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu)
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu)
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
with gr.Row():
start_button = gr.Button("Start LoRA Training", variant='primary', interactive=not mu)
@ -406,35 +406,28 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
train_template.clear()
# == Prep the dataset, format, etc ==
if text_dataset not in ['None', '']:
train_template["template_type"] = "text_dataset"
logger.info("Loading text dataset")
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{text_dataset}.json'))
has_text_dataset = text_dataset not in ['None', '']
has_chat_dataset = dataset not in ['None', '']
if has_text_dataset and has_chat_dataset:
yield "Error: select either a Chat Dataset or a Text Dataset, not both."
return
# Validate the first row has a "text" key
if "text" not in data['train'].column_names:
yield "Error: text dataset must have a \"text\" key per row."
return
# Tokenize each document and concatenate
def tokenize_text_data(data):
"""Tokenize text dataset rows, concatenate, and split into chunks."""
all_tokens = []
for row in data['train']:
for row in data:
tokens = shared.tokenizer.encode(row['text'])
if add_eos_token:
tokens.append(shared.tokenizer.eos_token_id)
all_tokens.extend(tokens)
# Split into chunks with optional overlap (stride)
stride = int(stride_length)
step = cutoff_len - stride if stride > 0 else cutoff_len
if step <= 0:
yield "Error: stride length must be smaller than cutoff length."
return
return None, "Error: stride length must be smaller than cutoff length."
if len(all_tokens) < cutoff_len:
yield "Error: dataset is too short to fill even one chunk of the given cutoff length."
return
return None, "Error: dataset is too short to fill even one chunk of the given cutoff length."
chunks = []
for start in range(0, len(all_tokens), step):
@ -442,7 +435,6 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
if len(chunk) == 0:
break
if len(chunk) < cutoff_len:
# Pad the remainder
pad_len = cutoff_len - len(chunk)
chunks.append({
"input_ids": chunk + [shared.tokenizer.pad_token_id] * pad_len,
@ -456,14 +448,34 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
"attention_mask": [1] * cutoff_len,
})
train_data = Dataset.from_list(chunks)
del all_tokens
eval_data = None
else:
if dataset in ['None', '']:
yield "Missing dataset choice input, cannot continue."
return Dataset.from_list(chunks), None
if has_text_dataset:
train_template["template_type"] = "text_dataset"
logger.info("Loading text dataset")
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{text_dataset}.json'))
if "text" not in data['train'].column_names:
yield "Error: text dataset must have a \"text\" key per row."
return
train_data, err = tokenize_text_data(data['train'])
if err:
yield err
return
if eval_dataset == 'None':
eval_data = None
else:
eval_raw = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
if "text" not in eval_raw['train'].column_names:
yield "Error: evaluation dataset must have a \"text\" key per row."
return
eval_data, err = tokenize_text_data(eval_raw['train'])
if err:
yield err
return
elif has_chat_dataset:
if format in ['None', '']:
yield "Missing format choice input, cannot continue."
return
@ -517,6 +529,9 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
new_fingerprint='%030x' % random.randrange(16**30)
)
eval_data = eval_data.filter(lambda x: len(x['input_ids']) > 0)
else:
yield "No dataset selected. Choose a Chat Dataset or a Text Dataset."
return
# == We MUST reload model if it went through any previous training, even failed one ==
if shared.model_dirty_from_training:

View file

@ -268,6 +268,39 @@ def get_datasets(path: str, ext: str):
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
def get_chat_datasets(path: str):
"""List JSON datasets that contain chat conversations (messages or ShareGPT format)."""
return ['None'] + sorted(set([k.stem for k in Path(path).glob('*.json') if k.stem != 'put-trainer-datasets-here' and _is_chat_dataset(k)]), key=natural_keys)
def get_text_datasets(path: str):
"""List JSON datasets that contain raw text ({"text": ...} format)."""
return ['None'] + sorted(set([k.stem for k in Path(path).glob('*.json') if k.stem != 'put-trainer-datasets-here' and _is_text_dataset(k)]), key=natural_keys)
def _peek_json_keys(filepath):
"""Read the first object in a JSON array file and return its keys."""
import json
try:
with open(filepath, 'r', encoding='utf-8') as f:
data = json.load(f)
if isinstance(data, list) and len(data) > 0 and isinstance(data[0], dict):
return set(data[0].keys())
except Exception:
pass
return set()
def _is_chat_dataset(filepath):
keys = _peek_json_keys(filepath)
return bool(keys & {'messages', 'conversations'})
def _is_text_dataset(filepath):
keys = _peek_json_keys(filepath)
return 'text' in keys
def get_available_chat_styles():
return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys)