mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00

Add MLX support

commit 365a997a7f
parent 6338dc0051
modules/loaders.py:

@@ -94,6 +94,9 @@ loaders_and_params = OrderedDict({
         'ctx_size',
         'cpp_runner',
         'tensorrt_llm_info',
+    ],
+    'MLX': [
+        'ctx_size',
     ]
 })

@@ -325,6 +328,26 @@ loaders_samplers = {
         'presence_penalty',
         'auto_max_new_tokens',
         'ban_eos_token',
+    },
+    'MLX': {
+        'temperature',
+        'dynatemp_low',
+        'dynatemp_high',
+        'dynatemp_exponent',
+        'top_p',
+        'top_k',
+        'min_p',
+        'xtc_threshold',
+        'xtc_probability',
+        'repetition_penalty',
+        'repetition_penalty_range',
+        'auto_max_new_tokens',
+        'ban_eos_token',
+        'add_bos_token',
+        'skip_special_tokens',
+        'dynamic_temperature',
+        'seed',
+        'sampler_priority',
     }
 }
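For context, loaders_and_params decides which UI elements appear for a given loader, and loaders_samplers decides which sampling parameters are forwarded to it. The gating is conceptually just a key filter, roughly like the sketch below (illustrative only, not the webui's actual implementation; state is the usual generation-settings dict):

def filter_state_for_loader(state, loader, loaders_samplers):
    # Keep only the sampling parameters the active loader declares support for.
    allowed = loaders_samplers.get(loader, set())
    return {k: v for k, v in state.items() if k in allowed}

With the entries above, a parameter such as 'mirostat_mode' would be dropped for 'MLX' because it is not listed, while 'temperature' and 'top_p' pass through.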
modules/mlx_loader.py (new file, 354 lines):

@@ -0,0 +1,354 @@
import platform
import traceback
from pathlib import Path

import modules.shared as shared
from modules.logging_colors import logger


def is_apple_silicon():
    """Check if running on Apple Silicon"""
    return platform.system() == "Darwin" and platform.machine() == "arm64"


class MLXModel:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name = None

    @classmethod
    def from_pretrained(cls, model_name):
        """Load MLX model from path or HuggingFace repository"""

        if not is_apple_silicon():
            logger.warning("MLX backend is only supported on Apple Silicon. Falling back to Transformers.")
            return None

        try:
            from mlx_lm import load
        except ImportError:
            logger.error("mlx-lm not found. Please install with: pip install mlx-lm")
            return None

        instance = cls()
        instance.model_name = model_name

        try:
            # Determine the model path/name
            model_path = cls._resolve_model_path(model_name)

            logger.info(f"Loading MLX model: {model_path}")
            model, tokenizer = load(model_path)

            instance.model = model
            instance.tokenizer = tokenizer

            logger.info(f"Successfully loaded MLX model: {model_name}")
            return instance, instance  # Return model, tokenizer tuple for compatibility

        except Exception as e:
            logger.error(f"Failed to load MLX model {model_name}: {str(e)}")
            traceback.print_exc()
            return None

    @staticmethod
    def _resolve_model_path(model_name):
        """Resolve model path - either local path or HuggingFace repo"""
        model_path = Path(f'{shared.args.model_dir}/{model_name}')

        if model_path.exists():
            # Local model path
            return str(model_path)
        elif model_name.startswith('mlx-community/'):
            # Already has mlx-community prefix
            return model_name
        else:
            # Try to find in mlx-community
            return f"mlx-community/{model_name}"
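As an aside, the mlx_lm API wrapped here can be exercised directly. A minimal sketch (assumes Apple Silicon with mlx-lm >= 0.26 installed; the repo id is only an example, and the exact keyword set of the top-level generate helper varies slightly across mlx-lm releases):

from mlx_lm import load, generate

# Resolution mirrors _resolve_model_path: a folder under shared.args.model_dir wins,
# otherwise the name is treated as a Hugging Face repo id.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))

Note that from_pretrained returns (instance, instance): the same wrapper object is handed back as both model and tokenizer so it can slot into the webui's usual load_model flow.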
    def _create_mlx_sampler(self, state):
        """Create MLX sampler with webui parameters"""
        try:
            from mlx_lm.sample_utils import make_sampler

            # Extract sampling parameters from state
            temperature = state.get('temperature', 1.0)
            top_p = state.get('top_p', 1.0)
            top_k = state.get('top_k', 0)  # 0 means no top_k limit
            min_p = state.get('min_p', 0.0)

            # Handle dynamic temperature
            if state.get('dynamic_temperature', False):
                temp_low = state.get('dynatemp_low', 1.0)
                temp_high = state.get('dynatemp_high', 1.0)
                temperature = (temp_low + temp_high) / 2  # Simple average for now

            # XTC sampling parameters
            xtc_probability = state.get('xtc_probability', 0.0)
            xtc_threshold = state.get('xtc_threshold', 0.1)

            # Ensure temperature is not zero (causes issues with MLX)
            if temperature <= 0.0:
                temperature = 0.01

            # Log sampling parameters for debugging
            if shared.args.verbose:
                logger.info(f"MLX Sampler - temp: {temperature}, top_p: {top_p}, top_k: {top_k}, min_p: {min_p}")

            # Create the sampler
            sampler = make_sampler(
                temp=temperature,
                top_p=top_p if top_p < 1.0 else 0.0,  # MLX expects 0.0 to disable
                top_k=int(top_k) if top_k > 0 else 0,
                min_p=min_p,
                min_tokens_to_keep=1,  # Always keep at least one token
                xtc_probability=xtc_probability,
                xtc_threshold=xtc_threshold,
                xtc_special_tokens=[]  # Could be customized later
            )

            return sampler

        except ImportError:
            logger.warning("MLX sampling utilities not available, using default sampler")
            return None
        except Exception as e:
            logger.error(f"Failed to create MLX sampler: {e}")
            return None

    def _create_logits_processors(self, state):
        """Create logits processors for repetition penalty, etc."""
        processors = []

        try:
            from mlx_lm.sample_utils import make_repetition_penalty

            # Repetition penalty
            repetition_penalty = state.get('repetition_penalty', 1.0)
            if repetition_penalty != 1.0:
                context_size = state.get('repetition_penalty_range', 1024)
                rep_processor = make_repetition_penalty(
                    penalty=repetition_penalty,
                    context_size=context_size
                )
                processors.append(rep_processor)

        except ImportError:
            logger.warning("MLX repetition penalty not available")
        except Exception as e:
            logger.error(f"Failed to create repetition penalty processor: {e}")

        return processors if processors else None
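Both helpers come from mlx_lm.sample_utils, exactly as imported above. Standalone, the objects produced here look roughly like this (the parameter values are arbitrary examples; the keyword names match the calls in the diff):

from mlx_lm.sample_utils import make_sampler, make_repetition_penalty

sampler = make_sampler(temp=0.7, top_p=0.9, top_k=40, min_p=0.05, min_tokens_to_keep=1)
rep_penalty = make_repetition_penalty(penalty=1.1, context_size=1024)

# _map_parameters below hands these to generate_step under the
# 'sampler' and 'logits_processors' keys.
mlx_params = {"sampler": sampler, "logits_processors": [rep_penalty]}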
    def _map_parameters(self, state):
        """Map text-generation-webui parameters to MLX parameters"""
        mlx_params = {}

        # Max tokens
        if 'max_new_tokens' in state and state['max_new_tokens'] > 0:
            mlx_params['max_tokens'] = state['max_new_tokens']
        else:
            mlx_params['max_tokens'] = 512  # Default

        # Create custom sampler with advanced parameters
        sampler = self._create_mlx_sampler(state)
        if sampler:
            mlx_params['sampler'] = sampler

        # Create logits processors
        logits_processors = self._create_logits_processors(state)
        if logits_processors:
            mlx_params['logits_processors'] = logits_processors

        # Seed handling
        seed = state.get('seed', -1)
        if seed != -1:
            try:
                import mlx.core as mx
                mx.random.seed(seed)
            except Exception as e:
                logger.warning(f"Failed to set MLX random seed: {e}")

        return mlx_params

    def _prepare_prompt(self, prompt):
        """Prepare prompt with chat template if available"""
        if self.tokenizer.chat_template is not None:
            messages = [{"role": "user", "content": prompt}]
            formatted_prompt = self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True
            )
            return formatted_prompt
        return prompt

    def generate(self, prompt, state):
        """Non-streaming generation with advanced sampling"""
        try:
            from mlx_lm.generate import generate_step
            import mlx.core as mx
        except ImportError:
            logger.error("mlx-lm not found. Please install with: pip install mlx-lm")
            return ""

        if self.model is None or self.tokenizer is None:
            logger.error("MLX model not loaded")
            return ""

        try:
            # Prepare the prompt
            formatted_prompt = self._prepare_prompt(prompt)

            # Tokenize the prompt
            prompt_tokens = self.tokenizer.encode(formatted_prompt)
            prompt_array = mx.array(prompt_tokens)

            # Map parameters for MLX
            mlx_params = self._map_parameters(state)

            # Remove max_tokens from params for generate_step
            max_tokens = mlx_params.pop('max_tokens', 512)

            # Generate all tokens at once
            generated_tokens = []

            for token, logprobs in generate_step(
                prompt_array,
                self.model,
                max_tokens=max_tokens,
                **mlx_params
            ):
                # Handle both MLX arrays and direct integers
                if hasattr(token, 'item'):
                    token_id = int(token.item())
                else:
                    token_id = int(token)
                generated_tokens.append(token_id)

                # Check for stop conditions
                if shared.stop_everything:
                    break

            # Decode all generated tokens
            if generated_tokens:
                response = self.tokenizer.decode(generated_tokens)
                return response
            else:
                return ""

        except Exception as e:
            logger.error(f"MLX generation failed: {str(e)}")
            traceback.print_exc()
            return ""

    def generate_with_streaming(self, prompt, state):
        """True streaming generation using MLX generate_step"""
        try:
            from mlx_lm.generate import generate_step
            import mlx.core as mx
        except ImportError:
            logger.error("mlx-lm not found. Please install with: pip install mlx-lm")
            yield ""
            return

        if self.model is None or self.tokenizer is None:
            logger.error("MLX model not loaded")
            yield ""
            return

        try:
            # Prepare the prompt
            formatted_prompt = self._prepare_prompt(prompt)

            # Tokenize the prompt
            prompt_tokens = self.tokenizer.encode(formatted_prompt)
            prompt_array = mx.array(prompt_tokens)

            # Map parameters for MLX
            mlx_params = self._map_parameters(state)

            # Remove max_tokens from params for generate_step (use different name)
            max_tokens = mlx_params.pop('max_tokens', 512)

            # Initialize streaming generation
            generated_tokens = []
            generated_text = ""

            # Use generate_step for true streaming
            for token, logprobs in generate_step(
                prompt_array,
                self.model,
                max_tokens=max_tokens,
                **mlx_params
            ):
                # Handle both MLX arrays and direct integers
                if hasattr(token, 'item'):
                    token_id = int(token.item())
                else:
                    token_id = int(token)
                generated_tokens.append(token_id)

                # Decode the new token
                try:
                    # Decode just the new token
                    new_text = self.tokenizer.decode([token_id])
                    generated_text += new_text

                    # Yield the accumulated text so far
                    yield generated_text

                except Exception as decode_error:
                    logger.warning(f"Failed to decode token {token_id}: {decode_error}")
                    continue

                # Check for stop conditions
                if shared.stop_everything:
                    break

            # Final yield with complete text
            if generated_text:
                yield generated_text

        except Exception as e:
            logger.error(f"MLX streaming generation failed: {str(e)}")
            traceback.print_exc()
            yield ""
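The streaming generator yields the accumulated text after every token, which is what the webui's custom-generation path expects. Consumed on its own it would look like this (a sketch continuing the load example above; the repo id is hypothetical and the state keys are the usual webui settings):

state = {"temperature": 0.7, "top_p": 0.9, "max_new_tokens": 128, "seed": -1}
# from_pretrained returns (instance, instance); failure (None) is not handled in this sketch.
wrapper, _ = MLXModel.from_pretrained("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

previous = ""
for partial in wrapper.generate_with_streaming("Why is the sky blue?", state):
    print(partial[len(previous):], end="", flush=True)  # print only the new delta
    previous = partial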
    def encode(self, text, add_bos_token=False, **kwargs):
        """Encode text to tokens"""
        if self.tokenizer is None:
            import torch
            return torch.tensor([[]], dtype=torch.long)

        try:
            # MLX tokenizer encode method
            tokens = self.tokenizer.encode(text)

            # Convert to tensor format expected by webui
            import torch
            tokens_tensor = torch.tensor([tokens], dtype=torch.long)
            return tokens_tensor
        except Exception as e:
            logger.error(f"MLX tokenization failed: {str(e)}")
            # Return empty tensor on failure
            import torch
            return torch.tensor([[]], dtype=torch.long)

    def decode(self, token_ids, **kwargs):
        """Decode tokens to text"""
        if self.tokenizer is None:
            return ""

        try:
            # MLX tokenizer decode method
            text = self.tokenizer.decode(token_ids)
            return text
        except Exception as e:
            logger.error(f"MLX detokenization failed: {str(e)}")
            return ""

    def unload(self):
        """Unload the model to free memory"""
        self.model = None
        self.tokenizer = None
        logger.info("MLX model unloaded")
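encode deliberately returns a torch.LongTensor of shape [1, n] so the rest of the webui (token counting, truncation) can treat the MLX backend like any other tokenizer, and decode accepts a plain list of token ids. A quick round trip, continuing the sketch above:

ids = wrapper.encode("Hello, world")      # torch tensor, shape [1, n]
text = wrapper.decode(ids[0].tolist())    # back to a string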
modules/models.py:

@@ -22,6 +22,7 @@ def load_model(model_name, loader=None):
         'ExLlamav2_HF': ExLlamav2_HF_loader,
         'ExLlamav2': ExLlamav2_loader,
         'TensorRT-LLM': TensorRT_LLM_loader,
+        'MLX': MLX_loader,
     }

     metadata = get_model_metadata(model_name)

@@ -51,7 +52,7 @@ def load_model(model_name, loader=None):
         tokenizer = load_tokenizer(model_name)

     shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
-    if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
+    if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp' or loader == 'MLX':
         shared.settings['truncation_length'] = shared.args.ctx_size

     logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")

@@ -111,6 +112,19 @@ def TensorRT_LLM_loader(model_name):
     return model


+def MLX_loader(model_name):
+    try:
+        from modules.mlx_loader import MLXModel
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError("Failed to import MLX loader. Please install mlx-lm: pip install mlx-lm")
+
+    result = MLXModel.from_pretrained(model_name)
+    if result is None:
+        raise RuntimeError(f"Failed to load MLX model: {model_name}")
+
+    return result
+
+
 def unload_model(keep_model_name=False):
     if shared.model is None:
         return

@@ -118,6 +132,8 @@ def unload_model(keep_model_name=False):
     is_llamacpp = (shared.model.__class__.__name__ == 'LlamaServer')
     if shared.model.__class__.__name__ == 'Exllamav3HF':
         shared.model.unload()
+    elif shared.model.__class__.__name__ == 'MLXModel':
+        shared.model.unload()

     shared.model = shared.tokenizer = None
     shared.lora_names = []
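With the loader registered, the normal webui entry point should work as for the other backends; something like the following (the folder name is a hypothetical example in the downloaded "mlx-community_*" form):

from modules.models import load_model

model, tokenizer = load_model("mlx-community_Mistral-7B-Instruct-v0.3-4bit", loader="MLX")
# Both names end up referring to the same MLXModel instance,
# since from_pretrained returns (instance, instance).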
modules/models_settings.py:

@@ -174,25 +174,34 @@ def get_model_metadata(model):


 def infer_loader(model_name, model_settings, hf_quant_method=None):
-    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
-    if not path_to_model.exists():
-        loader = None
-    elif shared.args.portable:
-        loader = 'llama.cpp'
-    elif len(list(path_to_model.glob('*.gguf'))) > 0:
-        loader = 'llama.cpp'
-    elif re.match(r'.*\.gguf', model_name.lower()):
-        loader = 'llama.cpp'
-    elif hf_quant_method == 'exl3':
-        loader = 'ExLlamav3_HF'
-    elif hf_quant_method in ['exl2', 'gptq']:
-        loader = 'ExLlamav2_HF'
-    elif re.match(r'.*exl3', model_name.lower()):
-        loader = 'ExLlamav3_HF'
-    elif re.match(r'.*exl2', model_name.lower()):
-        loader = 'ExLlamav2_HF'
+    import platform
+
+    # Check for MLX models first (before path checks)
+    if (model_name.startswith('mlx-community/') or model_name.startswith('mlx-community_')) and platform.system() == "Darwin" and platform.machine() == "arm64":
+        loader = 'MLX'
+    elif re.match(r'.*\.mlx', model_name.lower()) and platform.system() == "Darwin" and platform.machine() == "arm64":
+        loader = 'MLX'
     else:
-        loader = 'Transformers'
+        # Original logic for other loaders
+        path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
+        if not path_to_model.exists():
+            loader = None
+        elif shared.args.portable:
+            loader = 'llama.cpp'
+        elif len(list(path_to_model.glob('*.gguf'))) > 0:
+            loader = 'llama.cpp'
+        elif re.match(r'.*\.gguf', model_name.lower()):
+            loader = 'llama.cpp'
+        elif hf_quant_method == 'exl3':
+            loader = 'ExLlamav3_HF'
+        elif hf_quant_method in ['exl2', 'gptq']:
+            loader = 'ExLlamav2_HF'
+        elif re.match(r'.*exl3', model_name.lower()):
+            loader = 'ExLlamav3_HF'
+        elif re.match(r'.*exl2', model_name.lower()):
+            loader = 'ExLlamav2_HF'
+        else:
+            loader = 'Transformers'

     return loader
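The effect is that anything that looks like an MLX export is routed before the path- and extension-based checks run, but only on Apple Silicon. In isolation the new test is equivalent to the sketch below (the model names are made-up examples):

import platform
import re

def routes_to_mlx(model_name):
    """Mirror of the two new branches, for illustration outside the webui."""
    on_apple_silicon = platform.system() == "Darwin" and platform.machine() == "arm64"
    by_prefix = model_name.startswith(("mlx-community/", "mlx-community_"))
    by_suffix = re.match(r".*\.mlx", model_name.lower()) is not None
    return on_apple_silicon and (by_prefix or by_suffix)

# routes_to_mlx("mlx-community_Qwen2.5-7B-Instruct-4bit")  -> True on an M-series Mac
# routes_to_mlx("some-model.Q4_K_M.gguf")                  -> False (falls through to llama.cpp)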
modules/text_generation.py:

@@ -40,7 +40,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
         yield ''
         return

-    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']:
+    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel', 'MLXModel']:
         generate_func = generate_reply_custom
     else:
         generate_func = generate_reply_HF

@@ -153,7 +153,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]

-    if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel', 'MLXModel'] or shared.args.cpu:
         return input_ids
     else:
         device = get_device()
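Listing 'MLXModel' here routes MLX generations through generate_reply_custom (which drives the model's own generate and generate_with_streaming methods) and keeps encoded prompts on the CPU instead of moving them to a torch device. The duck-typed contract the new class has to satisfy is therefore small; roughly (a sketch of the expected interface, not code from the webui):

class CustomBackend:
    def generate(self, prompt, state): ...                  # returns the full completion
    def generate_with_streaming(self, prompt, state): ...   # yields accumulated text
    def encode(self, text, **kwargs): ...                   # token ids, torch-shaped
    def decode(self, token_ids, **kwargs): ...              # token ids back to text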
modules/ui_model_menu.py:

@@ -59,7 +59,7 @@ def create_ui():
     shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
     shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
     shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')

     # Speculative decoding
     with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
         with gr.Row():
@@ -7,6 +7,7 @@ gradio==4.37.*
 html2text==2025.4.15
 jinja2==3.1.6
 markdown
+mlx-lm>=0.26.3
 numpy==2.2.*
 pandas
 peft==0.15.*

@@ -3,6 +3,7 @@ gradio==4.37.*
 html2text==2025.4.15
 jinja2==3.1.6
 markdown
+mlx-lm>=0.26.3
 numpy==2.2.*
 pydantic==2.8.2
 PyPDF2==3.0.1