diff --git a/extensions/perplexity_colors/script.py b/extensions/perplexity_colors/script.py
index 849e4e63..d032cebd 100644
--- a/extensions/perplexity_colors/script.py
+++ b/extensions/perplexity_colors/script.py
@@ -96,23 +96,42 @@ def logits_processor_modifier(logits_processor_list, input_ids):
     logits_processor_list.append(ppl_logits_processor)
 
 
+def get_last_token(text, tokens_list, token_ids_list, token_probs_list):
+    for token, token_id, prob in zip(tokens_list, token_ids_list, token_probs_list):
+        if text.strip().endswith(token.strip()):  # Whitespace could be a problem
+            return token, token_id, prob
+    # Unknown?
+    print("Last token not found in list:", tokens_list)
+    return '', -1, 0.0
+
+
 def output_modifier(text):
     global ppl_logits_processor
     #t0 = time.time()
+    original_text = text
 
     if not params['active'] or ppl_logits_processor is None:
         return text
 
+    # Space at the beginning to account for tokenization spaces...
+    text = ' ' + html.unescape(text)
+
     # TODO: It's probably more efficient to do this above rather than modifying all these lists
     # Remove last element of perplexities_list, top_token_ids_list, top_tokens_list, top_probs_list since everything is off by one because this extension runs before generation
-    perplexities = ppl_logits_processor.perplexities_list[:-1]
-    top_token_ids_list = ppl_logits_processor.top_token_ids_list[:-1]
+    perplexities = ppl_logits_processor.perplexities_list
+    top_token_ids_list = ppl_logits_processor.top_token_ids_list
     top_tokens_list = [[shared.tokenizer.decode(token_id) for token_id in top_token_ids[0]] for top_token_ids in top_token_ids_list]
-    top_probs_list = ppl_logits_processor.top_probs_list[:-1]
+    top_probs_list = ppl_logits_processor.top_probs_list
     # Remove first element of generated_token_ids, generated_tokens, selected_probs because they are for the last token of the prompt
     gen_token_ids = ppl_logits_processor.generated_token_ids[1:]
+    # Add last sampled token, if possible (it could be past the end of the top 5 list)
+    last_token, last_token_id, last_prob = get_last_token(text, top_tokens_list[-1], top_token_ids_list[-1][0], top_probs_list[-1][0])
+    if last_token_id != -1:
+        gen_token_ids.append(last_token_id)
     gen_tokens = [shared.tokenizer.decode(token_id) for token_id in gen_token_ids]
     sel_probs = ppl_logits_processor.selected_probs[1:]
+    if last_token_id != -1:
+        sel_probs.append(last_prob)
 
     end_part = '</div></div>' if params['probability_dropdown'] else '</span>'  # Helps with finding the index after replacing part of the text.
 
@@ -120,8 +139,7 @@ def output_modifier(text):
     # Used to find where the message started generating, for working with "continue" generations
     # Doesn't work for longer messages... Not sure how I should handle this
     full_msg = shared.tokenizer.decode([token_id for token_id in gen_token_ids[:-1]]).strip()
-    # Space at the beginning to account for tokenization spaces...
-    text = ' ' + html.unescape(text)
+
     # There was an issue with tab lengths being off by one...
     # Seems like it might be model-dependent...
     #text = re.sub(r'( {3,})', r'\1 ', text)
@@ -137,6 +155,7 @@ def output_modifier(text):
     #i = 0
     # Add token index for ability to regenerate from there
     nonwhitespace_token_found = False
+    missing_token_count = 0
     for index, token, prob, ppl, top_tokens, top_probs in zip(range(len(gen_tokens)), gen_tokens, sel_probs, perplexities, top_tokens_list, top_probs_list):
         # Somehow this works without issues, but not sure how...
         if not nonwhitespace_token_found and token.strip() == '':
@@ -153,14 +172,20 @@ def output_modifier(text):
         color = probability_color_scale(prob)
         if token.strip() in text[i:]:
             if params['probability_dropdown']:
-                text = text[:i] + text[i:].replace(token.replace('\n', ''), add_dropdown_html(token, index, color, top_tokens, top_probs[0], ppl), 1)
+                text = text[:i] + text[i:].replace(token.replace('\n', ''), add_dropdown_html(token, index, i, color, top_tokens, top_probs[0], ppl), 1)
             else:
                 text = text[:i] + text[i:].replace(token.replace('\n', ''), add_color_html(token, color), 1)
             # This might be slightly inefficient
             i += text[i:].find(end_part) + len(end_part)
         else:
+            missing_token_count += 1
             print('Missing token:', token, '...', text[i:i+20])
 
+    # If there are any missing tokens, then either the tokenization was off, or this is the start of a conversation, or something else went wrong
+    if missing_token_count > 5:
+        print("Canceling token coloring...")
+        return original_text
+
     # Use full perplexity list for calculating the average here.
     # Fix issue with mean of empty slice
 
@@ -236,11 +261,11 @@ def add_color_html(token, color):
 # I think the issue is from HTML elements taking up space in the visible history, and things like history deepcopy add latency proportional to the size of the history.
 # Potential solution is maybe to modify the main generation code to send just the internal text and not the visible history, to avoid moving too much around.
 # I wonder if we can also avoid using deepcopy here.
-def add_dropdown_html(token, index, color, top_tokens, top_probs, perplexity=0):
+def add_dropdown_html(token, index, msg_position, color, top_tokens, top_probs, perplexity=0):
     #print("Token:", token, token.isspace(), '\n' in token or '\r' in token)
     output = ''
     # Use the repr to get characters like \n visible. Exclude the quotes around it
-    output += f'<div class="hoverable"><span style="color: {color}">{html.escape(repr(token)[1:-1])}</span>'
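The suffix match in `get_last_token()` is easiest to see in isolation. Below is a minimal, self-contained sketch of the same matching logic with invented candidate lists; in the extension these come from the logits processor's per-step top-5 records.

```python
# Standalone sketch of the get_last_token() recovery step, with toy data.
def get_last_token(text, tokens_list, token_ids_list, token_probs_list):
    for token, token_id, prob in zip(tokens_list, token_ids_list, token_probs_list):
        # Compare stripped strings: decoded tokens often carry leading spaces
        if text.strip().endswith(token.strip()):
            return token, token_id, prob
    # The sampled token was not among the recorded candidates
    print("Last token not found in list:", tokens_list)
    return '', -1, 0.0

# Toy top-3 candidates recorded for the final generation step
tokens = [' world', ' there', ' everyone']
ids = [1024, 2048, 4096]
probs = [0.61, 0.22, 0.09]

print(get_last_token(' Hello world', tokens, ids, probs))  # (' world', 1024, 0.61)
print(get_last_token(' Hello moon', tokens, ids, probs))   # ('', -1, 0.0): caller skips the append
```

One caveat the patch's "Whitespace could be a problem" comment hints at: a candidate that strips down to the empty string matches any text, because `str.endswith('')` is always true, so a whitespace-only candidate early in the list would shadow the real match.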
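The `i += text[i:].find(end_part) + len(end_part)` bookkeeping is what keeps repeated tokens from being re-colored: each replacement is applied only to the text at or after `i`, and `i` is then advanced past the closing tag that was just inserted. A minimal sketch of the non-dropdown path, assuming `add_color_html()` wraps the token in a single `<span>` (which is why `end_part` is `'</span>'` in that mode; the colors here are placeholders):

```python
import html

def add_color_html(token, color):
    # Assumed shape of the helper: one colored <span> per token
    return f'<span style="color: {color}">{html.escape(token)}</span>'

end_part = '</span>'
text = ' Hello world'
i = 0
for token, color in [(' Hello', '#228b22'), (' world', '#b22222')]:
    if token.strip() in text[i:]:
        # Replace the first occurrence at or after i, then advance i past the
        # closing tag so the next search never touches already-colored text
        text = text[:i] + text[i:].replace(token.replace('\n', ''), add_color_html(token, color), 1)
        i += text[i:].find(end_part) + len(end_part)

print(text)
# <span style="color: #228b22"> Hello</span><span style="color: #b22222"> world</span>
```

This is also why the patch's missing-token counter matters: when a token never appears in `text[i:]`, `i` stops advancing reliably, and after enough misses the patch gives up and returns the unmodified `original_text` instead of emitting misaligned markup.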
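`probability_color_scale(prob)` is defined elsewhere in script.py and is not part of this diff. For experimenting with the loop above in isolation, a hypothetical stand-in that ramps from red (low probability) to green (high probability) is enough:

```python
def probability_color_scale(prob):
    # Hypothetical stand-in only; the real function in script.py may use a
    # different ramp and color format.
    red = int(255 * (1 - prob))
    green = int(255 * prob)
    return f'#{red:02x}{green:02x}00'
```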