mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-07 22:23:51 +01:00
Training: fix silent error on model reload failure, minor cleanups
This commit is contained in:
parent
5b18be8582
commit
c2e494963f
|
|
@ -517,12 +517,14 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
if shared.model is not None:
|
||||
print("Model reloaded OK, continue with training.")
|
||||
else:
|
||||
return f"Failed to load {selected_model}."
|
||||
yield f"Failed to load {selected_model}."
|
||||
return
|
||||
except Exception:
|
||||
exc = traceback.format_exc()
|
||||
logger.error('Failed to reload the model.')
|
||||
print(exc)
|
||||
return exc.replace('\n', '\n\n')
|
||||
yield exc.replace('\n', '\n\n')
|
||||
return
|
||||
|
||||
# == Start prepping the model itself ==
|
||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||
|
|
@ -614,7 +616,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
print(f"\033[1;30;40mStep: {tracked.current_steps} \033[0;37;0m", end='')
|
||||
if 'loss' in logs:
|
||||
loss = float(logs['loss'])
|
||||
if loss <= stop_at_loss:
|
||||
if stop_at_loss > 0 and loss <= stop_at_loss:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
|
||||
|
|
@ -670,13 +672,13 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
|||
remove_unused_columns=False,
|
||||
),
|
||||
data_collator=collate_fn,
|
||||
callbacks=list([Callbacks()])
|
||||
callbacks=[Callbacks()]
|
||||
)
|
||||
|
||||
# == Save parameters for reuse ==
|
||||
with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
|
||||
vars = locals()
|
||||
json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
|
||||
local_vars = locals()
|
||||
json.dump({x: local_vars[x] for x in PARAMETERS}, file, indent=2)
|
||||
|
||||
# == Save training prompt ==
|
||||
with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file:
|
||||
|
|
|
|||
Loading…
Reference in a new issue