mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00

New llama.cpp loader (#6846)

This commit is contained in:
parent 5c2f8d828e
commit ae54d8faaa
@@ -5,7 +5,5 @@ from modules.logits import get_next_logits
 def _get_next_logits(body):
     # Pre-process the input payload to simulate a real generation
     use_samplers = body['use_samplers']
-    state = process_parameters(body) if use_samplers else {}
-    state['stream'] = True
-
+    state = process_parameters(body)
     return get_next_logits(body['prompt'], state, use_samplers, "", top_logits=body['top_logits'], return_dict=True)
@@ -1,115 +0,0 @@
-import torch
-from numba import njit
-
-from modules import shared
-
-
-def process_llamacpp_cache(model, new_sequence, past_sequence):
-    if len(past_sequence) == 0 or len(new_sequence) == 0:
-        return past_sequence
-
-    i1, i2, j1, j2 = find_longest_common_substring_indices(past_sequence, new_sequence)
-    overlap_length = i2 - i1 + 1
-
-    # Do StreamingLLM if i1 > 0 (ie the longest common subsequence is not a prefix)
-    # and the overlap length is sufficiently long.
-    if i1 > 0 and overlap_length > 0.2 * len(new_sequence):
-
-        new_sequence = torch.tensor(new_sequence)
-        past_sequence = torch.tensor(past_sequence)
-
-        prefix_length = find_prefix_length(past_sequence[:i1], new_sequence[:j1])
-        sink_length = max(prefix_length, shared.args.attention_sink_size)
-        removed_length = i1 - sink_length
-
-        if removed_length <= 0:
-            return past_sequence.tolist()
-
-        matching_prefix = past_sequence[:prefix_length]
-        removed_chunk = past_sequence[sink_length:i1]
-        overlapping_sequence = new_sequence[j1:j2 + 1]
-        added_chunk = new_sequence[j2 + 1:]
-
-        # print(past_sequence.tolist())
-        # print(new_sequence.tolist())
-
-        print()
-        print('MATCHING PREFIX=', repr(shared.tokenizer.decode(matching_prefix)))
-        print('ADDED CHUNK=', repr(shared.tokenizer.decode(added_chunk)))
-        print('REMOVED CHUNK=', repr(shared.tokenizer.decode(removed_chunk)))
-        print('REMOVED LENGTH=', removed_length)
-        print()
-
-        # Remove interval [sink_length, sink_length + removed_length) from the context
-        # Update model.n_tokens
-        model._ctx.kv_cache_seq_rm(0, sink_length, sink_length + removed_length)
-        model._ctx.kv_cache_seq_shift(0, sink_length + removed_length, -1, -removed_length)
-
-        new_sequence = new_sequence.tolist()
-        model.input_ids[:j2 + 1] = new_sequence[:j2 + 1]
-        model.n_tokens = j2 + 1
-
-        return new_sequence[:j2 + 1]
-    else:
-        return past_sequence
-
-
-def find_prefix_length(past_seq, seq_tensor):
-    '''
-    Given two torch tensors, finds the length of the longest
-    common prefix between the two.
-    '''
-    min_length = min(past_seq.shape[0], seq_tensor.shape[0])
-    indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
-    if len(indices) > 0:
-        prefix_length = indices[0].item()
-    else:
-        prefix_length = min_length
-
-    return prefix_length
-
-
-@njit
-def find_longest_common_substring_indices(list1, list2):
-    '''
-    Given two lists, solves the Longest Common Substring problem.
-
-    It returns the indices where the substring starts and ends in
-    s1 and s2.
-
-    Example:
-
-    ir, jr, ir2, jr2 = find_longest_common_substring_indices(s1, s2)
-    print(s1[ir:jr + 1])
-    print(s2[ir2:jr2 + 1])
-
-    Adapted from
-    https://rosettacode.org/wiki/Longest_common_substring#Python
-    '''
-
-    len_list1, len_list2 = len(list1), len(list2)
-    start_index_list1, end_index_list1 = 0, -1
-    start_index_list2, end_index_list2 = 0, -1
-
-    # for index1 in tqdm(range(0, len_list1), desc="StreamingLLM prompt comparison", leave=False):
-    for index1 in range(0, len_list1):
-        try:
-            index2 = list2.index(list1[index1])
-        except:
-            continue
-
-        while index2 >= 0:
-            temp_index1, temp_index2 = index1, index2
-            while temp_index1 < len_list1 and temp_index2 < len_list2 and list2[temp_index2] == list1[temp_index1]:
-                if temp_index1 - index1 >= end_index_list1 - start_index_list1:
-                    start_index_list1, end_index_list1 = index1, temp_index1
-                    start_index_list2, end_index_list2 = index2, temp_index2
-
-                temp_index1 += 1
-                temp_index2 += 1
-            try:
-                index2 = list2.index(list1[index1], index2 + 1)
-            except:
-                break
-
-    return start_index_list1, end_index_list1, start_index_list2, end_index_list2
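For reference, a plain-Python sketch of the same longest-common-substring search that the removed StreamingLLM cache trimming relies on (toy token lists, no numba; an illustration for this edit, not part of the commit):

def longest_common_substring_indices(list1, list2):
    # Same quadratic scan as find_longest_common_substring_indices above,
    # without the @njit decorator, for easy experimentation.
    best = (0, -1, 0, -1)
    for i in range(len(list1)):
        for j in range(len(list2)):
            if list2[j] != list1[i]:
                continue
            k = 0
            while i + k < len(list1) and j + k < len(list2) and list1[i + k] == list2[j + k]:
                if k >= best[1] - best[0]:
                    best = (i, i + k, j, j + k)
                k += 1
    return best


past = [1, 2, 3, 4, 5, 6]
new = [9, 9, 3, 4, 5, 7]
i1, i2, j1, j2 = longest_common_substring_indices(past, new)
print(past[i1:i2 + 1], new[j1:j2 + 1])  # [3, 4, 5] [3, 4, 5]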
@@ -40,17 +40,13 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
     '''
 
     if shared.args.loader == "llama.cpp":
-        logger.error("llamacpp_HF is required for perplexity evaluation with GGUF models. Please reload the model with llamacpp_HF instead of llama.cpp.")
+        logger.error("Perplexity evaluation is not implemented for the llama.cpp loader.")
         raise ValueError
 
     if shared.args.loader == "ExLlamav2":
         logger.error("ExLlamav2_HF is required for perplexity evaluation with EXL2 models. Please reload the model with ExLlamav2_HF instead of ExLlamav2.")
         raise ValueError
 
-    if shared.args.loader == "llamacpp_HF" and not shared.args.logits_all:
-        logger.error("--logits_all is required for perplexity evaluation with GGUF models. Please reload the model with that option set/checked.")
-        raise ValueError
-
     if not shared.args.no_use_fast:
         logger.warning("--no_use_fast is not set. If tokenizing the input dataset takes a long time, try reloading the model with that option set/checked.")
@@ -1,165 +0,0 @@
-import importlib
-import platform
-from typing import Sequence
-
-import numpy as np
-from tqdm import tqdm
-
-from modules import shared
-from modules.cache_utils import process_llamacpp_cache
-
-imported_module = None
-not_available_modules = set()
-
-
-def llama_cpp_lib():
-    global imported_module, not_available_modules
-
-    # Determine the platform
-    is_macos = platform.system() == 'Darwin'
-
-    # Define the library names based on the platform
-    if is_macos:
-        lib_names = [
-            (None, 'llama_cpp')
-        ]
-    else:
-        lib_names = [
-            ('cpu', 'llama_cpp'),
-            ('tensorcores', 'llama_cpp_cuda_tensorcores'),
-            (None, 'llama_cpp_cuda'),
-            (None, 'llama_cpp')
-        ]
-
-    for arg, lib_name in lib_names:
-        if lib_name in not_available_modules:
-            continue
-
-        should_import = (arg is None or getattr(shared.args, arg))
-
-        if should_import:
-            if imported_module and imported_module != lib_name:
-                # Conflict detected, raise an exception
-                raise Exception(f"Cannot import `{lib_name}` because `{imported_module}` is already imported. Switching to a different version of llama-cpp-python currently requires a server restart.")
-
-            try:
-                return_lib = importlib.import_module(lib_name)
-                imported_module = lib_name
-                monkey_patch_llama_cpp_python(return_lib)
-                return return_lib
-            except ImportError:
-                not_available_modules.add(lib_name)
-                continue
-
-    return None
-
-
-def eval_with_progress(self, tokens: Sequence[int]):
-    """
-    A copy of
-
-    https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py
-
-    with tqdm to show prompt processing progress.
-    """
-    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
-
-    if len(tokens) > self.n_batch:
-        progress_bar = tqdm(range(0, len(tokens), self.n_batch), desc="Prompt evaluation", leave=False)
-    else:
-        progress_bar = range(0, len(tokens), self.n_batch)
-
-    for i in progress_bar:
-        batch = tokens[i : min(len(tokens), i + self.n_batch)]
-        n_past = self.n_tokens
-        n_tokens = len(batch)
-        self._batch.set_batch(
-            batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
-        )
-        self._ctx.decode(self._batch)
-        # Save tokens
-        self.input_ids[n_past : n_past + n_tokens] = batch
-        # Save logits
-        if self.context_params.logits_all:
-            rows = n_tokens
-            cols = self._n_vocab
-            logits = np.ctypeslib.as_array(
-                self._ctx.get_logits(), shape=(rows * cols,)
-            )
-            self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
-            self.last_updated_index = n_past + n_tokens - 1
-        else:
-            rows = 1
-            cols = self._n_vocab
-            logits = np.ctypeslib.as_array(
-                self._ctx.get_logits(), shape=(rows * cols,)
-            )
-            last_token_index = min(n_past + n_tokens - 1, self.scores.shape[0] - 1)
-            self.scores[last_token_index, :] = logits.reshape(-1)
-            self.last_updated_index = last_token_index
-        # Update n_tokens
-        self.n_tokens += n_tokens
-
-
-def monkey_patch_llama_cpp_python(lib):
-    if getattr(lib.Llama, '_is_patched', False):
-        # If the patch is already applied, do nothing
-        return
-
-    def my_generate(self, *args, **kwargs):
-        if shared.args.streaming_llm:
-            new_sequence = args[0]
-            past_sequence = self._input_ids
-
-            # Do the cache trimming for StreamingLLM
-            process_llamacpp_cache(self, new_sequence, past_sequence)
-
-        for output in self.original_generate(*args, **kwargs):
-            yield output
-
-    lib.Llama.eval = eval_with_progress
-    lib.Llama.original_generate = lib.Llama.generate
-    lib.Llama.generate = my_generate
-
-    # Also patch Jinja2ChatFormatter to handle loop controls
-    if hasattr(lib, 'llama_chat_format') and hasattr(lib.llama_chat_format, 'Jinja2ChatFormatter'):
-        Formatter = lib.llama_chat_format.Jinja2ChatFormatter
-
-        if not getattr(Formatter, '_is_patched', False):
-            def patched_init(self, *args, **kwargs):
-                # Extract parameters from args or kwargs
-                if args:
-                    self.template = args[0]
-                    self.eos_token = args[1] if len(args) > 1 else kwargs.get('eos_token')
-                    self.bos_token = args[2] if len(args) > 2 else kwargs.get('bos_token')
-                    self.add_generation_prompt = args[3] if len(args) > 3 else kwargs.get('add_generation_prompt', True)
-                    self.stop_token_ids = args[4] if len(args) > 4 else kwargs.get('stop_token_ids')
-                else:
-                    self.template = kwargs.get('template')
-                    self.eos_token = kwargs.get('eos_token')
-                    self.bos_token = kwargs.get('bos_token')
-                    self.add_generation_prompt = kwargs.get('add_generation_prompt', True)
-                    self.stop_token_ids = kwargs.get('stop_token_ids')
-
-                # Process stop tokens as in the original
-                self.stop_token_ids = (
-                    set(self.stop_token_ids) if self.stop_token_ids is not None else None
-                )
-
-                # Create environment with loopcontrols extension
-                import jinja2
-                from jinja2.ext import loopcontrols
-
-                self._environment = jinja2.sandbox.ImmutableSandboxedEnvironment(
-                    loader=jinja2.BaseLoader(),
-                    trim_blocks=True,
-                    lstrip_blocks=True,
-                    extensions=[loopcontrols]
-                ).from_string(self.template)
-
-            # Replace the original __init__ with our patched version
-            Formatter.__init__ = patched_init
-            Formatter._is_patched = True
-
-    # Set the flag to indicate that the patch has been applied
-    lib.Llama._is_patched = True
modules/llama_cpp_server.py (new file, 338 lines)
@@ -0,0 +1,338 @@
+import json
+import pprint
+import socket
+import subprocess
+import sys
+import threading
+import time
+
+import llama_cpp_binaries
+import requests
+
+from modules import shared
+from modules.logging_colors import logger
+
+llamacpp_valid_cache_types = {"fp16", "q8_0", "q4_0"}
+
+
+class LlamaServer:
+    def __init__(self, model_path, server_path=None):
+        """
+        Initialize and start a server for llama.cpp models.
+        """
+        self.model_path = model_path
+        self.server_path = server_path
+        self.port = self._find_available_port()
+        self.process = None
+        self.max_context_length = None
+        self.bos_token = "<s>"
+
+        # Start the server
+        self._start_server()
+
+    def encode(self, text, add_bos_token=False, **kwargs):
+        if self.bos_token and text.startswith(self.bos_token):
+            add_bos_token = False
+
+        url = f"http://localhost:{self.port}/tokenize"
+        payload = {
+            "content": text,
+            "add_special": add_bos_token,
+        }
+
+        response = requests.post(url, json=payload)
+        result = response.json()
+        return result.get("tokens", [])
+
+    def decode(self, token_ids, **kwargs):
+        url = f"http://localhost:{self.port}/detokenize"
+        payload = {
+            "tokens": token_ids,
+        }
+
+        response = requests.post(url, json=payload)
+        result = response.json()
+        return result.get("content", "")
+
+    def prepare_payload(self, state):
+        # Prepare DRY
+        dry_sequence_breakers = state['dry_sequence_breakers']
+        if not dry_sequence_breakers.startswith("["):
+            dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
+        dry_sequence_breakers = json.loads(dry_sequence_breakers)
+
+        # Prepare the sampler order
+        samplers = state["sampler_priority"]
+        samplers = samplers.split("\n") if isinstance(samplers, str) else samplers
+        penalty_found = False
+        filtered_samplers = []
+        for s in samplers:
+            if s.strip() in ["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]:
+                filtered_samplers.append(s.strip())
+            elif not penalty_found and s.strip() == "repetition_penalty":
+                filtered_samplers.append("penalties")
+                penalty_found = True
+
+        samplers = filtered_samplers
+
+        # Move temperature to the end if temperature_last is true and temperature exists in the list
+        if state["temperature_last"] and "temperature" in samplers:
+            samplers.remove("temperature")
+            samplers.append("temperature")
+
+        payload = {
+            "temperature": state["temperature"] if not state["dynamic_temperature"] else (state["dynatemp_low"] + state["dynatemp_high"]) / 2,
+            "dynatemp_range": 0 if not state["dynamic_temperature"] else (state["dynatemp_high"] - state["dynatemp_low"]) / 2,
+            "dynatemp_exponent": state["dynatemp_exponent"],
+            "top_k": state["top_k"],
+            "top_p": state["top_p"],
+            "min_p": state["min_p"],
+            "tfs_z": state["tfs"],
+            "typical_p": state["typical_p"],
+            "repeat_penalty": state["repetition_penalty"],
+            "repeat_last_n": state["repetition_penalty_range"],
+            "presence_penalty": state["presence_penalty"],
+            "frequency_penalty": state["frequency_penalty"],
+            "dry_multiplier": state["dry_multiplier"],
+            "dry_base": state["dry_base"],
+            "dry_allowed_length": state["dry_allowed_length"],
+            "dry_penalty_last_n": state["repetition_penalty_range"],
+            "dry_sequence_breakers": dry_sequence_breakers,
+            "xtc_probability": state["xtc_probability"],
+            "xtc_threshold": state["xtc_threshold"],
+            "mirostat": state["mirostat_mode"],
+            "mirostat_tau": state["mirostat_tau"],
+            "mirostat_eta": state["mirostat_eta"],
+            "grammar": state["grammar_string"],
+            "seed": state["seed"],
+            "ignore_eos": state["ban_eos_token"],
+            "samplers": samplers,
+        }
+
+        if state['custom_token_bans']:
+            to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
+            payload["logit_bias"] = to_ban
+
+        return payload
+
+    def generate_with_streaming(
+        self,
+        prompt,
+        state,
+    ):
+        url = f"http://localhost:{self.port}/completion"
+        payload = self.prepare_payload(state)
+
+        token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
+        if state['auto_max_new_tokens']:
+            max_new_tokens = state['truncation_length'] - len(token_ids)
+        else:
+            max_new_tokens = state['max_new_tokens']
+
+        payload.update({
+            "prompt": token_ids,
+            "n_predict": max_new_tokens,
+            "stream": True,
+        })
+
+        if shared.args.verbose:
+            logger.info("GENERATE_PARAMS=")
+            printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
+            pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
+            print()
+
+        # Make a direct request with streaming enabled
+        response = requests.post(url, json=payload, stream=True)
+        response.raise_for_status()  # Raise an exception for HTTP errors
+
+        full_text = ""
+
+        # Process the streaming response
+        for line in response.iter_lines():
+            if shared.stop_everything:
+                break
+
+            if line:
+                try:
+                    # Check if the line starts with "data: " and remove it
+                    line_str = line.decode('utf-8')
+                    if line_str.startswith('data: '):
+                        line_str = line_str[6:]  # Remove the "data: " prefix
+
+                    # Parse the JSON data
+                    data = json.loads(line_str)
+
+                    # Extract the token content
+                    if 'content' in data:
+                        token_text = data['content']
+                        full_text += token_text
+                        yield full_text
+
+                    # Check if generation is complete
+                    if data.get('stop', False):
+                        break
+
+                except json.JSONDecodeError as e:
+                    # Log the error and the problematic line
+                    print(f"JSON decode error: {e}")
+                    print(f"Problematic line: {line}")
+                    continue
+
+    def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
+        """Get the logits/probabilities for the next token after a prompt"""
+        url = f"http://localhost:{self.port}/completion"
+
+        payload = self.prepare_payload(state)
+        payload.update({
+            "prompt": self.encode(prompt, add_bos_token=state["add_bos_token"]),
+            "n_predict": 0,
+            "logprobs": True,
+            "n_probs": n_probs,
+            "stream": False,
+            "post_sampling_probs": use_samplers,
+        })
+
+        if shared.args.verbose:
+            logger.info("GENERATE_PARAMS=")
+            printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
+            pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
+            print()
+
+        response = requests.post(url, json=payload)
+        result = response.json()
+
+        if "completion_probabilities" in result:
+            if use_samplers:
+                return result["completion_probabilities"][0]["top_probs"]
+            else:
+                return result["completion_probabilities"][0]["top_logprobs"]
+        else:
+            raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
+
+    def _get_max_context_length(self):
+        """Get and store the model's maximum context length."""
+        url = f"http://localhost:{self.port}/v1/models"
+        response = requests.get(url).json()
+
+        if "data" in response and len(response["data"]) > 0:
+            model_info = response["data"][0]
+            if "meta" in model_info and "n_vocab" in model_info["meta"]:
+                self.max_context_length = model_info["meta"]["n_vocab"]
+
+    def _get_bos_token(self):
+        """Get and store the model's BOS token."""
+        url = f"http://localhost:{self.port}/props"
+        response = requests.get(url).json()
+        if "bos_token" in response:
+            self.bos_token = response["bos_token"]
+
+    def _find_available_port(self):
+        """Find an available port by letting the OS assign one."""
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(('', 0))  # Bind to port 0 to get an available port
+            return s.getsockname()[1]
+
+    def _start_server(self):
+        """Start the llama.cpp server and wait until it's ready."""
+        # Determine the server path
+        if self.server_path is None:
+            self.server_path = llama_cpp_binaries.get_binary_path()
+
+        # Build the command
+        cmd = [
+            self.server_path,
+            "--model", self.model_path,
+            "--ctx-size", str(shared.args.n_ctx),
+            "--n-gpu-layers", str(shared.args.n_gpu_layers),
+            "--batch-size", str(shared.args.batch_size),
+            "--port", str(self.port),
+        ]
+
+        if shared.args.flash_attn:
+            cmd.append("--flash-attn")
+        if shared.args.threads > 0:
+            cmd += ["--threads", str(shared.args.threads)]
+        if shared.args.threads_batch > 0:
+            cmd += ["--threads-batch", str(shared.args.threads_batch)]
+        if shared.args.no_mmap:
+            cmd.append("--no-mmap")
+        if shared.args.mlock:
+            cmd.append("--mlock")
+        if shared.args.tensor_split:
+            cmd += ["--tensor-split", shared.args.tensor_split]
+        if shared.args.numa:
+            cmd += ["--numa", "distribute"]
+        if shared.args.no_kv_offload:
+            cmd.append("--no-kv-offload")
+        if shared.args.row_split:
+            cmd += ["--split-mode", "row"]
+        if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types:
+            cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type]
+        if shared.args.compress_pos_emb != 1:
+            cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
+
+        # Start the server with pipes for output
+        self.process = subprocess.Popen(
+            cmd,
+            stderr=subprocess.PIPE,
+            text=True,
+            bufsize=1
+        )
+
+        def filter_stderr():
+            for line in iter(self.process.stderr.readline, ''):
+                if not line.startswith(('srv ', 'slot ')) and not 'log_server_r: request: GET /health' in line:
+                    sys.stderr.write(line)
+                    sys.stderr.flush()
+
+        threading.Thread(target=filter_stderr, daemon=True).start()
+
+        # Wait for server to be healthy
+        health_url = f"http://localhost:{self.port}/health"
+        start_time = time.time()
+        timeout = 3600 * 8  # 8 hours
+        while time.time() - start_time < timeout:
+            # Check if process is still alive
+            if self.process.poll() is not None:
+                # Process has terminated
+                exit_code = self.process.poll()
+                raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}")
+
+            try:
+                response = requests.get(health_url)
+                if response.status_code == 200:
+                    break
+            except:
+                pass
+
+            time.sleep(1)
+        else:
+            raise TimeoutError(f"Server health check timed out after {timeout} seconds")
+
+        # Server is now healthy, get model info
+        self._get_max_context_length()
+        self._get_bos_token()
+        return self.port
+
+    def __enter__(self):
+        """Support for context manager."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Support for context manager."""
+        self.stop()
+
+    def __del__(self):
+        """Cleanup when the object is deleted."""
+        self.stop()
+
+    def stop(self):
+        """Stop the server process."""
+        if self.process:
+            self.process.terminate()
+            try:
+                self.process.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                self.process.kill()
+
+            self.process = None
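For reference, a minimal standalone sketch of the /completion streaming protocol that LlamaServer.generate_with_streaming() consumes above, assuming a llama.cpp server is already listening on localhost:8080 (the port and prompt are placeholders, not part of the commit):

import json

import requests

# Stream a completion from a running llama.cpp server, the same way the
# new LlamaServer class does: POST /completion with stream=True and read
# "data: {...}" lines until a chunk with "stop": true arrives.
payload = {"prompt": "Hello, my name is", "n_predict": 32, "stream": True}
response = requests.post("http://localhost:8080/completion", json=payload, stream=True)
response.raise_for_status()

text = ""
for line in response.iter_lines():
    if not line:
        continue
    line_str = line.decode("utf-8")
    if line_str.startswith("data: "):
        line_str = line_str[6:]
    data = json.loads(line_str)
    text += data.get("content", "")
    if data.get("stop", False):
        break

print(text)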
@@ -1,220 +0,0 @@
-import os
-from pathlib import Path
-from typing import Any, Dict, Optional, Union
-
-import torch
-from torch.nn import CrossEntropyLoss
-from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
-from modules import shared
-from modules.llama_cpp_python_hijack import llama_cpp_lib
-from modules.llamacpp_model import get_llamacpp_cache_type_for_string
-from modules.logging_colors import logger
-
-
-class LlamacppHF(PreTrainedModel):
-    def __init__(self, model, path):
-        super().__init__(PretrainedConfig())
-        self.model = model
-        self.generation_config = GenerationConfig()
-
-        self.past_seq = None
-        self.llamacpp_cache = {
-            'n_tokens': self.model.n_tokens,
-            'input_ids': self.model.input_ids,
-            'scores': self.model.scores,
-            'ctx': self.model._ctx.ctx
-        }
-
-        if shared.args.cfg_cache:
-            self.past_seq_negative = None
-            self.llamacpp_cache_negative = {
-                'n_tokens': self.model.n_tokens,
-                'input_ids': self.model.input_ids.copy(),
-                'scores': self.model.scores.copy(),
-                'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.context_params)
-            }
-
-    def _validate_model_class(self):
-        pass
-
-    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
-        pass
-
-    def prepare_inputs_for_generation(self, input_ids, **kwargs):
-        return {'input_ids': input_ids, **kwargs}
-
-    def save_cache(self):
-        self.llamacpp_cache.update({
-            'n_tokens': self.model.n_tokens,
-            'input_ids': self.model.input_ids,
-            'scores': self.model.scores,
-            'ctx': self.model._ctx.ctx
-        })
-
-    def save_negative_cache(self):
-        self.llamacpp_cache_negative.update({
-            'n_tokens': self.model.n_tokens,
-            'input_ids': self.model.input_ids,
-            'scores': self.model.scores,
-            'ctx': self.model._ctx.ctx
-        })
-
-    def load_cache(self):
-        self.model.n_tokens = self.llamacpp_cache['n_tokens']
-        self.model.input_ids = self.llamacpp_cache['input_ids']
-        self.model.scores = self.llamacpp_cache['scores']
-        self.model._ctx.ctx = self.llamacpp_cache['ctx']
-
-    def load_negative_cache(self):
-        self.model.n_tokens = self.llamacpp_cache_negative['n_tokens']
-        self.model.input_ids = self.llamacpp_cache_negative['input_ids']
-        self.model.scores = self.llamacpp_cache_negative['scores']
-        self.model._ctx.ctx = self.llamacpp_cache_negative['ctx']
-
-    @property
-    def device(self) -> torch.device:
-        return torch.device(0)
-
-    def __call__(self, *args, **kwargs):
-        use_cache = kwargs.get('use_cache', True)
-        labels = kwargs.get('labels', None)
-        past_key_values = kwargs.get('past_key_values', None)
-
-        if len(args) > 0:
-            if not shared.args.cfg_cache:
-                logger.error("Please enable the cfg-cache option to use CFG with llamacpp_HF.")
-                return
-
-            input_ids = args[0]
-            is_negative = True
-            past_seq = self.past_seq_negative
-            self.load_negative_cache()
-        else:
-            input_ids = kwargs['input_ids']
-            is_negative = False
-            past_seq = self.past_seq
-            self.load_cache()
-
-        seq = input_ids[0].tolist()
-        if is_negative and past_key_values is not None:
-            seq = past_key_values + seq
-
-        seq_tensor = torch.tensor(seq)
-        reset = True
-
-        # Make the forward call. The prefix-match code has been adapted from
-        # https://github.com/abetlen/llama-cpp-python/commit/f4090a0bb2a2a25acfe28d31c82cc1aa273bedee
-        if labels is None:
-            if past_seq is not None:
-                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
-                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
-                if len(indices) > 0:
-                    longest_prefix = indices[0].item()
-                else:
-                    longest_prefix = min_length
-
-                if longest_prefix > 0:
-                    reset = False
-                    self.model.n_tokens = longest_prefix
-                    if len(seq_tensor) - longest_prefix > 0:
-                        self.model.eval(seq[longest_prefix:])
-                    else:
-                        self.model.n_tokens -= 1
-                        self.model.eval([seq[-1]])
-
-            if reset:
-                self.model.reset()
-                self.model.eval(seq)
-
-            logits = torch.tensor(self.model.scores[self.model.last_updated_index, :]).view(1, 1, -1).to(input_ids.device)
-        else:
-            self.model.reset()
-            self.model.eval(seq)
-            logits = torch.tensor(self.model.eval_logits)
-            logits = logits.view(1, logits.shape[0], logits.shape[1]).to(input_ids.device)
-
-        if is_negative:
-            self.save_negative_cache()
-            self.past_seq_negative = seq_tensor
-        else:
-            self.save_cache()
-            self.past_seq = seq_tensor
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, logits.shape[-1])
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-        return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
-        assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported"
-
-        if isinstance(pretrained_model_name_or_path, str):
-            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
-
-        path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)
-        if path.is_file():
-            model_file = path
-        else:
-            model_file = sorted(path.glob('*.gguf'))[0]
-
-        logger.info(f"llama.cpp weights detected: {model_file}\n")
-
-        if shared.args.tensor_split is None or shared.args.tensor_split.strip() == '':
-            tensor_split_list = None
-        else:
-            tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
-
-        params = {
-            'model_path': str(model_file),
-            'n_ctx': shared.args.n_ctx,
-            'n_threads': shared.args.threads or None,
-            'n_threads_batch': shared.args.threads_batch or None,
-            'n_batch': shared.args.n_batch,
-            'use_mmap': not shared.args.no_mmap,
-            'use_mlock': shared.args.mlock,
-            'mul_mat_q': not shared.args.no_mul_mat_q,
-            'numa': shared.args.numa,
-            'n_gpu_layers': shared.args.n_gpu_layers,
-            'rope_freq_base': shared.args.rope_freq_base,
-            'tensor_split': tensor_split_list,
-            'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
-            'logits_all': shared.args.logits_all,
-            'offload_kqv': not shared.args.no_offload_kqv,
-            'split_mode': 1 if not shared.args.row_split else 2,
-            'flash_attn': shared.args.flash_attn
-        }
-
-        if shared.args.cache_type != 'fp16':
-            params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
-            params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
-
-        Llama = llama_cpp_lib().Llama
-        try:
-            model = Llama(**params)
-        except Exception as e:
-            error_message = (
-                f"Failed loading the model. **This usually happens due to lack of memory**. Try these steps:\n"
-                f"1. Reduce the context length `n_ctx` (currently {shared.args.n_ctx})."
-                f"{' Try a lower value like 4096.' if shared.args.n_ctx > 4096 else '.'}"
-                "\n"
-                f"2. Lower the `n-gpu-layers` value (currently {shared.args.n_gpu_layers})."
-            )
-
-            raise type(e)(error_message) from e
-
-        model.last_updated_index = -1
-
-        return LlamacppHF(model, model_file)
@@ -1,218 +0,0 @@
-import re
-from functools import partial
-
-import numpy as np
-import torch
-
-from modules import shared
-from modules.callbacks import Iteratorize
-from modules.llama_cpp_python_hijack import llama_cpp_lib
-from modules.logging_colors import logger
-from modules.text_generation import get_max_prompt_length
-
-llamacpp_quant_mapping = {
-    'f32': 0,
-    'fp16': 1,
-    'q4_0': 2,
-    'q4_1': 3,
-    'q5_0': 6,
-    'q5_1': 7,
-    'q8_0': 8,
-    'q8_1': 9,
-    'q2_k': 10,
-    'q3_k': 11,
-    'q4_k': 12,
-    'q5_k': 13,
-    'q6_k': 14,
-    'q8_k': 15,
-    'iq4_nl': 20,
-    'bf16': 30,
-}
-
-llamacpp_valid_cache_types = {'fp16', 'q8_0', 'q4_0'}
-
-
-def get_llamacpp_cache_type_for_string(quant_type: str):
-    quant_type = quant_type.lower()
-    if quant_type in llamacpp_valid_cache_types:
-        return llamacpp_quant_mapping[quant_type]
-    else:
-        raise ValueError(f"Invalid cache type for llama.cpp: {quant_type}. Valid options are: fp16, q8_0, q4_0.")
-
-
-def ban_eos_logits_processor(eos_token, input_ids, logits):
-    logits[eos_token] = -float('inf')
-    return logits
-
-
-def custom_token_ban_logits_processor(token_ids, input_ids, logits):
-    for token_id in token_ids:
-        logits[token_id] = -float('inf')
-
-    return logits
-
-
-class LlamaCppModel:
-    def __init__(self):
-        self.initialized = False
-        self.grammar_string = ''
-        self.grammar = None
-
-    def __del__(self):
-        del self.model
-
-    @classmethod
-    def from_pretrained(self, path):
-
-        Llama = llama_cpp_lib().Llama
-        LlamaCache = llama_cpp_lib().LlamaCache
-
-        result = self()
-        cache_capacity = 0
-        if shared.args.cache_capacity is not None:
-            if 'GiB' in shared.args.cache_capacity:
-                cache_capacity = int(re.sub('[a-zA-Z]', '', shared.args.cache_capacity)) * 1000 * 1000 * 1000
-            elif 'MiB' in shared.args.cache_capacity:
-                cache_capacity = int(re.sub('[a-zA-Z]', '', shared.args.cache_capacity)) * 1000 * 1000
-            else:
-                cache_capacity = int(shared.args.cache_capacity)
-
-        if cache_capacity > 0:
-            logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
-
-        if shared.args.tensor_split is None or shared.args.tensor_split.strip() == '':
-            tensor_split_list = None
-        else:
-            tensor_split_list = [float(x) for x in shared.args.tensor_split.strip().split(",")]
-
-        params = {
-            'model_path': str(path),
-            'n_ctx': shared.args.n_ctx,
-            'n_threads': shared.args.threads or None,
-            'n_threads_batch': shared.args.threads_batch or None,
-            'n_batch': shared.args.n_batch,
-            'use_mmap': not shared.args.no_mmap,
-            'use_mlock': shared.args.mlock,
-            'mul_mat_q': not shared.args.no_mul_mat_q,
-            'numa': shared.args.numa,
-            'n_gpu_layers': shared.args.n_gpu_layers,
-            'rope_freq_base': shared.args.rope_freq_base,
-            'tensor_split': tensor_split_list,
-            'rope_freq_scale': 1.0 / shared.args.compress_pos_emb,
-            'offload_kqv': not shared.args.no_offload_kqv,
-            'split_mode': 1 if not shared.args.row_split else 2,
-            'flash_attn': shared.args.flash_attn
-        }
-
-        if shared.args.cache_type != 'fp16':
-            params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
-            params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
-
-        try:
-            result.model = Llama(**params)
-        except Exception as e:
-            error_message = (
-                f"Failed loading the model. **This usually happens due to lack of memory**. Try these steps:\n"
-                f"1. Reduce the context length `n_ctx` (currently {shared.args.n_ctx})."
-                f"{' Try a lower value like 4096.' if shared.args.n_ctx > 4096 else '.'}"
-                "\n"
-                f"2. Lower the `n-gpu-layers` value (currently {shared.args.n_gpu_layers})."
-            )
-
-            raise type(e)(error_message) from e
-
-        if cache_capacity > 0:
-            result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
-
-        # This is ugly, but the model and the tokenizer are the same object in this library.
-        return result, result
-
-    def encode(self, string):
-        if type(string) is str:
-            string = string.encode()
-
-        return self.model.tokenize(string)
-
-    def decode(self, ids, **kwargs):
-        detokenized = self.model.detokenize(ids)
-        try:
-            # Attempt strict UTF-8 decoding first
-            return detokenized.decode('utf-8', 'strict')
-        except UnicodeDecodeError as e:
-            # Log the error and fall back to UTF-8 with replacement
-            logger.warning(f"Invalid UTF-8 in detokenized output. Using replacement characters.\n{e}")
-            return detokenized.decode('utf-8', 'replace')
-
-    def get_logits(self, tokens):
-        self.model.reset()
-        self.model.eval(tokens)
-        logits = self.model._scores
-        logits = np.expand_dims(logits, 0)  # batch dim is expected
-        return torch.tensor(logits, dtype=torch.float32)
-
-    def load_grammar(self, string):
-        if string != self.grammar_string:
-            self.grammar_string = string
-            if string.strip() != '':
-                self.grammar = llama_cpp_lib().LlamaGrammar.from_string(string)
-            else:
-                self.grammar = None
-
-    def generate(self, prompt, state, callback=None):
-        LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
-        prompt = prompt if type(prompt) is str else prompt.decode()
-
-        # Handle truncation
-        prompt = self.encode(prompt)
-        prompt = prompt[-get_max_prompt_length(state):]
-        prompt = self.decode(prompt)
-
-        self.load_grammar(state['grammar_string'])
-        logit_processors = LogitsProcessorList()
-        if state['ban_eos_token']:
-            logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos()))
-
-        if state['custom_token_bans']:
-            to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
-            if len(to_ban) > 0:
-                logit_processors.append(partial(custom_token_ban_logits_processor, to_ban))
-
-        completion_chunks = self.model.create_completion(
-            prompt=prompt,
-            max_tokens=state['max_new_tokens'],
-            temperature=state['temperature'],
-            top_p=state['top_p'] if state['top_p'] < 1 else 0.999,
-            min_p=state['min_p'],
-            typical_p=state['typical_p'],
-            frequency_penalty=state['frequency_penalty'],
-            presence_penalty=state['presence_penalty'],
-            repeat_penalty=state['repetition_penalty'],
-            top_k=state['top_k'],
-            stream=True,
-            seed=int(state['seed']) if state['seed'] != -1 else None,
-            tfs_z=state['tfs'],
-            mirostat_mode=int(state['mirostat_mode']),
-            mirostat_tau=state['mirostat_tau'],
-            mirostat_eta=state['mirostat_eta'],
-            logits_processor=logit_processors,
-            grammar=self.grammar
-        )
-
-        output = ""
-        for completion_chunk in completion_chunks:
-            if shared.stop_everything:
-                break
-
-            text = completion_chunk['choices'][0]['text']
-            output += text
-            if callback:
-                callback(text)
-
-        return output
-
-    def generate_with_streaming(self, *args, **kwargs):
-        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
-            reply = ''
-            for token in generator:
-                reply += token
-                yield reply
@@ -30,51 +30,20 @@ loaders_and_params = OrderedDict({
         'n_gpu_layers',
         'threads',
         'threads_batch',
-        'n_batch',
+        'batch_size',
         'n_ctx',
         'cache_type',
         'tensor_split',
         'rope_freq_base',
         'compress_pos_emb',
-        'attention_sink_size',
-        'tensorcores',
         'flash_attn',
-        'streaming_llm',
-        'cpu',
         'row_split',
-        'no_offload_kqv',
+        'no_kv_offload',
         'no_mul_mat_q',
        'no_mmap',
         'mlock',
         'numa',
     ],
-    'llamacpp_HF': [
-        'n_gpu_layers',
-        'threads',
-        'threads_batch',
-        'n_batch',
-        'n_ctx',
-        'cache_type',
-        'tensor_split',
-        'rope_freq_base',
-        'compress_pos_emb',
-        'attention_sink_size',
-        'tensorcores',
-        'flash_attn',
-        'streaming_llm',
-        'cpu',
-        'row_split',
-        'no_offload_kqv',
-        'no_mul_mat_q',
-        'no_mmap',
-        'mlock',
-        'numa',
-        'cfg_cache',
-        'logits_all',
-        'trust_remote_code',
-        'no_use_fast',
-        'llamacpp_HF_info',
-    ],
     'ExLlamav3_HF': [
         'max_seq_len',
         'gpu_split',
@@ -307,66 +276,34 @@ loaders_samplers = {
         'dry_sequence_breakers',
     },
     'llama.cpp': {
-        'temperature',
-        'min_p',
-        'top_p',
-        'top_k',
-        'typical_p',
-        'tfs',
-        'repetition_penalty',
-        'frequency_penalty',
-        'presence_penalty',
-        'mirostat_mode',
-        'mirostat_tau',
-        'mirostat_eta',
-        'ban_eos_token',
-        'seed',
-        'custom_token_bans',
-        'grammar_string',
-        'grammar_file_row',
-    },
-    'llamacpp_HF': {
         'temperature',
         'dynatemp_low',
         'dynatemp_high',
         'dynatemp_exponent',
-        'smoothing_factor',
-        'smoothing_curve',
         'min_p',
         'top_p',
         'top_k',
         'typical_p',
         'xtc_threshold',
         'xtc_probability',
-        'epsilon_cutoff',
-        'eta_cutoff',
         'tfs',
-        'top_a',
-        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',
         'repetition_penalty',
         'frequency_penalty',
         'presence_penalty',
-        'encoder_repetition_penalty',
-        'no_repeat_ngram_size',
         'repetition_penalty_range',
-        'guidance_scale',
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
-        'do_sample',
         'dynamic_temperature',
         'temperature_last',
         'auto_max_new_tokens',
         'ban_eos_token',
         'add_bos_token',
-        'skip_special_tokens',
         'seed',
         'sampler_priority',
-        'custom_token_bans',
-        'negative_prompt',
         'dry_sequence_breakers',
         'grammar_string',
         'grammar_file_row',
@@ -1,6 +1,7 @@
 import time
 import traceback
 
+import numpy as np
 import torch
 
 from modules import models, sampler_hijack, shared
@@ -38,10 +39,32 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
         return 'Error: No model is loaded! Select one in the Model tab.', previous
 
     is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
-    is_non_hf_llamacpp = shared.model.__class__.__name__ == 'LlamaCppModel'
+    is_llamacpp = shared.model.__class__.__name__ == 'LlamaServer'
 
+    if is_llamacpp:
+        logprobs = shared.model.get_logits(prompt, state, n_probs=top_logits, use_samplers=use_samplers)
+        if return_dict:
+            output = {}
+            for entry in logprobs:
+                token = repr(entry['token'])
+                prob = entry['prob'] if use_samplers else np.exp(entry['logprob'])
+                output[token] = prob
+
+            return output
+        else:
+            output = ''
+            for entry in logprobs:
+                token = repr(entry['token'])
+                prob = entry['prob'] if use_samplers else np.exp(entry['logprob'])
+                output += f"{prob:.5f} - {token}\n"
+
+            return output, previous
+    else:
+        if not use_samplers:
+            state = {'stream': True}
+
     if use_samplers:
-        if any([is_non_hf_exllamav2, is_non_hf_llamacpp]):
+        if is_non_hf_exllamav2:
             logger.error("Sampler hijacking is not supported non-Huggingface loaders.")
             # sampling is all done in c for exllama, so it is really hard to hijack
             # it should be possible to hijack llamacpp sampler by hijacking all their sampling methods,
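For reference, a small self-contained sketch of the conversion done in the new is_llamacpp branch above: the server returns per-token logprob entries and the UI exponentiates them into probabilities (the sample entries below are made up for illustration, not part of the commit):

import numpy as np

# Shape of the entries returned by LlamaServer.get_logits() when
# use_samplers is False (values here are illustrative only).
logprobs = [
    {'token': ' the', 'logprob': -0.21},
    {'token': ' a', 'logprob': -2.05},
    {'token': '\n', 'logprob': -3.40},
]

# Same conversion as the return_dict branch: repr() the token text and
# exponentiate the log-probability.
output = {repr(entry['token']): float(np.exp(entry['logprob'])) for entry in logprobs}
print(output)  # approximately {"' the'": 0.81, "' a'": 0.13, "'\\n'": 0.03}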
@@ -61,9 +84,6 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
         if device:
             tokens = tokens.to(device)
 
         scores = shared.model.get_logits(tokens)[-1][-1]
-    elif is_non_hf_llamacpp:
-        tokens = shared.tokenizer.encode(prompt)
-        scores = shared.model.get_logits(tokens)[-1][-1]
     else:
         device = get_device()
@@ -76,9 +96,6 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
 
     probs = torch.softmax(scores, dim=-1, dtype=torch.float)
     topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True)
-    if is_non_hf_llamacpp:
-        topk_indices = [i.expand((1, 1)) for i in topk_indices]
-
     if hasattr(shared.tokenizer, 'convert_ids_to_tokens'):
         tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices]
     else:
@@ -67,8 +67,7 @@ def load_model(model_name, loader=None):
     shared.model_name = model_name
     load_func_map = {
         'Transformers': huggingface_loader,
-        'llama.cpp': llamacpp_loader,
-        'llamacpp_HF': llamacpp_HF_loader,
+        'llama.cpp': llama_cpp_server_loader,
         'ExLlamav3_HF': ExLlamav3_HF_loader,
         'ExLlamav2_HF': ExLlamav2_HF_loader,
         'ExLlamav2': ExLlamav2_loader,
@@ -101,7 +100,7 @@ def load_model(model_name, loader=None):
         shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
         if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt'):
             shared.settings['truncation_length'] = shared.args.max_seq_len
-        elif loader in ['llama.cpp', 'llamacpp_HF']:
+        elif loader == 'llama.cpp':
             shared.settings['truncation_length'] = shared.args.n_ctx
 
     logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
@@ -268,8 +267,8 @@ def huggingface_loader(model_name):
     return model
 
 
-def llamacpp_loader(model_name):
-    from modules.llamacpp_model import LlamaCppModel
+def llama_cpp_server_loader(model_name):
+    from modules.llama_cpp_server import LlamaServer
 
     path = Path(f'{shared.args.model_dir}/{model_name}')
     if path.is_file():
@@ -278,31 +277,11 @@ def llamacpp_loader(model_name):
         model_file = sorted(Path(f'{shared.args.model_dir}/{model_name}').glob('*.gguf'))[0]
 
     logger.info(f"llama.cpp weights detected: \"{model_file}\"")
-    model, tokenizer = LlamaCppModel.from_pretrained(model_file)
-    return model, tokenizer
-
-
-def llamacpp_HF_loader(model_name):
-    from modules.llamacpp_hf import LlamacppHF
-
-    if shared.args.tokenizer_dir:
-        logger.info(f'Using tokenizer from: \"{shared.args.tokenizer_dir}\"')
-    else:
-        path = Path(f'{shared.args.model_dir}/{model_name}')
-        # Check if a HF tokenizer is available for the model
-        if all((path / file).exists() for file in ['tokenizer_config.json']):
-            logger.info(f'Using tokenizer from: \"{path}\"')
-        else:
-            logger.error("Could not load the model because a tokenizer in Transformers format was not found.")
-            return None, None
-
-    model = LlamacppHF.from_pretrained(model_name)
-
-    if shared.args.tokenizer_dir:
-        tokenizer = load_tokenizer(model_name, tokenizer_dir=shared.args.tokenizer_dir)
-        return model, tokenizer
-    else:
-        return model
+    try:
+        model = LlamaServer(model_file)
+        return model, model
+    except Exception as e:
+        logger.error(f"Error loading the model with llama.cpp: {str(e)}")
 
 
 def ExLlamav3_HF_loader(model_name):
@@ -48,7 +48,7 @@ def get_model_metadata(model):
     )
 
     # GGUF metadata
-    if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF']:
+    if model_settings['loader'] == 'llama.cpp':
         path = Path(f'{shared.args.model_dir}/{model}')
         if path.is_file():
             model_file = path
@@ -163,8 +163,6 @@ def infer_loader(model_name, model_settings, hf_quant_method=None):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
if not path_to_model.exists():
loader = None
-elif len(list(path_to_model.glob('*.gguf'))) > 0 and path_to_model.is_dir() and (path_to_model / 'tokenizer_config.json').exists():
-loader = 'llamacpp_HF'
elif len(list(path_to_model.glob('*.gguf'))) > 0:
loader = 'llama.cpp'
elif re.match(r'.*\.gguf', model_name.lower()):
@@ -86,7 +86,7 @@ group.add_argument('--idle-timeout', type=int, default=0, help='Unload model aft

# Model loader
group = parser.add_argument_group('Model loader')
-group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, llamacpp_HF, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, HQQ, TensorRT-LLM.')
+group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, HQQ, TensorRT-LLM.')

# Transformers/Accelerate
group = parser.add_argument_group('Transformers/Accelerate')
@@ -116,24 +116,17 @@ group.add_argument('--quant_type', type=str, default='nf4', help='quant_type for
# llama.cpp
group = parser.add_argument_group('llama.cpp')
group.add_argument('--flash-attn', action='store_true', help='Use flash-attention.')
-group.add_argument('--tensorcores', action='store_true', help='NVIDIA only: use llama-cpp-python compiled without GGML_CUDA_FORCE_MMQ. This may improve performance on newer cards.')
group.add_argument('--n_ctx', type=int, default=8192, help='Size of the prompt context.')
group.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.')
-group.add_argument('--no_mul_mat_q', action='store_true', help='Disable the mulmat kernels.')
-group.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.')
+group.add_argument('--batch-size', type=int, default=2048, help='Maximum number of prompt tokens to batch together when calling llama_eval.')
group.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.')
group.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.')
group.add_argument('--n-gpu-layers', type=int, default=0, help='Number of layers to offload to the GPU.')
-group.add_argument('--tensor_split', type=str, default=None, help='Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40.')
+group.add_argument('--tensor-split', type=str, default=None, help='Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40.')
group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
-group.add_argument('--logits_all', action='store_true', help='Needs to be set for perplexity evaluation to work. Otherwise, ignore it, as it makes prompt processing slower.')
-group.add_argument('--no_offload_kqv', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
-group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (llama-cpp-python). Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.')
-group.add_argument('--row_split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
-group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
-group.add_argument('--attention-sink-size', type=int, default=5, help='StreamingLLM: number of sink tokens. Only used if the trimmed prompt does not share a prefix with the old prompt.')
-group.add_argument('--tokenizer-dir', type=str, help='Load the tokenizer from this folder. Meant to be used with llamacpp_HF through the command-line.')
+group.add_argument('--no-kv-offload', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
+group.add_argument('--row-split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')

# ExLlamaV2
group = parser.add_argument_group('ExLlamaV2')
@@ -212,6 +205,13 @@ group.add_argument('--wbits', type=int, default=0, help='DEPRECATED')
group.add_argument('--groupsize', type=int, default=-1, help='DEPRECATED')
group.add_argument('--model-menu', action='store_true', help='DEPRECATED')
group.add_argument('--multimodal-pipeline', type=str, default=None, help='DEPRECATED')
+group.add_argument('--streaming-llm', action='store_true', help='DEPRECATED')
+group.add_argument('--attention-sink-size', type=int, default=5, help='DEPRECATED')
+group.add_argument('--tokenizer-dir', type=str, help='DEPRECATED')
+group.add_argument('--logits_all', action='store_true', help='DEPRECATED')
+group.add_argument('--no_mul_mat_q', action='store_true', help='DEPRECATED')
+group.add_argument('--cache-capacity', type=str, help='DEPRECATED')
+group.add_argument('--tensorcores', action='store_true', help='DEPRECATED')

args = parser.parse_args()
args_defaults = parser.parse_args([])
@@ -262,8 +262,6 @@ def fix_loader_name(name):
name = name.lower()
if name in ['llamacpp', 'llama.cpp', 'llama-cpp', 'llama cpp']:
return 'llama.cpp'
-if name in ['llamacpp_hf', 'llama.cpp_hf', 'llama-cpp-hf', 'llamacpp-hf', 'llama.cpp-hf']:
-return 'llamacpp_HF'
elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']:
return 'Transformers'
elif name in ['exllamav2', 'exllama-v2', 'ex_llama-v2', 'exlamav2', 'exlama-v2', 'exllama2', 'exllama-2']:
@@ -316,7 +314,7 @@ def transform_legacy_kv_cache_options(opts):
set('cache_type', 'fp8')
elif cache_4bit:
set('cache_type', 'q4')
-elif loader.lower() in ['llama.cpp', 'llamacpp_hf']:
+elif loader.lower() == 'llama.cpp':
# Llama.cpp loader-specific cache type
if cache_4bit:
set('cache_type', 'q4_0')
@@ -17,7 +17,6 @@ from transformers import (

import modules.shared as shared
from modules import models, sampler_hijack
-from modules.cache_utils import process_llamacpp_cache
from modules.callbacks import (
Iteratorize,
Stream,
@@ -56,7 +55,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
yield ''
return

-if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel']:
+if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']:
generate_func = generate_reply_custom
else:
generate_func = generate_reply_HF
@@ -133,8 +132,12 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if shared.tokenizer is None:
raise ValueError('No tokenizer is loaded')

-if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel']:
+if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']:
+if shared.model.__class__.__name__ == 'LlamaServer':
+input_ids = shared.tokenizer.encode(str(prompt), add_bos_token=add_bos_token)
+else:
input_ids = shared.tokenizer.encode(str(prompt))
+
if shared.model.__class__.__name__ not in ['Exllamav2Model']:
input_ids = np.array(input_ids).reshape(1, len(input_ids))
else:
@@ -159,7 +162,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if truncation_length is not None:
input_ids = input_ids[:, -truncation_length:]

-if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
+if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
return input_ids
else:
device = get_device()
@@ -186,7 +189,7 @@ def get_encoded_length(prompt):

def get_token_ids(prompt):
tokens = encode(prompt)[0]
-decoded_tokens = [shared.tokenizer.decode([i]) for i in tokens]
+decoded_tokens = [shared.tokenizer.decode([int(i)]) for i in tokens]

output = ''
for row in list(zip(tokens, decoded_tokens)):
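Taken together with the encode() change above, the tokenize/detokenize round trip for the new loader can be sketched as follows (illustrative only; shared.tokenizer is the LlamaServer instance returned by the loader, and the signatures are assumed from the calls shown in this diff):

    # Sketch only, not part of the commit.
    ids = shared.tokenizer.encode(str("Hello world"), add_bos_token=True)  # LlamaServer branch of encode()
    pieces = [shared.tokenizer.decode([int(i)]) for i in ids]              # as in get_token_ids()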
@@ -401,12 +404,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
logger.info("PROMPT=")
print_prompt(decode(input_ids[0], skip_special_tokens=False))

-# Handle StreamingLLM for llamacpp_HF
-if shared.model.__class__.__name__ == 'LlamacppHF' and shared.args.streaming_llm:
-tmp = process_llamacpp_cache(shared.model.model, input_ids[-1].tolist(), shared.model.model._input_ids.tolist())
-shared.model.past_seq = torch.tensor(tmp)
-shared.model.save_cache()
-
t0 = time.time()
try:
if not is_chat and not shared.is_seq2seq:
@@ -110,7 +110,7 @@ def list_model_elements():
'n_gpu_layers',
'threads',
'threads_batch',
-'n_batch',
+'batch_size',
'hqq_backend',
'n_ctx',
'max_seq_len',
@@ -122,20 +122,17 @@ def list_model_elements():
'compress_pos_emb',
'compute_dtype',
'quant_type',
-'attention_sink_size',
'num_experts_per_token',
-'tensorcores',
'load_in_8bit',
'load_in_4bit',
'torch_compile',
'flash_attn',
'use_flash_attention_2',
-'streaming_llm',
'auto_devices',
'cpu',
'disk',
'row_split',
-'no_offload_kqv',
+'no_kv_offload',
'no_mul_mat_q',
'no_mmap',
'mlock',
@@ -150,7 +147,6 @@ def list_model_elements():
'no_sdpa',
'cfg_cache',
'cpp_runner',
-'logits_all',
'trust_remote_code',
'no_use_fast',
]
@@ -87,7 +87,7 @@ def create_ui():
shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=256, value=shared.args.n_gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.')
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
-shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, step=1, value=shared.args.n_batch)
+shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
shared.gradio['n_ctx'] = gr.Number(label="n_ctx", precision=0, step=256, value=shared.args.n_ctx, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768.')
shared.gradio['max_seq_len'] = gr.Number(label='max_seq_len', precision=0, step=256, value=shared.args.max_seq_len, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768.')
@@ -99,22 +99,19 @@ def create_ui():
shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
-shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.')
shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')

with gr.Column():
-shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled without GGML_CUDA_FORCE_MMQ. This may improve performance on newer cards.')
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
shared.gradio['torch_compile'] = gr.Checkbox(label="torch-compile", value=shared.args.torch_compile, info='Compile the model with torch.compile for improved performance.')
shared.gradio['flash_attn'] = gr.Checkbox(label="flash_attn", value=shared.args.flash_attn, info='Use flash-attention.')
shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
-shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming_llm", value=shared.args.streaming_llm, info='(experimental) Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.')
shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.')
-shared.gradio['no_offload_kqv'] = gr.Checkbox(label="no_offload_kqv", value=shared.args.no_offload_kqv, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
+shared.gradio['no_kv_offload'] = gr.Checkbox(label="no_kv_offload", value=shared.args.no_kv_offload, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
shared.gradio['no_mul_mat_q'] = gr.Checkbox(label="no_mul_mat_q", value=shared.args.no_mul_mat_q, info='Disable the mulmat kernels.')
shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
@@ -129,7 +126,6 @@ def create_ui():
shared.gradio['no_sdpa'] = gr.Checkbox(label="no_sdpa", value=shared.args.no_sdpa)
shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
-shared.gradio['logits_all'] = gr.Checkbox(label="logits_all", value=shared.args.logits_all, info='Needs to be set for perplexity evaluation to work with this loader. Otherwise, ignore it, as it makes prompt processing slower.')
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
shared.gradio['llamacpp_HF_info'] = gr.Markdown("llamacpp_HF loads llama.cpp as a Transformers model. To use it, you need to place your GGUF in a subfolder of models/ with the necessary tokenizer files.\n\nYou can use the \"llamacpp_HF creator\" menu to do that automatically.")
@@ -147,15 +143,6 @@ def create_ui():
shared.gradio['download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
shared.gradio['get_file_list'] = gr.Button("Get file list", interactive=not mu)

-with gr.Tab("llamacpp_HF creator"):
-with gr.Row():
-shared.gradio['gguf_menu'] = gr.Dropdown(choices=utils.get_available_ggufs(), value=lambda: shared.model_name, label='Choose your GGUF', elem_classes='slim-dropdown', interactive=not mu)
-ui.create_refresh_button(shared.gradio['gguf_menu'], lambda: None, lambda: {'choices': utils.get_available_ggufs()}, 'refresh-button', interactive=not mu)
-
-shared.gradio['unquantized_url'] = gr.Textbox(label="Enter the URL for the original (unquantized) model", info="Example: https://huggingface.co/lmsys/vicuna-13b-v1.5", max_lines=1)
-shared.gradio['create_llamacpp_hf_button'] = gr.Button("Submit", variant="primary", interactive=not mu)
-gr.Markdown("This will move your gguf file into a subfolder of `models` along with the necessary tokenizer files.")
-
with gr.Tab("Customize instruction template"):
with gr.Row():
shared.gradio['customized_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), value='None', label='Select the desired instruction template', elem_classes='slim-dropdown')
@@ -195,7 +182,6 @@ def create_event_handlers():
shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), gradio('load_model'))
-shared.gradio['create_llamacpp_hf_button'].click(create_llamacpp_hf, gradio('gguf_menu', 'unquantized_url'), gradio('model_status'), show_progress=True)
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)

@@ -286,34 +272,11 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
yield traceback.format_exc().replace('\n', '\n\n')


-def create_llamacpp_hf(gguf_name, unquantized_url, progress=gr.Progress()):
-try:
-downloader = importlib.import_module("download-model").ModelDownloader()
-
-progress(0.0)
-model, branch = downloader.sanitize_model_and_branch_names(unquantized_url, None)
-
-yield ("Getting the tokenizer files links from Hugging Face")
-links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=True)
-output_folder = Path(shared.args.model_dir) / (re.sub(r'(?i)\.gguf$', '', gguf_name) + "-HF")
-
-yield (f"Downloading tokenizer to `{output_folder}/`")
-downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=4, is_llamacpp=False)
-
-# Move the GGUF
-(Path(shared.args.model_dir) / gguf_name).rename(output_folder / gguf_name)
-
-yield (f"Model saved to `{output_folder}/`.\n\nYou can now load it using llamacpp_HF.")
-except:
-progress(1.0)
-yield traceback.format_exc().replace('\n', '\n\n')
-
-
def update_truncation_length(current_length, state):
if 'loader' in state:
if state['loader'].lower().startswith('exllama'):
return state['max_seq_len']
-elif state['loader'] in ['llama.cpp', 'llamacpp_HF']:
+elif state['loader'] == 'llama.cpp':
return state['n_ctx']

return current_length
@@ -31,19 +31,9 @@ flask_cloudflared==0.0.14
sse-starlette==1.6.5
tiktoken

-# llama-cpp-python (CPU only, AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-
-# llama-cpp-python (CUDA, with GGML_CUDA_FORCE_MMQ)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-
-# llama-cpp-python (CUDA, without GGML_CUDA_FORCE_MMQ)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-
# CUDA wheels
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/textgen-webui/llama_cpp_binaries-0.1.0+cu124-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/textgen-webui/llama_cpp_binaries-0.1.0+cu124-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2+cu124.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
@@ -30,11 +30,7 @@ flask_cloudflared==0.0.14
sse-starlette==1.6.5
tiktoken

-# llama-cpp-python (CPU only, AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-
# AMD wheels
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/rocm/llama_cpp_python_cuda-0.3.8+rocm6.1.2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/rocm/llama_cpp_binaries-0.1.0+rocm6.1.2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.1.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64"
@@ -30,10 +30,7 @@ flask_cloudflared==0.0.14
sse-starlette==1.6.5
tiktoken

-# llama-cpp-python (CPU only, no AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-
# AMD wheels
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/rocm/llama_cpp_binaries-0.1.0+rocm6.1.2avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+rocm6.1.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl; platform_system != "Darwin" and platform_machine != "x86_64"
@@ -31,7 +31,7 @@ sse-starlette==1.6.5
tiktoken

# Mac wheels
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/metal/llama_cpp_binaries-0.1.0-cp311-cp311-macosx_15_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/metal/llama_cpp_binaries-0.1.0-cp311-cp311-macosx_14_0_x86_64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11"
https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2-py3-none-any.whl
https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl
@@ -31,8 +31,8 @@ sse-starlette==1.6.5
tiktoken

# Mac wheels
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/metal/llama_cpp_python-0.3.8-cp311-cp311-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/metal/llama_cpp_binaries-0.1.0-cp311-cp311-macosx_15_0_arm64.whl; platform_system == "Darwin" and platform_release >= "24.0.0" and platform_release < "25.0.0" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/metal/llama_cpp_binaries-0.1.0-cp311-cp311-macosx_14_0_arm64.whl; platform_system == "Darwin" and platform_release >= "23.0.0" and platform_release < "24.0.0" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/metal/llama_cpp_binaries-0.1.0-cp311-cp311-macosx_13_0_arm64.whl; platform_system == "Darwin" and platform_release >= "22.0.0" and platform_release < "23.0.0" and python_version == "3.11"
https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2-py3-none-any.whl
https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8-py3-none-any.whl
@@ -30,6 +30,6 @@ flask_cloudflared==0.0.14
sse-starlette==1.6.5
tiktoken

-# llama-cpp-python (CPU only, AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+# llama.cpp (CPU only, AVX2)
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/cpu/llama_cpp_binaries-0.1.0+cpuavx2-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/cpu/llama_cpp_binaries-0.1.0+cpuavx2-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
@@ -30,6 +30,6 @@ flask_cloudflared==0.0.14
sse-starlette==1.6.5
tiktoken

-# llama-cpp-python (CPU only, no AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+# llama.cpp (CPU only, no AVX2)
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/cpu/llama_cpp_binaries-0.1.0+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/cpu/llama_cpp_binaries-0.1.0+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
@@ -31,19 +31,9 @@ flask_cloudflared==0.0.14
sse-starlette==1.6.5
tiktoken

-# llama-cpp-python (CPU only, no AVX2)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/cpu/llama_cpp_python-0.3.8+cpuavx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-
-# llama-cpp-python (CUDA, with GGML_CUDA_FORCE_MMQ)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda-0.3.8+cu124avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-
-# llama-cpp-python (CUDA, without GGML_CUDA_FORCE_MMQ)
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
-https://github.com/oobabooga/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui/llama_cpp_python_cuda_tensorcores-0.3.8+cu124avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
-
# CUDA wheels
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/textgen-webui/llama_cpp_binaries-0.1.0+cu124avx-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
+https://github.com/oobabooga/llama-cpp-binaries/releases/download/textgen-webui/llama_cpp_binaries-0.1.0+cu124avx-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
https://github.com/oobabooga/exllamav3/releases/download/v0.0.1a2/exllamav3-0.0.1a2+cu124.torch2.6.0-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11"
https://github.com/oobabooga/exllamav2/releases/download/v0.2.8/exllamav2-0.2.8+cu124.torch2.6.0-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"