mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-06 07:03:37 +00:00
Training: replace raw text file with JSONL text dataset, re-add stride overlap
- Replace "Raw text file" tab with "Text Dataset" tab using JSONL format with "text" key per row - Re-add stride overlap for chunking (configurable Stride Length slider, 0-2048 tokens) - Pad remainder chunks instead of dropping them - Remove hard_cut_string, min_chars, raw_text_file parameters - Remove .txt file and directory loading support
This commit is contained in:
parent
d278bb46a2
commit
da2d4f1a6a
2 changed files with 72 additions and 63 deletions
|
|
@ -26,7 +26,7 @@ from modules.logging_colors import logger
|
|||
from modules.models import reload_model
|
||||
from modules.utils import natural_keys
|
||||
|
||||
PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
|
||||
PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "train_only_after", "stop_at_loss", "add_eos_token", "report_to"]
|
||||
WANT_INTERRUPT = False
|
||||
|
||||
train_log = {}
|
||||
|
|
@ -120,13 +120,12 @@ def create_ui():
|
|||
|
||||
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
||||
|
||||
with gr.Tab(label="Raw text file"):
|
||||
with gr.Tab(label="Text Dataset"):
|
||||
with gr.Row():
|
||||
raw_text_file = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'txt')}, 'refresh-button', interactive=not mu)
|
||||
text_dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Text Dataset', info='A JSONL file with a "text" key per row, for pretraining-style training. Each row is one document.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||
ui.create_refresh_button(text_dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu)
|
||||
|
||||
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
|
||||
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Hard Cut blocks that have less or equal characters than this number')
|
||||
stride_length = gr.Slider(label='Stride Length', minimum=0, maximum=2048, value=0, step=32, info='Overlap between chunks in tokens. 0 = no overlap. Values like 256 or 512 help preserve context across chunk boundaries.')
|
||||
|
||||
with gr.Row():
|
||||
start_button = gr.Button("Start LoRA Training", variant='primary', interactive=not mu)
|
||||
|
|
@ -160,7 +159,7 @@ def create_ui():
|
|||
refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)
|
||||
|
||||
# Training events
|
||||
all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
|
||||
all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, train_only_after, stop_at_loss, add_eos_token, report_to]
|
||||
|
||||
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
||||
start_button.click(do_train, all_params, output)
|
||||
|
|
@ -271,7 +270,7 @@ def calc_trainable_parameters(model):
|
|||
return trainable_params, all_param
|
||||
|
||||
|
||||
def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
|
||||
def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, stride_length: int, train_only_after: str, stop_at_loss: float, add_eos_token: bool, report_to: str):
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
|
@ -441,58 +440,57 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
train_template.clear()
|
||||
|
||||
# == Prep the dataset, format, etc ==
|
||||
if raw_text_file not in ['None', '']:
|
||||
train_template["template_type"] = "raw_text"
|
||||
logger.info("Loading raw text file dataset")
|
||||
fullpath = clean_path('user_data/training/datasets', f'{raw_text_file}')
|
||||
fullpath = Path(fullpath)
|
||||
if fullpath.is_dir():
|
||||
logger.info('Training path directory {}'.format(raw_text_file))
|
||||
raw_text = ""
|
||||
file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name))
|
||||
for file_path in file_paths:
|
||||
if file_path.is_file():
|
||||
with file_path.open('r', encoding='utf-8') as file:
|
||||
raw_text += file.read().replace('\r', '')
|
||||
if text_dataset not in ['None', '']:
|
||||
train_template["template_type"] = "text_dataset"
|
||||
logger.info("Loading text dataset")
|
||||
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{text_dataset}.json'))
|
||||
|
||||
logger.info(f"Loaded training file: {file_path.name}")
|
||||
else:
|
||||
with open(clean_path('user_data/training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
||||
raw_text = file.read().replace('\r', '')
|
||||
|
||||
cut_string = hard_cut_string.replace('\\n', '\n')
|
||||
eos_added = 0
|
||||
all_tokens = []
|
||||
for text_part in raw_text.split(cut_string):
|
||||
if len(text_part.strip()) <= min_chars:
|
||||
continue
|
||||
|
||||
tokens = shared.tokenizer.encode(text_part)
|
||||
if add_eos_token:
|
||||
tokens.append(shared.tokenizer.eos_token_id)
|
||||
eos_added += 1
|
||||
|
||||
all_tokens.extend(tokens)
|
||||
|
||||
if eos_added > 0:
|
||||
print(f"EOS added to {eos_added} text blocks")
|
||||
|
||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||
|
||||
# Concatenate-and-split: non-overlapping chunks of exactly cutoff_len
|
||||
num_chunks = len(all_tokens) // cutoff_len
|
||||
if num_chunks == 0:
|
||||
yield "Error: text is too short to fill even one chunk of the given cutoff length."
|
||||
# Validate the first row has a "text" key
|
||||
if "text" not in data['train'].column_names:
|
||||
yield "Error: text dataset must have a \"text\" key per row."
|
||||
return
|
||||
|
||||
train_data = Dataset.from_list([
|
||||
{
|
||||
"input_ids": all_tokens[i * cutoff_len:(i + 1) * cutoff_len],
|
||||
"labels": all_tokens[i * cutoff_len:(i + 1) * cutoff_len],
|
||||
"attention_mask": [1] * cutoff_len,
|
||||
}
|
||||
for i in range(num_chunks)
|
||||
])
|
||||
# Tokenize each document and concatenate
|
||||
all_tokens = []
|
||||
for row in data['train']:
|
||||
tokens = shared.tokenizer.encode(row['text'])
|
||||
if add_eos_token:
|
||||
tokens.append(shared.tokenizer.eos_token_id)
|
||||
all_tokens.extend(tokens)
|
||||
|
||||
# Split into chunks with optional overlap (stride)
|
||||
stride = int(stride_length)
|
||||
step = cutoff_len - stride if stride > 0 else cutoff_len
|
||||
|
||||
if step <= 0:
|
||||
yield "Error: stride length must be smaller than cutoff length."
|
||||
return
|
||||
|
||||
if len(all_tokens) < cutoff_len:
|
||||
yield "Error: dataset is too short to fill even one chunk of the given cutoff length."
|
||||
return
|
||||
|
||||
chunks = []
|
||||
for start in range(0, len(all_tokens), step):
|
||||
chunk = all_tokens[start:start + cutoff_len]
|
||||
if len(chunk) == 0:
|
||||
break
|
||||
if len(chunk) < cutoff_len:
|
||||
# Pad the remainder
|
||||
pad_len = cutoff_len - len(chunk)
|
||||
chunks.append({
|
||||
"input_ids": chunk + [shared.tokenizer.pad_token_id] * pad_len,
|
||||
"labels": list(chunk) + [-100] * pad_len,
|
||||
"attention_mask": [1] * len(chunk) + [0] * pad_len,
|
||||
})
|
||||
else:
|
||||
chunks.append({
|
||||
"input_ids": chunk,
|
||||
"labels": list(chunk),
|
||||
"attention_mask": [1] * cutoff_len,
|
||||
})
|
||||
|
||||
train_data = Dataset.from_list(chunks)
|
||||
del all_tokens
|
||||
eval_data = None
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue