From 1d5a015ce778270968e01adb6a30a99c6b431f76 Mon Sep 17 00:00:00 2001 From: Johan Date: Sat, 21 Oct 2023 06:54:06 +0200 Subject: [PATCH] Enable special token support for exllamav2 (#4314) --- modules/exllamav2.py | 8 ++++---- modules/loaders.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/exllamav2.py b/modules/exllamav2.py index a75ede46..0287a177 100644 --- a/modules/exllamav2.py +++ b/modules/exllamav2.py @@ -64,7 +64,7 @@ class Exllamav2Model: return result, result def encode(self, string, **kwargs): - return self.tokenizer.encode(string, add_bos=True) + return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True) def decode(self, ids, **kwargs): if isinstance(ids, list): @@ -72,7 +72,7 @@ class Exllamav2Model: elif isinstance(ids, torch.Tensor) and ids.numel() == 1: ids = ids.view(1, -1) - return self.tokenizer.decode(ids)[0] + return self.tokenizer.decode(ids, decode_special_tokens=True)[0] def get_logits(self, token_ids, **kwargs): self.cache.current_seq_len = 0 @@ -97,7 +97,7 @@ class Exllamav2Model: if len(to_ban) > 0: settings.disallow_tokens(self.tokenizer, to_ban) - ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token']) + ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True) ids = ids[:, -get_max_prompt_length(state):] initial_len = ids.shape[-1] @@ -119,7 +119,7 @@ class Exllamav2Model: if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'): has_leading_space = True - decoded_text = self.tokenizer.decode(ids[:, initial_len:])[0] + decoded_text = self.tokenizer.decode(ids[:, initial_len:], decode_special_tokens=not state['skip_special_tokens'])[0] if has_leading_space: decoded_text = ' ' + decoded_text diff --git a/modules/loaders.py b/modules/loaders.py index ab10e0a4..bd3f04af 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -231,6 +231,7 @@ loaders_samplers = { 'ban_eos_token', 'add_bos_token', 'custom_token_bans', + 'skip_special_tokens', 'auto_max_new_tokens', }, 'ExLlamav2_HF': {