From 518e5c4244b1d373d616ab32215b2f1c195deae8 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 13 Mar 2023 16:45:08 -0300
Subject: [PATCH] Some minor fixes to the GPTQ loader

---
 modules/quant_loader.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/modules/quant_loader.py b/modules/quant_loader.py
index 7a5f8461..c2723490 100644
--- a/modules/quant_loader.py
+++ b/modules/quant_loader.py
@@ -7,6 +7,8 @@ import torch
 import modules.shared as shared
 
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
+import llama
+import opt
 
 
 def load_quantized(model_name):
@@ -21,9 +23,9 @@ def load_quantized(model_name):
     model_type = shared.args.gptq_model_type.lower()
 
     if model_type == 'llama':
-        from llama import load_quant
+        load_quant = llama.load_quant
     elif model_type == 'opt':
-        from opt import load_quant
+        load_quant = opt.load_quant
     else:
         print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
         exit()
@@ -50,7 +52,7 @@ def load_quantized(model_name):
         print(f"Could not find {pt_model}, exiting...")
         exit()
 
-    model = load_quant(path_to_model, str(pt_path), shared.args.gptq_bits)
+    model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
 
     # Multiple GPUs or GPU+CPU
     if shared.args.gpu_memory: