diff --git a/modules/chat.py b/modules/chat.py index cd82b813..023f5a3e 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -86,36 +86,6 @@ yaml.add_representer(str, str_presenter) yaml.representer.SafeRepresenter.add_representer(str, str_presenter) -def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True): - ''' - Given a Jinja template, reverse-engineers the prefix and the suffix for - an assistant message (if impersonate=False) or an user message - (if impersonate=True) - ''' - - if impersonate: - messages = [ - {"role": "user", "content": "<<|user-message-1|>>"}, - {"role": "user", "content": "<<|user-message-2|>>"}, - ] - else: - messages = [ - {"role": "assistant", "content": "<<|user-message-1|>>"}, - {"role": "assistant", "content": "<<|user-message-2|>>"}, - ] - - prompt = renderer(messages=messages) - - suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0] - suffix = prompt.split("<<|user-message-2|>>")[1] - prefix = suffix_plus_prefix[len(suffix):] - - if strip_trailing_spaces: - prefix = prefix.rstrip(' ') - - return prefix, suffix - - def get_thinking_suppression_string(template): """ Determines what string needs to be added to suppress thinking mode @@ -341,26 +311,16 @@ def generate_chat_prompt(user_input, state, **kwargs): command = command.replace('<|prompt|>', prompt) command = replace_character_names(command, state['name1'], state['name2']) - if _continue: - prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0] - prefix += messages[-1]["content"] - else: - prefix = get_generation_prompt(renderer, impersonate=impersonate)[0] - if not impersonate: - prefix = apply_extensions('bot_prefix', prefix, state) - - suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1] - outer_messages = [] if state['custom_system_message'].strip() != '': outer_messages.append({"role": "system", "content": state['custom_system_message']}) outer_messages.append({"role": "user", "content": command}) - outer_messages.append({"role": "assistant", "content": prefix}) - prompt = instruct_renderer(messages=outer_messages) - if len(suffix) > 0: - prompt = prompt[:-len(suffix)] + prompt = instruct_renderer( + messages=outer_messages, + add_generation_prompt=True + ) else: # Handle GPT-OSS as a special case when continuing # (otherwise the thinking block gets removed...) @@ -375,29 +335,10 @@ def generate_chat_prompt(user_input, state, **kwargs): assistant_reply_so_far += f"<|channel|>final<|message|>{last_message_to_continue.get('content', '')}" prompt += assistant_reply_so_far else: - prompt = renderer(messages=messages) - if _continue: - suffix = get_generation_prompt(renderer, impersonate=impersonate)[1] - if len(suffix) > 0: - prompt = prompt[:-len(suffix)] - else: - prefix = get_generation_prompt(renderer, impersonate=impersonate)[0] - - # Handle GPT-OSS as a special case when not continuing - if '<|channel|>final<|message|>' in state['instruction_template_str']: - if prefix.endswith("<|channel|>final<|message|>"): - prefix = prefix[:-len("<|channel|>final<|message|>")] - - if impersonate: - prefix += "<|message|>" - - if state['mode'] == 'chat' and not impersonate: - prefix = apply_extensions('bot_prefix', prefix, state) - - prompt += prefix - - if state['mode'] == 'instruct' and 'enable_thinking' in state['instruction_template_str'] and not any((_continue, impersonate, state['enable_thinking'])): - prompt += get_thinking_suppression_string(instruction_template) + prompt = renderer( + messages=messages, + add_generation_prompt=True + ) return prompt @@ -523,24 +464,41 @@ def get_stopping_strings(state): renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2']) renderers.append(renderer) - for renderer in renderers: - prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False) - prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True) + fake_messages = [ + {"role": "user", "content": "first user message"}, + {"role": "assistant", "content": "first assistant message"}, + {"role": "user", "content": "second user message"}, + {"role": "assistant", "content": "second assistant message"}, + ] - stopping_strings += [ - suffix_user + prefix_bot, - suffix_user + prefix_user, - suffix_bot + prefix_bot, - suffix_bot + prefix_user, + stopping_strings = [] + for renderer in renderers: + prompt = renderer(messages=fake_messages) + + # Find positions of each message content + first_user_end = prompt.find("first user message") + len("first user message") + first_assistant_start = prompt.find("first assistant message") + first_assistant_end = prompt.find("first assistant message") + len("first assistant message") + second_user_start = prompt.find("second user message") + second_assistant_end = prompt.find("second assistant message") + len("second assistant message") + + # Extract pieces of text potentially containing unique stopping strings + texts = [ + prompt[first_user_end:first_assistant_start], + prompt[first_assistant_end:second_user_start], + prompt[second_assistant_end:] ] - # Try to find the EOT token - for item in stopping_strings.copy(): - item = item.strip() - if item.startswith("<") and ">" in item: - stopping_strings.append(item.split(">")[0] + ">") - elif item.startswith("[") and "]" in item: - stopping_strings.append(item.split("]")[0] + "]") + for text in texts: + text = text.strip() + if text.startswith("<") and ">" in text: + stopping_strings.append(text.split(">")[0] + ">") + elif text.startswith("[") and "]" in text: + stopping_strings.append(text.split("]")[0] + "]") + elif text.startswith("(") and ")" in text: + stopping_strings.append(text.split(")")[0] + ")") + elif text.startswith("{") and "}" in text: + stopping_strings.append(text.split("}")[0] + "}") if 'stopping_strings' in state and isinstance(state['stopping_strings'], list): stopping_strings += state.pop('stopping_strings') @@ -549,12 +507,6 @@ def get_stopping_strings(state): result = [item for item in stopping_strings if not any(item.startswith(other) and item != other for other in stopping_strings)] result = list(set(result)) - # Handle GPT-OSS as a special case - if '<|channel|>final<|message|>' in state['instruction_template_str'] and "<|end|>" in result: - result.remove("<|end|>") - result.append("<|result|>") - result = list(set(result)) - if shared.args.verbose: logger.info("STOPPING_STRINGS=") pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(result)