Instruction template: make "Send to default/notebook" work without a tokenizer

This commit is contained in:
oobabooga 2024-02-16 07:59:09 -08:00
parent f465b7b486
commit 0e1d8d5601
2 changed files with 34 additions and 33 deletions

View file

@@ -166,53 +166,54 @@ def generate_chat_prompt(user_input, state, **kwargs):
prompt = remove_extra_bos(prompt)
return prompt
# Handle truncation
max_length = get_max_prompt_length(state)
prompt = make_prompt(messages)
encoded_length = get_encoded_length(prompt)
while len(messages) > 0 and encoded_length > max_length:
# Handle truncation
if shared.tokenizer is not None:
max_length = get_max_prompt_length(state)
encoded_length = get_encoded_length(prompt)
while len(messages) > 0 and encoded_length > max_length:
# Remove old message, save system message
if len(messages) > 2 and messages[0]['role'] == 'system':
messages.pop(1)
# Remove old message, save system message
if len(messages) > 2 and messages[0]['role'] == 'system':
messages.pop(1)
# Remove old message when no system message is present
elif len(messages) > 1 and messages[0]['role'] != 'system':
messages.pop(0)
# Remove old message when no system message is present
elif len(messages) > 1 and messages[0]['role'] != 'system':
messages.pop(0)
# Resort to truncating the user input
else:
# Resort to truncating the user input
else:
user_message = messages[-1]['content']
user_message = messages[-1]['content']
# Bisect the truncation point
left, right = 0, len(user_message) - 1
# Bisect the truncation point
left, right = 0, len(user_message) - 1
while right - left > 1:
mid = (left + right) // 2
while right - left > 1:
mid = (left + right) // 2
messages[-1]['content'] = user_message[mid:]
messages[-1]['content'] = user_message[mid:]
prompt = make_prompt(messages)
encoded_length = get_encoded_length(prompt)
if encoded_length <= max_length:
right = mid
else:
left = mid
messages[-1]['content'] = user_message[right:]
prompt = make_prompt(messages)
encoded_length = get_encoded_length(prompt)
if encoded_length <= max_length:
right = mid
if encoded_length > max_length:
logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
raise ValueError
else:
left = mid
logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}, available context length: {max_length}.")
break
messages[-1]['content'] = user_message[right:]
prompt = make_prompt(messages)
encoded_length = get_encoded_length(prompt)
if encoded_length > max_length:
logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
raise ValueError
else:
logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}, available context length: {max_length}.")
break
prompt = make_prompt(messages)
encoded_length = get_encoded_length(prompt)
if also_return_rows:
return prompt, [message['content'] for message in messages]