mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-04 22:27:29 +00:00
Backend cleanup (#6025)
This commit is contained in:
parent
6a1682aa95
commit
bd7cc4234d
23 changed files with 57 additions and 442 deletions
|
|
@ -73,13 +73,11 @@ def load_model(model_name, loader=None):
|
|||
load_func_map = {
|
||||
'Transformers': huggingface_loader,
|
||||
'AutoGPTQ': AutoGPTQ_loader,
|
||||
'GPTQ-for-LLaMa': GPTQ_loader,
|
||||
'llama.cpp': llamacpp_loader,
|
||||
'llamacpp_HF': llamacpp_HF_loader,
|
||||
'ExLlamav2': ExLlamav2_loader,
|
||||
'ExLlamav2_HF': ExLlamav2_HF_loader,
|
||||
'AutoAWQ': AutoAWQ_loader,
|
||||
'QuIP#': QuipSharp_loader,
|
||||
'HQQ': HQQ_loader,
|
||||
}
|
||||
|
||||
|
|
@ -310,55 +308,6 @@ def AutoAWQ_loader(model_name):
|
|||
return model
|
||||
|
||||
|
||||
def QuipSharp_loader(model_name):
|
||||
try:
|
||||
with RelativeImport("repositories/quip-sharp"):
|
||||
from lib.utils.unsafe_import import model_from_hf_path
|
||||
except:
|
||||
logger.error(
|
||||
"\nQuIP# has not been found. It must be installed manually for now.\n"
|
||||
"For instructions on how to do that, please consult:\n"
|
||||
"https://github.com/oobabooga/text-generation-webui/pull/4803\n"
|
||||
)
|
||||
return None, None
|
||||
|
||||
# This fixes duplicate logging messages after the import above.
|
||||
handlers = logging.getLogger().handlers
|
||||
if len(handlers) > 1:
|
||||
logging.getLogger().removeHandler(handlers[1])
|
||||
|
||||
model_dir = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if not all((model_dir / file).exists() for file in ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']):
|
||||
logger.error(f"Could not load the model because the tokenizer files could not be found in the model folder. Please download the following files from the original (unquantized) model into {model_dir}: special_tokens_map.json, tokenizer.json, tokenizer.model, tokenizer_config.json.")
|
||||
return None, None
|
||||
|
||||
model, model_str = model_from_hf_path(
|
||||
model_dir,
|
||||
use_cuda_graph=False,
|
||||
use_flash_attn=not shared.args.no_flash_attn
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def GPTQ_loader(model_name):
|
||||
|
||||
# Monkey patch
|
||||
if shared.args.monkey_patch:
|
||||
logger.warning("Applying the monkey patch for using LoRAs with GPTQ models. It may cause undefined behavior outside its intended scope.")
|
||||
from modules.monkey_patch_gptq_lora import load_model_llama
|
||||
|
||||
model, _ = load_model_llama(model_name)
|
||||
|
||||
# No monkey patch
|
||||
else:
|
||||
import modules.GPTQ_loader
|
||||
|
||||
model = modules.GPTQ_loader.load_quantized(model_name)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def AutoGPTQ_loader(model_name):
|
||||
import modules.AutoGPTQ_loader
|
||||
|
||||
|
|
@ -380,12 +329,12 @@ def ExLlamav2_HF_loader(model_name):
|
|||
|
||||
def HQQ_loader(model_name):
|
||||
from hqq.core.quantize import HQQBackend, HQQLinear
|
||||
from hqq.engine.hf import HQQModelForCausalLM
|
||||
from hqq.models.hf.base import AutoHQQHFModel
|
||||
|
||||
logger.info(f"Loading HQQ model with backend: \"{shared.args.hqq_backend}\"")
|
||||
|
||||
model_dir = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
model = HQQModelForCausalLM.from_quantized(str(model_dir))
|
||||
model = AutoHQQHFModel.from_quantized(str(model_dir))
|
||||
HQQLinear.set_backend(getattr(HQQBackend, shared.args.hqq_backend))
|
||||
return model
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue