From 99588be5760702d413435c9167935f29832a6d06 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 20 Apr 2025 18:57:26 -0700 Subject: [PATCH] Organize one_click.py --- one_click.py | 183 +++++++++++++++++++++++++-------------------------- 1 file changed, 91 insertions(+), 92 deletions(-) diff --git a/one_click.py b/one_click.py index 99bdc41e..eff2ed9f 100644 --- a/one_click.py +++ b/one_click.py @@ -15,7 +15,6 @@ import sys # os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0' # os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030' - # Define the required versions TORCH_VERSION = "2.6.0" TORCHVISION_VERSION = "0.21.0" @@ -62,6 +61,19 @@ def is_x86_64(): return platform.machine() == "x86_64" +def is_installed(): + site_packages_path = None + for sitedir in site.getsitepackages(): + if "site-packages" in sitedir and conda_env_path in sitedir: + site_packages_path = sitedir + break + + if site_packages_path: + return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py')) + else: + return os.path.isdir(conda_env_path) + + def cpu_has_avx2(): try: import cpuinfo @@ -104,44 +116,13 @@ def torch_version(): return torver -def update_pytorch_and_python(): - print_big_message("Checking for PyTorch updates.") - - # Update the Python version. Left here for future reference in case this becomes necessary. - # print_big_message("Checking for PyTorch and Python updates.") - # current_python_version = f"{sys.version_info.major}.{sys.version_info.minor}" - # if current_python_version != PYTHON_VERSION: - # run_cmd(f"conda install -y python={PYTHON_VERSION}", assert_success=True, environment=True) - - torver = torch_version() - base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}" - - if "+cu" in torver: - install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cu124" - elif "+rocm" in torver: - install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.1" - elif "+cpu" in torver: - install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cpu" - elif "+cxx11" in torver: - intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10" - install_cmd = f"{base_cmd} {intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/" - else: - install_cmd = base_cmd - - run_cmd(install_cmd, assert_success=True, environment=True) +def get_current_commit(): + result = run_cmd("git rev-parse HEAD", capture_output=True, environment=True) + return result.stdout.decode('utf-8').strip() -def is_installed(): - site_packages_path = None - for sitedir in site.getsitepackages(): - if "site-packages" in sitedir and conda_env_path in sitedir: - site_packages_path = sitedir - break - - if site_packages_path: - return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py')) - else: - return os.path.isdir(conda_env_path) +def get_extensions_names(): + return [foldername for foldername in os.listdir('extensions') if os.path.isfile(os.path.join('extensions', foldername, 'requirements.txt'))] def check_env(): @@ -157,35 +138,11 @@ def check_env(): sys.exit(1) -def get_current_commit(): - result = run_cmd("git rev-parse HEAD", capture_output=True, environment=True) - return result.stdout.decode('utf-8').strip() - - def clear_cache(): run_cmd("conda clean -a -y", environment=True) run_cmd("python -m pip cache purge", environment=True) -def print_big_message(message): - message = message.strip() - lines = message.split('\n') - print("\n\n*******************************************************************") - for line in lines: - print("*", line) - - print("*******************************************************************\n\n") - - -def calculate_file_hash(file_path): - p = os.path.join(script_dir, file_path) - if os.path.isfile(p): - with open(p, 'rb') as f: - return hashlib.sha256(f.read()).hexdigest() - else: - return '' - - def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None): # Use the conda environment if environment: @@ -210,6 +167,25 @@ def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, return result +def print_big_message(message): + message = message.strip() + lines = message.split('\n') + print("\n\n*******************************************************************") + for line in lines: + print("*", line) + + print("*******************************************************************\n\n") + + +def calculate_file_hash(file_path): + p = os.path.join(script_dir, file_path) + if os.path.isfile(p): + with open(p, 'rb') as f: + return hashlib.sha256(f.read()).hexdigest() + else: + return '' + + def generate_alphabetic_sequence(index): result = '' while index >= 0: @@ -238,6 +214,51 @@ def get_user_choice(question, options_dict): return choice +def update_pytorch_and_python(): + print_big_message("Checking for PyTorch updates.") + + # Update the Python version. Left here for future reference in case this becomes necessary. + # print_big_message("Checking for PyTorch and Python updates.") + # current_python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + # if current_python_version != PYTHON_VERSION: + # run_cmd(f"conda install -y python={PYTHON_VERSION}", assert_success=True, environment=True) + + torver = torch_version() + base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}" + + if "+cu" in torver: + install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cu124" + elif "+rocm" in torver: + install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.1" + elif "+cpu" in torver: + install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cpu" + elif "+cxx11" in torver: + intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10" + install_cmd = f"{base_cmd} {intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/" + else: + install_cmd = base_cmd + + run_cmd(install_cmd, assert_success=True, environment=True) + + +def clean_outdated_pytorch_cuda_dependencies(): + patterns = ["cu121", "cu122", "torch2.4"] + result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True) + matching_packages = [] + + for line in result.stdout.decode('utf-8').splitlines(): + if "==" in line: + pkg_name, version = line.split('==', 1) + if any(pattern in version for pattern in patterns): + matching_packages.append(pkg_name) + + if matching_packages: + print(f"\nUninstalling: {', '.join(matching_packages)}\n") + run_cmd(f"python -m pip uninstall -y {' '.join(matching_packages)}", assert_success=True, environment=True) + + return matching_packages + + def install_webui(): if os.path.isfile(state_file): os.remove(state_file) @@ -323,37 +344,6 @@ def install_webui(): update_requirements(initial_installation=True, pull=False) -def get_extensions_names(): - return [foldername for foldername in os.listdir('extensions') if os.path.isfile(os.path.join('extensions', foldername, 'requirements.txt'))] - - -def install_extensions_requirements(): - print_big_message("Installing extensions requirements.\nSome of these may fail on Windows.\nDon\'t worry if you see error messages, as they will not affect the main program.") - extensions = get_extensions_names() - for i, extension in enumerate(extensions): - print(f"\n\n--- [{i + 1}/{len(extensions)}]: {extension}\n\n") - extension_req_path = os.path.join("extensions", extension, "requirements.txt") - run_cmd(f"python -m pip install -r {extension_req_path} --upgrade", assert_success=False, environment=True) - - -def clean_outdated_pytorch_cuda_dependencies(): - patterns = ["cu121", "cu122", "torch2.4"] - result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True) - matching_packages = [] - - for line in result.stdout.decode('utf-8').splitlines(): - if "==" in line: - pkg_name, version = line.split('==', 1) - if any(pattern in version for pattern in patterns): - matching_packages.append(pkg_name) - - if matching_packages: - print(f"\nUninstalling: {', '.join(matching_packages)}\n") - run_cmd(f"python -m pip uninstall -y {' '.join(matching_packages)}", assert_success=True, environment=True) - - return matching_packages - - def update_requirements(initial_installation=False, pull=True): # Create .git directory if missing if not os.path.exists(os.path.join(script_dir, ".git")): @@ -475,6 +465,15 @@ def update_requirements(initial_installation=False, pull=True): clear_cache() +def install_extensions_requirements(): + print_big_message("Installing extensions requirements.\nSome of these may fail on Windows.\nDon\'t worry if you see error messages, as they will not affect the main program.") + extensions = get_extensions_names() + for i, extension in enumerate(extensions): + print(f"\n\n--- [{i + 1}/{len(extensions)}]: {extension}\n\n") + extension_req_path = os.path.join("extensions", extension, "requirements.txt") + run_cmd(f"python -m pip install -r {extension_req_path} --upgrade", assert_success=False, environment=True) + + def launch_webui(): run_cmd(f"python server.py {flags}", environment=True)