mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
fix: Pass bos_token and eos_token from metadata to jinja2
Fixes loading Seed-Instruct-36B
This commit is contained in:
parent
15c6e43597
commit
b4f06a50b0
|
|
@ -112,7 +112,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
add_generation_prompt=False,
|
||||
enable_thinking=state['enable_thinking'],
|
||||
reasoning_effort=state['reasoning_effort'],
|
||||
thinking_budget=-1 if state.get('enable_thinking', True) else 0
|
||||
thinking_budget=-1 if state.get('enable_thinking', True) else 0,
|
||||
bos_token=shared.bos_token,
|
||||
eos_token=shared.eos_token,
|
||||
)
|
||||
|
||||
chat_renderer = partial(
|
||||
|
|
@ -475,7 +477,7 @@ def get_stopping_strings(state):
|
|||
|
||||
if state['mode'] in ['instruct', 'chat-instruct']:
|
||||
template = jinja_env.from_string(state['instruction_template_str'])
|
||||
renderer = partial(template.render, add_generation_prompt=False)
|
||||
renderer = partial(template.render, add_generation_prompt=False, bos_token=shared.bos_token, eos_token=shared.eos_token)
|
||||
renderers.append(renderer)
|
||||
|
||||
if state['mode'] in ['chat']:
|
||||
|
|
|
|||
|
|
@ -89,8 +89,9 @@ def get_model_metadata(model):
|
|||
else:
|
||||
bos_token = ""
|
||||
|
||||
template = template.replace('eos_token', "'{}'".format(eos_token))
|
||||
template = template.replace('bos_token', "'{}'".format(bos_token))
|
||||
|
||||
shared.bos_token = bos_token
|
||||
shared.eos_token = eos_token
|
||||
|
||||
template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL)
|
||||
template = re.sub(r'raise_exception\([^)]*\)', "''", template)
|
||||
|
|
@ -160,13 +161,16 @@ def get_model_metadata(model):
|
|||
|
||||
# 4. If a template was found from any source, process it
|
||||
if template:
|
||||
shared.bos_token = '<s>'
|
||||
shared.eos_token = '</s>'
|
||||
|
||||
for k in ['eos_token', 'bos_token']:
|
||||
if k in metadata:
|
||||
value = metadata[k]
|
||||
if isinstance(value, dict):
|
||||
value = value['content']
|
||||
|
||||
template = template.replace(k, "'{}'".format(value))
|
||||
setattr(shared, k, value)
|
||||
|
||||
template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL)
|
||||
template = re.sub(r'raise_exception\([^)]*\)', "''", template)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ is_seq2seq = False
|
|||
is_multimodal = False
|
||||
model_dirty_from_training = False
|
||||
lora_names = []
|
||||
bos_token = '<s>'
|
||||
eos_token = '</s>'
|
||||
|
||||
# Image model variables
|
||||
image_model = None
|
||||
|
|
|
|||
Loading…
Reference in a new issue