mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-18 03:14:39 +01:00
Organize one_click.py
This commit is contained in:
parent
e243424ba1
commit
99588be576
183
one_click.py
183
one_click.py
|
|
@ -15,7 +15,6 @@ import sys
|
|||
# os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0'
|
||||
# os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030'
|
||||
|
||||
|
||||
# Define the required versions
|
||||
TORCH_VERSION = "2.6.0"
|
||||
TORCHVISION_VERSION = "0.21.0"
|
||||
|
|
@ -62,6 +61,19 @@ def is_x86_64():
|
|||
return platform.machine() == "x86_64"
|
||||
|
||||
|
||||
def is_installed():
|
||||
site_packages_path = None
|
||||
for sitedir in site.getsitepackages():
|
||||
if "site-packages" in sitedir and conda_env_path in sitedir:
|
||||
site_packages_path = sitedir
|
||||
break
|
||||
|
||||
if site_packages_path:
|
||||
return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py'))
|
||||
else:
|
||||
return os.path.isdir(conda_env_path)
|
||||
|
||||
|
||||
def cpu_has_avx2():
|
||||
try:
|
||||
import cpuinfo
|
||||
|
|
@ -104,44 +116,13 @@ def torch_version():
|
|||
return torver
|
||||
|
||||
|
||||
def update_pytorch_and_python():
|
||||
print_big_message("Checking for PyTorch updates.")
|
||||
|
||||
# Update the Python version. Left here for future reference in case this becomes necessary.
|
||||
# print_big_message("Checking for PyTorch and Python updates.")
|
||||
# current_python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
|
||||
# if current_python_version != PYTHON_VERSION:
|
||||
# run_cmd(f"conda install -y python={PYTHON_VERSION}", assert_success=True, environment=True)
|
||||
|
||||
torver = torch_version()
|
||||
base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}"
|
||||
|
||||
if "+cu" in torver:
|
||||
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cu124"
|
||||
elif "+rocm" in torver:
|
||||
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.1"
|
||||
elif "+cpu" in torver:
|
||||
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cpu"
|
||||
elif "+cxx11" in torver:
|
||||
intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10"
|
||||
install_cmd = f"{base_cmd} {intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
|
||||
else:
|
||||
install_cmd = base_cmd
|
||||
|
||||
run_cmd(install_cmd, assert_success=True, environment=True)
|
||||
def get_current_commit():
|
||||
result = run_cmd("git rev-parse HEAD", capture_output=True, environment=True)
|
||||
return result.stdout.decode('utf-8').strip()
|
||||
|
||||
|
||||
def is_installed():
|
||||
site_packages_path = None
|
||||
for sitedir in site.getsitepackages():
|
||||
if "site-packages" in sitedir and conda_env_path in sitedir:
|
||||
site_packages_path = sitedir
|
||||
break
|
||||
|
||||
if site_packages_path:
|
||||
return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py'))
|
||||
else:
|
||||
return os.path.isdir(conda_env_path)
|
||||
def get_extensions_names():
|
||||
return [foldername for foldername in os.listdir('extensions') if os.path.isfile(os.path.join('extensions', foldername, 'requirements.txt'))]
|
||||
|
||||
|
||||
def check_env():
|
||||
|
|
@ -157,35 +138,11 @@ def check_env():
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
def get_current_commit():
|
||||
result = run_cmd("git rev-parse HEAD", capture_output=True, environment=True)
|
||||
return result.stdout.decode('utf-8').strip()
|
||||
|
||||
|
||||
def clear_cache():
|
||||
run_cmd("conda clean -a -y", environment=True)
|
||||
run_cmd("python -m pip cache purge", environment=True)
|
||||
|
||||
|
||||
def print_big_message(message):
|
||||
message = message.strip()
|
||||
lines = message.split('\n')
|
||||
print("\n\n*******************************************************************")
|
||||
for line in lines:
|
||||
print("*", line)
|
||||
|
||||
print("*******************************************************************\n\n")
|
||||
|
||||
|
||||
def calculate_file_hash(file_path):
|
||||
p = os.path.join(script_dir, file_path)
|
||||
if os.path.isfile(p):
|
||||
with open(p, 'rb') as f:
|
||||
return hashlib.sha256(f.read()).hexdigest()
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None):
|
||||
# Use the conda environment
|
||||
if environment:
|
||||
|
|
@ -210,6 +167,25 @@ def run_cmd(cmd, assert_success=False, environment=False, capture_output=False,
|
|||
return result
|
||||
|
||||
|
||||
def print_big_message(message):
|
||||
message = message.strip()
|
||||
lines = message.split('\n')
|
||||
print("\n\n*******************************************************************")
|
||||
for line in lines:
|
||||
print("*", line)
|
||||
|
||||
print("*******************************************************************\n\n")
|
||||
|
||||
|
||||
def calculate_file_hash(file_path):
|
||||
p = os.path.join(script_dir, file_path)
|
||||
if os.path.isfile(p):
|
||||
with open(p, 'rb') as f:
|
||||
return hashlib.sha256(f.read()).hexdigest()
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def generate_alphabetic_sequence(index):
|
||||
result = ''
|
||||
while index >= 0:
|
||||
|
|
@ -238,6 +214,51 @@ def get_user_choice(question, options_dict):
|
|||
return choice
|
||||
|
||||
|
||||
def update_pytorch_and_python():
|
||||
print_big_message("Checking for PyTorch updates.")
|
||||
|
||||
# Update the Python version. Left here for future reference in case this becomes necessary.
|
||||
# print_big_message("Checking for PyTorch and Python updates.")
|
||||
# current_python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
|
||||
# if current_python_version != PYTHON_VERSION:
|
||||
# run_cmd(f"conda install -y python={PYTHON_VERSION}", assert_success=True, environment=True)
|
||||
|
||||
torver = torch_version()
|
||||
base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}"
|
||||
|
||||
if "+cu" in torver:
|
||||
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cu124"
|
||||
elif "+rocm" in torver:
|
||||
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.1"
|
||||
elif "+cpu" in torver:
|
||||
install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cpu"
|
||||
elif "+cxx11" in torver:
|
||||
intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10"
|
||||
install_cmd = f"{base_cmd} {intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
|
||||
else:
|
||||
install_cmd = base_cmd
|
||||
|
||||
run_cmd(install_cmd, assert_success=True, environment=True)
|
||||
|
||||
|
||||
def clean_outdated_pytorch_cuda_dependencies():
|
||||
patterns = ["cu121", "cu122", "torch2.4"]
|
||||
result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True)
|
||||
matching_packages = []
|
||||
|
||||
for line in result.stdout.decode('utf-8').splitlines():
|
||||
if "==" in line:
|
||||
pkg_name, version = line.split('==', 1)
|
||||
if any(pattern in version for pattern in patterns):
|
||||
matching_packages.append(pkg_name)
|
||||
|
||||
if matching_packages:
|
||||
print(f"\nUninstalling: {', '.join(matching_packages)}\n")
|
||||
run_cmd(f"python -m pip uninstall -y {' '.join(matching_packages)}", assert_success=True, environment=True)
|
||||
|
||||
return matching_packages
|
||||
|
||||
|
||||
def install_webui():
|
||||
if os.path.isfile(state_file):
|
||||
os.remove(state_file)
|
||||
|
|
@ -323,37 +344,6 @@ def install_webui():
|
|||
update_requirements(initial_installation=True, pull=False)
|
||||
|
||||
|
||||
def get_extensions_names():
|
||||
return [foldername for foldername in os.listdir('extensions') if os.path.isfile(os.path.join('extensions', foldername, 'requirements.txt'))]
|
||||
|
||||
|
||||
def install_extensions_requirements():
|
||||
print_big_message("Installing extensions requirements.\nSome of these may fail on Windows.\nDon\'t worry if you see error messages, as they will not affect the main program.")
|
||||
extensions = get_extensions_names()
|
||||
for i, extension in enumerate(extensions):
|
||||
print(f"\n\n--- [{i + 1}/{len(extensions)}]: {extension}\n\n")
|
||||
extension_req_path = os.path.join("extensions", extension, "requirements.txt")
|
||||
run_cmd(f"python -m pip install -r {extension_req_path} --upgrade", assert_success=False, environment=True)
|
||||
|
||||
|
||||
def clean_outdated_pytorch_cuda_dependencies():
|
||||
patterns = ["cu121", "cu122", "torch2.4"]
|
||||
result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True)
|
||||
matching_packages = []
|
||||
|
||||
for line in result.stdout.decode('utf-8').splitlines():
|
||||
if "==" in line:
|
||||
pkg_name, version = line.split('==', 1)
|
||||
if any(pattern in version for pattern in patterns):
|
||||
matching_packages.append(pkg_name)
|
||||
|
||||
if matching_packages:
|
||||
print(f"\nUninstalling: {', '.join(matching_packages)}\n")
|
||||
run_cmd(f"python -m pip uninstall -y {' '.join(matching_packages)}", assert_success=True, environment=True)
|
||||
|
||||
return matching_packages
|
||||
|
||||
|
||||
def update_requirements(initial_installation=False, pull=True):
|
||||
# Create .git directory if missing
|
||||
if not os.path.exists(os.path.join(script_dir, ".git")):
|
||||
|
|
@ -475,6 +465,15 @@ def update_requirements(initial_installation=False, pull=True):
|
|||
clear_cache()
|
||||
|
||||
|
||||
def install_extensions_requirements():
|
||||
print_big_message("Installing extensions requirements.\nSome of these may fail on Windows.\nDon\'t worry if you see error messages, as they will not affect the main program.")
|
||||
extensions = get_extensions_names()
|
||||
for i, extension in enumerate(extensions):
|
||||
print(f"\n\n--- [{i + 1}/{len(extensions)}]: {extension}\n\n")
|
||||
extension_req_path = os.path.join("extensions", extension, "requirements.txt")
|
||||
run_cmd(f"python -m pip install -r {extension_req_path} --upgrade", assert_success=False, environment=True)
|
||||
|
||||
|
||||
def launch_webui():
|
||||
run_cmd(f"python server.py {flags}", environment=True)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue