From 7a1fa8c9ea7ddc74821ff8ddf17bd790dacdd40c Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 5 Mar 2026 15:50:39 -0300
Subject: [PATCH] Training: fix checkpoint resume and surface training errors to UI

---
 modules/training.py | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/modules/training.py b/modules/training.py
index 45520830..0cb29dce 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -649,9 +649,6 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
 
     lora_model.config.use_cache = False
 
-    if sys.platform != "win32":
-        lora_model = torch.compile(lora_model)
-
     def collate_fn(batch):
         max_len = max(len(item['input_ids']) for item in batch)
         input_ids, labels, attention_mask = [], [], []
@@ -753,16 +750,23 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
     except Exception as e:
         logger.error(f"Failed to create log file due to error: {e}")
 
+    thread_error = None
+
     def threaded_run():
-        log_train_dataset(trainer)
-        trainer.train(resume_from_checkpoint=resume_checkpoint)
-        # Note: save in the thread in case the gradio thread breaks (eg browser closed)
-        lora_model.save_pretrained(lora_file_path)
-        tracked.did_save = True
-        logger.info("LoRA training run is completed and saved.")
-        # Save log
-        with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
-            json.dump(train_log, file, indent=2)
+        nonlocal thread_error
+        try:
+            log_train_dataset(trainer)
+            trainer.train(resume_from_checkpoint=resume_checkpoint)
+            # Note: save in the thread in case the gradio thread breaks (eg browser closed)
+            lora_model.save_pretrained(lora_file_path)
+            tracked.did_save = True
+            logger.info("LoRA training run is completed and saved.")
+            # Save log
+            with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
+                json.dump(train_log, file, indent=2)
+        except Exception as e:
+            thread_error = e
+            logger.error(f"Training error: {e}")
 
     thread = threading.Thread(target=threaded_run)
     thread.start()
@@ -791,6 +795,11 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
 
         yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
 
+        # Check for errors from the training thread
+        if thread_error is not None:
+            yield f"Training failed: {thread_error}"
+            return
+
     # Saving in the train thread might fail if an error occurs, so save here if so.
     if not tracked.did_save:
         logger.info("Training complete, saving")