diff --git a/modules/chat.py b/modules/chat.py index ad2f4001..9857479a 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -3,8 +3,10 @@ import copy import functools import html import json +import os import pprint import re +import shutil import time from datetime import datetime from functools import partial @@ -1194,7 +1196,7 @@ def find_all_histories_with_first_prompts(state): if re.match(r'^[0-9]{8}-[0-9]{2}-[0-9]{2}-[0-9]{2}$', filename): first_prompt = "" if data and 'visible' in data and len(data['visible']) > 0: - if data['internal'][0][0] == '<|BEGIN-VISIBLE-CHAT|>': + if len(data['internal']) > 0 and data['internal'][0][0] == '<|BEGIN-VISIBLE-CHAT|>': if len(data['visible']) > 1: first_prompt = html.unescape(data['visible'][1][0]) elif i == 0: @@ -1385,12 +1387,17 @@ def generate_pfp_cache(character): for path in [Path(f"user_data/characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]: if path.exists(): original_img = Image.open(path) - original_img.save(Path(f'{cache_folder}/pfp_character.png'), format='PNG') + # Define file paths + pfp_path = Path(f'{cache_folder}/pfp_character.png') + thumb_path = Path(f'{cache_folder}/pfp_character_thumb.png') + # Save main picture and thumbnail + original_img.save(pfp_path, format='PNG') thumb = make_thumbnail(original_img) - thumb.save(Path(f'{cache_folder}/pfp_character_thumb.png'), format='PNG') + thumb.save(thumb_path, format='PNG') - return thumb + # Return the path to the thumbnail, not the in-memory PIL Image object. + return str(thumb_path) return None @@ -1507,7 +1514,22 @@ def load_instruction_template_memoized(template): return load_instruction_template(template) -def upload_character(file, img, tavern=False): +def open_image_safely(path): + if path is None or not isinstance(path, str) or not Path(path).exists(): + return None + + if os.path.islink(path): + return None + + try: + return Image.open(path) + except Exception as e: + logger.error(f"Failed to open image file: {path}. Reason: {e}") + return None + + +def upload_character(file, img_path, tavern=False): + img = open_image_safely(img_path) decoded_file = file if isinstance(file, str) else file.decode('utf-8') try: data = json.loads(decoded_file) @@ -1554,12 +1576,17 @@ def build_pygmalion_style_context(data): return context -def upload_tavern_character(img, _json): +def upload_tavern_character(img_path, _json): _json = {'char_name': _json['name'], 'char_persona': _json['description'], 'char_greeting': _json['first_mes'], 'example_dialogue': _json['mes_example'], 'world_scenario': _json['scenario']} - return upload_character(json.dumps(_json), img, tavern=True) + return upload_character(json.dumps(_json), img_path, tavern=True) -def check_tavern_character(img): +def check_tavern_character(img_path): + img = open_image_safely(img_path) + + if img is None: + return "Invalid or disallowed image file.", None, None, gr.update(interactive=False) + if "chara" not in img.info: return "Not a TavernAI card", None, None, gr.update(interactive=False) @@ -1571,7 +1598,8 @@ def check_tavern_character(img): return _json['name'], _json['description'], _json, gr.update(interactive=True) -def upload_your_profile_picture(img): +def upload_your_profile_picture(img_path): + img = open_image_safely(img_path) cache_folder = Path(shared.args.disk_cache_dir) if not cache_folder.exists(): cache_folder.mkdir() @@ -1614,15 +1642,19 @@ def save_character(name, greeting, context, picture, filename): save_file(filepath, data) path_to_img = Path(f'user_data/characters/{filename}.png') if picture is not None: - picture.save(path_to_img) + # Copy the image file from its source path to the character folder + shutil.copy(picture, path_to_img) logger.info(f'Saved {path_to_img}.') def delete_character(name, instruct=False): + # Check for character data files for extension in ["yml", "yaml", "json"]: delete_file(Path(f'user_data/characters/{name}.{extension}')) - delete_file(Path(f'user_data/characters/{name}.png')) + # Check for character image files + for extension in ["png", "jpg", "jpeg"]: + delete_file(Path(f'user_data/characters/{name}.{extension}')) def jinja_template_from_old_format(params, verbose=False): @@ -1974,8 +2006,9 @@ def handle_character_menu_change(state): ] -def handle_character_picture_change(picture): +def handle_character_picture_change(picture_path): """Update or clear cache when character picture changes""" + picture = open_image_safely(picture_path) cache_folder = Path(shared.args.disk_cache_dir) if not cache_folder.exists(): cache_folder.mkdir() diff --git a/modules/exllamav3.py b/modules/exllamav3.py index f7078028..d884bbf7 100644 --- a/modules/exllamav3.py +++ b/modules/exllamav3.py @@ -2,6 +2,8 @@ import traceback from pathlib import Path from typing import Any, List, Tuple +import torch + from exllamav3 import Cache, Config, Generator, Model, Tokenizer from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant from exllamav3.generator import Job @@ -16,7 +18,6 @@ from exllamav3.generator.sampler import ( SS_TopK, SS_TopP ) - from modules import shared from modules.image_utils import ( convert_image_attachments_to_pil, @@ -171,7 +172,7 @@ class Exllamav3Model: result.draft_model = draft_model result.draft_cache = draft_cache - return result + return result, result def is_multimodal(self) -> bool: """Check if this model supports multimodal input.""" @@ -367,11 +368,51 @@ class Exllamav3Model: return output + def get_logits(self, token_ids, **kwargs): + """ + Process a batch of token_ids and return the logits for the last token. + This will reset and overwrite the model's cache. + """ + # Initialize a single params dictionary that will be updated in-place + params = { + "cache": self.cache, + "reconstruct": False, + "attn_mode": "flash_attn", + "batch_shape": (1, self.max_tokens), + "past_len": 0 + } + params.update(kwargs) + + # Process prefix tokens to fill the cache and generate recurrent state + if token_ids.shape[-1] > 1: + prefix_ids = token_ids[:, :-1] + + # This forward call updates the 'params' dict with the recurrent state + self.model.forward( + input_ids=prefix_ids, + params=params + ) + + # Update past_len for the next call + params["past_len"] = prefix_ids.shape[-1] + + # Process the last token, now using the state-filled 'params' dict + last_token_ids = token_ids[:, -1:] + logits = self.model.forward( + input_ids=last_token_ids, + params=params + ) + + return logits.float().cpu() + def encode(self, string, **kwargs): add_bos = kwargs.pop('add_bos', True) return self.tokenizer.encode(string, add_bos=add_bos, **kwargs) def decode(self, ids, **kwargs): + if isinstance(ids, torch.Tensor) and ids.dim() == 0: + ids = ids.view(1) + return self.tokenizer.decode(ids, **kwargs) @property diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py index 05b473b7..c606912b 100644 --- a/modules/exllamav3_hf.py +++ b/modules/exllamav3_hf.py @@ -103,6 +103,12 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): labels = kwargs.get('labels', None) past_key_values = kwargs.get('past_key_values', None) + # Reset the internal sequence state for standalone calls (logit viewer) + # or the very first step of a new generation. + if past_key_values is None: + self.past_seq = None + self.past_seq_negative = None + if len(args) > 0: if not shared.args.cfg_cache: logger.error("Please enable the cfg-cache option to use CFG with ExLlamav3_HF.") @@ -119,8 +125,8 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): ex_cache = self.ex_cache seq = input_ids[0].tolist() - if is_negative and past_key_values is not None: - seq = past_key_values + seq + if is_negative and past_key_values is not None and isinstance(past_key_values, list): + seq = past_key_values + seq seq_tensor = torch.tensor(seq) reset = True @@ -128,97 +134,50 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): # Maximum number of tokens to process in a single forward pass max_chunk_size = 256 + if past_seq is not None: + min_length = min(past_seq.shape[0], seq_tensor.shape[0]) + indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) + if len(indices) == 0 and seq_tensor.shape[0] > past_seq.shape[0]: + reset = False + + # Create a single `params` dictionary that will be used and modified + # in-place across all `forward` calls within this function. + params = { + "attn_mode": "flash_attn", + "cache": ex_cache, + "batch_shape": (1, self.max_tokens), + "reconstruct": False, + "past_len": 0 + } + # Make the forward call if labels is None: - if past_seq is not None: - min_length = min(past_seq.shape[0], seq_tensor.shape[0]) - indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) - if len(indices) > 0: - longest_prefix = indices[0].item() - else: - longest_prefix = min_length + # If it's an efficient continuation, process only the new tokens + if not reset: + params["past_len"] = past_seq.shape[0] + tokens_to_process = seq_tensor[past_seq.shape[0]:] + # Otherwise, process the whole sequence from scratch + else: + tokens_to_process = seq_tensor - if longest_prefix > 0: - reset = False - current_len = longest_prefix - remaining_tokens = len(seq_tensor) - longest_prefix - 1 + # Process all but the last token of the sequence/sub-sequence + if tokens_to_process.shape[0] > 1: + prefix_to_process = tokens_to_process[:-1] - if remaining_tokens > 0: - # Process tokens from longest_prefix to second-to-last token - tokens_to_process = seq_tensor[longest_prefix:-1] + # Process in chunks if the number of tokens is large + for i in range(0, prefix_to_process.shape[0], max_chunk_size): + chunk = prefix_to_process[i:i + max_chunk_size] + self.ex_model.forward(input_ids=chunk.view(1, -1), params=params) + params["past_len"] += chunk.shape[0] - # Process in chunks if the number of tokens is large - for i in range(0, tokens_to_process.shape[0], max_chunk_size): - chunk = tokens_to_process[i:i + max_chunk_size] - self.ex_model.forward( - input_ids=chunk.view(1, -1), - params={ - "attn_mode": "flash_attn", - "cache": ex_cache, - "past_len": longest_prefix + i, - "batch_shape": (1, self.max_tokens), - "reconstruct": False # Force memory-efficient path - } - ) - - current_len = longest_prefix + remaining_tokens - - if reset: - if len(seq_tensor) > 1: - # Process all tokens except the last one - tokens_to_process = seq_tensor[:-1] - - # Process in chunks if the number of tokens is large - current_len = 0 - for i in range(0, tokens_to_process.shape[0], max_chunk_size): - chunk = tokens_to_process[i:i + max_chunk_size] - self.ex_model.forward( - input_ids=chunk.view(1, -1), - params={ - "attn_mode": "flash_attn", - "cache": ex_cache, - "past_len": current_len, - "batch_shape": (1, self.max_tokens), - "reconstruct": False # Force memory-efficient path - } - ) - current_len += chunk.shape[0] - else: - current_len = 0 - - # Process the last token and get logits - logits = self.ex_model.forward( - input_ids=seq_tensor[-1:].view(1, -1), - params={ - "attn_mode": "flash_attn", - "cache": ex_cache, - "past_len": current_len, - "batch_shape": (1, self.max_tokens), - "reconstruct": False # Force memory-efficient path - } - ).to(input_ids.device).float() + # Process the last token to get logits + last_token = tokens_to_process[-1:].view(1, -1) + logits = self.ex_model.forward(input_ids=last_token, params=params).to(input_ids.device).float() else: # When processing with labels, handle as a complete sequence - # Process in chunks if the number of tokens is large - tokens_to_process = seq_tensor - all_logits = None + params["attn_mode"] = "flash_attn_nc" + logits = self.ex_model.forward(input_ids=seq_tensor.view(1,-1), params=params).float() - for i in range(0, tokens_to_process.shape[0], max_chunk_size): - chunk = tokens_to_process[i:i + max_chunk_size] - chunk_logits = self.ex_model.forward( - input_ids=chunk.view(1, -1), - params={ - "attn_mode": "flash_attn_nc", # No caching for training - "reconstruct": False # Force memory-efficient path - } - ).float() - - if all_logits is None: - all_logits = chunk_logits - else: - all_logits = torch.cat([all_logits, chunk_logits], dim=1) - - logits = all_logits if is_negative: self.past_seq_negative = seq_tensor diff --git a/modules/logits.py b/modules/logits.py index 56a20572..d668e44e 100644 --- a/modules/logits.py +++ b/modules/logits.py @@ -71,6 +71,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur from modules.torch_utils import get_device is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model' + is_non_hf_exllamav3 = shared.model.__class__.__name__ == 'Exllamav3Model' if not use_samplers: state = {'stream': True} @@ -88,7 +89,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur scores = sampler_hijack.global_scores[-1] else: - if is_non_hf_exllamav2: + if is_non_hf_exllamav2 or is_non_hf_exllamav3: device = get_device() tokens = shared.tokenizer.encode(prompt) if device: diff --git a/modules/models.py b/modules/models.py index 9535ea82..8c0f1c37 100644 --- a/modules/models.py +++ b/modules/models.py @@ -104,8 +104,7 @@ def ExLlamav3_HF_loader(model_name): def ExLlamav3_loader(model_name): from modules.exllamav3 import Exllamav3Model - model = Exllamav3Model.from_pretrained(model_name) - tokenizer = model.tokenizer + model, tokenizer = Exllamav3Model.from_pretrained(model_name) return model, tokenizer diff --git a/modules/torch_utils.py b/modules/torch_utils.py index ad9b26ad..418520a8 100644 --- a/modules/torch_utils.py +++ b/modules/torch_utils.py @@ -8,7 +8,9 @@ from modules import shared def get_device(): - if torch.cuda.is_available(): + if hasattr(shared.model, 'device'): + return shared.model.device + elif torch.cuda.is_available(): return torch.device('cuda') elif shared.args.deepspeed: import deepspeed diff --git a/modules/ui_chat.py b/modules/ui_chat.py index 7c388607..c342ce5b 100644 --- a/modules/ui_chat.py +++ b/modules/ui_chat.py @@ -152,14 +152,14 @@ def create_character_settings_ui(): with gr.Tab('YAML or JSON'): with gr.Row(): shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json', '.yaml'], label='JSON or YAML File', interactive=not mu) - shared.gradio['upload_img_bot'] = gr.Image(type='pil', label='Profile Picture (optional)', interactive=not mu) + shared.gradio['upload_img_bot'] = gr.Image(type='filepath', label='Profile Picture (optional)', interactive=not mu) shared.gradio['Submit character'] = gr.Button(value='Submit', interactive=False) with gr.Tab('TavernAI PNG'): with gr.Row(): with gr.Column(): - shared.gradio['upload_img_tavern'] = gr.Image(type='pil', label='TavernAI PNG File', elem_id='upload_img_tavern', interactive=not mu) + shared.gradio['upload_img_tavern'] = gr.Image(type='filepath', label='TavernAI PNG File', elem_id='upload_img_tavern', interactive=not mu) shared.gradio['tavern_json'] = gr.State() with gr.Column(): shared.gradio['tavern_name'] = gr.Textbox(value='', lines=1, label='Name', interactive=False) @@ -168,8 +168,8 @@ def create_character_settings_ui(): shared.gradio['Submit tavern character'] = gr.Button(value='Submit', interactive=False) with gr.Column(scale=1): - shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil', interactive=not mu) - shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('user_data/cache/pfp_me.png')) if Path('user_data/cache/pfp_me.png').exists() else None, interactive=not mu) + shared.gradio['character_picture'] = gr.Image(label='Character picture', type='filepath', interactive=not mu) + shared.gradio['your_picture'] = gr.Image(label='Your picture', type='filepath', value=Image.open(Path('user_data/cache/pfp_me.png')) if Path('user_data/cache/pfp_me.png').exists() else None, interactive=not mu) def create_chat_settings_ui(): diff --git a/requirements/full/requirements.txt b/requirements/full/requirements.txt index 85119c65..97eac769 100644 --- a/requirements/full/requirements.txt +++ b/requirements/full/requirements.txt @@ -1,10 +1,11 @@ accelerate==1.8.* audioop-lts<1.0; python_version >= "3.13" -bitsandbytes==0.46.* +bitsandbytes==0.48.* colorama datasets einops fastapi==0.112.4 +flash-linear-attention==0.3.2 gradio==4.37.* html2text==2025.4.15 jinja2==3.1.6 @@ -24,7 +25,7 @@ safetensors==0.6.* scipy sentencepiece tensorboard -transformers==4.56.* +transformers==4.57.* triton-windows==3.3.1.post19; platform_system == "Windows" tqdm wandb @@ -35,10 +36,10 @@ sse-starlette==1.6.5 tiktoken # CUDA wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/turboderp-org/exllamav3/releases/download/v0.0.6/exllamav3-0.0.6+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/turboderp-org/exllamav3/releases/download/v0.0.6/exllamav3-0.0.6+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/turboderp-org/exllamav3/releases/download/v0.0.7/exllamav3-0.0.7+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/turboderp-org/exllamav3/releases/download/v0.0.7/exllamav3-0.0.7+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2-py3-none-any.whl; platform_system == "Linux" and platform_machine != "x86_64" diff --git a/requirements/full/requirements_amd.txt b/requirements/full/requirements_amd.txt index ffd496f3..b3b0005e 100644 --- a/requirements/full/requirements_amd.txt +++ b/requirements/full/requirements_amd.txt @@ -23,7 +23,7 @@ safetensors==0.6.* scipy sentencepiece tensorboard -transformers==4.56.* +transformers==4.57.* triton-windows==3.2.0.post19; platform_system == "Windows" tqdm wandb @@ -34,7 +34,7 @@ sse-starlette==1.6.5 tiktoken # AMD wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+rocm6.2.4.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64" diff --git a/requirements/full/requirements_amd_noavx2.txt b/requirements/full/requirements_amd_noavx2.txt index 7a35b553..5e0d375e 100644 --- a/requirements/full/requirements_amd_noavx2.txt +++ b/requirements/full/requirements_amd_noavx2.txt @@ -23,7 +23,7 @@ safetensors==0.6.* scipy sentencepiece tensorboard -transformers==4.56.* +transformers==4.57.* triton-windows==3.2.0.post19; platform_system == "Windows" tqdm wandb @@ -34,7 +34,7 @@ sse-starlette==1.6.5 tiktoken # AMD wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+vulkanavx-py3-none-win_amd64.whl; platform_system == "Windows" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+vulkanavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+vulkanavx-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+vulkanavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+rocm6.2.4.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64" diff --git a/requirements/full/requirements_apple_intel.txt b/requirements/full/requirements_apple_intel.txt index ebf13242..0bb837ba 100644 --- a/requirements/full/requirements_apple_intel.txt +++ b/requirements/full/requirements_apple_intel.txt @@ -23,7 +23,7 @@ safetensors==0.6.* scipy sentencepiece tensorboard -transformers==4.56.* +transformers==4.57.* triton-windows==3.2.0.post19; platform_system == "Windows" tqdm wandb @@ -34,7 +34,5 @@ sse-starlette==1.6.5 tiktoken # Mac wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0-py3-none-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0-py3-none-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11" -https://github.com/oobabooga/exllamav3/releases/download/v0.0.6/exllamav3-0.0.6-py3-none-any.whl -https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2-py3-none-any.whl +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0-py3-none-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0-py3-none-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11" diff --git a/requirements/full/requirements_apple_silicon.txt b/requirements/full/requirements_apple_silicon.txt index 00303ff9..514c0662 100644 --- a/requirements/full/requirements_apple_silicon.txt +++ b/requirements/full/requirements_apple_silicon.txt @@ -23,7 +23,7 @@ safetensors==0.6.* scipy sentencepiece tensorboard -transformers==4.56.* +transformers==4.57.* triton-windows==3.2.0.post19; platform_system == "Windows" tqdm wandb @@ -34,8 +34,6 @@ sse-starlette==1.6.5 tiktoken # Mac wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0-py3-none-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0-py3-none-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" and python_version == "3.11" -https://github.com/oobabooga/exllamav3/releases/download/v0.0.6/exllamav3-0.0.6-py3-none-any.whl -https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2-py3-none-any.whl +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0-py3-none-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0-py3-none-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" and python_version == "3.11" diff --git a/requirements/full/requirements_cpu_only.txt b/requirements/full/requirements_cpu_only.txt index 9a578501..f68fdd9d 100644 --- a/requirements/full/requirements_cpu_only.txt +++ b/requirements/full/requirements_cpu_only.txt @@ -23,7 +23,7 @@ safetensors==0.6.* scipy sentencepiece tensorboard -transformers==4.56.* +transformers==4.57.* triton-windows==3.2.0.post19; platform_system == "Windows" tqdm wandb @@ -34,5 +34,5 @@ sse-starlette==1.6.5 tiktoken # llama.cpp (CPU only, AVX2) -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cpuavx2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cpuavx2-py3-none-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cpuavx2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cpuavx2-py3-none-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" diff --git a/requirements/full/requirements_cpu_only_noavx2.txt b/requirements/full/requirements_cpu_only_noavx2.txt index d777a013..b40f1af2 100644 --- a/requirements/full/requirements_cpu_only_noavx2.txt +++ b/requirements/full/requirements_cpu_only_noavx2.txt @@ -23,7 +23,7 @@ safetensors==0.6.* scipy sentencepiece tensorboard -transformers==4.56.* +transformers==4.57.* triton-windows==3.2.0.post19; platform_system == "Windows" tqdm wandb @@ -34,5 +34,5 @@ sse-starlette==1.6.5 tiktoken # llama.cpp (CPU only, no AVX2) -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cpuavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cpuavx-py3-none-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cpuavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cpuavx-py3-none-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" diff --git a/requirements/full/requirements_noavx2.txt b/requirements/full/requirements_noavx2.txt index f35dd111..9de9e65a 100644 --- a/requirements/full/requirements_noavx2.txt +++ b/requirements/full/requirements_noavx2.txt @@ -1,10 +1,11 @@ accelerate==1.8.* audioop-lts<1.0; python_version >= "3.13" -bitsandbytes==0.46.* +bitsandbytes==0.48.* colorama datasets einops fastapi==0.112.4 +flash-linear-attention==0.3.2 gradio==4.37.* html2text==2025.4.15 jinja2==3.1.6 @@ -24,7 +25,7 @@ safetensors==0.6.* scipy sentencepiece tensorboard -transformers==4.56.* +transformers==4.57.* triton-windows==3.3.1.post19; platform_system == "Windows" tqdm wandb @@ -35,10 +36,10 @@ sse-starlette==1.6.5 tiktoken # CUDA wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cu124avx-py3-none-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cu124avx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/turboderp-org/exllamav3/releases/download/v0.0.6/exllamav3-0.0.6+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/turboderp-org/exllamav3/releases/download/v0.0.6/exllamav3-0.0.6+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cu124avx-py3-none-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cu124avx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/turboderp-org/exllamav3/releases/download/v0.0.7/exllamav3-0.0.7+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/turboderp-org/exllamav3/releases/download/v0.0.7/exllamav3-0.0.7+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+cu128.torch2.7.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2-py3-none-any.whl; platform_system == "Linux" and platform_machine != "x86_64" diff --git a/requirements/full/requirements_nowheels.txt b/requirements/full/requirements_nowheels.txt index 1f63e304..3bd20dd9 100644 --- a/requirements/full/requirements_nowheels.txt +++ b/requirements/full/requirements_nowheels.txt @@ -23,7 +23,7 @@ safetensors==0.6.* scipy sentencepiece tensorboard -transformers==4.56.* +transformers==4.57.* triton-windows==3.2.0.post19; platform_system == "Windows" tqdm wandb diff --git a/requirements/portable/requirements.txt b/requirements/portable/requirements.txt index 2162fddf..7a38b1e6 100644 --- a/requirements/portable/requirements.txt +++ b/requirements/portable/requirements.txt @@ -19,5 +19,5 @@ sse-starlette==1.6.5 tiktoken # CUDA wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" diff --git a/requirements/portable/requirements_apple_intel.txt b/requirements/portable/requirements_apple_intel.txt index 91150ed1..047d1c54 100644 --- a/requirements/portable/requirements_apple_intel.txt +++ b/requirements/portable/requirements_apple_intel.txt @@ -19,5 +19,6 @@ sse-starlette==1.6.5 tiktoken # Mac wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0-py3-none-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0-py3-none-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0-py3-none-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0-py3-none-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" \ No newline at end of file diff --git a/requirements/portable/requirements_apple_silicon.txt b/requirements/portable/requirements_apple_silicon.txt index 22240386..5c8ae4df 100644 --- a/requirements/portable/requirements_apple_silicon.txt +++ b/requirements/portable/requirements_apple_silicon.txt @@ -19,6 +19,6 @@ sse-starlette==1.6.5 tiktoken # Mac wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0-py3-none-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0-py3-none-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0-py3-none-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0-py3-none-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" diff --git a/requirements/portable/requirements_cpu_only.txt b/requirements/portable/requirements_cpu_only.txt index 847e4450..f41efd58 100644 --- a/requirements/portable/requirements_cpu_only.txt +++ b/requirements/portable/requirements_cpu_only.txt @@ -19,5 +19,5 @@ sse-starlette==1.6.5 tiktoken # llama.cpp (CPU only, AVX2) -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cpuavx2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cpuavx2-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cpuavx2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cpuavx2-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements/portable/requirements_cpu_only_noavx2.txt b/requirements/portable/requirements_cpu_only_noavx2.txt index f34e1847..69158050 100644 --- a/requirements/portable/requirements_cpu_only_noavx2.txt +++ b/requirements/portable/requirements_cpu_only_noavx2.txt @@ -19,5 +19,5 @@ sse-starlette==1.6.5 tiktoken # llama.cpp (CPU only, no AVX2) -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cpuavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cpuavx-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cpuavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cpuavx-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements/portable/requirements_noavx2.txt b/requirements/portable/requirements_noavx2.txt index 771d0362..ca66098c 100644 --- a/requirements/portable/requirements_noavx2.txt +++ b/requirements/portable/requirements_noavx2.txt @@ -19,5 +19,5 @@ sse-starlette==1.6.5 tiktoken # CUDA wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cu124avx-py3-none-win_amd64.whl; platform_system == "Windows" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+cu124avx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cu124avx-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+cu124avx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" diff --git a/requirements/portable/requirements_vulkan.txt b/requirements/portable/requirements_vulkan.txt index bb3a5ab8..36aff361 100644 --- a/requirements/portable/requirements_vulkan.txt +++ b/requirements/portable/requirements_vulkan.txt @@ -19,5 +19,5 @@ sse-starlette==1.6.5 tiktoken # CUDA wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" diff --git a/requirements/portable/requirements_vulkan_noavx2.txt b/requirements/portable/requirements_vulkan_noavx2.txt index fbc52282..be7170e3 100644 --- a/requirements/portable/requirements_vulkan_noavx2.txt +++ b/requirements/portable/requirements_vulkan_noavx2.txt @@ -19,5 +19,5 @@ sse-starlette==1.6.5 tiktoken # CUDA wheels -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+vulkanavx-py3-none-win_amd64.whl; platform_system == "Windows" -https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.46.0/llama_cpp_binaries-0.46.0+vulkanavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+vulkanavx-py3-none-win_amd64.whl; platform_system == "Windows" +https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.49.0/llama_cpp_binaries-0.49.0+vulkanavx-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64"