mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-07 06:03:51 +01:00
Training: fix checkpoint resume and surface training errors to UI
This commit is contained in:
parent
275810c843
commit
7a1fa8c9ea
|
|
@@ -649,9 +649,6 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
|
||||
lora_model.config.use_cache = False
|
||||
|
||||
if sys.platform != "win32":
|
||||
lora_model = torch.compile(lora_model)
|
||||
|
||||
def collate_fn(batch):
|
||||
max_len = max(len(item['input_ids']) for item in batch)
|
||||
input_ids, labels, attention_mask = [], [], []
|
||||
|
|
@@ -753,16 +750,23 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to create log file due to error: {e}")
|
||||
|
||||
thread_error = None
|
||||
|
||||
def threaded_run():
|
||||
log_train_dataset(trainer)
|
||||
trainer.train(resume_from_checkpoint=resume_checkpoint)
|
||||
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
||||
lora_model.save_pretrained(lora_file_path)
|
||||
tracked.did_save = True
|
||||
logger.info("LoRA training run is completed and saved.")
|
||||
# Save log
|
||||
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
|
||||
json.dump(train_log, file, indent=2)
|
||||
nonlocal thread_error
|
||||
try:
|
||||
log_train_dataset(trainer)
|
||||
trainer.train(resume_from_checkpoint=resume_checkpoint)
|
||||
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
||||
lora_model.save_pretrained(lora_file_path)
|
||||
tracked.did_save = True
|
||||
logger.info("LoRA training run is completed and saved.")
|
||||
# Save log
|
||||
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
|
||||
json.dump(train_log, file, indent=2)
|
||||
except Exception as e:
|
||||
thread_error = e
|
||||
logger.error(f"Training error: {e}")
|
||||
|
||||
thread = threading.Thread(target=threaded_run)
|
||||
thread.start()
|
||||
|
|
@@ -791,6 +795,11 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
|
||||
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||
|
||||
# Check for errors from the training thread
|
||||
if thread_error is not None:
|
||||
yield f"Training failed: {thread_error}"
|
||||
return
|
||||
|
||||
# Saving in the train thread might fail if an error occurs, so save here if so.
|
||||
if not tracked.did_save:
|
||||
logger.info("Training complete, saving")
|
||||
|
|
|
|||
Loading…
Reference in a new issue