Training: fix checkpoint resume and surface training errors to UI

This commit is contained in:
oobabooga 2026-03-05 15:50:39 -03:00
parent 275810c843
commit 7a1fa8c9ea

View file

@@ -649,9 +649,6 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
lora_model.config.use_cache = False
if sys.platform != "win32":
lora_model = torch.compile(lora_model)
def collate_fn(batch):
max_len = max(len(item['input_ids']) for item in batch)
input_ids, labels, attention_mask = [], [], []
@@ -753,16 +750,23 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
except Exception as e:
logger.error(f"Failed to create log file due to error: {e}")
thread_error = None
def threaded_run():
log_train_dataset(trainer)
trainer.train(resume_from_checkpoint=resume_checkpoint)
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
lora_model.save_pretrained(lora_file_path)
tracked.did_save = True
logger.info("LoRA training run is completed and saved.")
# Save log
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
json.dump(train_log, file, indent=2)
nonlocal thread_error
try:
log_train_dataset(trainer)
trainer.train(resume_from_checkpoint=resume_checkpoint)
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
lora_model.save_pretrained(lora_file_path)
tracked.did_save = True
logger.info("LoRA training run is completed and saved.")
# Save log
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
json.dump(train_log, file, indent=2)
except Exception as e:
thread_error = e
logger.error(f"Training error: {e}")
thread = threading.Thread(target=threaded_run)
thread.start()
@@ -791,6 +795,11 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
# Check for errors from the training thread
if thread_error is not None:
yield f"Training failed: {thread_error}"
return
# Saving in the train thread might fail if an error occurs, so save here if so.
if not tracked.did_save:
logger.info("Training complete, saving")