diff --git a/modules/chat.py b/modules/chat.py index 1a16a689..c10d91a7 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -108,7 +108,14 @@ def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=Tru suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0] suffix = prompt.split("<<|user-message-2|>>")[1] - prefix = suffix_plus_prefix[len(suffix):] + + # Remove the message suffix. The first case handles the GPT-OSS model + # in a way that is likely to not interfere with previous models. + if '<|start|>user' in suffix_plus_prefix or '<|start|>assistant' in suffix_plus_prefix: + start_index = suffix_plus_prefix.rindex('<|start|>') + prefix = suffix_plus_prefix[start_index:] + else: + prefix = suffix_plus_prefix[len(suffix):] if strip_trailing_spaces: prefix = prefix.rstrip(' ')