Training: drop conversations exceeding cutoff length instead of truncating
commit 33a38d7ece
parent c2e494963f
@@ -26,7 +26,7 @@ from modules.evaluate import (
 from modules.logging_colors import logger
 from modules.models import reload_model

-PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "report_to"]
+PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "excess_length", "report_to"]
 WANT_INTERRUPT = False

 train_log = {}
@@ -99,6 +99,7 @@ def create_ui():
 warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')

 add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each document in text datasets.")
+excess_length = gr.Dropdown(label='Excess length', value='drop', choices=['drop', 'truncate'], info='What to do with conversations that exceed the cutoff length. "Drop" removes them entirely (recommended). "Truncate" cuts from the right, which may produce incomplete responses.', elem_classes=['slim-dropdown'])

 higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
 report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
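As context for the new control, here is a minimal, self-contained sketch of how a Gradio dropdown of this shape delivers its string value to a handler. The handler name and the demo wiring are illustrative, not part of this commit:

import gradio as gr

def describe_policy(excess_length: str) -> str:
    # The dropdown passes its current value ('drop' or 'truncate') as a plain string.
    return f"Conversations over the cutoff will be handled with: {excess_length}"

with gr.Blocks() as demo:
    excess_length = gr.Dropdown(label='Excess length', value='drop', choices=['drop', 'truncate'])
    status = gr.Textbox(label='Effect')
    excess_length.change(describe_policy, excess_length, status)

demo.launch()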
@@ -158,7 +159,7 @@ def create_ui():
 refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)

 # Training events
-all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, report_to]
+all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, excess_length, report_to]

 copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
 start_button.click(do_train, all_params, output)
@@ -292,7 +293,7 @@ def calc_trainable_parameters(model):
     return trainable_params, all_param


-def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, stride_length: int, stop_at_loss: float, add_eos_token: bool, report_to: str):
+def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, stride_length: int, stop_at_loss: float, add_eos_token: bool, excess_length: str, report_to: str):

     import torch
     import transformers
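Note that the new argument occupies the same relative position in PARAMETERS, in all_params, and in the do_train signature. Gradio hands component values to the handler positionally, so the three orderings have to stay in sync. A toy illustration of the correspondence this preserves (names shortened for the example):

PARAMETERS = ["cutoff_len", "excess_length", "report_to"]

def do_train(cutoff_len: int, excess_length: str, report_to: str):
    # zip() pairs names with values by position, mirroring how the UI
    # components in all_params line up with the function parameters.
    return dict(zip(PARAMETERS, (cutoff_len, excess_length, report_to)))

print(do_train(2048, 'drop', 'None'))
# {'cutoff_len': 2048, 'excess_length': 'drop', 'report_to': 'None'}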
@@ -389,10 +390,12 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
             end = min(len(through_ids), len(full_ids))
             labels[start:end] = full_ids[start:end]

-        # Truncate from the right: keeps the system prompt and early turns
         if len(full_ids) > cutoff_len:
-            full_ids = full_ids[:cutoff_len]
-            labels = labels[:cutoff_len]
+            if excess_length == 'truncate':
+                full_ids = full_ids[:cutoff_len]
+                labels = labels[:cutoff_len]
+            else:
+                return {"input_ids": [], "labels": [], "attention_mask": []}

         return {
             "input_ids": full_ids,
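The behavioral difference is easiest to see on bare token lists. A minimal sketch of the same policy, with made-up token values and cutoff (not code from this commit):

def apply_excess_length(full_ids, labels, cutoff_len, excess_length):
    # 'truncate' keeps the first cutoff_len tokens; 'drop' returns empty
    # sequences so the conversation can be filtered out downstream.
    if len(full_ids) > cutoff_len:
        if excess_length == 'truncate':
            return full_ids[:cutoff_len], labels[:cutoff_len]
        return [], []
    return full_ids, labels

ids, labs = apply_excess_length(list(range(10)), list(range(10)), 8, 'truncate')
assert len(ids) == 8             # cut from the right, early turns survive
ids, labs = apply_excess_length(list(range(10)), list(range(10)), 8, 'drop')
assert ids == [] and labs == []  # sentinel meaning "remove this conversation"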
@@ -490,11 +493,19 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
             yield f"Error: {e}"
             return

+        total = len(data['train'])
         train_data = data['train'].map(
             tokenize_conversation,
             remove_columns=data['train'].column_names,
             new_fingerprint='%030x' % random.randrange(16**30)
         )
+        train_data = train_data.filter(lambda x: len(x['input_ids']) > 0)
+        dropped = total - len(train_data)
+        if dropped > 0:
+            logger.warning(f"Dropped {dropped}/{total} conversations exceeding cutoff length of {cutoff_len} tokens.")
+        if len(train_data) == 0:
+            yield f"Error: all {total} conversations exceed the cutoff length of {cutoff_len} tokens. Increase the cutoff length or shorten your data."
+            return

         if eval_dataset == 'None':
             eval_data = None
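The empty rows emitted by tokenize_conversation are what the .filter() call removes, and the before/after lengths give the dropped count. A small self-contained demonstration with Hugging Face datasets; the toy data and the byte-level stand-in tokenizer are invented for illustration:

from datasets import Dataset

data = Dataset.from_dict({"text": ["hi", "a conversation that is far too long"]})
cutoff_len = 10

def tokenize(example):
    ids = list(example["text"].encode())  # stand-in tokenizer: one token per byte
    if len(ids) > cutoff_len:
        return {"input_ids": []}          # over the cutoff: mark for dropping
    return {"input_ids": ids}

total = len(data)
mapped = data.map(tokenize, remove_columns=data.column_names)
kept = mapped.filter(lambda x: len(x["input_ids"]) > 0)
print(f"Dropped {total - len(kept)}/{total} conversations")  # Dropped 1/2 conversations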
@@ -505,6 +516,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
             remove_columns=eval_data['train'].column_names,
             new_fingerprint='%030x' % random.randrange(16**30)
         )
+        eval_data = eval_data.filter(lambda x: len(x['input_ids']) > 0)

     # == We MUST reload model if it went through any previous training, even failed one ==
     if shared.model_dirty_from_training: