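"""TensorRT-LLM backend for text-generation-webui.

Wraps tensorrt_llm's LLM engine behind a small model class that offers
streaming and blocking generation plus an unload hook.
"""
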
from pathlib import Path

from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm.llmapi import SamplingParams

from modules import shared
from modules.logging_colors import logger


class TensorRTLLMModel:
    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path_to_model):
        # Resolve the engine directory relative to the configured model folder
        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)

        llm = LLM(
            model=str(path_to_model),
            skip_tokenizer_init=False,
        )

        result = cls()
        result.llm = llm
        result.tokenizer = llm.tokenizer
        return result

    def generate_with_streaming(self, prompt, state):
        sampling_params = SamplingParams(
            max_tokens=state['max_new_tokens'] if not state['auto_max_new_tokens']
            else state['truncation_length'] - len(shared.tokenizer.encode(prompt)),
            end_id=shared.tokenizer.eos_token_id,
            temperature=state['temperature'],
            top_k=state['top_k'],
            top_p=state['top_p'],
            min_p=state['min_p'],
            repetition_penalty=state['repetition_penalty'],
            presence_penalty=state['presence_penalty'],
            frequency_penalty=state['frequency_penalty'],
            no_repeat_ngram_size=state['no_repeat_ngram_size'] if state['no_repeat_ngram_size'] > 0 else None,
            seed=state['seed'],
            ignore_eos=state['ban_eos_token'],
            add_special_tokens=state['add_bos_token'],
            skip_special_tokens=state['skip_special_tokens'],
        )

        stop_event = state.get('stop_event')
        result = self.llm.generate_async(prompt, sampling_params=sampling_params, streaming=True)

        cumulative_reply = ''
        for output in result:
            # Abort the request if the user pressed stop or the UI signalled cancellation
            if shared.stop_everything or (stop_event and stop_event.is_set()):
                result.abort()
                break

            # Each streamed chunk carries only the newly generated text
            text_diff = output.outputs[0].text_diff
            if text_diff:
                cumulative_reply += text_diff
                yield cumulative_reply

    def generate(self, prompt, state):
        # Blocking generation: exhaust the streaming generator and return the final reply
        output = ''
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output

    def unload(self):
        if hasattr(self, 'llm') and self.llm is not None:
            self.llm.shutdown()
            self.llm = None
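

# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a TensorRT-LLM engine directory under shared.args.model_dir and that
# shared.tokenizer has been assigned (generate_with_streaming reads it for the EOS
# id and prompt-length accounting); the folder name and state values below are
# hypothetical, with keys mirroring the ones read above.
#
#     model = TensorRTLLMModel.from_pretrained('MyModel-trt-engine')
#     shared.tokenizer = model.tokenizer
#     state = {
#         'max_new_tokens': 512, 'auto_max_new_tokens': False, 'truncation_length': 8192,
#         'temperature': 0.7, 'top_k': 40, 'top_p': 0.9, 'min_p': 0.0,
#         'repetition_penalty': 1.1, 'presence_penalty': 0.0, 'frequency_penalty': 0.0,
#         'no_repeat_ngram_size': 0, 'seed': -1, 'ban_eos_token': False,
#         'add_bos_token': True, 'skip_special_tokens': True,
#     }
#     for partial in model.generate_with_streaming('Hello, world', state):
#         print(partial)
#     model.unload()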