Handle CMD_FLAGS.txt in the main code (closes #6896)

This commit is contained in:
oobabooga 2025-04-24 08:19:46 -07:00
parent bfbde73409
commit c71a2af5ab
2 changed files with 18 additions and 8 deletions

View file

@ -1,6 +1,7 @@
import argparse
import copy
import os
import shlex
import sys
from collections import OrderedDict
from pathlib import Path
@ -201,6 +202,21 @@ group.add_argument('--nowebui', action='store_true', help='Do not launch the Gra
# Deprecated parameters
group = parser.add_argument_group('Deprecated')
# Handle CMD_FLAGS.txt
cmd_flags_path = Path(__file__).parent.parent / "CMD_FLAGS.txt"
if cmd_flags_path.exists():
with cmd_flags_path.open('r', encoding='utf-8') as f:
cmd_flags = ' '.join(
line.strip().rstrip('\\').strip()
for line in f
if line.strip().rstrip('\\').strip() and not line.strip().startswith('#')
)
if cmd_flags:
# Command-line takes precedence over CMD_FLAGS.txt
sys.argv = [sys.argv[0]] + shlex.split(cmd_flags) + sys.argv[1:]
args = parser.parse_args()
args_defaults = parser.parse_args([])
provided_arguments = []

View file

@ -28,14 +28,7 @@ conda_env_path = os.path.join(script_dir, "installer_files", "env")
state_file = '.installer_state.json'
# Command-line flags
cmd_flags_path = os.path.join(script_dir, "CMD_FLAGS.txt")
if os.path.exists(cmd_flags_path):
with open(cmd_flags_path, 'r') as f:
CMD_FLAGS = ' '.join(line.strip().rstrip('\\').strip() for line in f if line.strip().rstrip('\\').strip() and not line.strip().startswith('#'))
else:
CMD_FLAGS = ''
flags = f"{' '.join([flag for flag in sys.argv[1:] if flag != '--update-wizard'])} {CMD_FLAGS}"
flags = f"{' '.join([flag for flag in sys.argv[1:] if flag != '--update-wizard'])}"
def signal_handler(sig, frame):
@ -300,6 +293,7 @@ def install_webui():
# Write a flag to CMD_FLAGS.txt for CPU mode
if selected_gpu == "NONE":
cmd_flags_path = os.path.join(script_dir, "CMD_FLAGS.txt")
with open(cmd_flags_path, 'r+') as cmd_flags_file:
if "--cpu" not in cmd_flags_file.read():
print_big_message("Adding the --cpu flag to CMD_FLAGS.txt.")