Add PyPI fallback for PyTorch install commands

This commit is contained in:
oobabooga 2026-03-07 23:06:15 -03:00
parent aeeff41cc0
commit b3705d87bf

View file

@ -111,13 +111,14 @@ def get_gpu_choice():
def get_pytorch_install_command(gpu_choice):
"""Get PyTorch installation command based on GPU choice"""
base_cmd = f"python -m pip install torch=={TORCH_VERSION} "
pypi_fallback = " --extra-index-url https://pypi.org/simple/"
if gpu_choice == "NVIDIA_CUDA128":
return base_cmd + "--index-url https://download.pytorch.org/whl/cu128"
return base_cmd + "--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback
elif gpu_choice == "AMD":
return base_cmd + "--index-url https://download.pytorch.org/whl/rocm6.4"
return base_cmd + "--index-url https://download.pytorch.org/whl/rocm6.4" + pypi_fallback
elif gpu_choice in ["APPLE", "NONE"]:
return base_cmd + "--index-url https://download.pytorch.org/whl/cpu"
return base_cmd + "--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback
elif gpu_choice == "INTEL":
if is_linux():
return "python -m pip install torch==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
@ -130,16 +131,17 @@ def get_pytorch_install_command(gpu_choice):
def get_pytorch_update_command(gpu_choice):
"""Get PyTorch update command based on GPU choice"""
base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} "
pypi_fallback = " --extra-index-url https://pypi.org/simple/"
if gpu_choice == "NVIDIA_CUDA128":
return f"{base_cmd} --index-url https://download.pytorch.org/whl/cu128"
return f"{base_cmd}--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback
elif gpu_choice == "AMD":
return f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.4"
return f"{base_cmd}--index-url https://download.pytorch.org/whl/rocm6.4" + pypi_fallback
elif gpu_choice in ["APPLE", "NONE"]:
return f"{base_cmd} --index-url https://download.pytorch.org/whl/cpu"
return f"{base_cmd}--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback
elif gpu_choice == "INTEL":
intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10"
return f"{base_cmd} {intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
return f"{base_cmd}{intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
else:
return base_cmd