Organize one_click.py

This commit is contained in:
oobabooga 2025-04-20 18:57:26 -07:00
parent e243424ba1
commit 99588be576

View file

@ -15,7 +15,6 @@ import sys
# os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0'
# os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030'
# Define the required versions
TORCH_VERSION = "2.6.0"
TORCHVISION_VERSION = "0.21.0"
@ -62,6 +61,19 @@ def is_x86_64():
return platform.machine() == "x86_64"
def is_installed():
site_packages_path = None
for sitedir in site.getsitepackages():
if "site-packages" in sitedir and conda_env_path in sitedir:
site_packages_path = sitedir
break
if site_packages_path:
return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py'))
else:
return os.path.isdir(conda_env_path)
def cpu_has_avx2():
try:
import cpuinfo
@ -104,44 +116,13 @@ def torch_version():
return torver
def update_pytorch_and_python():
print_big_message("Checking for PyTorch updates.")
# Update the Python version. Left here for future reference in case this becomes necessary.
# print_big_message("Checking for PyTorch and Python updates.")
# current_python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
# if current_python_version != PYTHON_VERSION:
# run_cmd(f"conda install -y python={PYTHON_VERSION}", assert_success=True, environment=True)
torver = torch_version()
base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}"
if "+cu" in torver:
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cu124"
elif "+rocm" in torver:
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.1"
elif "+cpu" in torver:
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cpu"
elif "+cxx11" in torver:
intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10"
install_cmd = f"{base_cmd} {intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
else:
install_cmd = base_cmd
run_cmd(install_cmd, assert_success=True, environment=True)
def get_current_commit():
result = run_cmd("git rev-parse HEAD", capture_output=True, environment=True)
return result.stdout.decode('utf-8').strip()
def is_installed():
site_packages_path = None
for sitedir in site.getsitepackages():
if "site-packages" in sitedir and conda_env_path in sitedir:
site_packages_path = sitedir
break
if site_packages_path:
return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py'))
else:
return os.path.isdir(conda_env_path)
def get_extensions_names():
return [foldername for foldername in os.listdir('extensions') if os.path.isfile(os.path.join('extensions', foldername, 'requirements.txt'))]
def check_env():
@ -157,35 +138,11 @@ def check_env():
sys.exit(1)
def get_current_commit():
result = run_cmd("git rev-parse HEAD", capture_output=True, environment=True)
return result.stdout.decode('utf-8').strip()
def clear_cache():
run_cmd("conda clean -a -y", environment=True)
run_cmd("python -m pip cache purge", environment=True)
def print_big_message(message):
message = message.strip()
lines = message.split('\n')
print("\n\n*******************************************************************")
for line in lines:
print("*", line)
print("*******************************************************************\n\n")
def calculate_file_hash(file_path):
p = os.path.join(script_dir, file_path)
if os.path.isfile(p):
with open(p, 'rb') as f:
return hashlib.sha256(f.read()).hexdigest()
else:
return ''
def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None):
# Use the conda environment
if environment:
@ -210,6 +167,25 @@ def run_cmd(cmd, assert_success=False, environment=False, capture_output=False,
return result
def print_big_message(message):
message = message.strip()
lines = message.split('\n')
print("\n\n*******************************************************************")
for line in lines:
print("*", line)
print("*******************************************************************\n\n")
def calculate_file_hash(file_path):
p = os.path.join(script_dir, file_path)
if os.path.isfile(p):
with open(p, 'rb') as f:
return hashlib.sha256(f.read()).hexdigest()
else:
return ''
def generate_alphabetic_sequence(index):
result = ''
while index >= 0:
@ -238,6 +214,51 @@ def get_user_choice(question, options_dict):
return choice
def update_pytorch_and_python():
print_big_message("Checking for PyTorch updates.")
# Update the Python version. Left here for future reference in case this becomes necessary.
# print_big_message("Checking for PyTorch and Python updates.")
# current_python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
# if current_python_version != PYTHON_VERSION:
# run_cmd(f"conda install -y python={PYTHON_VERSION}", assert_success=True, environment=True)
torver = torch_version()
base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}"
if "+cu" in torver:
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cu124"
elif "+rocm" in torver:
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.1"
elif "+cpu" in torver:
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cpu"
elif "+cxx11" in torver:
intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10"
install_cmd = f"{base_cmd} {intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
else:
install_cmd = base_cmd
run_cmd(install_cmd, assert_success=True, environment=True)
def clean_outdated_pytorch_cuda_dependencies():
patterns = ["cu121", "cu122", "torch2.4"]
result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True)
matching_packages = []
for line in result.stdout.decode('utf-8').splitlines():
if "==" in line:
pkg_name, version = line.split('==', 1)
if any(pattern in version for pattern in patterns):
matching_packages.append(pkg_name)
if matching_packages:
print(f"\nUninstalling: {', '.join(matching_packages)}\n")
run_cmd(f"python -m pip uninstall -y {' '.join(matching_packages)}", assert_success=True, environment=True)
return matching_packages
def install_webui():
if os.path.isfile(state_file):
os.remove(state_file)
@ -323,37 +344,6 @@ def install_webui():
update_requirements(initial_installation=True, pull=False)
def get_extensions_names():
return [foldername for foldername in os.listdir('extensions') if os.path.isfile(os.path.join('extensions', foldername, 'requirements.txt'))]
def install_extensions_requirements():
print_big_message("Installing extensions requirements.\nSome of these may fail on Windows.\nDon\'t worry if you see error messages, as they will not affect the main program.")
extensions = get_extensions_names()
for i, extension in enumerate(extensions):
print(f"\n\n--- [{i + 1}/{len(extensions)}]: {extension}\n\n")
extension_req_path = os.path.join("extensions", extension, "requirements.txt")
run_cmd(f"python -m pip install -r {extension_req_path} --upgrade", assert_success=False, environment=True)
def clean_outdated_pytorch_cuda_dependencies():
patterns = ["cu121", "cu122", "torch2.4"]
result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True)
matching_packages = []
for line in result.stdout.decode('utf-8').splitlines():
if "==" in line:
pkg_name, version = line.split('==', 1)
if any(pattern in version for pattern in patterns):
matching_packages.append(pkg_name)
if matching_packages:
print(f"\nUninstalling: {', '.join(matching_packages)}\n")
run_cmd(f"python -m pip uninstall -y {' '.join(matching_packages)}", assert_success=True, environment=True)
return matching_packages
def update_requirements(initial_installation=False, pull=True):
# Create .git directory if missing
if not os.path.exists(os.path.join(script_dir, ".git")):
@ -475,6 +465,15 @@ def update_requirements(initial_installation=False, pull=True):
clear_cache()
def install_extensions_requirements():
print_big_message("Installing extensions requirements.\nSome of these may fail on Windows.\nDon\'t worry if you see error messages, as they will not affect the main program.")
extensions = get_extensions_names()
for i, extension in enumerate(extensions):
print(f"\n\n--- [{i + 1}/{len(extensions)}]: {extension}\n\n")
extension_req_path = os.path.join("extensions", extension, "requirements.txt")
run_cmd(f"python -m pip install -r {extension_req_path} --upgrade", assert_success=False, environment=True)
def launch_webui():
run_cmd(f"python server.py {flags}", environment=True)