From 8b8d39ec4e66affac03c22176ac368785095f584 Mon Sep 17 00:00:00 2001 From: oobabooga Date: Wed, 9 Apr 2025 00:07:08 -0300 Subject: [PATCH] Add ExLlamaV3 support (#6832) --- README.md | 24 ++--- modules/exllamav3_hf.py | 179 +++++++++++++++++++++++++++++++++ modules/loaders.py | 56 ++++++++++- modules/models.py | 17 +++- modules/models_settings.py | 4 +- modules/shared.py | 4 +- one_click.py | 50 +++++++-- requirements.txt | 18 ++-- requirements_amd.txt | 2 +- requirements_amd_noavx2.txt | 2 +- requirements_apple_intel.txt | 1 + requirements_apple_silicon.txt | 1 + requirements_noavx2.txt | 18 ++-- 13 files changed, 322 insertions(+), 54 deletions(-) create mode 100644 modules/exllamav3_hf.py diff --git a/README.md b/README.md index 542e1ae1..63b8931a 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. ## Features -- Supports multiple text generation backends in one UI/API, including [Transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), and [ExLlamaV2](https://github.com/turboderp-org/exllamav2). [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) is supported via its own [Dockerfile](https://github.com/oobabooga/text-generation-webui/blob/main/docker/TensorRT-LLM/Dockerfile), and the Transformers loader is compatible with libraries like [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ), [AutoAWQ](https://github.com/casper-hansen/AutoAWQ), [HQQ](https://github.com/mobiusml/hqq), and [AQLM](https://github.com/Vahe1994/AQLM), but they must be installed manually. +- Supports multiple text generation backends in one UI/API, including [Transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [ExLlamaV2](https://github.com/turboderp-org/exllamav2). [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) is supported via its own [Dockerfile](https://github.com/oobabooga/text-generation-webui/blob/main/docker/TensorRT-LLM/Dockerfile), and the Transformers loader is compatible with libraries like [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ), [AutoAWQ](https://github.com/casper-hansen/AutoAWQ), [HQQ](https://github.com/mobiusml/hqq), and [AQLM](https://github.com/Vahe1994/AQLM), but they must be installed manually. - OpenAI-compatible API with Chat and Completions endpoints – see [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples). - Automatic prompt formatting using Jinja2 templates. - Three chat modes: `instruct`, `chat-instruct`, and `chat`, with automatic prompt templates in `chat-instruct`. 
@@ -78,25 +78,19 @@ conda activate textgen | System | GPU | Command | |--------|---------|---------| -| Linux/WSL | NVIDIA | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121` | -| Linux/WSL | CPU only | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cpu` | -| Linux | AMD | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/rocm6.1` | -| MacOS + MPS | Any | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1` | -| Windows | NVIDIA | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121` | -| Windows | CPU only | `pip3 install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1` | +| Linux/WSL | NVIDIA | `pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124` | +| Linux/WSL | CPU only | `pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu` | +| Linux | AMD | `pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/rocm6.1` | +| MacOS + MPS | Any | `pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0` | +| Windows | NVIDIA | `pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124` | +| Windows | CPU only | `pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0` | The up-to-date commands can be found here: https://pytorch.org/get-started/locally/. -For NVIDIA, you also need to install the CUDA runtime libraries: +If you need `nvcc` to compile some library manually, you will additionally need to install this: ``` -conda install -y -c "nvidia/label/cuda-12.1.1" cuda-runtime -``` - -If you need `nvcc` to compile some library manually, replace the command above with - -``` -conda install -y -c "nvidia/label/cuda-12.1.1" cuda +conda install -y -c "nvidia/label/cuda-12.4.1" cuda ``` #### 3. Install the web UI diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py new file mode 100644 index 00000000..3bf44c9b --- /dev/null +++ b/modules/exllamav3_hf.py @@ -0,0 +1,179 @@ +import os +import traceback +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import torch +from exllamav3 import Cache, Config, Model +from torch.nn import CrossEntropyLoss +from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + +from modules import shared +from modules.logging_colors import logger + +try: + import flash_attn +except Exception: + logger.warning('Failed to load flash-attention due to the following error:\n') + traceback.print_exc() + + +class Exllamav3HF(PreTrainedModel): + def __init__(self, model_dir): + super().__init__(PretrainedConfig()) + self.generation_config = GenerationConfig() + + config = Config.from_directory(model_dir) + self.ex_model = Model.from_config(config) + + # Calculate the closest multiple of 256 at or above the chosen value + max_tokens = shared.args.max_seq_len + if max_tokens % 256 != 0: + adjusted_tokens = ((max_tokens // 256) + 1) * 256 + logger.warning(f"max_num_tokens must be a multiple of 256. 
Adjusting from {max_tokens} to {adjusted_tokens}") + max_tokens = adjusted_tokens + + self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens) + + # Create load parameters dictionary + load_params = {'progressbar': True} + if shared.args.gpu_split: + split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] + load_params['use_per_device'] = split + + self.ex_model.load(**load_params) + self.past_seq = None + self.max_tokens = max_tokens + + def _validate_model_class(self): + pass + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + pass + + def prepare_inputs_for_generation(self, input_ids, **kwargs): + return {'input_ids': input_ids, **kwargs} + + @property + def device(self) -> torch.device: + return torch.device(0) + + def __call__(self, *args, **kwargs): + use_cache = kwargs.get('use_cache', True) + labels = kwargs.get('labels', None) + past_key_values = kwargs.get('past_key_values', None) + + if len(args) > 0: + if not shared.args.cfg_cache: + logger.error("Please enable the cfg-cache option to use CFG with ExLlamav3_HF.") + return + + input_ids = args[0] + is_negative = True + past_seq = self.past_seq_negative + ex_cache = self.ex_cache_negative + else: + input_ids = kwargs['input_ids'] + is_negative = False + past_seq = self.past_seq + ex_cache = self.ex_cache + + seq = input_ids[0].tolist() + if is_negative and past_key_values is not None: + seq = past_key_values + seq + + seq_tensor = torch.tensor(seq) + reset = True + + # Make the forward call + if labels is None: + if past_seq is not None: + min_length = min(past_seq.shape[0], seq_tensor.shape[0]) + indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) + if len(indices) > 0: + longest_prefix = indices[0].item() + else: + longest_prefix = min_length + + if longest_prefix > 0: + reset = False + current_len = longest_prefix + if len(seq_tensor) - longest_prefix > 1: + self.ex_model.forward( + input_ids=seq_tensor[longest_prefix:-1].view(1, -1), + params={ + "attn_mode": "flash_attn", + "cache": ex_cache, + "past_len": longest_prefix, + "batch_shape": (1, self.max_tokens) + } + ) + + current_len = longest_prefix + len(seq_tensor) - longest_prefix - 1 + + if reset: + if len(seq_tensor) > 1: + self.ex_model.forward( + input_ids=seq_tensor[:-1].view(1, -1), + params={ + "attn_mode": "flash_attn", + "cache": ex_cache, + "past_len": 0, + "batch_shape": (1, self.max_tokens) + } + ) + + current_len = len(seq_tensor) - 1 + else: + current_len = 0 + + logits = self.ex_model.forward( + input_ids=seq_tensor[-1:].view(1, -1), + params={ + "attn_mode": "flash_attn", + "cache": ex_cache, + "past_len": current_len, + "batch_shape": (1, self.max_tokens) + } + ).to(input_ids.device).float() + else: + logits = self.ex_model.forward( + input_ids=seq_tensor.view(1, -1), + params={ + "attn_mode": "flash_attn", + "cache": ex_cache, + "past_len": 0, + "batch_shape": (1, self.max_tokens) + } + ).float() + + if is_negative: + self.past_seq_negative = seq_tensor + else: + self.past_seq = seq_tensor + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, logits.shape[-1]) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return CausalLMOutputWithPast(logits=logits, 
past_key_values=seq if use_cache else None, loss=loss) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): + assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported" + if isinstance(pretrained_model_name_or_path, str): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + + pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path) + + return Exllamav3HF(pretrained_model_name_or_path) diff --git a/modules/loaders.py b/modules/loaders.py index 88ded1d1..980a13e6 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -23,7 +23,6 @@ loaders_and_params = OrderedDict({ 'use_double_quant', 'use_eager_attention', 'bf16', - 'trust_remote_code', 'no_use_fast', ], @@ -76,6 +75,13 @@ loaders_and_params = OrderedDict({ 'no_use_fast', 'llamacpp_HF_info', ], + 'ExLlamav3_HF': [ + 'max_seq_len', + 'gpu_split', + 'cfg_cache', + 'trust_remote_code', + 'no_use_fast', + ], 'ExLlamav2_HF': [ 'max_seq_len', 'cache_type', @@ -174,30 +180,38 @@ def transformers_samplers(): loaders_samplers = { 'Transformers': transformers_samplers(), 'HQQ': transformers_samplers(), - 'ExLlamav2': { + 'ExLlamav3_HF': { 'temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', + 'smoothing_curve', 'min_p', 'top_p', 'top_k', 'typical_p', 'xtc_threshold', 'xtc_probability', + 'epsilon_cutoff', + 'eta_cutoff', 'tfs', 'top_a', + 'top_n_sigma', 'dry_multiplier', 'dry_allowed_length', 'dry_base', 'repetition_penalty', 'frequency_penalty', 'presence_penalty', + 'encoder_repetition_penalty', + 'no_repeat_ngram_size', 'repetition_penalty_range', + 'guidance_scale', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'do_sample', 'dynamic_temperature', 'temperature_last', 'auto_max_new_tokens', @@ -205,8 +219,12 @@ loaders_samplers = { 'add_bos_token', 'skip_special_tokens', 'seed', + 'sampler_priority', 'custom_token_bans', + 'negative_prompt', 'dry_sequence_breakers', + 'grammar_string', + 'grammar_file_row', }, 'ExLlamav2_HF': { 'temperature', @@ -254,6 +272,40 @@ loaders_samplers = { 'grammar_string', 'grammar_file_row', }, + 'ExLlamav2': { + 'temperature', + 'dynatemp_low', + 'dynatemp_high', + 'dynatemp_exponent', + 'smoothing_factor', + 'min_p', + 'top_p', + 'top_k', + 'typical_p', + 'xtc_threshold', + 'xtc_probability', + 'tfs', + 'top_a', + 'dry_multiplier', + 'dry_allowed_length', + 'dry_base', + 'repetition_penalty', + 'frequency_penalty', + 'presence_penalty', + 'repetition_penalty_range', + 'mirostat_mode', + 'mirostat_tau', + 'mirostat_eta', + 'dynamic_temperature', + 'temperature_last', + 'auto_max_new_tokens', + 'ban_eos_token', + 'add_bos_token', + 'skip_special_tokens', + 'seed', + 'custom_token_bans', + 'dry_sequence_breakers', + }, 'llama.cpp': { 'temperature', 'min_p', diff --git a/modules/models.py b/modules/models.py index 3951fe82..288bc1b6 100644 --- a/modules/models.py +++ b/modules/models.py @@ -69,8 +69,9 @@ def load_model(model_name, loader=None): 'Transformers': huggingface_loader, 'llama.cpp': llamacpp_loader, 'llamacpp_HF': llamacpp_HF_loader, - 'ExLlamav2': ExLlamav2_loader, + 'ExLlamav3_HF': ExLlamav3_HF_loader, 'ExLlamav2_HF': ExLlamav2_HF_loader, + 'ExLlamav2': ExLlamav2_loader, 'HQQ': HQQ_loader, 'TensorRT-LLM': TensorRT_LLM_loader, } @@ -304,11 +305,10 @@ def llamacpp_HF_loader(model_name): return model -def ExLlamav2_loader(model_name): - from modules.exllamav2 import Exllamav2Model +def 
ExLlamav3_HF_loader(model_name): + from modules.exllamav3_hf import Exllamav3HF - model, tokenizer = Exllamav2Model.from_pretrained(model_name) - return model, tokenizer + return Exllamav3HF.from_pretrained(model_name) def ExLlamav2_HF_loader(model_name): @@ -317,6 +317,13 @@ def ExLlamav2_HF_loader(model_name): return Exllamav2HF.from_pretrained(model_name) +def ExLlamav2_loader(model_name): + from modules.exllamav2 import Exllamav2Model + + model, tokenizer = Exllamav2Model.from_pretrained(model_name) + return model, tokenizer + + def HQQ_loader(model_name): try: from hqq.core.quantize import HQQBackend, HQQLinear diff --git a/modules/models_settings.py b/modules/models_settings.py index b67d28a0..51994e23 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -158,14 +158,14 @@ def infer_loader(model_name, model_settings): path_to_model = Path(f'{shared.args.model_dir}/{model_name}') if not path_to_model.exists(): loader = None - elif (path_to_model / 'quantize_config.json').exists(): # Old GPTQ metadata file - loader = 'ExLlamav2_HF' elif len(list(path_to_model.glob('*.gguf'))) > 0 and path_to_model.is_dir() and (path_to_model / 'tokenizer_config.json').exists(): loader = 'llamacpp_HF' elif len(list(path_to_model.glob('*.gguf'))) > 0: loader = 'llama.cpp' elif re.match(r'.*\.gguf', model_name.lower()): loader = 'llama.cpp' + elif re.match(r'.*exl3', model_name.lower()): + loader = 'ExLlamav3_HF' elif re.match(r'.*exl2', model_name.lower()): loader = 'ExLlamav2_HF' elif re.match(r'.*-hqq', model_name.lower()): diff --git a/modules/shared.py b/modules/shared.py index 77bd7639..0981f6fb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,7 +86,7 @@ group.add_argument('--idle-timeout', type=int, default=0, help='Unload model aft # Model loader group = parser.add_argument_group('Model loader') -group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav2_HF, ExLlamav2, HQQ, TensorRT-LLM.') +group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. 
Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, HQQ, TensorRT-LLM.') # Transformers/Accelerate group = parser.add_argument_group('Transformers/Accelerate') @@ -273,6 +273,8 @@ def fix_loader_name(name): return 'ExLlamav2' elif name in ['exllamav2-hf', 'exllamav2_hf', 'exllama-v2-hf', 'exllama_v2_hf', 'exllama-v2_hf', 'exllama2-hf', 'exllama2_hf', 'exllama-2-hf', 'exllama_2_hf', 'exllama-2_hf']: return 'ExLlamav2_HF' + elif name in ['exllamav3-hf', 'exllamav3_hf', 'exllama-v3-hf', 'exllama_v3_hf', 'exllama-v3_hf', 'exllama3-hf', 'exllama3_hf', 'exllama-3-hf', 'exllama_3_hf', 'exllama-3_hf']: + return 'ExLlamav3_HF' elif name in ['hqq']: return 'HQQ' elif name in ['tensorrt', 'tensorrtllm', 'tensorrt_llm', 'tensorrt-llm', 'tensort', 'tensortllm']: diff --git a/one_click.py b/one_click.py index 72626010..fcca4ff5 100644 --- a/one_click.py +++ b/one_click.py @@ -16,10 +16,11 @@ import sys # os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030' -# Define the required PyTorch version -TORCH_VERSION = "2.4.1" -TORCHVISION_VERSION = "0.19.1" -TORCHAUDIO_VERSION = "2.4.1" +# Define the required versions +TORCH_VERSION = "2.6.0" +TORCHVISION_VERSION = "0.21.0" +TORCHAUDIO_VERSION = "2.6.0" +PYTHON_VERSION = "3.11" # Environment script_dir = os.getcwd() @@ -101,13 +102,20 @@ def torch_version(): return torver -def update_pytorch(): +def update_pytorch_and_python(): print_big_message("Checking for PyTorch updates.") + + # Update the Python version. Left here for future reference in case this becomes necessary. + # print_big_message("Checking for PyTorch and Python updates.") + # current_python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + # if current_python_version != PYTHON_VERSION: + # run_cmd(f"conda install -y python={PYTHON_VERSION}", assert_success=True, environment=True) + torver = torch_version() base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}" if "+cu" in torver: - install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cu121" + install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cu124" elif "+rocm" in torver: install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.1" elif "+cpu" in torver: @@ -245,7 +253,7 @@ def install_webui(): choice = get_user_choice( "What is your GPU?", { - 'A': 'NVIDIA - CUDA 12.1', + 'A': 'NVIDIA - CUDA 12.4', 'B': 'AMD - Linux/macOS only, requires ROCm 6.1', 'C': 'Apple M Series', 'D': 'Intel Arc (beta)', @@ -273,7 +281,7 @@ def install_webui(): # Handle CUDA version display elif any((is_windows(), is_linux())) and selected_gpu == "NVIDIA": - print("CUDA: 12.1") + print("CUDA: 12.4") # No PyTorch for AMD on Windows (?) 
elif is_windows() and selected_gpu == "AMD": @@ -284,7 +292,7 @@ def install_webui(): install_pytorch = f"python -m pip install torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION} " if selected_gpu == "NVIDIA": - install_pytorch += "--index-url https://download.pytorch.org/whl/cu121" + install_pytorch += "--index-url https://download.pytorch.org/whl/cu124" elif selected_gpu == "AMD": install_pytorch += "--index-url https://download.pytorch.org/whl/rocm6.1" elif selected_gpu in ["APPLE", "NONE"]: @@ -297,7 +305,7 @@ def install_webui(): # Install Git and then Pytorch print_big_message("Installing PyTorch.") - run_cmd(f"conda install -y -k ninja git && {install_pytorch} && python -m pip install py-cpuinfo==9.0.0", assert_success=True, environment=True) + run_cmd(f"conda install -y ninja git && {install_pytorch} && python -m pip install py-cpuinfo==9.0.0", assert_success=True, environment=True) if selected_gpu == "INTEL": # Install oneAPI dependencies via conda @@ -323,6 +331,24 @@ def install_extensions_requirements(): run_cmd(f"python -m pip install -r {extension_req_path} --upgrade", assert_success=False, environment=True) +def clean_outdated_pytorch_cuda_dependencies(): + patterns = ["cu121", "cu122", "torch2.4"] + result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True) + matching_packages = [] + + for line in result.stdout.decode('utf-8').splitlines(): + if "==" in line: + pkg_name, version = line.split('==', 1) + if any(pattern in version for pattern in patterns): + matching_packages.append(pkg_name) + + if matching_packages: + print(f"Uninstalling: {', '.join(matching_packages)}") + run_cmd(f"python -m pip uninstall -y {' '.join(matching_packages)}", assert_success=True, environment=True) + + return matching_packages + + def update_requirements(initial_installation=False, pull=True): # Create .git directory if missing if not os.path.exists(os.path.join(script_dir, ".git")): @@ -410,7 +436,9 @@ def update_requirements(initial_installation=False, pull=True): # Update PyTorch if not initial_installation: - update_pytorch() + clean_outdated_pytorch_cuda_dependencies() + update_pytorch_and_python() + torver = torch_version() print_big_message(f"Installing webui requirements from file: {requirements_file}") print(f"TORCH: {torver}\n") diff --git a/requirements.txt b/requirements.txt index 4cf99b69..b9b4ea7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,16 +36,18 @@ https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cp https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" # llama-cpp-python (CUDA, with GGML_CUDA_FORCE_MMQ) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" 
+https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" # llama-cpp-python (CUDA, without GGML_CUDA_FORCE_MMQ) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" # CUDA wheels -https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu121.torch2.4.1-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu121.torch2.4.1-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu124.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl; platform_system == "Linux" and platform_machine != "x86_64" -https://github.com/oobabooga/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu122torch2.4.1cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu124torch2.6.0cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" diff --git 
a/requirements_amd.txt b/requirements_amd.txt index 0d205725..3d24891f 100644 --- a/requirements_amd.txt +++ b/requirements_amd.txt @@ -36,5 +36,5 @@ https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cp # AMD wheels https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/rocm/llama_cpp_python_cuda-0.3.8+rocm6.1.2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.1.torch2.4.1-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.1.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64" diff --git a/requirements_amd_noavx2.txt b/requirements_amd_noavx2.txt index 93a46a64..057b631d 100644 --- a/requirements_amd_noavx2.txt +++ b/requirements_amd_noavx2.txt @@ -35,5 +35,5 @@ https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cp https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" # AMD wheels -https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.1.torch2.4.1-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.1.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64" diff --git a/requirements_apple_intel.txt b/requirements_apple_intel.txt index 00353bfd..eba21ec2 100644 --- a/requirements_apple_intel.txt +++ b/requirements_apple_intel.txt @@ -33,4 +33,5 @@ tiktoken # Mac wheels https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11" https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11" +https://github.com/oobabooga/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1-py3-none-any.whl https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl diff --git a/requirements_apple_silicon.txt b/requirements_apple_silicon.txt index 7076b386..2048c99b 100644 --- a/requirements_apple_silicon.txt +++ b/requirements_apple_silicon.txt @@ -34,4 +34,5 @@ tiktoken https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and 
python_version == "3.11" https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11" https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" and python_version == "3.11" +https://github.com/oobabooga/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1-py3-none-any.whl https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl diff --git a/requirements_noavx2.txt b/requirements_noavx2.txt index d5f456f8..60b71ac1 100644 --- a/requirements_noavx2.txt +++ b/requirements_noavx2.txt @@ -36,16 +36,18 @@ https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cp https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" # llama-cpp-python (CUDA, with GGML_CUDA_FORCE_MMQ) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu121avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu121avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" # llama-cpp-python (CUDA, without GGML_CUDA_FORCE_MMQ) -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu121avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu121avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" # CUDA wheels -https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu121.torch2.4.1-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu121.torch2.4.1-cp311-cp311-linux_x86_64.whl; platform_system 
== "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu124.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl; platform_system == "Linux" and platform_machine != "x86_64" -https://github.com/oobabooga/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu122torch2.4.1cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu124torch2.6.0cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"