Create an update wizard (#5623)

This commit is contained in:
oobabooga 2024-03-04 15:52:24 -03:00 committed by GitHub
parent 6adf222599
commit 97dc3602fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 57 additions and 130 deletions

View file

@ -32,7 +32,7 @@ if os.path.exists(cmd_flags_path):
else:
CMD_FLAGS = ''
flags = f"{' '.join([flag for flag in sys.argv[1:] if flag != '--update'])} {CMD_FLAGS}"
flags = f"{' '.join([flag for flag in sys.argv[1:] if flag != '--update-wizard'])} {CMD_FLAGS}"
def signal_handler(sig, frame):
@ -200,6 +200,24 @@ def run_cmd(cmd, assert_success=False, environment=False, capture_output=False,
return result
def get_user_choice(question, options_dict):
print()
print(question)
print()
for key, value in options_dict.items():
print(f"{key}) {value}")
print()
choice = input("Input> ").upper()
while choice not in options_dict.keys():
print("Invalid choice. Please try again.")
choice = input("Input> ").upper()
return choice
def install_webui():
# Ask the user for the GPU vendor
@ -207,20 +225,16 @@ def install_webui():
choice = os.environ["GPU_CHOICE"].upper()
print_big_message(f"Selected GPU choice \"{choice}\" based on the GPU_CHOICE environment variable.")
else:
print()
print("What is your GPU?")
print()
print("A) NVIDIA")
print("B) AMD (Linux/MacOS only. Requires ROCm SDK 5.6 on Linux)")
print("C) Apple M Series")
print("D) Intel Arc (IPEX)")
print("N) None (I want to run models in CPU mode)")
print()
choice = input("Input> ").upper()
while choice not in 'ABCDN':
print("Invalid choice. Please try again.")
choice = input("Input> ").upper()
choice = get_user_choice(
"What is your GPU?",
{
'A': 'NVIDIA',
'B': 'AMD (Linux/MacOS only. Requires ROCm SDK 5.6 on Linux)',
'C': 'Apple M Series',
'D': 'Intel Arc (IPEX)',
'N': 'None (I want to run models in CPU mode)'
},
)
gpu_choice_to_name = {
"A": "NVIDIA",
@ -395,15 +409,29 @@ if __name__ == "__main__":
check_env()
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--update', action='store_true', help='Update the web UI.')
parser.add_argument('--install-extensions', action='store_true', help='Install extensions requirements.')
parser.add_argument('--update-wizard', action='store_true', help='Launch a menu with update options.')
args, _ = parser.parse_known_args()
if args.update:
update_requirements()
elif args.install_extensions:
install_extensions_requirements()
update_requirements()
if args.update_wizard:
choice = get_user_choice(
"What would you like to do?",
{
'A': 'Update the web UI',
'B': 'Install/update extensions requirements',
'C': 'Revert local changes to repository files with \"git reset --hard\"',
'N': 'Nothing (exit).'
},
)
if choice == 'A':
update_requirements()
elif choice == 'B':
install_extensions_requirements()
update_requirements()
elif choice == 'C':
run_cmd("git reset --hard", assert_success=True, environment=True)
elif choice == 'N':
sys.exit()
else:
if not is_installed():
install_webui()