diff --git a/modules/chat.py b/modules/chat.py index acfc2f66..d1474cfe 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -112,7 +112,9 @@ def generate_chat_prompt(user_input, state, **kwargs): add_generation_prompt=False, enable_thinking=state['enable_thinking'], reasoning_effort=state['reasoning_effort'], - thinking_budget=-1 if state.get('enable_thinking', True) else 0 + thinking_budget=-1 if state.get('enable_thinking', True) else 0, + bos_token=shared.bos_token, + eos_token=shared.eos_token, ) chat_renderer = partial( @@ -475,7 +477,7 @@ def get_stopping_strings(state): if state['mode'] in ['instruct', 'chat-instruct']: template = jinja_env.from_string(state['instruction_template_str']) - renderer = partial(template.render, add_generation_prompt=False) + renderer = partial(template.render, add_generation_prompt=False, bos_token=shared.bos_token, eos_token=shared.eos_token) renderers.append(renderer) if state['mode'] in ['chat']: diff --git a/modules/models_settings.py b/modules/models_settings.py index 6dc000b4..d333e269 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -89,8 +89,9 @@ def get_model_metadata(model): else: bos_token = "" - template = template.replace('eos_token', "'{}'".format(eos_token)) - template = template.replace('bos_token', "'{}'".format(bos_token)) + + shared.bos_token = bos_token + shared.eos_token = eos_token template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL) template = re.sub(r'raise_exception\([^)]*\)', "''", template) @@ -160,13 +161,16 @@ def get_model_metadata(model): # 4. If a template was found from any source, process it if template: + shared.bos_token = '' + shared.eos_token = '' + for k in ['eos_token', 'bos_token']: if k in metadata: value = metadata[k] if isinstance(value, dict): value = value['content'] - template = template.replace(k, "'{}'".format(value)) + setattr(shared, k, value) template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL) template = re.sub(r'raise_exception\([^)]*\)', "''", template) diff --git a/modules/shared.py b/modules/shared.py index 2f39e495..7b572dec 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -19,6 +19,8 @@ is_seq2seq = False is_multimodal = False model_dirty_from_training = False lora_names = [] +bos_token = '' +eos_token = '' # Image model variables image_model = None