Address copilot feedback

This commit is contained in:
oobabooga 2026-03-12 19:55:02 -07:00
parent 24fdcc52b3
commit 04213dff14
6 changed files with 14 additions and 6 deletions

View file

@ -136,7 +136,7 @@ class LlamaServer:
logit_bias = []
if state['custom_token_bans']:
logit_bias.extend([[int(token_id), False] for token_id in state['custom_token_bans'].split(',')])
logit_bias.extend([[int(token_id.strip()), False] for token_id in state['custom_token_bans'].split(',') if token_id.strip()])
if state.get('logit_bias'):
for token_id_str, bias in state['logit_bias'].items():

View file

@ -431,7 +431,8 @@ def load_instruction_template(template):
else:
return ''
file_contents = open(filepath, 'r', encoding='utf-8').read()
with open(filepath, 'r', encoding='utf-8') as f:
file_contents = f.read()
data = yaml.safe_load(file_contents)
if 'instruction_template' in data:
return data['instruction_template']

View file

@ -378,7 +378,7 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
if state['custom_token_bans']:
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
to_ban = [int(x.strip()) for x in state['custom_token_bans'].split(',') if x.strip()]
if len(to_ban) > 0:
if generate_params.get('suppress_tokens', None):
generate_params['suppress_tokens'] += to_ban

View file

@ -42,6 +42,9 @@ def load_tools(selected_names):
continue
func_name = tool_def.get('function', {}).get('name', name)
if func_name in executors:
logger.warning(f'Tool "{name}" declares function name "{func_name}" which conflicts with an already loaded tool. Skipping.')
continue
tool_defs.append(tool_def)
executors[func_name] = execute_fn

View file

@ -16,7 +16,11 @@ def _eval(node):
if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
return node.value
elif isinstance(node, ast.BinOp) and type(node.op) in OPERATORS:
return OPERATORS[type(node.op)](_eval(node.left), _eval(node.right))
left = _eval(node.left)
right = _eval(node.right)
if isinstance(node.op, ast.Pow) and isinstance(right, (int, float)) and abs(right) > 10000:
raise ValueError("Exponent too large (max 10000)")
return OPERATORS[type(node.op)](left, right)
elif isinstance(node, ast.UnaryOp) and type(node.op) in OPERATORS:
return OPERATORS[type(node.op)](_eval(node.operand))
raise ValueError(f"Unsupported expression")

View file

@ -17,7 +17,7 @@ tool = {
def execute(arguments):
count = arguments.get("count", 1)
sides = arguments.get("sides", 20)
count = max(1, min(arguments.get("count", 1), 1000))
sides = max(2, min(arguments.get("sides", 20), 1000))
rolls = [random.randint(1, sides) for _ in range(count)]
return {"rolls": rolls, "total": sum(rolls)}