diff --git a/modules/chat.py b/modules/chat.py
index 9290dd62..827b6050 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -643,6 +643,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
     output = apply_extensions('history', output)
     state = apply_extensions('state', state)
 
+    # Let the jinja2 template handle the BOS token
+    if state['mode'] in ['instruct', 'chat-instruct']:
+        state['add_bos_token'] = False
+
     # Initialize metadata if not present
     if 'metadata' not in output:
         output['metadata'] = {}
diff --git a/modules/text_generation.py b/modules/text_generation.py
index a75141f1..8d1950b9 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -134,7 +134,6 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
     else:
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
-
     if hasattr(shared.tokenizer, 'bos_token_id') and shared.tokenizer.bos_token_id is not None:
         if add_bos_token:
             # Add BOS token if missing
@@ -142,13 +141,9 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
                 bos_tensor = torch.tensor([[shared.tokenizer.bos_token_id]])
                 input_ids = torch.cat((bos_tensor, input_ids), 1)
 
-            # Prevent double BOS tokens from jinja templates
-            while len(input_ids[0]) > 1 and input_ids[0][0] == shared.tokenizer.bos_token_id and input_ids[0][1] == shared.tokenizer.bos_token_id:
-                input_ids = input_ids[:, 1:]
-        else:
-            # Remove BOS tokens when not wanted
-            while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id:
-                input_ids = input_ids[:, 1:]
+        # Always prevent double BOS tokens (regardless of add_bos_token setting)
+        while len(input_ids[0]) > 1 and input_ids[0][0] == shared.tokenizer.bos_token_id and input_ids[0][1] == shared.tokenizer.bos_token_id:
+            input_ids = input_ids[:, 1:]
 
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py
index 5c92f32e..45bef5f9 100644
--- a/modules/ui_parameters.py
+++ b/modules/ui_parameters.py
@@ -84,7 +84,7 @@ def create_ui():
 
                     shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
                     shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
-                    shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
+                    shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Only applies to text completion (notebook). In chat mode, templates control BOS tokens.')
                     shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
                     shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming')
                     shared.gradio['static_cache'] = gr.Checkbox(value=shared.settings['static_cache'], label='Static KV cache', info='Use a static cache for improved performance.')