mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-05 06:35:15 +00:00
Add TensorRT-LLM support (#5715)
This commit is contained in:
parent
536f8d58d4
commit
577a8cd3ee
9 changed files with 197 additions and 4 deletions
|
|
@ -77,6 +77,7 @@ def load_model(model_name, loader=None):
|
|||
'ExLlamav2_HF': ExLlamav2_HF_loader,
|
||||
'AutoAWQ': AutoAWQ_loader,
|
||||
'HQQ': HQQ_loader,
|
||||
'TensorRT-LLM': TensorRT_LLM_loader,
|
||||
}
|
||||
|
||||
metadata = get_model_metadata(model_name)
|
||||
|
|
@ -101,7 +102,7 @@ def load_model(model_name, loader=None):
|
|||
tokenizer = load_tokenizer(model_name, model)
|
||||
|
||||
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
|
||||
if loader.lower().startswith('exllama'):
|
||||
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt'):
|
||||
shared.settings['truncation_length'] = shared.args.max_seq_len
|
||||
elif loader in ['llama.cpp', 'llamacpp_HF']:
|
||||
shared.settings['truncation_length'] = shared.args.n_ctx
|
||||
|
|
@ -337,6 +338,13 @@ def HQQ_loader(model_name):
|
|||
return model
|
||||
|
||||
|
||||
def TensorRT_LLM_loader(model_name):
    """Load a model through the TensorRT-LLM backend.

    Passes ``model_name`` to ``TensorRTLLMModel.from_pretrained`` and
    returns whatever that call produces.
    """
    # NOTE(review): the import is function-local, presumably so the
    # TensorRT-LLM dependency is only needed when this loader is actually
    # selected -- confirm against the other loaders in this module.
    from modules.tensorrt_llm import TensorRTLLMModel

    return TensorRTLLMModel.from_pretrained(model_name)
|
||||
|
||||
|
||||
def get_max_memory_dict():
|
||||
max_memory = {}
|
||||
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue