From b131f865840aff5ccb7516535efc2c683f763cf1 Mon Sep 17 00:00:00 2001
From: SeanScripts <64337075+SeanScripts@users.noreply.github.com>
Date: Tue, 18 Feb 2025 06:56:28 -0800
Subject: [PATCH] Perplexity colors extension v2 (#6756)
---
extensions/perplexity_colors/script.py | 275 ++++++++++++++++++-------
1 file changed, 201 insertions(+), 74 deletions(-)
diff --git a/extensions/perplexity_colors/script.py b/extensions/perplexity_colors/script.py
index 2a986ac4..849e4e63 100644
--- a/extensions/perplexity_colors/script.py
+++ b/extensions/perplexity_colors/script.py
@@ -1,9 +1,14 @@
import time
+import html
+import functools
+import re
+
import gradio
import numpy as np
import torch
from transformers import LogitsProcessor
+import colorsys
from modules import html_generator, shared
@@ -28,7 +33,7 @@ class PerplexityLogits(LogitsProcessor):
self.verbose = verbose
def __call__(self, input_ids, scores):
- # t0 = time.time()
+ #t0 = time.time()
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
log_probs = torch.nan_to_num(torch.log(probs)) # Note: This is to convert log(0) nan to 0, but probs*log_probs makes this 0 not affect the perplexity.
entropy = -torch.sum(probs * log_probs)
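For reference, the per-token perplexity tracked by this class is, by the usual definition, the exponential of this entropy (the log above is natural, so `exp` is the matching base). A minimal standalone sketch of that computation on made-up logits (the tensor values are illustrative only):

```python
import torch

# Illustrative next-token logits for a tiny 5-token vocabulary
scores = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])

probs = torch.softmax(scores, dim=-1, dtype=torch.float)
# torch.log(0) would be -inf; nan_to_num clamps it to a finite value so the
# 0 * log(0) terms contribute 0 to the sum instead of poisoning it with nan
log_probs = torch.nan_to_num(torch.log(probs))
entropy = -torch.sum(probs * log_probs)
perplexity = float(torch.exp(entropy))
print(f"entropy={float(entropy):.4f} perplexity={perplexity:.4f}")
```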
@@ -42,9 +47,8 @@ class PerplexityLogits(LogitsProcessor):
if len(self.selected_probs) > 0:
# Is the selected token in the top tokens?
if self.verbose:
- print('Probs: Token after', shared.tokenizer.decode(last_token_id))
- print('Probs:', [shared.tokenizer.decode(token_id) for token_id in self.top_token_ids_list[-1][0]])
- print('Probs:', [round(float(prob), 4) for prob in self.top_probs_list[-1][0]])
+ print(shared.tokenizer.decode(last_token_id), [shared.tokenizer.decode(token_id) for token_id in self.top_token_ids_list[-1][0]],
+ [round(float(prob), 4) for prob in self.top_probs_list[-1][0]])
if last_token_id in self.top_token_ids_list[-1][0]:
idx = self.top_token_ids_list[-1][0].index(last_token_id)
self.selected_probs.append(self.top_probs_list[-1][0][idx])
@@ -60,7 +64,7 @@ class PerplexityLogits(LogitsProcessor):
pplbar = "-"
if not np.isnan(perplexity):
pplbar = "*" * round(perplexity)
- print(f"PPL: Token after {shared.tokenizer.decode(last_token_id)}\t{perplexity:.2f}\t{pplbar}")
+ print(f"PPL for token after {shared.tokenizer.decode(last_token_id)}: {perplexity:.2f} {pplbar}")
# Get top 5 probabilities
top_tokens_and_probs = torch.topk(probs, 5)
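`torch.topk` returns both the values and their indices in one call, which is exactly what the dropdown needs: candidate token ids and their probabilities, best first. A tiny illustration on a hand-written distribution (variable names are just for the example):

```python
import torch

probs = torch.tensor([0.02, 0.50, 0.10, 0.30, 0.08])  # made-up 5-token distribution
top = torch.topk(probs, 5)
print([int(i) for i in top.indices])             # [1, 3, 2, 4, 0] -- token ids, best first
print([round(float(p), 2) for p in top.values])  # [0.5, 0.3, 0.1, 0.08, 0.02]
```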
@@ -73,14 +77,15 @@ class PerplexityLogits(LogitsProcessor):
probs = probs.cpu().numpy().flatten()
self.last_probs = probs # Need to keep this as a reference for top probs
- # t1 = time.time()
- # print(f"PPL Processor: {(t1-t0):.3f} s")
+ #t1 = time.time()
+ #print(f"PPL Processor: {(t1-t0):.3f} s")
# About 1 ms, though occasionally up to around 100 ms, not sure why...
# Doesn't actually modify the logits!
return scores
# Stores the perplexity and top probabilities
+# global ppl_logits_processor
ppl_logits_processor = None
@@ -93,9 +98,9 @@ def logits_processor_modifier(logits_processor_list, input_ids):
def output_modifier(text):
global ppl_logits_processor
- # t0 = time.time()
+ #t0 = time.time()
- if not params['active']:
+ if not params['active'] or ppl_logits_processor is None:
return text
# TODO: It's probably more efficient to do this above rather than modifying all these lists
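For context, `logits_processor_modifier` is the hook that attaches the processor above to generation: the extension appends a `LogitsProcessor` that only observes the scores and returns them untouched. The sketch below is a toy stand-in for that pattern (the `RecordingProcessor` class is invented for this example; it is not the extension's class):

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class RecordingProcessor(LogitsProcessor):
    """Toy observer: records the max next-token probability each step, never edits the scores."""

    def __init__(self):
        self.max_probs = []

    def __call__(self, input_ids, scores):
        probs = torch.softmax(scores, dim=-1, dtype=torch.float)
        self.max_probs.append(float(probs.max()))
        return scores  # passed through unmodified, same as PerplexityLogits


# The webui gathers processors from extensions into a list like this one
processors = LogitsProcessorList([RecordingProcessor()])
```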
@@ -111,110 +116,147 @@ def output_modifier(text):
end_part = '</div></div>' if params['probability_dropdown'] else '</span>' # Helps with finding the index after replacing part of the text.
- i = 0
- for token, prob, ppl, top_tokens, top_probs in zip(gen_tokens, sel_probs, perplexities, top_tokens_list, top_probs_list):
+ # Initial space added to deal with some tokenizers...
+ # Used to find where the message started generating, for working with "continue" generations
+ # Doesn't work for longer messages... Not sure how I should handle this
+ full_msg = shared.tokenizer.decode([token_id for token_id in gen_token_ids[:-1]]).strip()
+ # Space at the beginning to account for tokenization spaces...
+ text = ' ' + html.unescape(text)
+ # There was an issue with tab lengths being off by one...
+ # Seems like it might be model-dependent...
+ #text = re.sub(r'( {3,})', r'\1 ', text)
+ # Subtracting 2 to hopefully help with the tokenization spaces and continue issues,
+ # Though it's possible it could overwrite the previous token if it's the same in the last 2 chars
+ i = text.find(full_msg) - 2
+ if i < 0:
+ # Backup, try removing the extra whitespace (needed for continue)
+ i = text.find(full_msg.strip()) - 2
+ if i < 0:
+ i = 0
+
+ #i = 0
+ # Add token index for ability to regenerate from there
+ nonwhitespace_token_found = False
+ for index, token, prob, ppl, top_tokens, top_probs in zip(range(len(gen_tokens)), gen_tokens, sel_probs, perplexities, top_tokens_list, top_probs_list):
+ # Somehow this works without issues, but not sure how...
+ if not nonwhitespace_token_found and token.strip() == '':
+ #print('Ignoring initial whitespace token...')
+ continue
+ nonwhitespace_token_found = True
+ max_prob = top_probs[0][0]
color = 'ffffff'
if params['color_by_probability'] and params['color_by_perplexity']:
- color = probability_perplexity_color_scale(prob, ppl)
+ color = probability_perplexity_color_scale(prob, max_prob, ppl)
elif params['color_by_perplexity']:
color = perplexity_color_scale(ppl)
elif params['color_by_probability']:
color = probability_color_scale(prob)
- if token in text[i:]:
+ if token.strip() in text[i:]:
if params['probability_dropdown']:
- text = text[:i] + text[i:].replace(token, add_dropdown_html(token, color, top_tokens, top_probs[0], ppl), 1)
+ text = text[:i] + text[i:].replace(token.replace('\n', ''), add_dropdown_html(token, index, color, top_tokens, top_probs[0], ppl), 1)
else:
- text = text[:i] + text[i:].replace(token, add_color_html(token, color), 1)
+ text = text[:i] + text[i:].replace(token.replace('\n', ''), add_color_html(token, color), 1)
+
+ # This might be slightly inefficient
i += text[i:].find(end_part) + len(end_part)
+ else:
+ print('Missing token:', token, '...', text[i:i+20])
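The loop above keeps a cursor `i` into the already-rewritten prefix and only replaces the first occurrence of each token in the remaining tail, so a token that repeats later in the message cannot be recolored in the wrong place. A stripped-down, self-contained version of that cursor-and-replace pattern (`color_tokens` and the bare `<span>` markup are illustrative only):

```python
def color_tokens(text, tokens):
    """Wrap each token occurrence, scanning left to right, without touching earlier replacements."""
    i = 0
    end_part = '</span>'
    for token in tokens:
        if token in text[i:]:
            text = text[:i] + text[i:].replace(token, f'<span>{token}</span>', 1)
            # Advance the cursor just past the replacement that was made
            i += text[i:].find(end_part) + len(end_part)
    return text


print(color_tokens('the cat sat on the mat', ['the', 'cat', 'the', 'mat']))
```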
# Use full perplexity list for calculating the average here.
- print('Average perplexity:', round(np.mean(ppl_logits_processor.perplexities_list[:-1]), 4))
- # t1 = time.time()
- # print(f"Modifier: {(t1-t0):.3f} s")
+ # Fix issue with mean of empty slice
+ if len(ppl_logits_processor.perplexities_list) > 1:
+ print('Average perplexity:', round(np.mean(ppl_logits_processor.perplexities_list[:-1]), 4))
+ #t1 = time.time()
+ #print(f"Output modifier: {(t1-t0):.3f} s")
# About 50 ms
- return text
+ return text.strip() # Remove extra beginning whitespace that some tokenizers add
def probability_color_scale(prob):
'''
Green-yellow-red color scale
'''
+ # hue (0.0 = red, 0.33 = green)
+ # saturation (0.0 = gray / white, 1.0 = normal, just leave at 1.0)
+ # brightness (0.0 = black, 1.0 = brightest, use something in between for better readability if you want...)
+ hue = prob * 0.33
+ rv, gv, bv = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
+ # to hex
+ hex_col = f"{int(rv*255):02x}{int(gv*255):02x}{int(bv*255):02x}"
- rv = 0
- gv = 0
- if prob <= 0.5:
- rv = 'ff'
- gv = hex(int(255 * prob * 2))[2:]
- if len(gv) < 2:
- gv = '0' * (2 - len(gv)) + gv
- else:
- rv = hex(int(255 - 255 * (prob - 0.5) * 2))[2:]
- gv = 'ff'
- if len(rv) < 2:
- rv = '0' * (2 - len(rv)) + rv
-
- return rv + gv + '00'
+ return hex_col
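The probability scale now goes through HSV: hue 0.0 is red, ~0.33 is green, and the standard-library `colorsys.hsv_to_rgb` does the conversion before formatting to hex. A quick standalone check of the endpoints (`prob_to_hex` is just a name for this example):

```python
import colorsys

def prob_to_hex(prob):
    # Illustrative helper mirroring the hue mapping above
    r, g, b = colorsys.hsv_to_rgb(prob * 0.33, 1.0, 1.0)
    return f"{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"

print(prob_to_hex(0.0))  # ff0000 -- pure red
print(prob_to_hex(0.5))  # roughly yellow
print(prob_to_hex(1.0))  # close to pure green
```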
def perplexity_color_scale(ppl):
'''
Red component only, white for 0 perplexity (sorry if you're not in dark mode)
'''
- value = hex(max(int(255.0 - params['ppl_scale'] * (float(ppl) - 1.0)), 0))[2:]
- if len(value) < 2:
- value = '0' * (2 - len(value)) + value
+ # hue (0.0 = red)
+ # saturation (1.0 = red)
+ # brightness (0.0 = black, 1.0 = red)
+ # scale saturation from white to red the higher the perplexity
- return 'ff' + value + value
+ ppl = min(ppl, params['ppl_scale']) # clip ppl to 0-params['ppl_scale'] for color scaling. 15 should be fine for clipping and scaling
+ sat = ppl / params['ppl_scale']
+ rv, gv, bv = colorsys.hsv_to_rgb(0.0, sat, 1.0)
+
+ # to hex
+ hex_col = f"{int(rv*255):02x}{int(gv*255):02x}{int(bv*255):02x}"
+
+ return hex_col
-def probability_perplexity_color_scale(prob, ppl):
+def probability_perplexity_color_scale(prob, max_prob, ppl):
'''
- Green-yellow-red for probability and blue component for perplexity
+ Green-yellow-red for relative probability compared to maximum for the current token, and blue component for perplexity
'''
-
- rv = 0
- gv = 0
- bv = hex(min(max(int(params['ppl_scale'] * (float(ppl) - 1.0)), 0), 255))[2:]
- if len(bv) < 2:
- bv = '0' * (2 - len(bv)) + bv
-
- if prob <= 0.5:
- rv = 'ff'
- gv = hex(int(255 * prob * 2))[2:]
- if len(gv) < 2:
- gv = '0' * (2 - len(gv)) + gv
- else:
- rv = hex(int(255 - 255 * (prob - 0.5) * 2))[2:]
- gv = 'ff'
- if len(rv) < 2:
- rv = '0' * (2 - len(rv)) + rv
-
- return rv + gv + bv
+ hue = prob/max_prob * 0.33
+ rv, gv, _ = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
+
+ ppl = min(ppl, params['ppl_scale']) # clip ppl to 0-params['ppl_scale'] for color scaling. 15 should be fine for clipping and scaling
+ bv = ppl / params['ppl_scale']
+
+ # to hex
+ hex_col = f"{int(rv*255):02x}{int(gv*255):02x}{int(bv*255):02x}"
+
+ return hex_col
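In the combined scale, hue now encodes the selected token's probability relative to the most likely token (`prob / max_prob`), while the blue channel separately encodes clipped perplexity. A hedged sketch of the same idea (`prob_ppl_to_hex` is illustrative, and 15 is assumed as the `ppl_scale` default based on the comment above):

```python
import colorsys

PPL_SCALE = 15  # assumed default for params['ppl_scale']

def prob_ppl_to_hex(prob, max_prob, ppl):
    # Hue: how close the chosen token was to the model's top choice (ratio 1.0 -> green)
    r, g, _ = colorsys.hsv_to_rgb(prob / max_prob * 0.33, 1.0, 1.0)
    # Blue channel: perplexity, clipped to the scale
    b = min(ppl, PPL_SCALE) / PPL_SCALE
    return f"{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"

print(prob_ppl_to_hex(0.10, 0.40, 3.0))   # orange-red with a little blue
print(prob_ppl_to_hex(0.40, 0.40, 12.0))  # green with a strong blue component
```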
def add_color_html(token, color):
- return f'<span style="color: #{color}">{token}</span>'
+ output = ''
+ output += f'<span style="color: #{color}">{html.escape(repr(token)[1:-1])}</span>'
+ #if '\n' in token or '\r' in token: #token.isspace():
+ #    output += '<br>'
+ return output
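The new `add_color_html` pushes each token through `repr(...)[1:-1]` and `html.escape` so whitespace and control characters are shown literally instead of being collapsed by the HTML renderer. A small illustration of just that transformation (`visible_token` is a name made up for the example; the surrounding span markup is omitted):

```python
import html

def visible_token(token):
    # repr() turns '\n' and '\t' into two-character escapes; [1:-1] strips repr's
    # quotes; html.escape guards <, > and & so tokens can't inject markup
    return html.escape(repr(token)[1:-1])

print(visible_token('\n'))       # \n
print(visible_token('\t<div>'))  # \t&lt;div&gt;
```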
-# TODO: Major issue: Applying this to too many tokens will cause a permanent slowdown in generation speed until the messages are removed from the history.
+# TODO: Might also need message index for the click-to-regenerate feature to work... For now it only works in the last message, which I think is fine.
+
+# TODO: Major issue: Applying this to too many tokens will cause a permanent slowdown in generation speed until the messages are removed from the history. The slowdown seems to be mostly resolved in the current version though
# I think the issue is from HTML elements taking up space in the visible history, and things like history deepcopy add latency proportional to the size of the history.
# Potential solution is maybe to modify the main generation code to send just the internal text and not the visible history, to avoid moving too much around.
# I wonder if we can also avoid using deepcopy here.
-def add_dropdown_html(token, color, top_tokens, top_probs, perplexity=0):
+def add_dropdown_html(token, index, color, top_tokens, top_probs, perplexity=0):
    [dropdown-building body: the HTML markup was lost in extraction; the surviving row templates show one row per candidate token -- {token_option} in the old version, {html.escape(repr(token_option))} in the new one -- paired with its {prob:.4f}, plus a final "Perplexity: {perplexity:.4f}" row]
# We don't render the markdown and instead wrap everything in <pre> tags to preserve whitespace
# formatting. If you're coloring tokens by perplexity or probability, or especially if you're using
# the probability dropdown, you probably care more about seeing the tokens the model actually outputted
# rather than rendering ```code blocks``` or *italics*.
+@functools.lru_cache(maxsize=4096)
def convert_to_markdown(string):
return '<pre style="white-space: pre-wrap;">' + string + '</pre>'
+def convert_to_markdown_wrapped(string, use_cache=True):
+ if use_cache:
+ return convert_to_markdown(string)
+ return convert_to_markdown.__wrapped__(string)
+# This is still necessary for formatting to work correctly
html_generator.convert_to_markdown = convert_to_markdown
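`functools.lru_cache` keeps the monkeypatched `convert_to_markdown` cheap when the same chunks of history are converted repeatedly, and the wrapper it creates still exposes the original function as `__wrapped__`, which is what `convert_to_markdown_wrapped` uses to bypass the cache. A standalone sketch of that pattern (`render`/`render_uncached` are stand-ins, not the extension's functions):

```python
import functools

@functools.lru_cache(maxsize=4096)
def render(string):
    # Stand-in for the real conversion; repeated calls with the same text hit the cache
    return string.replace('\n', '<br>')

def render_uncached(string):
    # lru_cache exposes the undecorated function as __wrapped__
    return render.__wrapped__(string)

print(render('a\nb'))
print(render.cache_info())      # CacheInfo(hits=0, misses=1, maxsize=4096, currsize=1)
print(render_uncached('a\nb'))  # same result, but skips the cache entirely
```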
@@ -298,7 +425,7 @@ def ui():
def update_prob_dropdown_check(x):
params.update({'probability_dropdown': x})
- active_check = gradio.Checkbox(value=True, label="Compute probabilities and perplexity scores", info="Activate this extension. Note that this extension currently does not work with exllama or llama.cpp.")
+ active_check = gradio.Checkbox(value=True, label="Compute probabilities and perplexity scores", info="Activate this extension. Note that this extension currently does not work with llama.cpp, but it does work with ExLlamav2_HF and llamacpp_HF when set up correctly")
color_by_ppl_check = gradio.Checkbox(value=False, label="Color by perplexity", info="Higher perplexity is more red. If also showing probability, higher perplexity has more blue component.")
color_by_prob_check = gradio.Checkbox(value=False, label="Color by probability", info="Green-yellow-red linear scale, with 100% green, 50% yellow, 0% red.")
prob_dropdown_check = gradio.Checkbox(value=False, label="Probability dropdown", info="Hover over a token to show a dropdown of top token probabilities. Currently slightly buggy with whitespace between tokens.")
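The UI hunk follows the extension's existing pattern: each `gradio.Checkbox` writes its value back into the shared `params` dict through a small callback such as `update_prob_dropdown_check`. A simplified standalone sketch of that wiring (`ui_sketch` and the two-checkbox layout are illustrative, not copied from the extension):

```python
import gradio

params = {'active': True, 'probability_dropdown': False}

def ui_sketch():
    with gradio.Blocks() as demo:
        active_check = gradio.Checkbox(value=params['active'], label="Compute probabilities and perplexity scores")
        prob_dropdown_check = gradio.Checkbox(value=params['probability_dropdown'], label="Probability dropdown")
        # Each checkbox writes its value straight back into the shared params dict
        active_check.change(lambda x: params.update({'active': x}), active_check, None)
        prob_dropdown_check.change(lambda x: params.update({'probability_dropdown': x}), prob_dropdown_check, None)
    return demo
```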