Add TensorRT-LLM support (#5715)

This commit is contained in:
oobabooga 2024-06-24 02:30:03 -03:00 committed by GitHub
parent 536f8d58d4
commit 577a8cd3ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 197 additions and 4 deletions

View file

@ -77,6 +77,7 @@ def load_model(model_name, loader=None):
'ExLlamav2_HF': ExLlamav2_HF_loader,
'AutoAWQ': AutoAWQ_loader,
'HQQ': HQQ_loader,
'TensorRT-LLM': TensorRT_LLM_loader,
}
metadata = get_model_metadata(model_name)
@ -101,7 +102,7 @@ def load_model(model_name, loader=None):
tokenizer = load_tokenizer(model_name, model)
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
if loader.lower().startswith('exllama'):
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt'):
shared.settings['truncation_length'] = shared.args.max_seq_len
elif loader in ['llama.cpp', 'llamacpp_HF']:
shared.settings['truncation_length'] = shared.args.n_ctx
@ -337,6 +338,13 @@ def HQQ_loader(model_name):
return model
def TensorRT_LLM_loader(model_name):
from modules.tensorrt_llm import TensorRTLLMModel
model = TensorRTLLMModel.from_pretrained(model_name)
return model
def get_max_memory_dict():
max_memory = {}
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'