diff --git a/modules/chat.py b/modules/chat.py
index acfc2f66..d1474cfe 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -112,7 +112,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
add_generation_prompt=False,
enable_thinking=state['enable_thinking'],
reasoning_effort=state['reasoning_effort'],
- thinking_budget=-1 if state.get('enable_thinking', True) else 0
+ thinking_budget=-1 if state.get('enable_thinking', True) else 0,
+ bos_token=shared.bos_token,
+ eos_token=shared.eos_token,
)
chat_renderer = partial(
@@ -475,7 +477,7 @@ def get_stopping_strings(state):
if state['mode'] in ['instruct', 'chat-instruct']:
template = jinja_env.from_string(state['instruction_template_str'])
- renderer = partial(template.render, add_generation_prompt=False)
+ renderer = partial(template.render, add_generation_prompt=False, bos_token=shared.bos_token, eos_token=shared.eos_token)
renderers.append(renderer)
if state['mode'] in ['chat']:
diff --git a/modules/models_settings.py b/modules/models_settings.py
index 6dc000b4..d333e269 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -89,8 +89,9 @@ def get_model_metadata(model):
else:
bos_token = ""
- template = template.replace('eos_token', "'{}'".format(eos_token))
- template = template.replace('bos_token', "'{}'".format(bos_token))
+
+ shared.bos_token = bos_token
+ shared.eos_token = eos_token
template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL)
template = re.sub(r'raise_exception\([^)]*\)', "''", template)
@@ -160,13 +161,16 @@ def get_model_metadata(model):
# 4. If a template was found from any source, process it
if template:
+ shared.bos_token = ''
+ shared.eos_token = ''
+
for k in ['eos_token', 'bos_token']:
if k in metadata:
value = metadata[k]
if isinstance(value, dict):
value = value['content']
- template = template.replace(k, "'{}'".format(value))
+ setattr(shared, k, value)
template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL)
template = re.sub(r'raise_exception\([^)]*\)', "''", template)
diff --git a/modules/shared.py b/modules/shared.py
index 2f39e495..7b572dec 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -19,6 +19,8 @@ is_seq2seq = False
is_multimodal = False
model_dirty_from_training = False
lora_names = []
+bos_token = ''
+eos_token = ''
# Image model variables
image_model = None