diff --git a/extensions/openai/logits.py b/extensions/openai/logits.py index 357e70fa..280612db 100644 --- a/extensions/openai/logits.py +++ b/extensions/openai/logits.py @@ -5,7 +5,5 @@ from modules.logits import get_next_logits def _get_next_logits(body): # Pre-process the input payload to simulate a real generation use_samplers = body['use_samplers'] - state = process_parameters(body) if use_samplers else {} - state['stream'] = True - + state = process_parameters(body) return get_next_logits(body['prompt'], state, use_samplers, "", top_logits=body['top_logits'], return_dict=True) diff --git a/modules/cache_utils.py b/modules/cache_utils.py deleted file mode 100644 index 0d1368a2..00000000 --- a/modules/cache_utils.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -from numba import njit - -from modules import shared - - -def process_llamacpp_cache(model, new_sequence, past_sequence): - if len(past_sequence) == 0 or len(new_sequence) == 0: - return past_sequence - - i1, i2, j1, j2 = find_longest_common_substring_indices(past_sequence, new_sequence) - overlap_length = i2 - i1 + 1 - - # Do StreamingLLM if i1 > 0 (ie the longest common subsequence is not a prefix) - # and the overlap length is sufficiently long. - if i1 > 0 and overlap_length > 0.2 * len(new_sequence): - - new_sequence = torch.tensor(new_sequence) - past_sequence = torch.tensor(past_sequence) - - prefix_length = find_prefix_length(past_sequence[:i1], new_sequence[:j1]) - sink_length = max(prefix_length, shared.args.attention_sink_size) - removed_length = i1 - sink_length - - if removed_length <= 0: - return past_sequence.tolist() - - matching_prefix = past_sequence[:prefix_length] - removed_chunk = past_sequence[sink_length:i1] - overlapping_sequence = new_sequence[j1:j2 + 1] - added_chunk = new_sequence[j2 + 1:] - - # print(past_sequence.tolist()) - # print(new_sequence.tolist()) - - print() - print('MATCHING PREFIX=', repr(shared.tokenizer.decode(matching_prefix))) - print('ADDED CHUNK=', repr(shared.tokenizer.decode(added_chunk))) - print('REMOVED CHUNK=', repr(shared.tokenizer.decode(removed_chunk))) - print('REMOVED LENGTH=', removed_length) - print() - - # Remove interval [sink_length, sink_length + removed_length) from the context - # Update model.n_tokens - model._ctx.kv_cache_seq_rm(0, sink_length, sink_length + removed_length) - model._ctx.kv_cache_seq_shift(0, sink_length + removed_length, -1, -removed_length) - - new_sequence = new_sequence.tolist() - model.input_ids[:j2 + 1] = new_sequence[:j2 + 1] - model.n_tokens = j2 + 1 - - return new_sequence[:j2 + 1] - else: - return past_sequence - - -def find_prefix_length(past_seq, seq_tensor): - ''' - Given two torch tensors, finds the length of the longest - common prefix between the two. - ''' - min_length = min(past_seq.shape[0], seq_tensor.shape[0]) - indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) - if len(indices) > 0: - prefix_length = indices[0].item() - else: - prefix_length = min_length - - return prefix_length - - -@njit -def find_longest_common_substring_indices(list1, list2): - ''' - Given two lists, solves the Longest Common Substring problem. - - It returns the indices where the substring starts and ends in - s1 and s2. 
- - Example: - - ir, jr, ir2, jr2 = find_longest_common_substring_indices(s1, s2) - print(s1[ir:jr + 1]) - print(s2[ir2:jr2 + 1]) - - Adapted from - https://rosettacode.org/wiki/Longest_common_substring#Python - ''' - - len_list1, len_list2 = len(list1), len(list2) - start_index_list1, end_index_list1 = 0, -1 - start_index_list2, end_index_list2 = 0, -1 - - # for index1 in tqdm(range(0, len_list1), desc="StreamingLLM prompt comparison", leave=False): - for index1 in range(0, len_list1): - try: - index2 = list2.index(list1[index1]) - except: - continue - - while index2 >= 0: - temp_index1, temp_index2 = index1, index2 - while temp_index1 < len_list1 and temp_index2 < len_list2 and list2[temp_index2] == list1[temp_index1]: - if temp_index1 - index1 >= end_index_list1 - start_index_list1: - start_index_list1, end_index_list1 = index1, temp_index1 - start_index_list2, end_index_list2 = index2, temp_index2 - - temp_index1 += 1 - temp_index2 += 1 - try: - index2 = list2.index(list1[index1], index2 + 1) - except: - break - - return start_index_list1, end_index_list1, start_index_list2, end_index_list2 diff --git a/modules/evaluate.py b/modules/evaluate.py index 35c72689..dd7ef9a7 100644 --- a/modules/evaluate.py +++ b/modules/evaluate.py @@ -40,17 +40,13 @@ def calculate_perplexity(models, input_dataset, stride, _max_length): ''' if shared.args.loader == "llama.cpp": - logger.error("llamacpp_HF is required for perplexity evaluation with GGUF models. Please reload the model with llamacpp_HF instead of llama.cpp.") + logger.error("Perplexity evaluation is not implemented for the llama.cpp loader.") raise ValueError if shared.args.loader == "ExLlamav2": logger.error("ExLlamav2_HF is required for perplexity evaluation with EXL2 models. Please reload the model with ExLlamav2_HF instead of ExLlamav2.") raise ValueError - if shared.args.loader == "llamacpp_HF" and not shared.args.logits_all: - logger.error("--logits_all is required for perplexity evaluation with GGUF models. Please reload the model with that option set/checked.") - raise ValueError - if not shared.args.no_use_fast: logger.warning("--no_use_fast is not set. If tokenizing the input dataset takes a long time, try reloading the model with that option set/checked.") diff --git a/modules/llama_cpp_python_hijack.py b/modules/llama_cpp_python_hijack.py deleted file mode 100644 index 8572cd81..00000000 --- a/modules/llama_cpp_python_hijack.py +++ /dev/null @@ -1,165 +0,0 @@ -import importlib -import platform -from typing import Sequence - -import numpy as np -from tqdm import tqdm - -from modules import shared -from modules.cache_utils import process_llamacpp_cache - -imported_module = None -not_available_modules = set() - - -def llama_cpp_lib(): - global imported_module, not_available_modules - - # Determine the platform - is_macos = platform.system() == 'Darwin' - - # Define the library names based on the platform - if is_macos: - lib_names = [ - (None, 'llama_cpp') - ] - else: - lib_names = [ - ('cpu', 'llama_cpp'), - ('tensorcores', 'llama_cpp_cuda_tensorcores'), - (None, 'llama_cpp_cuda'), - (None, 'llama_cpp') - ] - - for arg, lib_name in lib_names: - if lib_name in not_available_modules: - continue - - should_import = (arg is None or getattr(shared.args, arg)) - - if should_import: - if imported_module and imported_module != lib_name: - # Conflict detected, raise an exception - raise Exception(f"Cannot import `{lib_name}` because `{imported_module}` is already imported. 
Switching to a different version of llama-cpp-python currently requires a server restart.") - - try: - return_lib = importlib.import_module(lib_name) - imported_module = lib_name - monkey_patch_llama_cpp_python(return_lib) - return return_lib - except ImportError: - not_available_modules.add(lib_name) - continue - - return None - - -def eval_with_progress(self, tokens: Sequence[int]): - """ - A copy of - - https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py - - with tqdm to show prompt processing progress. - """ - self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) - - if len(tokens) > self.n_batch: - progress_bar = tqdm(range(0, len(tokens), self.n_batch), desc="Prompt evaluation", leave=False) - else: - progress_bar = range(0, len(tokens), self.n_batch) - - for i in progress_bar: - batch = tokens[i : min(len(tokens), i + self.n_batch)] - n_past = self.n_tokens - n_tokens = len(batch) - self._batch.set_batch( - batch=batch, n_past=n_past, logits_all=self.context_params.logits_all - ) - self._ctx.decode(self._batch) - # Save tokens - self.input_ids[n_past : n_past + n_tokens] = batch - # Save logits - if self.context_params.logits_all: - rows = n_tokens - cols = self._n_vocab - logits = np.ctypeslib.as_array( - self._ctx.get_logits(), shape=(rows * cols,) - ) - self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits - self.last_updated_index = n_past + n_tokens - 1 - else: - rows = 1 - cols = self._n_vocab - logits = np.ctypeslib.as_array( - self._ctx.get_logits(), shape=(rows * cols,) - ) - last_token_index = min(n_past + n_tokens - 1, self.scores.shape[0] - 1) - self.scores[last_token_index, :] = logits.reshape(-1) - self.last_updated_index = last_token_index - # Update n_tokens - self.n_tokens += n_tokens - - -def monkey_patch_llama_cpp_python(lib): - if getattr(lib.Llama, '_is_patched', False): - # If the patch is already applied, do nothing - return - - def my_generate(self, *args, **kwargs): - if shared.args.streaming_llm: - new_sequence = args[0] - past_sequence = self._input_ids - - # Do the cache trimming for StreamingLLM - process_llamacpp_cache(self, new_sequence, past_sequence) - - for output in self.original_generate(*args, **kwargs): - yield output - - lib.Llama.eval = eval_with_progress - lib.Llama.original_generate = lib.Llama.generate - lib.Llama.generate = my_generate - - # Also patch Jinja2ChatFormatter to handle loop controls - if hasattr(lib, 'llama_chat_format') and hasattr(lib.llama_chat_format, 'Jinja2ChatFormatter'): - Formatter = lib.llama_chat_format.Jinja2ChatFormatter - - if not getattr(Formatter, '_is_patched', False): - def patched_init(self, *args, **kwargs): - # Extract parameters from args or kwargs - if args: - self.template = args[0] - self.eos_token = args[1] if len(args) > 1 else kwargs.get('eos_token') - self.bos_token = args[2] if len(args) > 2 else kwargs.get('bos_token') - self.add_generation_prompt = args[3] if len(args) > 3 else kwargs.get('add_generation_prompt', True) - self.stop_token_ids = args[4] if len(args) > 4 else kwargs.get('stop_token_ids') - else: - self.template = kwargs.get('template') - self.eos_token = kwargs.get('eos_token') - self.bos_token = kwargs.get('bos_token') - self.add_generation_prompt = kwargs.get('add_generation_prompt', True) - self.stop_token_ids = kwargs.get('stop_token_ids') - - # Process stop tokens as in the original - self.stop_token_ids = ( - set(self.stop_token_ids) if self.stop_token_ids is not None else None - ) - - # Create environment with loopcontrols extension - import 
jinja2 - from jinja2.ext import loopcontrols - - self._environment = jinja2.sandbox.ImmutableSandboxedEnvironment( - loader=jinja2.BaseLoader(), - trim_blocks=True, - lstrip_blocks=True, - extensions=[loopcontrols] - ).from_string(self.template) - - # Replace the original __init__ with our patched version - Formatter.__init__ = patched_init - Formatter._is_patched = True - - # Set the flag to indicate that the patch has been applied - lib.Llama._is_patched = True diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py new file mode 100644 index 00000000..983b506f --- /dev/null +++ b/modules/llama_cpp_server.py @@ -0,0 +1,338 @@ +import json +import pprint +import socket +import subprocess +import sys +import threading +import time + +import llama_cpp_binaries +import requests + +from modules import shared +from modules.logging_colors import logger + +llamacpp_valid_cache_types = {"fp16", "q8_0", "q4_0"} + + +class LlamaServer: + def __init__(self, model_path, server_path=None): + """ + Initialize and start a server for llama.cpp models. + """ + self.model_path = model_path + self.server_path = server_path + self.port = self._find_available_port() + self.process = None + self.max_context_length = None + self.bos_token = "" + + # Start the server + self._start_server() + + def encode(self, text, add_bos_token=False, **kwargs): + if self.bos_token and text.startswith(self.bos_token): + add_bos_token = False + + url = f"http://localhost:{self.port}/tokenize" + payload = { + "content": text, + "add_special": add_bos_token, + } + + response = requests.post(url, json=payload) + result = response.json() + return result.get("tokens", []) + + def decode(self, token_ids, **kwargs): + url = f"http://localhost:{self.port}/detokenize" + payload = { + "tokens": token_ids, + } + + response = requests.post(url, json=payload) + result = response.json() + return result.get("content", "") + + def prepare_payload(self, state): + # Prepare DRY + dry_sequence_breakers = state['dry_sequence_breakers'] + if not dry_sequence_breakers.startswith("["): + dry_sequence_breakers = "[" + dry_sequence_breakers + "]" + dry_sequence_breakers = json.loads(dry_sequence_breakers) + + # Prepare the sampler order + samplers = state["sampler_priority"] + samplers = samplers.split("\n") if isinstance(samplers, str) else samplers + penalty_found = False + filtered_samplers = [] + for s in samplers: + if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]: + filtered_samplers.append(s.strip()) + elif not penalty_found and s.strip() == "repetition_penalty": + filtered_samplers.append("penalties") + penalty_found = True + + samplers = filtered_samplers + + # Move temperature to the end if temperature_last is true and temperature exists in the list + if state["temperature_last"] and "temperature" in samplers: + samplers.remove("temperature") + samplers.append("temperature") + + payload = { + "temperature": state["temperature"] if not state["dynamic_temperature"] else (state["dynatemp_low"] + state["dynatemp_high"]) / 2, + "dynatemp_range": 0 if not state["dynamic_temperature"] else (state["dynatemp_high"] - state["dynatemp_low"]) / 2, + "dynatemp_exponent": state["dynatemp_exponent"], + "top_k": state["top_k"], + "top_p": state["top_p"], + "min_p": state["min_p"], + "tfs_z": state["tfs"], + "typical_p": state["typical_p"], + "repeat_penalty": state["repetition_penalty"], + "repeat_last_n": state["repetition_penalty_range"], + "presence_penalty": state["presence_penalty"], + "frequency_penalty": 
state["frequency_penalty"], + "dry_multiplier": state["dry_multiplier"], + "dry_base": state["dry_base"], + "dry_allowed_length": state["dry_allowed_length"], + "dry_penalty_last_n": state["repetition_penalty_range"], + "dry_sequence_breakers": dry_sequence_breakers, + "xtc_probability": state["xtc_probability"], + "xtc_threshold": state["xtc_threshold"], + "mirostat": state["mirostat_mode"], + "mirostat_tau": state["mirostat_tau"], + "mirostat_eta": state["mirostat_eta"], + "grammar": state["grammar_string"], + "seed": state["seed"], + "ignore_eos": state["ban_eos_token"], + "samplers": samplers, + } + + if state['custom_token_bans']: + to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')] + payload["logit_bias"] = to_ban + + return payload + + def generate_with_streaming( + self, + prompt, + state, + ): + url = f"http://localhost:{self.port}/completion" + payload = self.prepare_payload(state) + + token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"]) + if state['auto_max_new_tokens']: + max_new_tokens = state['truncation_length'] - len(token_ids) + else: + max_new_tokens = state['max_new_tokens'] + + payload.update({ + "prompt": token_ids, + "n_predict": max_new_tokens, + "stream": True, + }) + + if shared.args.verbose: + logger.info("GENERATE_PARAMS=") + printable_payload = {k: v for k, v in payload.items() if k != "prompt"} + pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload) + print() + + # Make a direct request with streaming enabled + response = requests.post(url, json=payload, stream=True) + response.raise_for_status() # Raise an exception for HTTP errors + + full_text = "" + + # Process the streaming response + for line in response.iter_lines(): + if shared.stop_everything: + break + + if line: + try: + # Check if the line starts with "data: " and remove it + line_str = line.decode('utf-8') + if line_str.startswith('data: '): + line_str = line_str[6:] # Remove the "data: " prefix + + # Parse the JSON data + data = json.loads(line_str) + + # Extract the token content + if 'content' in data: + token_text = data['content'] + full_text += token_text + yield full_text + + # Check if generation is complete + if data.get('stop', False): + break + + except json.JSONDecodeError as e: + # Log the error and the problematic line + print(f"JSON decode error: {e}") + print(f"Problematic line: {line}") + continue + + def get_logits(self, prompt, state, n_probs=128, use_samplers=False): + """Get the logits/probabilities for the next token after a prompt""" + url = f"http://localhost:{self.port}/completion" + + payload = self.prepare_payload(state) + payload.update({ + "prompt": self.encode(prompt, add_bos_token=state["add_bos_token"]), + "n_predict": 0, + "logprobs": True, + "n_probs": n_probs, + "stream": False, + "post_sampling_probs": use_samplers, + }) + + if shared.args.verbose: + logger.info("GENERATE_PARAMS=") + printable_payload = {k: v for k, v in payload.items() if k != "prompt"} + pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload) + print() + + response = requests.post(url, json=payload) + result = response.json() + + if "completion_probabilities" in result: + if use_samplers: + return result["completion_probabilities"][0]["top_probs"] + else: + return result["completion_probabilities"][0]["top_logprobs"] + else: + raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}") + + def _get_max_context_length(self): + """Get and store the model's maximum 
context length.""" + url = f"http://localhost:{self.port}/v1/models" + response = requests.get(url).json() + + if "data" in response and len(response["data"]) > 0: + model_info = response["data"][0] + if "meta" in model_info and "n_vocab" in model_info["meta"]: + self.max_context_length = model_info["meta"]["n_vocab"] + + def _get_bos_token(self): + """Get and store the model's BOS token.""" + url = f"http://localhost:{self.port}/props" + response = requests.get(url).json() + if "bos_token" in response: + self.bos_token = response["bos_token"] + + def _find_available_port(self): + """Find an available port by letting the OS assign one.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) # Bind to port 0 to get an available port + return s.getsockname()[1] + + def _start_server(self): + """Start the llama.cpp server and wait until it's ready.""" + # Determine the server path + if self.server_path is None: + self.server_path = llama_cpp_binaries.get_binary_path() + + # Build the command + cmd = [ + self.server_path, + "--model", self.model_path, + "--ctx-size", str(shared.args.n_ctx), + "--n-gpu-layers", str(shared.args.n_gpu_layers), + "--batch-size", str(shared.args.batch_size), + "--port", str(self.port), + ] + + if shared.args.flash_attn: + cmd.append("--flash-attn") + if shared.args.threads > 0: + cmd += ["--threads", str(shared.args.threads)] + if shared.args.threads_batch > 0: + cmd += ["--threads-batch", str(shared.args.threads_batch)] + if shared.args.no_mmap: + cmd.append("--no-mmap") + if shared.args.mlock: + cmd.append("--mlock") + if shared.args.tensor_split: + cmd += ["--tensor-split", shared.args.tensor_split] + if shared.args.numa: + cmd += ["--numa", "distribute"] + if shared.args.no_kv_offload: + cmd.append("--no-kv-offload") + if shared.args.row_split: + cmd += ["--split-mode", "row"] + if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types: + cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type] + if shared.args.compress_pos_emb != 1: + cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)] + + # Start the server with pipes for output + self.process = subprocess.Popen( + cmd, + stderr=subprocess.PIPE, + text=True, + bufsize=1 + ) + + def filter_stderr(): + for line in iter(self.process.stderr.readline, ''): + if not line.startswith(('srv ', 'slot ')) and not 'log_server_r: request: GET /health' in line: + sys.stderr.write(line) + sys.stderr.flush() + + threading.Thread(target=filter_stderr, daemon=True).start() + + # Wait for server to be healthy + health_url = f"http://localhost:{self.port}/health" + start_time = time.time() + timeout = 3600 * 8 # 8 hours + while time.time() - start_time < timeout: + # Check if process is still alive + if self.process.poll() is not None: + # Process has terminated + exit_code = self.process.poll() + raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}") + + try: + response = requests.get(health_url) + if response.status_code == 200: + break + except: + pass + + time.sleep(1) + else: + raise TimeoutError(f"Server health check timed out after {timeout} seconds") + + # Server is now healthy, get model info + self._get_max_context_length() + self._get_bos_token() + return self.port + + def __enter__(self): + """Support for context manager.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Support for context manager.""" + self.stop() + + def __del__(self): + 
"""Cleanup when the object is deleted.""" + self.stop() + + def stop(self): + """Stop the server process.""" + if self.process: + self.process.terminate() + try: + self.process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.process.kill() + + self.process = None diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py deleted file mode 100644 index b3761e0f..00000000 --- a/modules/llamacpp_hf.py +++ /dev/null @@ -1,220 +0,0 @@ -import os -from pathlib import Path -from typing import Any, Dict, Optional, Union - -import torch -from torch.nn import CrossEntropyLoss -from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel -from transformers.modeling_outputs import CausalLMOutputWithPast - -from modules import shared -from modules.llama_cpp_python_hijack import llama_cpp_lib -from modules.llamacpp_model import get_llamacpp_cache_type_for_string -from modules.logging_colors import logger - - -class LlamacppHF(PreTrainedModel): - def __init__(self, model, path): - super().__init__(PretrainedConfig()) - self.model = model - self.generation_config = GenerationConfig() - - self.past_seq = None - self.llamacpp_cache = { - 'n_tokens': self.model.n_tokens, - 'input_ids': self.model.input_ids, - 'scores': self.model.scores, - 'ctx': self.model._ctx.ctx - } - - if shared.args.cfg_cache: - self.past_seq_negative = None - self.llamacpp_cache_negative = { - 'n_tokens': self.model.n_tokens, - 'input_ids': self.model.input_ids.copy(), - 'scores': self.model.scores.copy(), - 'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.context_params) - } - - def _validate_model_class(self): - pass - - def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): - pass - - def prepare_inputs_for_generation(self, input_ids, **kwargs): - return {'input_ids': input_ids, **kwargs} - - def save_cache(self): - self.llamacpp_cache.update({ - 'n_tokens': self.model.n_tokens, - 'input_ids': self.model.input_ids, - 'scores': self.model.scores, - 'ctx': self.model._ctx.ctx - }) - - def save_negative_cache(self): - self.llamacpp_cache_negative.update({ - 'n_tokens': self.model.n_tokens, - 'input_ids': self.model.input_ids, - 'scores': self.model.scores, - 'ctx': self.model._ctx.ctx - }) - - def load_cache(self): - self.model.n_tokens = self.llamacpp_cache['n_tokens'] - self.model.input_ids = self.llamacpp_cache['input_ids'] - self.model.scores = self.llamacpp_cache['scores'] - self.model._ctx.ctx = self.llamacpp_cache['ctx'] - - def load_negative_cache(self): - self.model.n_tokens = self.llamacpp_cache_negative['n_tokens'] - self.model.input_ids = self.llamacpp_cache_negative['input_ids'] - self.model.scores = self.llamacpp_cache_negative['scores'] - self.model._ctx.ctx = self.llamacpp_cache_negative['ctx'] - - @property - def device(self) -> torch.device: - return torch.device(0) - - def __call__(self, *args, **kwargs): - use_cache = kwargs.get('use_cache', True) - labels = kwargs.get('labels', None) - past_key_values = kwargs.get('past_key_values', None) - - if len(args) > 0: - if not shared.args.cfg_cache: - logger.error("Please enable the cfg-cache option to use CFG with llamacpp_HF.") - return - - input_ids = args[0] - is_negative = True - past_seq = self.past_seq_negative - self.load_negative_cache() - else: - input_ids = kwargs['input_ids'] - is_negative = False - past_seq = self.past_seq - self.load_cache() - - seq = input_ids[0].tolist() - if is_negative and past_key_values is not None: - seq = past_key_values + seq - - seq_tensor = torch.tensor(seq) - reset = 
True - - # Make the forward call. The prefix-match code has been adapted from - # https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee - if labels is None: - if past_seq is not None: - min_length = min(past_seq.shape[0], seq_tensor.shape[0]) - indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) - if len(indices) > 0: - longest_prefix = indices[0].item() - else: - longest_prefix = min_length - - if longest_prefix > 0: - reset = False - self.model.n_tokens = longest_prefix - if len(seq_tensor) - longest_prefix > 0: - self.model.eval(seq[longest_prefix:]) - else: - self.model.n_tokens -= 1 - self.model.eval([seq[-1]]) - - if reset: - self.model.reset() - self.model.eval(seq) - - logits = torch.tensor(self.model.scores[self.model.last_updated_index, :]).view(1, 1, -1).to(input_ids.device) - else: - self.model.reset() - self.model.eval(seq) - logits = torch.tensor(self.model.eval_logits) - logits = logits.view(1, logits.shape[0], logits.shape[1]).to(input_ids.device) - - if is_negative: - self.save_negative_cache() - self.past_seq_negative = seq_tensor - else: - self.save_cache() - self.past_seq = seq_tensor - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, logits.shape[-1]) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): - assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported" - - if isinstance(pretrained_model_name_or_path, str): - pretrained_model_name_or_path = Path(pretrained_model_name_or_path) - - path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path) - if path.is_file(): - model_file = path - else: - model_file = sorted(path.glob('*.gguf'))[0] - - logger.info(f"llama.cpp weights detected: {model_file}\n") - - if shared.args.tensor_split is None or shared.args.tensor_split.strip() == '': - tensor_split_list = None - else: - tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")] - - params = { - 'model_path': str(model_file), - 'n_ctx': shared.args.n_ctx, - 'n_threads': shared.args.threads or None, - 'n_threads_batch': shared.args.threads_batch or None, - 'n_batch': shared.args.n_batch, - 'use_mmap': not shared.args.no_mmap, - 'use_mlock': shared.args.mlock, - 'mul_mat_q': not shared.args.no_mul_mat_q, - 'numa': shared.args.numa, - 'n_gpu_layers': shared.args.n_gpu_layers, - 'rope_freq_base': shared.args.rope_freq_base, - 'tensor_split': tensor_split_list, - 'rope_freq_scale': 1.0 / shared.args.compress_pos_emb, - 'logits_all': shared.args.logits_all, - 'offload_kqv': not shared.args.no_offload_kqv, - 'split_mode': 1 if not shared.args.row_split else 2, - 'flash_attn': shared.args.flash_attn - } - - if shared.args.cache_type != 'fp16': - params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type) - params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type) - - Llama = llama_cpp_lib().Llama - try: - model = Llama(**params) - except 
Exception as e: - error_message = ( - f"Failed loading the model. **This usually happens due to lack of memory**. Try these steps:\n" - f"1. Reduce the context length `n_ctx` (currently {shared.args.n_ctx})." - f"{' Try a lower value like 4096.' if shared.args.n_ctx > 4096 else '.'}" - "\n" - f"2. Lower the `n-gpu-layers` value (currently {shared.args.n_gpu_layers})." - ) - - raise type(e)(error_message) from e - - model.last_updated_index = -1 - - return LlamacppHF(model, model_file) diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py deleted file mode 100644 index db25c66c..00000000 --- a/modules/llamacpp_model.py +++ /dev/null @@ -1,218 +0,0 @@ -import re -from functools import partial - -import numpy as np -import torch - -from modules import shared -from modules.callbacks import Iteratorize -from modules.llama_cpp_python_hijack import llama_cpp_lib -from modules.logging_colors import logger -from modules.text_generation import get_max_prompt_length - -llamacpp_quant_mapping = { - 'f32': 0, - 'fp16': 1, - 'q4_0': 2, - 'q4_1': 3, - 'q5_0': 6, - 'q5_1': 7, - 'q8_0': 8, - 'q8_1': 9, - 'q2_k': 10, - 'q3_k': 11, - 'q4_k': 12, - 'q5_k': 13, - 'q6_k': 14, - 'q8_k': 15, - 'iq4_nl': 20, - 'bf16': 30, -} - -llamacpp_valid_cache_types = {'fp16', 'q8_0', 'q4_0'} - - -def get_llamacpp_cache_type_for_string(quant_type: str): - quant_type = quant_type.lower() - if quant_type in llamacpp_valid_cache_types: - return llamacpp_quant_mapping[quant_type] - else: - raise ValueError(f"Invalid cache type for llama.cpp: {quant_type}. Valid options are: fp16, q8_0, q4_0.") - - -def ban_eos_logits_processor(eos_token, input_ids, logits): - logits[eos_token] = -float('inf') - return logits - - -def custom_token_ban_logits_processor(token_ids, input_ids, logits): - for token_id in token_ids: - logits[token_id] = -float('inf') - - return logits - - -class LlamaCppModel: - def __init__(self): - self.initialized = False - self.grammar_string = '' - self.grammar = None - - def __del__(self): - del self.model - - @classmethod - def from_pretrained(self, path): - - Llama = llama_cpp_lib().Llama - LlamaCache = llama_cpp_lib().LlamaCache - - result = self() - cache_capacity = 0 - if shared.args.cache_capacity is not None: - if 'GiB' in shared.args.cache_capacity: - cache_capacity = int(re.sub('[a-zA-Z]', '', shared.args.cache_capacity)) * 1000 * 1000 * 1000 - elif 'MiB' in shared.args.cache_capacity: - cache_capacity = int(re.sub('[a-zA-Z]', '', shared.args.cache_capacity)) * 1000 * 1000 - else: - cache_capacity = int(shared.args.cache_capacity) - - if cache_capacity > 0: - logger.info("Cache capacity is " + str(cache_capacity) + " bytes") - - if shared.args.tensor_split is None or shared.args.tensor_split.strip() == '': - tensor_split_list = None - else: - tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")] - - params = { - 'model_path': str(path), - 'n_ctx': shared.args.n_ctx, - 'n_threads': shared.args.threads or None, - 'n_threads_batch': shared.args.threads_batch or None, - 'n_batch': shared.args.n_batch, - 'use_mmap': not shared.args.no_mmap, - 'use_mlock': shared.args.mlock, - 'mul_mat_q': not shared.args.no_mul_mat_q, - 'numa': shared.args.numa, - 'n_gpu_layers': shared.args.n_gpu_layers, - 'rope_freq_base': shared.args.rope_freq_base, - 'tensor_split': tensor_split_list, - 'rope_freq_scale': 1.0 / shared.args.compress_pos_emb, - 'offload_kqv': not shared.args.no_offload_kqv, - 'split_mode': 1 if not shared.args.row_split else 2, - 'flash_attn': shared.args.flash_attn 
- } - - if shared.args.cache_type != 'fp16': - params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type) - params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type) - - try: - result.model = Llama(**params) - except Exception as e: - error_message = ( - f"Failed loading the model. **This usually happens due to lack of memory**. Try these steps:\n" - f"1. Reduce the context length `n_ctx` (currently {shared.args.n_ctx})." - f"{' Try a lower value like 4096.' if shared.args.n_ctx > 4096 else '.'}" - "\n" - f"2. Lower the `n-gpu-layers` value (currently {shared.args.n_gpu_layers})." - ) - - raise type(e)(error_message) from e - - if cache_capacity > 0: - result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity)) - - # This is ugly, but the model and the tokenizer are the same object in this library. - return result, result - - def encode(self, string): - if type(string) is str: - string = string.encode() - - return self.model.tokenize(string) - - def decode(self, ids, **kwargs): - detokenized = self.model.detokenize(ids) - try: - # Attempt strict UTF-8 decoding first - return detokenized.decode('utf-8', 'strict') - except UnicodeDecodeError as e: - # Log the error and fall back to UTF-8 with replacement - logger.warning(f"Invalid UTF-8 in detokenized output. Using replacement characters.\n{e}") - return detokenized.decode('utf-8', 'replace') - - def get_logits(self, tokens): - self.model.reset() - self.model.eval(tokens) - logits = self.model._scores - logits = np.expand_dims(logits, 0) # batch dim is expected - return torch.tensor(logits, dtype=torch.float32) - - def load_grammar(self, string): - if string != self.grammar_string: - self.grammar_string = string - if string.strip() != '': - self.grammar = llama_cpp_lib().LlamaGrammar.from_string(string) - else: - self.grammar = None - - def generate(self, prompt, state, callback=None): - LogitsProcessorList = llama_cpp_lib().LogitsProcessorList - prompt = prompt if type(prompt) is str else prompt.decode() - - # Handle truncation - prompt = self.encode(prompt) - prompt = prompt[-get_max_prompt_length(state):] - prompt = self.decode(prompt) - - self.load_grammar(state['grammar_string']) - logit_processors = LogitsProcessorList() - if state['ban_eos_token']: - logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos())) - - if state['custom_token_bans']: - to_ban = [int(x) for x in state['custom_token_bans'].split(',')] - if len(to_ban) > 0: - logit_processors.append(partial(custom_token_ban_logits_processor, to_ban)) - - completion_chunks = self.model.create_completion( - prompt=prompt, - max_tokens=state['max_new_tokens'], - temperature=state['temperature'], - top_p=state['top_p'] if state['top_p'] < 1 else 0.999, - min_p=state['min_p'], - typical_p=state['typical_p'], - frequency_penalty=state['frequency_penalty'], - presence_penalty=state['presence_penalty'], - repeat_penalty=state['repetition_penalty'], - top_k=state['top_k'], - stream=True, - seed=int(state['seed']) if state['seed'] != -1 else None, - tfs_z=state['tfs'], - mirostat_mode=int(state['mirostat_mode']), - mirostat_tau=state['mirostat_tau'], - mirostat_eta=state['mirostat_eta'], - logits_processor=logit_processors, - grammar=self.grammar - ) - - output = "" - for completion_chunk in completion_chunks: - if shared.stop_everything: - break - - text = completion_chunk['choices'][0]['text'] - output += text - if callback: - callback(text) - - return output - - def generate_with_streaming(self, *args, **kwargs): - 
with Iteratorize(self.generate, args, kwargs, callback=None) as generator: - reply = '' - for token in generator: - reply += token - yield reply diff --git a/modules/loaders.py b/modules/loaders.py index 980a13e6..3060406d 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -30,51 +30,20 @@ loaders_and_params = OrderedDict({ 'n_gpu_layers', 'threads', 'threads_batch', - 'n_batch', + 'batch_size', 'n_ctx', 'cache_type', 'tensor_split', 'rope_freq_base', 'compress_pos_emb', - 'attention_sink_size', - 'tensorcores', 'flash_attn', - 'streaming_llm', - 'cpu', 'row_split', - 'no_offload_kqv', + 'no_kv_offload', 'no_mul_mat_q', 'no_mmap', 'mlock', 'numa', ], - 'llamacpp_HF': [ - 'n_gpu_layers', - 'threads', - 'threads_batch', - 'n_batch', - 'n_ctx', - 'cache_type', - 'tensor_split', - 'rope_freq_base', - 'compress_pos_emb', - 'attention_sink_size', - 'tensorcores', - 'flash_attn', - 'streaming_llm', - 'cpu', - 'row_split', - 'no_offload_kqv', - 'no_mul_mat_q', - 'no_mmap', - 'mlock', - 'numa', - 'cfg_cache', - 'logits_all', - 'trust_remote_code', - 'no_use_fast', - 'llamacpp_HF_info', - ], 'ExLlamav3_HF': [ 'max_seq_len', 'gpu_split', @@ -307,66 +276,34 @@ loaders_samplers = { 'dry_sequence_breakers', }, 'llama.cpp': { - 'temperature', - 'min_p', - 'top_p', - 'top_k', - 'typical_p', - 'tfs', - 'repetition_penalty', - 'frequency_penalty', - 'presence_penalty', - 'mirostat_mode', - 'mirostat_tau', - 'mirostat_eta', - 'ban_eos_token', - 'seed', - 'custom_token_bans', - 'grammar_string', - 'grammar_file_row', - }, - 'llamacpp_HF': { 'temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', - 'smoothing_factor', - 'smoothing_curve', 'min_p', 'top_p', 'top_k', 'typical_p', 'xtc_threshold', 'xtc_probability', - 'epsilon_cutoff', - 'eta_cutoff', 'tfs', - 'top_a', - 'top_n_sigma', 'dry_multiplier', 'dry_allowed_length', 'dry_base', 'repetition_penalty', 'frequency_penalty', 'presence_penalty', - 'encoder_repetition_penalty', - 'no_repeat_ngram_size', 'repetition_penalty_range', - 'guidance_scale', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', - 'do_sample', 'dynamic_temperature', 'temperature_last', 'auto_max_new_tokens', 'ban_eos_token', 'add_bos_token', - 'skip_special_tokens', 'seed', 'sampler_priority', - 'custom_token_bans', - 'negative_prompt', 'dry_sequence_breakers', 'grammar_string', 'grammar_file_row', diff --git a/modules/logits.py b/modules/logits.py index f8a1e80c..c5d9112d 100644 --- a/modules/logits.py +++ b/modules/logits.py @@ -1,6 +1,7 @@ import time import traceback +import numpy as np import torch from modules import models, sampler_hijack, shared @@ -38,70 +39,86 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur return 'Error: No model is loaded1 Select one in the Model tab.', previous is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model' - is_non_hf_llamacpp = shared.model.__class__.__name__ == 'LlamaCppModel' + is_llamacpp = shared.model.__class__.__name__ == 'LlamaServer' - if use_samplers: - if any([is_non_hf_exllamav2, is_non_hf_llamacpp]): - logger.error("Sampler hijacking is not supported non-Huggingface loaders.") - # sampling is all done in c for exllama, so it is really hard to hijack - # it should be possible to hijack llamacpp sampler by hijacking all their sampling methods, - # but it is not implemented yet - return 'Error: Sampler hijacking is not supported non-Huggingface loaders. 
Please disable the "Use samplers" option.', previous + if is_llamacpp: + logprobs = shared.model.get_logits(prompt, state, n_probs=top_logits, use_samplers=use_samplers) + if return_dict: + output = {} + for entry in logprobs: + token = repr(entry['token']) + prob = entry['prob'] if use_samplers else np.exp(entry['logprob']) + output[token] = prob - state['max_new_tokens'] = 1 - state['auto_max_new_tokens'] = False - for _ in generate_reply(prompt, state): - pass - - scores = sampler_hijack.global_scores[-1] - else: - if is_non_hf_exllamav2: - device = get_device() - tokens = shared.tokenizer.encode(prompt) - if device: - tokens = tokens.to(device) - - scores = shared.model.get_logits(tokens)[-1][-1] - elif is_non_hf_llamacpp: - tokens = shared.tokenizer.encode(prompt) - scores = shared.model.get_logits(tokens)[-1][-1] + return output else: - device = get_device() - tokens = shared.tokenizer.encode(prompt, return_tensors='pt') - if device: - tokens = tokens.to(device) + output = '' + for entry in logprobs: + token = repr(entry['token']) + prob = entry['prob'] if use_samplers else np.exp(entry['logprob']) + output += f"{prob:.5f} - {token}\n" - output = shared.model(input_ids=tokens) - scores = output['logits'][-1][-1] - - probs = torch.softmax(scores, dim=-1, dtype=torch.float) - topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True) - if is_non_hf_llamacpp: - topk_indices = [i.expand((1, 1)) for i in topk_indices] - - if hasattr(shared.tokenizer, 'convert_ids_to_tokens'): - tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices] + return output, previous else: - tokens = [shared.tokenizer.decode(i) for i in topk_indices] + if not use_samplers: + state = {'stream': True} - if return_dict: - topk_values = [float(i) for i in topk_values] - output = {} - for row in list(zip(topk_values, tokens)): - key = row[1] - if isinstance(key, bytes): - try: - key = key.decode() - except: - key = key.decode('latin') + if use_samplers: + if is_non_hf_exllamav2: + logger.error("Sampler hijacking is not supported non-Huggingface loaders.") + # sampling is all done in c for exllama, so it is really hard to hijack + # it should be possible to hijack llamacpp sampler by hijacking all their sampling methods, + # but it is not implemented yet + return 'Error: Sampler hijacking is not supported non-Huggingface loaders. 
Please disable the "Use samplers" option.', previous - output[key] = row[0] + state['max_new_tokens'] = 1 + state['auto_max_new_tokens'] = False + for _ in generate_reply(prompt, state): + pass - return output - else: - topk_values = [f"{float(i):.5f}" for i in topk_values] - output = '' - for row in list(zip(topk_values, tokens)): - output += f"{row[0]} - {repr(row[1])}\n" + scores = sampler_hijack.global_scores[-1] + else: + if is_non_hf_exllamav2: + device = get_device() + tokens = shared.tokenizer.encode(prompt) + if device: + tokens = tokens.to(device) - return output, previous + scores = shared.model.get_logits(tokens)[-1][-1] + else: + device = get_device() + tokens = shared.tokenizer.encode(prompt, return_tensors='pt') + if device: + tokens = tokens.to(device) + + output = shared.model(input_ids=tokens) + scores = output['logits'][-1][-1] + + probs = torch.softmax(scores, dim=-1, dtype=torch.float) + topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True) + if hasattr(shared.tokenizer, 'convert_ids_to_tokens'): + tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices] + else: + tokens = [shared.tokenizer.decode(i) for i in topk_indices] + + if return_dict: + topk_values = [float(i) for i in topk_values] + output = {} + for row in list(zip(topk_values, tokens)): + key = row[1] + if isinstance(key, bytes): + try: + key = key.decode() + except: + key = key.decode('latin') + + output[key] = row[0] + + return output + else: + topk_values = [f"{float(i):.5f}" for i in topk_values] + output = '' + for row in list(zip(topk_values, tokens)): + output += f"{row[0]} - {repr(row[1])}\n" + + return output, previous diff --git a/modules/models.py b/modules/models.py index 288bc1b6..48d92bd5 100644 --- a/modules/models.py +++ b/modules/models.py @@ -67,8 +67,7 @@ def load_model(model_name, loader=None): shared.model_name = model_name load_func_map = { 'Transformers': huggingface_loader, - 'llama.cpp': llamacpp_loader, - 'llamacpp_HF': llamacpp_HF_loader, + 'llama.cpp': llama_cpp_server_loader, 'ExLlamav3_HF': ExLlamav3_HF_loader, 'ExLlamav2_HF': ExLlamav2_HF_loader, 'ExLlamav2': ExLlamav2_loader, @@ -101,7 +100,7 @@ def load_model(model_name, loader=None): shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings}) if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt'): shared.settings['truncation_length'] = shared.args.max_seq_len - elif loader in ['llama.cpp', 'llamacpp_HF']: + elif loader == 'llama.cpp': shared.settings['truncation_length'] = shared.args.n_ctx logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.") @@ -268,8 +267,8 @@ def huggingface_loader(model_name): return model -def llamacpp_loader(model_name): - from modules.llamacpp_model import LlamaCppModel +def llama_cpp_server_loader(model_name): + from modules.llama_cpp_server import LlamaServer path = Path(f'{shared.args.model_dir}/{model_name}') if path.is_file(): @@ -278,31 +277,11 @@ def llamacpp_loader(model_name): model_file = sorted(Path(f'{shared.args.model_dir}/{model_name}').glob('*.gguf'))[0] logger.info(f"llama.cpp weights detected: \"{model_file}\"") - model, tokenizer = LlamaCppModel.from_pretrained(model_file) - return model, tokenizer - - -def llamacpp_HF_loader(model_name): - from modules.llamacpp_hf import LlamacppHF - - if shared.args.tokenizer_dir: - logger.info(f'Using tokenizer from: \"{shared.args.tokenizer_dir}\"') - else: - path = Path(f'{shared.args.model_dir}/{model_name}') - # Check 
if a HF tokenizer is available for the model - if all((path / file).exists() for file in ['tokenizer_config.json']): - logger.info(f'Using tokenizer from: \"{path}\"') - else: - logger.error("Could not load the model because a tokenizer in Transformers format was not found.") - return None, None - - model = LlamacppHF.from_pretrained(model_name) - - if shared.args.tokenizer_dir: - tokenizer = load_tokenizer(model_name, tokenizer_dir=shared.args.tokenizer_dir) - return model, tokenizer - else: - return model + try: + model = LlamaServer(model_file) + return model, model + except Exception as e: + logger.error(f"Error loading the model with llama.cpp: {str(e)}") def ExLlamav3_HF_loader(model_name): diff --git a/modules/models_settings.py b/modules/models_settings.py index 693c4dde..0af89d2c 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -48,7 +48,7 @@ def get_model_metadata(model): ) # GGUF metadata - if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF']: + if model_settings['loader'] == 'llama.cpp': path = Path(f'{shared.args.model_dir}/{model}') if path.is_file(): model_file = path @@ -163,8 +163,6 @@ def infer_loader(model_name, model_settings, hf_quant_method=None): path_to_model = Path(f'{shared.args.model_dir}/{model_name}') if not path_to_model.exists(): loader = None - elif len(list(path_to_model.glob('*.gguf'))) > 0 and path_to_model.is_dir() and (path_to_model / 'tokenizer_config.json').exists(): - loader = 'llamacpp_HF' elif len(list(path_to_model.glob('*.gguf'))) > 0: loader = 'llama.cpp' elif re.match(r'.*\.gguf', model_name.lower()): diff --git a/modules/shared.py b/modules/shared.py index 91a2d361..83761b75 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,7 +86,7 @@ group.add_argument('--idle-timeout', type=int, default=0, help='Unload model aft # Model loader group = parser.add_argument_group('Model loader') -group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, HQQ, TensorRT-LLM.') +group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, HQQ, TensorRT-LLM.') # Transformers/Accelerate group = parser.add_argument_group('Transformers/Accelerate') @@ -116,24 +116,17 @@ group.add_argument('--quant_type', type=str, default='nf4', help='quant_type for # llama.cpp group = parser.add_argument_group('llama.cpp') group.add_argument('--flash-attn', action='store_true', help='Use flash-attention.') -group.add_argument('--tensorcores', action='store_true', help='NVIDIA only: use llama-cpp-python compiled without GGML_CUDA_FORCE_MMQ. 
This may improve performance on newer cards.') group.add_argument('--n_ctx', type=int, default=8192, help='Size of the prompt context.') group.add_argument('--threads', type=int, default=0, help='Number of threads to use.') group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.') -group.add_argument('--no_mul_mat_q', action='store_true', help='Disable the mulmat kernels.') -group.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.') +group.add_argument('--batch-size', type=int, default=2048, help='Maximum number of prompt tokens to batch together when calling llama_eval.') group.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.') group.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.') group.add_argument('--n-gpu-layers', type=int, default=0, help='Number of layers to offload to the GPU.') -group.add_argument('--tensor_split', type=str, default=None, help='Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40.') +group.add_argument('--tensor-split', type=str, default=None, help='Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40.') group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.') -group.add_argument('--logits_all', action='store_true', help='Needs to be set for perplexity evaluation to work. Otherwise, ignore it, as it makes prompt processing slower.') -group.add_argument('--no_offload_kqv', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.') -group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (llama-cpp-python). Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.') -group.add_argument('--row_split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.') -group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.') -group.add_argument('--attention-sink-size', type=int, default=5, help='StreamingLLM: number of sink tokens. Only used if the trimmed prompt does not share a prefix with the old prompt.') -group.add_argument('--tokenizer-dir', type=str, help='Load the tokenizer from this folder. Meant to be used with llamacpp_HF through the command-line.') +group.add_argument('--no-kv-offload', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.') +group.add_argument('--row-split', action='store_true', help='Split the model by rows across GPUs. 
This may improve multi-gpu performance.') # ExLlamaV2 group = parser.add_argument_group('ExLlamaV2') @@ -212,6 +205,13 @@ group.add_argument('--wbits', type=int, default=0, help='DEPRECATED') group.add_argument('--groupsize', type=int, default=-1, help='DEPRECATED') group.add_argument('--model-menu', action='store_true', help='DEPRECATED') group.add_argument('--multimodal-pipeline', type=str, default=None, help='DEPRECATED') +group.add_argument('--streaming-llm', action='store_true', help='DEPRECATED') +group.add_argument('--attention-sink-size', type=int, default=5, help='DEPRECATED') +group.add_argument('--tokenizer-dir', type=str, help='DEPRECATED') +group.add_argument('--logits_all', action='store_true', help='DEPRECATED') +group.add_argument('--no_mul_mat_q', action='store_true', help='DEPRECATED') +group.add_argument('--cache-capacity', type=str, help='DEPRECATED') +group.add_argument('--tensorcores', action='store_true', help='DEPRECATED') args = parser.parse_args() args_defaults = parser.parse_args([]) @@ -262,8 +262,6 @@ def fix_loader_name(name): name = name.lower() if name in ['llamacpp', 'llama.cpp', 'llama-cpp', 'llama cpp']: return 'llama.cpp' - if name in ['llamacpp_hf', 'llama.cpp_hf', 'llama-cpp-hf', 'llamacpp-hf', 'llama.cpp-hf']: - return 'llamacpp_HF' elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']: return 'Transformers' elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2']: @@ -316,7 +314,7 @@ def transform_legacy_kv_cache_options(opts): set('cache_type', 'fp8') elif cache_4bit: set('cache_type', 'q4') - elif loader.lower() in ['llama.cpp', 'llamacpp_hf']: + elif loader.lower() == 'llama.cpp': # Llama.cpp loader-specific cache type if cache_4bit: set('cache_type', 'q4_0') diff --git a/modules/text_generation.py b/modules/text_generation.py index eff6495e..16aba3cb 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -17,7 +17,6 @@ from transformers import ( import modules.shared as shared from modules import models, sampler_hijack -from modules.cache_utils import process_llamacpp_cache from modules.callbacks import ( Iteratorize, Stream, @@ -56,7 +55,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap yield '' return - if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel']: + if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']: generate_func = generate_reply_custom else: generate_func = generate_reply_HF @@ -133,8 +132,12 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt if shared.tokenizer is None: raise ValueError('No tokenizer is loaded') - if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel']: - input_ids = shared.tokenizer.encode(str(prompt)) + if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']: + if shared.model.__class__.__name__ == 'LlamaServer': + input_ids = shared.tokenizer.encode(str(prompt), add_bos_token=add_bos_token) + else: + input_ids = shared.tokenizer.encode(str(prompt)) + if shared.model.__class__.__name__ not in ['Exllamav2Model']: input_ids = np.array(input_ids).reshape(1, len(input_ids)) else: @@ -159,7 +162,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt if truncation_length is not None: input_ids = input_ids[:, -truncation_length:] - if shared.model.__class__.__name__ in 
['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu: + if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu: return input_ids else: device = get_device() @@ -186,7 +189,7 @@ def get_encoded_length(prompt): def get_token_ids(prompt): tokens = encode(prompt)[0] - decoded_tokens = [shared.tokenizer.decode([i]) for i in tokens] + decoded_tokens = [shared.tokenizer.decode([int(i)]) for i in tokens] output = '' for row in list(zip(tokens, decoded_tokens)): @@ -401,12 +404,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings logger.info("PROMPT=") print_prompt(decode(input_ids[0], skip_special_tokens=False)) - # Handle StreamingLLM for llamacpp_HF - if shared.model.__class__.__name__ == 'LlamacppHF' and shared.args.streaming_llm: - tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids.tolist()) - shared.model.past_seq = torch.tensor(tmp) - shared.model.save_cache() - t0 = time.time() try: if not is_chat and not shared.is_seq2seq: diff --git a/modules/ui.py b/modules/ui.py index adbb67b0..919ab8da 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -110,7 +110,7 @@ def list_model_elements(): 'n_gpu_layers', 'threads', 'threads_batch', - 'n_batch', + 'batch_size', 'hqq_backend', 'n_ctx', 'max_seq_len', @@ -122,20 +122,17 @@ def list_model_elements(): 'compress_pos_emb', 'compute_dtype', 'quant_type', - 'attention_sink_size', 'num_experts_per_token', - 'tensorcores', 'load_in_8bit', 'load_in_4bit', 'torch_compile', 'flash_attn', 'use_flash_attention_2', - 'streaming_llm', 'auto_devices', 'cpu', 'disk', 'row_split', - 'no_offload_kqv', + 'no_kv_offload', 'no_mul_mat_q', 'no_mmap', 'mlock', @@ -150,7 +147,6 @@ def list_model_elements(): 'no_sdpa', 'cfg_cache', 'cpp_runner', - 'logits_all', 'trust_remote_code', 'no_use_fast', ] diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 4fc1de08..fca883eb 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -87,7 +87,7 @@ def create_ui(): shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=256, value=shared.args.n_gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.') shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads) shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch) - shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, step=1, value=shared.args.n_batch) + shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size) shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend) shared.gradio['n_ctx'] = gr.Number(label="n_ctx", precision=0, step=256, value=shared.args.n_ctx, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768.') shared.gradio['max_seq_len'] = gr.Number(label='max_seq_len', precision=0, step=256, value=shared.args.max_seq_len, info='Context length. ⚠️ Lower this value if you can\'t load the model. 
diff --git a/modules/ui.py b/modules/ui.py
index adbb67b0..919ab8da 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -110,7 +110,7 @@ def list_model_elements():
         'n_gpu_layers',
         'threads',
         'threads_batch',
-        'n_batch',
+        'batch_size',
         'hqq_backend',
         'n_ctx',
         'max_seq_len',
@@ -122,20 +122,17 @@ def list_model_elements():
         'compress_pos_emb',
         'compute_dtype',
         'quant_type',
-        'attention_sink_size',
         'num_experts_per_token',
-        'tensorcores',
         'load_in_8bit',
         'load_in_4bit',
         'torch_compile',
         'flash_attn',
         'use_flash_attention_2',
-        'streaming_llm',
         'auto_devices',
         'cpu',
         'disk',
         'row_split',
-        'no_offload_kqv',
+        'no_kv_offload',
         'no_mul_mat_q',
         'no_mmap',
         'mlock',
@@ -150,7 +147,6 @@ def list_model_elements():
         'no_sdpa',
         'cfg_cache',
         'cpp_runner',
-        'logits_all',
         'trust_remote_code',
         'no_use_fast',
     ]
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 4fc1de08..fca883eb 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -87,7 +87,7 @@ def create_ui():
             shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=256, value=shared.args.n_gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.')
             shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
             shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
-            shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, step=1, value=shared.args.n_batch)
+            shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
             shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
             shared.gradio['n_ctx'] = gr.Number(label="n_ctx", precision=0, step=256, value=shared.args.n_ctx, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768.')
             shared.gradio['max_seq_len'] = gr.Number(label='max_seq_len', precision=0, step=256, value=shared.args.max_seq_len, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768.')
@@ -99,22 +99,19 @@ def create_ui():
             shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
             shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
             shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
-            shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.')
             shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')

         with gr.Column():
-            shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled without GGML_CUDA_FORCE_MMQ. This may improve performance on newer cards.')
             shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
             shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
             shared.gradio['torch_compile'] = gr.Checkbox(label="torch-compile", value=shared.args.torch_compile, info='Compile the model with torch.compile for improved performance.')
             shared.gradio['flash_attn'] = gr.Checkbox(label="flash_attn", value=shared.args.flash_attn, info='Use flash-attention.')
             shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
-            shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming_llm", value=shared.args.streaming_llm, info='(experimental) Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
             shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
             shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.')
             shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
             shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.')
-            shared.gradio['no_offload_kqv'] = gr.Checkbox(label="no_offload_kqv", value=shared.args.no_offload_kqv, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
+            shared.gradio['no_kv_offload'] = gr.Checkbox(label="no_kv_offload", value=shared.args.no_kv_offload, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
             shared.gradio['no_mul_mat_q'] = gr.Checkbox(label="no_mul_mat_q", value=shared.args.no_mul_mat_q, info='Disable the mulmat kernels.')
             shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
             shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
@@ -129,7 +126,6 @@ def create_ui():
             shared.gradio['no_sdpa'] = gr.Checkbox(label="no_sdpa", value=shared.args.no_sdpa)
             shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
             shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
-            shared.gradio['logits_all'] = gr.Checkbox(label="logits_all", value=shared.args.logits_all, info='Needs to be set for perplexity evaluation to work with this loader. Otherwise, ignore it, as it makes prompt processing slower.')
             shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
             shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
             shared.gradio['llamacpp_HF_info'] = gr.Markdown("llamacpp_HF loads llama.cpp as a Transformers model. To use it, you need to place your GGUF in a subfolder of models/ with the necessary tokenizer files.\n\nYou can use the \"llamacpp_HF creator\" menu to do that automatically.")
@@ -147,15 +143,6 @@ def create_ui():
                 shared.gradio['download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
                 shared.gradio['get_file_list'] = gr.Button("Get file list", interactive=not mu)

-            with gr.Tab("llamacpp_HF creator"):
-                with gr.Row():
-                    shared.gradio['gguf_menu'] = gr.Dropdown(choices=utils.get_available_ggufs(), value=lambda: shared.model_name, label='Choose your GGUF', elem_classes='slim-dropdown', interactive=not mu)
-                    ui.create_refresh_button(shared.gradio['gguf_menu'], lambda: None, lambda: {'choices': utils.get_available_ggufs()}, 'refresh-button', interactive=not mu)
-
-                shared.gradio['unquantized_url'] = gr.Textbox(label="Enter the URL for the original (unquantized) model", info="Example: https://huggingface.co/lmsys/vicuna-13b-v1.5", max_lines=1)
-                shared.gradio['create_llamacpp_hf_button'] = gr.Button("Submit", variant="primary", interactive=not mu)
-                gr.Markdown("This will move your gguf file into a subfolder of `models` along with the necessary tokenizer files.")
-
             with gr.Tab("Customize instruction template"):
                 with gr.Row():
                     shared.gradio['customized_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), value='None', label='Select the desired instruction template', elem_classes='slim-dropdown')
@@ -195,7 +182,6 @@ def create_event_handlers():
     shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
     shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
     shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), gradio('load_model'))

-    shared.gradio['create_llamacpp_hf_button'].click(create_llamacpp_hf, gradio('gguf_menu', 'unquantized_url'), gradio('model_status'), show_progress=True)

     shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
@@ -286,34 +272,11 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
         yield traceback.format_exc().replace('\n', '\n\n')


-def create_llamacpp_hf(gguf_name, unquantized_url, progress=gr.Progress()):
-    try:
-        downloader = importlib.import_module("download-model").ModelDownloader()
-
-        progress(0.0)
-        model, branch = downloader.sanitize_model_and_branch_names(unquantized_url, None)
-
-        yield ("Getting the tokenizer files links from Hugging Face")
-        links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=True)
-        output_folder = Path(shared.args.model_dir) / (re.sub(r'(?i)\.gguf$', '', gguf_name) + "-HF")
-
-        yield (f"Downloading tokenizer to `{output_folder}/`")
-        downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=4, is_llamacpp=False)
-
-        # Move the GGUF
-        (Path(shared.args.model_dir) / gguf_name).rename(output_folder / gguf_name)
-
-        yield (f"Model saved to `{output_folder}/`.\n\nYou can now load it using llamacpp_HF.")
-    except:
-        progress(1.0)
-        yield traceback.format_exc().replace('\n', '\n\n')
-
-
 def update_truncation_length(current_length, state):
     if 'loader' in state:
         if state['loader'].lower().startswith('exllama'):
             return state['max_seq_len']
-        elif state['loader'] in ['llama.cpp', 'llamacpp_HF']:
+        elif state['loader'] == 'llama.cpp':
             return state['n_ctx']

     return current_length
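Several launcher flags are renamed or retired by the UI changes above. If you keep saved launch settings that still use the old names, a hypothetical migration helper along the following lines captures the mapping visible in this diff (the helper itself is not part of the PR):

# Hypothetical helper: renames/drops only the keys that this diff touches
# in modules/ui.py and modules/ui_model_menu.py.
RENAMED = {
    'n_batch': 'batch_size',
    'no_offload_kqv': 'no_kv_offload',
}
REMOVED = {'attention_sink_size', 'tensorcores', 'streaming_llm', 'logits_all'}


def migrate_model_settings(old: dict) -> dict:
    new = {}
    for key, value in old.items():
        if key in REMOVED:
            continue  # these options no longer exist in the UI
        new[RENAMED.get(key, key)] = value
    return new


# Example: {'n_batch': 512, 'tensorcores': True} -> {'batch_size': 512}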
diff --git a/requirements.txt b/requirements.txt
index 54853528..84bdbd62 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -31,19 +31,9 @@ flask_cloudflared==0.0.14
 sse-starlette==1.6.5
 tiktoken

-# llama-cpp-python (CPU only, AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-
-# llama-cpp-python (CUDA, with GGML_CUDA_FORCE_MMQ)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-
-# llama-cpp-python (CUDA, without GGML_CUDA_FORCE_MMQ)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-
 # CUDA wheels
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/textgen-webui/llama_cpp_binaries-0.1.0+cu124-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/textgen-webui/llama_cpp_binaries-0.1.0+cu124-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
 https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
 https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2+cu124.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
 https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
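Each wheel line in these requirements files is gated by a PEP 508 environment marker (the text after the semicolon). To check which llama.cpp wheel applies on a given machine, the marker can be evaluated against the current environment with the third-party packaging library, for example:

from packaging.markers import Marker

# Marker string copied from the Linux CUDA wheel line above.
marker = Marker('platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"')
print(marker.evaluate())  # True only on x86_64 Linux running Python 3.11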
diff --git a/requirements_amd.txt b/requirements_amd.txt
index 3d24891f..c94522cf 100644
--- a/requirements_amd.txt
+++ b/requirements_amd.txt
@@ -30,11 +30,7 @@ flask_cloudflared==0.0.14
 sse-starlette==1.6.5
 tiktoken

-# llama-cpp-python (CPU only, AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-
 # AMD wheels
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/rocm/llama_cpp_python_cuda-0.3.8+rocm6.1.2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/rocm/llama_cpp_binaries-0.1.0+rocm6.1.2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
 https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.1.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
 https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64"
diff --git a/requirements_amd_noavx2.txt b/requirements_amd_noavx2.txt
index 057b631d..a62d4c1e 100644
--- a/requirements_amd_noavx2.txt
+++ b/requirements_amd_noavx2.txt
@@ -30,10 +30,7 @@ flask_cloudflared==0.0.14
 sse-starlette==1.6.5
 tiktoken

-# llama-cpp-python (CPU only, no AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-
 # AMD wheels
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/rocm/llama_cpp_binaries-0.1.0+rocm6.1.2avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
 https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.1.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
 https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64"
diff --git a/requirements_apple_intel.txt b/requirements_apple_intel.txt
index a9e88a25..766a3dd6 100644
--- a/requirements_apple_intel.txt
+++ b/requirements_apple_intel.txt
@@ -31,7 +31,7 @@ sse-starlette==1.6.5
 tiktoken

 # Mac wheels
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/metal/llama_cpp_binaries-0.1.0-cp311-cp311-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/metal/llama_cpp_binaries-0.1.0-cp311-cp311-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11"
 https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2-py3-none-any.whl
 https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl
diff --git a/requirements_apple_silicon.txt b/requirements_apple_silicon.txt
index 4c64c030..6a4cd80c 100644
--- a/requirements_apple_silicon.txt
+++ b/requirements_apple_silicon.txt
@@ -31,8 +31,8 @@ sse-starlette==1.6.5
 tiktoken

 # Mac wheels
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/metal/llama_cpp_binaries-0.1.0-cp311-cp311-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/metal/llama_cpp_binaries-0.1.0-cp311-cp311-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/metal/llama_cpp_binaries-0.1.0-cp311-cp311-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" and python_version == "3.11"
 https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2-py3-none-any.whl
 https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl
diff --git a/requirements_cpu_only.txt b/requirements_cpu_only.txt
index c7e2687c..67c8fe31 100644
--- a/requirements_cpu_only.txt
+++ b/requirements_cpu_only.txt
@@ -30,6 +30,6 @@ flask_cloudflared==0.0.14
 sse-starlette==1.6.5
 tiktoken

-# llama-cpp-python (CPU only, AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+# llama.cpp (CPU only, AVX2)
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/cpu/llama_cpp_binaries-0.1.0+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/cpu/llama_cpp_binaries-0.1.0+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
diff --git a/requirements_cpu_only_noavx2.txt b/requirements_cpu_only_noavx2.txt
index 2003c544..3e988c46 100644
--- a/requirements_cpu_only_noavx2.txt
+++ b/requirements_cpu_only_noavx2.txt
@@ -30,6 +30,6 @@ flask_cloudflared==0.0.14
 sse-starlette==1.6.5
 tiktoken

-# llama-cpp-python (CPU only, no AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+# llama.cpp (CPU only, no AVX2)
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/cpu/llama_cpp_binaries-0.1.0+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/cpu/llama_cpp_binaries-0.1.0+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
diff --git a/requirements_noavx2.txt b/requirements_noavx2.txt
index 0096c9ab..27cd2150 100644
--- a/requirements_noavx2.txt
+++ b/requirements_noavx2.txt
@@ -31,19 +31,9 @@ flask_cloudflared==0.0.14
 sse-starlette==1.6.5
 tiktoken

-# llama-cpp-python (CPU only, no AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-
-# llama-cpp-python (CUDA, with GGML_CUDA_FORCE_MMQ)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-
-# llama-cpp-python (CUDA, without GGML_CUDA_FORCE_MMQ)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-
 # CUDA wheels
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/textgen-webui/llama_cpp_binaries-0.1.0+cu124avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/textgen-webui/llama_cpp_binaries-0.1.0+cu124avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
 https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
 https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2+cu124.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
 https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
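After reinstalling from the updated requirements files, a quick sanity check that the old llama-cpp-python wheels are gone and the new backend is present is to query installed distribution metadata; the distribution names below are taken from the wheel filenames above:

from importlib.metadata import PackageNotFoundError, version

for dist in ('llama_cpp_binaries', 'llama_cpp_python', 'exllamav2', 'exllamav3'):
    try:
        print(f'{dist}: {version(dist)}')
    except PackageNotFoundError:
        print(f'{dist}: not installed')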