fix: Pass bos_token and eos_token from metadata to jinja2

Fixes loading Seed-Instruct-36B
oobabooga 2025-12-04 19:11:31 -08:00
parent 15c6e43597
commit b4f06a50b0
3 changed files with 13 additions and 5 deletions
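Background: chat templates are rendered with Jinja2, and with strict undefined handling (StrictUndefined, assumed here) a template that reads bos_token or eos_token as variables cannot render unless those names are supplied; Seed-Instruct-36B's template reads them, which is presumably why loading it failed. A minimal sketch of the failure mode and the fix, using an illustrative template string and a plain Environment rather than the project's sandboxed one:

from jinja2 import Environment, StrictUndefined

# Hypothetical template in the style of models whose chat templates read
# bos_token/eos_token as variables; not Seed-Instruct-36B's actual template.
template_str = "{{ bos_token }}{% for m in messages %}{{ m['content'] }}{{ eos_token }}{% endfor %}"

env = Environment(undefined=StrictUndefined)
template = env.from_string(template_str)
messages = [{'role': 'user', 'content': 'Hello'}]

# Without the two extra variables, StrictUndefined makes render() raise
# jinja2.exceptions.UndefinedError; passing them, as this commit does,
# renders cleanly.
print(template.render(messages=messages, bos_token='<s>', eos_token='</s>'))
# -> <s>Hello</s>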


@@ -112,7 +112,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
             add_generation_prompt=False,
             enable_thinking=state['enable_thinking'],
             reasoning_effort=state['reasoning_effort'],
-            thinking_budget=-1 if state.get('enable_thinking', True) else 0
+            thinking_budget=-1 if state.get('enable_thinking', True) else 0,
+            bos_token=shared.bos_token,
+            eos_token=shared.eos_token,
         )
     chat_renderer = partial(
@@ -475,7 +477,7 @@ def get_stopping_strings(state):
     if state['mode'] in ['instruct', 'chat-instruct']:
         template = jinja_env.from_string(state['instruction_template_str'])
-        renderer = partial(template.render, add_generation_prompt=False)
+        renderer = partial(template.render, add_generation_prompt=False, bos_token=shared.bos_token, eos_token=shared.eos_token)
         renderers.append(renderer)
     if state['mode'] in ['chat']:
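Both call sites above rely on functools.partial to pre-bind the extra keyword arguments, so downstream code keeps calling the renderer with only messages. A small sketch of that pattern under the same assumptions as before (plain Environment, illustrative template):

from functools import partial
from jinja2 import Environment, StrictUndefined

env = Environment(undefined=StrictUndefined)
template = env.from_string("{{ bos_token }}{{ messages[-1]['content'] }}")

# partial() bakes the token values in once; every renderer(...) call below
# automatically carries them, mirroring the change to get_stopping_strings().
renderer = partial(template.render, add_generation_prompt=False,
                   bos_token='<s>', eos_token='</s>')

print(renderer(messages=[{'role': 'user', 'content': 'hi'}]))  # -> <s>hi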


@@ -89,8 +89,9 @@ def get_model_metadata(model):
             else:
                 bos_token = ""
-            template = template.replace('eos_token', "'{}'".format(eos_token))
-            template = template.replace('bos_token', "'{}'".format(bos_token))
+            shared.bos_token = bos_token
+            shared.eos_token = eos_token
             template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL)
             template = re.sub(r'raise_exception\([^)]*\)', "''", template)
@@ -160,13 +161,16 @@ def get_model_metadata(model):
     # 4. If a template was found from any source, process it
     if template:
+        shared.bos_token = '<s>'
+        shared.eos_token = '</s>'
         for k in ['eos_token', 'bos_token']:
             if k in metadata:
                 value = metadata[k]
                 if isinstance(value, dict):
                     value = value['content']
-                template = template.replace(k, "'{}'".format(value))
+                setattr(shared, k, value)
         template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL)
         template = re.sub(r'raise_exception\([^)]*\)', "''", template)
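For contrast with the removed lines: the old code substituted the literal text 'bos_token'/'eos_token' into the template source, which mangles any identifier that merely contains those substrings, whereas supplying the values at render time leaves the template untouched. An illustration with a hypothetical template (not necessarily the exact failure seen with Seed-Instruct-36B):

# Hypothetical template, not taken from a real model.
template = "{% if add_bos_token %}{{ bos_token }}{% endif %}"

# The old approach rewrote the template text itself:
patched = template.replace('bos_token', "'<s>'")
print(patched)  # {% if add_'<s>' %}{{ '<s>' }}{% endif %}  -- no longer valid Jinja2

# The approach in this commit leaves the template alone and supplies the
# value as a render-time variable instead:
#   template.render(..., bos_token=shared.bos_token, eos_token=shared.eos_token)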


@@ -19,6 +19,8 @@ is_seq2seq = False
 is_multimodal = False
 model_dirty_from_training = False
 lora_names = []
+bos_token = '<s>'
+eos_token = '</s>'
 # Image model variables
 image_model = None
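The two new module-level names act as fallbacks: they hold the generic '<s>'/'</s>' until get_model_metadata() overwrites them with values found in the model's metadata, as in the hunk above. A self-contained sketch of that interaction (SimpleNamespace and apply_metadata_tokens are illustrative stand-ins, not names from the codebase):

from types import SimpleNamespace

# Stand-in for modules/shared.py: the module-level defaults added above.
shared = SimpleNamespace(bos_token='<s>', eos_token='</s>')

def apply_metadata_tokens(metadata: dict) -> None:
    # Mirrors the loop in get_model_metadata(): overwrite the defaults only
    # when the model's metadata actually defines the tokens.
    for k in ('bos_token', 'eos_token'):
        if k in metadata:
            value = metadata[k]
            if isinstance(value, dict):  # some configs nest the token as {'content': ...}
                value = value['content']
            setattr(shared, k, value)    # otherwise the '<s>'/'</s>' defaults remain

apply_metadata_tokens({'eos_token': {'content': '<|end|>'}})
print(shared.bos_token, shared.eos_token)  # -> <s> <|end|>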