Minor cleanup across multiple modules

This commit is contained in:
oobabooga 2026-03-19 08:02:23 -07:00
parent 5453b9f30e
commit e0e20ab9e7
6 changed files with 28 additions and 82 deletions

View file

@ -263,7 +263,7 @@ def convert_history(history):
seen_non_system = True
meta = {}
tool_calls = entry.get("tool_calls")
if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0:
if tool_calls and isinstance(tool_calls, list):
meta["tool_calls"] = tool_calls
if content.strip() == "":
content = "" # keep empty content, don't skip
@ -315,7 +315,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
raise InvalidRequestError(message="messages is required", param='messages')
tools = None
if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and body['tools']:
tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails
tool_choice = body.get('tool_choice', None)

View file

@ -500,9 +500,8 @@ class LlamaServer:
health_url = f"http://127.0.0.1:{self.port}/health"
while True:
# Check if process is still alive
if self.process.poll() is not None:
# Process has terminated
exit_code = self.process.poll()
exit_code = self.process.poll()
if exit_code is not None:
raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}")
try:

View file

@ -453,15 +453,11 @@ def load_user_config():
'''
Loads custom model-specific settings
'''
user_config = {}
if Path(f'{args.model_dir}/config-user.yaml').exists():
file_content = open(f'{args.model_dir}/config-user.yaml', 'r').read().strip()
if file_content:
user_config = yaml.safe_load(file_content)
else:
user_config = {}
else:
user_config = {}
return user_config

View file

@ -3,6 +3,10 @@ import random
import re
def _make_tool_call(name, arguments):
return {"type": "function", "function": {"name": name, "arguments": arguments}}
def get_tool_call_id() -> str:
letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b = [random.choice(letter_bytes) for _ in range(8)]
@ -149,13 +153,7 @@ def _parse_channel_tool_calls(answer: str, tool_names: list[str]):
if start_pos is None:
prefix = answer.rfind('<|start|>assistant', 0, m.start())
start_pos = prefix if prefix != -1 else m.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
matches.append(_make_tool_call(func_name, arguments))
except json.JSONDecodeError:
pass
if matches:
@ -185,13 +183,7 @@ def _parse_mistral_token_tool_calls(answer: str, tool_names: list[str]):
arguments = json.loads(json_str)
if start_pos is None:
start_pos = m.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
matches.append(_make_tool_call(func_name, arguments))
except json.JSONDecodeError:
pass
return matches, start_pos
@ -226,13 +218,7 @@ def _parse_bare_name_tool_calls(answer: str, tool_names: list[str]):
arguments = json.loads(json_str)
if start_pos is None:
start_pos = match.start()
matches.append({
"type": "function",
"function": {
"name": name,
"arguments": arguments
}
})
matches.append(_make_tool_call(name, arguments))
except json.JSONDecodeError:
pass
return matches, start_pos
@ -269,13 +255,7 @@ def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]):
arguments[param_name] = param_value
if start_pos is None:
start_pos = tc_match.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
matches.append(_make_tool_call(func_name, arguments))
return matches, start_pos
@ -305,13 +285,7 @@ def _parse_kimi_tool_calls(answer: str, tool_names: list[str]):
# Check for section begin marker before the call marker
section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start())
start_pos = section if section != -1 else m.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
matches.append(_make_tool_call(func_name, arguments))
except json.JSONDecodeError:
pass
return matches, start_pos
@ -348,13 +322,7 @@ def _parse_minimax_tool_calls(answer: str, tool_names: list[str]):
arguments[param_name] = param_value
if start_pos is None:
start_pos = tc_match.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
matches.append(_make_tool_call(func_name, arguments))
return matches, start_pos
@ -382,13 +350,7 @@ def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]):
# Check for section begin marker before the call marker
section = answer.rfind('<tool▁calls▁begin>', 0, m.start())
start_pos = section if section != -1 else m.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
matches.append(_make_tool_call(func_name, arguments))
except json.JSONDecodeError:
pass
return matches, start_pos
@ -428,13 +390,7 @@ def _parse_glm_tool_calls(answer: str, tool_names: list[str]):
arguments[k] = v
if start_pos is None:
start_pos = tc_match.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
matches.append(_make_tool_call(func_name, arguments))
return matches, start_pos
@ -486,13 +442,7 @@ def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
if start_pos is None:
start_pos = bracket_match.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
matches.append(_make_tool_call(func_name, arguments))
return matches, start_pos

View file

@ -732,11 +732,13 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
if lora_all_param > 0:
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
train_log.update({"base_model_name": shared.model_name})
train_log.update({"base_model_class": shared.model.__class__.__name__})
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
train_log.update({"projections": projections_string})
train_log.update({
"base_model_name": shared.model_name,
"base_model_class": shared.model.__class__.__name__,
"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False),
"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False),
"projections": projections_string,
})
if stop_at_loss > 0:
print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")

View file

@ -299,7 +299,7 @@ def apply_interface_values(state, use_persistent=False):
elements = list_interface_input_elements()
if len(state) == 0:
if not state:
return [gr.update() for k in elements] # Dummy, do nothing
else:
return [state[k] if k in state else gr.update() for k in elements]
@ -307,9 +307,8 @@ def apply_interface_values(state, use_persistent=False):
def save_settings(state, preset, extensions_list, show_controls, theme_state, manual_save=False):
output = copy.deepcopy(shared.settings)
exclude = []
for k in state:
if k in shared.settings and k not in exclude:
if k in shared.settings:
output[k] = state[k]
if preset:
@ -323,7 +322,7 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state, ma
output['custom_stopping_strings'] = output.get('custom_stopping_strings') or ''
output['custom_token_bans'] = output.get('custom_token_bans') or ''
output['show_controls'] = show_controls
output['dark_theme'] = True if theme_state == 'dark' else False
output['dark_theme'] = theme_state == 'dark'
output.pop('instruction_template_str')
output.pop('truncation_length')