TensorRT-LLM: Migrate from ModelRunner to LLM API, add concurrent API request support

This commit is contained in:
oobabooga 2026-03-05 18:09:45 -08:00
parent 9824c82cb6
commit f52d9336e5
7 changed files with 50 additions and 89 deletions

View file

@@ -1,15 +1,10 @@
from pathlib import Path
import tensorrt_llm
import torch
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm.llmapi import SamplingParams
from modules import shared
from modules.logging_colors import logger
from modules.text_generation import (
get_max_prompt_length,
get_reply_from_output_ids
)
class TensorRTLLMModel:
@@ -18,91 +13,50 @@ class TensorRTLLMModel:
@classmethod
def from_pretrained(cls, path_to_model):
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
runtime_rank = tensorrt_llm.mpi_rank()
# Define model settings
runner_kwargs = dict(
engine_dir=str(path_to_model),
lora_dir=None,
rank=runtime_rank,
debug_mode=False,
lora_ckpt_source="hf",
llm = LLM(
model=str(path_to_model),
skip_tokenizer_init=False,
)
if shared.args.cpp_runner:
logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
runner_kwargs.update(
max_batch_size=1,
max_beam_width=1,
)
else:
logger.info("TensorRT-LLM: Using \"ModelRunner\"")
# Load the model
runner_cls = ModelRunnerCpp if shared.args.cpp_runner else ModelRunner
runner = runner_cls.from_dir(**runner_kwargs)
result = cls()
result.model = runner
result.runtime_rank = runtime_rank
result.llm = llm
result.tokenizer = llm.tokenizer
return result
def generate_with_streaming(self, prompt, state):
batch_input_ids = []
input_ids = shared.tokenizer.encode(
prompt,
add_special_tokens=True,
truncation=False,
sampling_params = SamplingParams(
max_tokens=state['max_new_tokens'] if not state['auto_max_new_tokens']
else state['truncation_length'] - len(shared.tokenizer.encode(prompt)),
end_id=shared.tokenizer.eos_token_id,
temperature=state['temperature'],
top_k=state['top_k'],
top_p=state['top_p'],
min_p=state['min_p'],
repetition_penalty=state['repetition_penalty'],
presence_penalty=state['presence_penalty'],
frequency_penalty=state['frequency_penalty'],
no_repeat_ngram_size=state['no_repeat_ngram_size'] if state['no_repeat_ngram_size'] > 0 else None,
seed=state['seed'],
ignore_eos=state['ban_eos_token'],
add_special_tokens=state['add_bos_token'],
skip_special_tokens=state['skip_special_tokens'],
)
input_ids = torch.tensor(input_ids, dtype=torch.int32)
input_ids = input_ids[-get_max_prompt_length(state):] # Apply truncation_length
batch_input_ids.append(input_ids)
if state['auto_max_new_tokens']:
max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
else:
max_new_tokens = state['max_new_tokens']
with torch.no_grad():
generator = self.model.generate(
batch_input_ids,
max_new_tokens=max_new_tokens,
end_id=shared.tokenizer.eos_token_id if not state['ban_eos_token'] else -1,
pad_id=shared.tokenizer.pad_token_id or shared.tokenizer.eos_token_id,
temperature=state['temperature'],
top_k=state['top_k'],
top_p=state['top_p'],
repetition_penalty=state['repetition_penalty'],
presence_penalty=state['presence_penalty'],
frequency_penalty=state['frequency_penalty'],
stop_words_list=None,
bad_words_list=None,
lora_uids=None,
prompt_table=None,
prompt_tasks=None,
streaming=True,
output_sequence_lengths=True,
return_dict=True,
)
torch.cuda.synchronize()
stop_event = state.get('stop_event')
result = self.llm.generate_async(prompt, sampling_params=sampling_params, streaming=True)
cumulative_reply = ''
starting_from = batch_input_ids[0].shape[-1]
for curr_outputs in generator:
if shared.stop_everything:
for output in result:
if shared.stop_everything or (stop_event and stop_event.is_set()):
result.abort()
break
sequence_length = curr_outputs['sequence_lengths'][0].item()
output_ids = curr_outputs['output_ids'][0][0][:sequence_length].tolist()
cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
starting_from = sequence_length
yield cumulative_reply
text_diff = output.outputs[0].text_diff
if text_diff:
cumulative_reply += text_diff
yield cumulative_reply
def generate(self, prompt, state):
output = ''
@@ -110,3 +64,8 @@ class TensorRTLLMModel:
pass
return output
def unload(self):
    """Shut down the wrapped LLM engine, if any, and drop the reference to it.

    Does nothing when the instance has no ``llm`` attribute or when that
    attribute is already ``None``, so calling this repeatedly is harmless.

    Raises:
        Whatever ``self.llm.shutdown()`` raises. The reference is still
        cleared in that case so the instance never keeps pointing at an
        engine whose shutdown was attempted.
    """
    engine = getattr(self, 'llm', None)
    if engine is None:
        return

    try:
        engine.shutdown()
    finally:
        # Clear the attribute even when shutdown() fails: a second unload()
        # must not call shutdown() again on the same engine object.
        self.llm = None