diff --git a/modules/training.py b/modules/training.py
index 6472e692..a762cbda 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -517,12 +517,14 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
                 if shared.model is not None:
                     print("Model reloaded OK, continue with training.")
                 else:
-                    return f"Failed to load {selected_model}."
+                    yield f"Failed to load {selected_model}."
+                    return
             except Exception:
                 exc = traceback.format_exc()
                 logger.error('Failed to reload the model.')
                 print(exc)
-                return exc.replace('\n', '\n\n')
+                yield exc.replace('\n', '\n\n')
+                return
 
     # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
@@ -614,7 +616,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
             print(f"\033[1;30;40mStep: {tracked.current_steps} \033[0;37;0m", end='')
             if 'loss' in logs:
                 loss = float(logs['loss'])
-                if loss <= stop_at_loss:
+                if stop_at_loss > 0 and loss <= stop_at_loss:
                     control.should_epoch_stop = True
                     control.should_training_stop = True
                     print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
@@ -670,13 +672,13 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
             remove_unused_columns=False,
         ),
         data_collator=collate_fn,
-        callbacks=list([Callbacks()])
+        callbacks=[Callbacks()]
     )
 
     # == Save parameters for reuse ==
     with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
-        vars = locals()
-        json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
+        local_vars = locals()
+        json.dump({x: local_vars[x] for x in PARAMETERS}, file, indent=2)
 
     # == Save training prompt ==
     with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file:
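
Note (not part of the patch): a minimal sketch of why the first hunk replaces "return msg" with "yield msg" followed by a bare "return". do_train() is consumed as a generator, so a plain return of a string only terminates iteration without delivering the message to the consumer. The names report_status and consume_status below are hypothetical and exist only to illustrate that behavior, assuming a UI loop that iterates the generator (as Gradio does for streaming outputs).

    def report_status(ok: bool):
        # Generator that streams status strings to whatever iterates it.
        yield "Loading model..."
        if not ok:
            # A bare `return "Failed..."` here would end the generator without
            # surfacing the text; yield the message first, then stop.
            yield "Failed to load the model."
            return
        yield "Model loaded, continuing."

    def consume_status():
        # The consumer (e.g. a UI update loop) sees every yielded message.
        for message in report_status(ok=False):
            print(message)

    consume_status()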