mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-05 06:35:15 +00:00
TensorRT-LLM: Migrate from ModelRunner to LLM API, add concurrent API request support
This commit is contained in:
parent
9824c82cb6
commit
f52d9336e5
7 changed files with 50 additions and 89 deletions
|
|
@ -1,15 +1,10 @@
|
|||
from pathlib import Path
|
||||
|
||||
import tensorrt_llm
|
||||
import torch
|
||||
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp
|
||||
from tensorrt_llm._tensorrt_engine import LLM
|
||||
from tensorrt_llm.llmapi import SamplingParams
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import (
|
||||
get_max_prompt_length,
|
||||
get_reply_from_output_ids
|
||||
)
|
||||
|
||||
|
||||
class TensorRTLLMModel:
|
||||
|
|
@ -18,91 +13,50 @@ class TensorRTLLMModel:
|
|||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path_to_model):
|
||||
|
||||
path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
|
||||
# Define model settings
|
||||
runner_kwargs = dict(
|
||||
engine_dir=str(path_to_model),
|
||||
lora_dir=None,
|
||||
rank=runtime_rank,
|
||||
debug_mode=False,
|
||||
lora_ckpt_source="hf",
|
||||
llm = LLM(
|
||||
model=str(path_to_model),
|
||||
skip_tokenizer_init=False,
|
||||
)
|
||||
|
||||
if shared.args.cpp_runner:
|
||||
logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
|
||||
runner_kwargs.update(
|
||||
max_batch_size=1,
|
||||
max_beam_width=1,
|
||||
)
|
||||
else:
|
||||
logger.info("TensorRT-LLM: Using \"ModelRunner\"")
|
||||
|
||||
# Load the model
|
||||
runner_cls = ModelRunnerCpp if shared.args.cpp_runner else ModelRunner
|
||||
runner = runner_cls.from_dir(**runner_kwargs)
|
||||
|
||||
result = cls()
|
||||
result.model = runner
|
||||
result.runtime_rank = runtime_rank
|
||||
|
||||
result.llm = llm
|
||||
result.tokenizer = llm.tokenizer
|
||||
return result
|
||||
|
||||
def generate_with_streaming(self, prompt, state):
|
||||
batch_input_ids = []
|
||||
input_ids = shared.tokenizer.encode(
|
||||
prompt,
|
||||
add_special_tokens=True,
|
||||
truncation=False,
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=state['max_new_tokens'] if not state['auto_max_new_tokens']
|
||||
else state['truncation_length'] - len(shared.tokenizer.encode(prompt)),
|
||||
end_id=shared.tokenizer.eos_token_id,
|
||||
temperature=state['temperature'],
|
||||
top_k=state['top_k'],
|
||||
top_p=state['top_p'],
|
||||
min_p=state['min_p'],
|
||||
repetition_penalty=state['repetition_penalty'],
|
||||
presence_penalty=state['presence_penalty'],
|
||||
frequency_penalty=state['frequency_penalty'],
|
||||
no_repeat_ngram_size=state['no_repeat_ngram_size'] if state['no_repeat_ngram_size'] > 0 else None,
|
||||
seed=state['seed'],
|
||||
ignore_eos=state['ban_eos_token'],
|
||||
add_special_tokens=state['add_bos_token'],
|
||||
skip_special_tokens=state['skip_special_tokens'],
|
||||
)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int32)
|
||||
input_ids = input_ids[-get_max_prompt_length(state):] # Apply truncation_length
|
||||
batch_input_ids.append(input_ids)
|
||||
|
||||
if state['auto_max_new_tokens']:
|
||||
max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
|
||||
else:
|
||||
max_new_tokens = state['max_new_tokens']
|
||||
|
||||
with torch.no_grad():
|
||||
generator = self.model.generate(
|
||||
batch_input_ids,
|
||||
max_new_tokens=max_new_tokens,
|
||||
end_id=shared.tokenizer.eos_token_id if not state['ban_eos_token'] else -1,
|
||||
pad_id=shared.tokenizer.pad_token_id or shared.tokenizer.eos_token_id,
|
||||
temperature=state['temperature'],
|
||||
top_k=state['top_k'],
|
||||
top_p=state['top_p'],
|
||||
repetition_penalty=state['repetition_penalty'],
|
||||
presence_penalty=state['presence_penalty'],
|
||||
frequency_penalty=state['frequency_penalty'],
|
||||
stop_words_list=None,
|
||||
bad_words_list=None,
|
||||
lora_uids=None,
|
||||
prompt_table=None,
|
||||
prompt_tasks=None,
|
||||
streaming=True,
|
||||
output_sequence_lengths=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
stop_event = state.get('stop_event')
|
||||
result = self.llm.generate_async(prompt, sampling_params=sampling_params, streaming=True)
|
||||
|
||||
cumulative_reply = ''
|
||||
starting_from = batch_input_ids[0].shape[-1]
|
||||
|
||||
for curr_outputs in generator:
|
||||
if shared.stop_everything:
|
||||
for output in result:
|
||||
if shared.stop_everything or (stop_event and stop_event.is_set()):
|
||||
result.abort()
|
||||
break
|
||||
|
||||
sequence_length = curr_outputs['sequence_lengths'][0].item()
|
||||
output_ids = curr_outputs['output_ids'][0][0][:sequence_length].tolist()
|
||||
|
||||
cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
|
||||
starting_from = sequence_length
|
||||
yield cumulative_reply
|
||||
text_diff = output.outputs[0].text_diff
|
||||
if text_diff:
|
||||
cumulative_reply += text_diff
|
||||
yield cumulative_reply
|
||||
|
||||
def generate(self, prompt, state):
|
||||
output = ''
|
||||
|
|
@ -110,3 +64,8 @@ class TensorRTLLMModel:
|
|||
pass
|
||||
|
||||
return output
|
||||
|
||||
def unload(self):
|
||||
if hasattr(self, 'llm') and self.llm is not None:
|
||||
self.llm.shutdown()
|
||||
self.llm = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue