diff --git a/modules/loaders.py b/modules/loaders.py
index 0bf3781b..f81cfaeb 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -105,6 +105,9 @@ loaders_and_params = OrderedDict({
         'ctx_size',
         'cpp_runner',
         'tensorrt_llm_info',
+    ],
+    'MLX': [
+        'ctx_size',
     ]
 })
 
@@ -359,6 +362,26 @@ loaders_samplers = {
         'presence_penalty',
         'auto_max_new_tokens',
         'ban_eos_token',
+    },
+    'MLX': {
+        'temperature',
+        'dynatemp_low',
+        'dynatemp_high',
+        'dynatemp_exponent',
+        'top_p',
+        'top_k',
+        'min_p',
+        'xtc_threshold',
+        'xtc_probability',
+        'repetition_penalty',
+        'repetition_penalty_range',
+        'auto_max_new_tokens',
+        'ban_eos_token',
+        'add_bos_token',
+        'skip_special_tokens',
+        'dynamic_temperature',
+        'seed',
+        'sampler_priority',
     }
 }
 
diff --git a/modules/mlx_loader.py b/modules/mlx_loader.py
new file mode 100644
index 00000000..5cfc3059
--- /dev/null
+++ b/modules/mlx_loader.py
@@ -0,0 +1,371 @@
+import platform
+import traceback
+from pathlib import Path
+
+import modules.shared as shared
+from modules.logging_colors import logger
+
+# Constants for MLX configuration
+MLX_TOP_P_DISABLED = 0.0  # MLX expects 0.0 to disable top_p
+DEFAULT_MAX_TOKENS = 512  # Default maximum tokens for generation
+
+
+def is_apple_silicon():
+    """Check if running on Apple Silicon"""
+    return platform.system() == "Darwin" and platform.machine() == "arm64"
+
+
+class MLXModel:
+    def __init__(self):
+        self.model = None
+        self.tokenizer = None
+        self.model_name = None
+
+    @classmethod
+    def from_pretrained(cls, model_name):
+        """Load MLX model from path or HuggingFace repository"""
+
+        if not is_apple_silicon():
+            logger.warning("MLX backend is only supported on Apple Silicon. Falling back to Transformers.")
+            return None
+
+        try:
+            from mlx_lm import load
+        except ImportError:
+            logger.error("mlx-lm not found. Please install with: pip install mlx-lm")
+            return None
+
+        instance = cls()
+        instance.model_name = model_name
+
+        try:
+            # Determine the model path/name
+            model_path = cls._resolve_model_path(model_name)
+
+            logger.info(f"Loading MLX model: {model_path}")
+            tokenizer_config = {"trust_remote_code": True}
+            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
+
+            instance.model = model
+            instance.tokenizer = tokenizer
+
+            logger.info(f"Successfully loaded MLX model: {model_name}")
+            return instance  # Return instance for compatibility
+
+        except Exception as e:
+            error_msg = str(e)
+            if "not supported" in error_msg.lower():
+                logger.error(f"MLX model {model_name} uses an unsupported model type: {error_msg}")
+                logger.error("Consider using a different loader or updating mlx-lm to a newer version")
+            else:
+                logger.error(f"Failed to load MLX model {model_name}: {error_msg}")
+            traceback.print_exc()
+            return None
+
+    @staticmethod
+    def _resolve_model_path(model_name):
+        """Resolve model path - either local path or HuggingFace repo"""
+        model_path = Path(f'{shared.args.model_dir}/{model_name}')
+
+        if model_path.exists():
+            # Local model path
+            return str(model_path)
+        elif '/' in model_name:
+            # Already has repo/model format
+            return model_name
+        elif '_' in model_name and not model_name.startswith('_'):
+            # Handle repo_name format - convert first underscore to slash
+            # e.g., "mlx-community_model-name" -> "mlx-community/model-name"
+            parts = model_name.split('_', 1)
+            if len(parts) == 2:
+                return f"{parts[0]}/{parts[1]}"
+            return model_name
+        else:
+            # Default to mlx-community for standalone model names
+            return f"mlx-community/{model_name}"
+
+    def _create_mlx_sampler(self, state):
+        """Create MLX sampler with webui parameters"""
+        try:
+            from mlx_lm.sample_utils import make_sampler
+
+            # Extract sampling parameters from state
+            temperature = state.get('temperature', 1.0)
+            top_p = state.get('top_p', 1.0)
+            top_k = state.get('top_k', 0)  # 0 means no top_k limit
+            min_p = state.get('min_p', 0.0)
+
+            # Handle dynamic temperature
+            if state.get('dynamic_temperature', False):
+                temp_low = state.get('dynatemp_low', 1.0)
+                temp_high = state.get('dynatemp_high', 1.0)
+                temperature = (temp_low + temp_high) / 2  # Simple average for now
+
+            # XTC sampling parameters
+            xtc_probability = state.get('xtc_probability', 0.0)
+            xtc_threshold = state.get('xtc_threshold', 0.1)
+
+            # Ensure temperature is not zero (causes issues with MLX)
+            if temperature <= 0.0:
+                temperature = 0.01
+
+            # Log sampling parameters for debugging
+            if shared.args.verbose:
+                logger.info(f"MLX Sampler - temp: {temperature}, top_p: {top_p}, top_k: {top_k}, min_p: {min_p}")
+
+            # Create the sampler
+            sampler = make_sampler(
+                temp=temperature,
+                top_p=top_p if top_p < 1.0 else MLX_TOP_P_DISABLED,  # MLX expects 0.0 to disable
+                top_k=int(top_k) if top_k > 0 else 0,
+                min_p=min_p,
+                min_tokens_to_keep=1,  # Always keep at least one token
+                xtc_probability=xtc_probability,
+                xtc_threshold=xtc_threshold,
+                xtc_special_tokens=[]  # Could be customized later
+            )
+
+            return sampler
+
+        except ImportError:
+            logger.warning("MLX sampling utilities not available, using default sampler")
+            return None
+        except Exception as e:
+            logger.error(f"Failed to create MLX sampler: {e}")
+            return None
+
+    def _create_logits_processors(self, state):
+        """Create logits processors for repetition penalty, etc."""
+        processors = []
+
+        try:
+            from mlx_lm.sample_utils import make_repetition_penalty
+
+            # Repetition penalty
+            repetition_penalty = state.get('repetition_penalty', 1.0)
+            if repetition_penalty != 1.0:
+                context_size = state.get('repetition_penalty_range', 1024)
+                rep_processor = make_repetition_penalty(
+                    penalty=repetition_penalty,
+                    context_size=context_size
+                )
+                processors.append(rep_processor)
+
+        except ImportError:
+            logger.warning("MLX repetition penalty not available")
+        except Exception as e:
+            logger.error(f"Failed to create repetition penalty processor: {e}")
+
+        return processors if processors else None
+
+    def _map_parameters(self, state):
+        """Map text-generation-webui parameters to MLX parameters"""
+        mlx_params = {}
+
+        # Max tokens
+        if 'max_new_tokens' in state and state['max_new_tokens'] > 0:
+            mlx_params['max_tokens'] = state['max_new_tokens']
+        else:
+            mlx_params['max_tokens'] = DEFAULT_MAX_TOKENS  # Default
+
+        # Create custom sampler with advanced parameters
+        sampler = self._create_mlx_sampler(state)
+        if sampler:
+            mlx_params['sampler'] = sampler
+
+        # Create logits processors
+        logits_processors = self._create_logits_processors(state)
+        if logits_processors:
+            mlx_params['logits_processors'] = logits_processors
+
+        # Seed handling
+        seed = state.get('seed', -1)
+        if seed != -1:
+            try:
+                import mlx.core as mx
+                mx.random.seed(seed)
+            except Exception as e:
+                logger.warning(f"Failed to set MLX random seed: {e}")
+
+        return mlx_params
+
+    def _prepare_prompt(self, prompt):
+        """Prepare prompt with chat template if available"""
+        if self.tokenizer.chat_template is not None:
+            messages = [{"role": "user", "content": prompt}]
+            formatted_prompt = self.tokenizer.apply_chat_template(
+                messages, add_generation_prompt=True, tokenize=False
+            )
+            return formatted_prompt
+        return prompt
+
+    def generate(self, prompt, state):
+        """Non-streaming generation with advanced sampling"""
+        try:
+            from mlx_lm.generate import generate_step
+            import mlx.core as mx
+        except ImportError:
+            logger.error("mlx-lm not found. Please install with: pip install mlx-lm")
+            return ""
+
+        if self.model is None or self.tokenizer is None:
+            logger.error("MLX model not loaded")
+            return ""
+
+        try:
+            # Prepare the prompt
+            formatted_prompt = self._prepare_prompt(prompt)
+
+            # Tokenize the prompt
+            prompt_tokens = self.tokenizer.encode(formatted_prompt)
+            prompt_array = mx.array(prompt_tokens)
+
+            # Map parameters for MLX
+            mlx_params = self._map_parameters(state)
+
+            # Remove max_tokens from params for generate_step
+            max_tokens = mlx_params.pop('max_tokens', 512)
+
+            # Generate all tokens at once
+            generated_tokens = []
+
+            for token, logprobs in generate_step(
+                prompt_array,
+                self.model,
+                max_tokens=max_tokens,
+                **mlx_params
+            ):
+                # Handle both MLX arrays and direct integers
+                if hasattr(token, 'item'):
+                    token_id = int(token.item())
+                else:
+                    token_id = int(token)
+                generated_tokens.append(token_id)
+
+                # Check for stop conditions
+                if shared.stop_everything:
+                    break
+
+            # Decode all generated tokens
+            if generated_tokens:
+                response = self.tokenizer.decode(generated_tokens)
+                return response
+            else:
+                return ""
+
+        except Exception as e:
+            logger.error(f"MLX generation failed: {str(e)}")
+            traceback.print_exc()
+            return ""
+
+    def generate_with_streaming(self, prompt, state):
+        """True streaming generation using MLX generate_step"""
+        try:
+            from mlx_lm.generate import generate_step
+            import mlx.core as mx
+        except ImportError:
+            logger.error("mlx-lm not found. Please install with: pip install mlx-lm")
+            yield ""
+            return
+
+        if self.model is None or self.tokenizer is None:
+            logger.error("MLX model not loaded")
+            yield ""
+            return
+
+        try:
+            # Prepare the prompt
+            formatted_prompt = self._prepare_prompt(prompt)
+
+            # Tokenize the prompt
+            prompt_tokens = self.tokenizer.encode(formatted_prompt)
+            prompt_array = mx.array(prompt_tokens)
+
+            # Map parameters for MLX
+            mlx_params = self._map_parameters(state)
+
+            # Remove max_tokens from params for generate_step (use different name)
+            max_tokens = mlx_params.pop('max_tokens', 512)
+
+            # Initialize streaming generation
+            generated_tokens = []
+            generated_text = ""
+
+            # Use generate_step for true streaming
+            for token, logprobs in generate_step(
+                prompt_array,
+                self.model,
+                max_tokens=max_tokens,
+                **mlx_params
+            ):
+                # Handle both MLX arrays and direct integers
+                if hasattr(token, 'item'):
+                    token_id = int(token.item())
+                else:
+                    token_id = int(token)
+                generated_tokens.append(token_id)
+
+                # Decode the new token
+                try:
+                    # Decode just the new token
+                    new_text = self.tokenizer.decode([token_id])
+                    generated_text += new_text
+
+                    # Yield the accumulated text so far
+                    yield generated_text
+
+                except Exception as decode_error:
+                    logger.warning(f"Failed to decode token {token_id}: {decode_error}")
+                    continue
+
+                # Check for stop conditions
+                if shared.stop_everything:
+                    break
+
+            # Final yield with complete text
+            if generated_text:
+                yield generated_text
+
+        except Exception as e:
+            logger.error(f"MLX streaming generation failed: {str(e)}")
+            traceback.print_exc()
+            yield ""
+
+    def encode(self, text, add_bos_token=False, **kwargs):
+        """Encode text to tokens"""
+        if self.tokenizer is None:
+            import torch
+            return torch.tensor([[]], dtype=torch.long)
+
+        try:
+            # MLX tokenizer encode method
+            tokens = self.tokenizer.encode(text)
+
+            # Convert to tensor format expected by webui
+            import torch
+            tokens_tensor = torch.tensor([tokens], dtype=torch.long)
+            return tokens_tensor
+        except Exception as e:
+            logger.error(f"MLX tokenization failed: {str(e)}")
+            # Return empty tensor on failure
+            import torch
+            return torch.tensor([[]], dtype=torch.long)
+
+    def decode(self, token_ids, **kwargs):
+        """Decode tokens to text"""
+        if self.tokenizer is None:
+            return ""
+
+        try:
+            # MLX tokenizer decode method
+            text = self.tokenizer.decode(token_ids)
+            return text
+        except Exception as e:
+            logger.error(f"MLX detokenization failed: {str(e)}")
+            return ""
+
+    def unload(self):
+        """Unload the model to free memory"""
+        self.model = None
+        self.tokenizer = None
+        logger.info("MLX model unloaded")
\ No newline at end of file
diff --git a/modules/models.py b/modules/models.py
index 8c0f1c37..6b8d6de3 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -23,6 +23,7 @@ def load_model(model_name, loader=None):
         'ExLlamav2_HF': ExLlamav2_HF_loader,
         'ExLlamav2': ExLlamav2_loader,
         'TensorRT-LLM': TensorRT_LLM_loader,
+        'MLX': MLX_loader,
     }
 
     metadata = get_model_metadata(model_name)
@@ -53,7 +54,7 @@ def load_model(model_name, loader=None):
                 return None, None
 
     shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
-    if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
+    if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp' or loader == 'MLX':
         shared.settings['truncation_length'] = shared.args.ctx_size
 
     shared.is_multimodal = False
@@ -131,6 +132,19 @@ def TensorRT_LLM_loader(model_name):
     return model
 
 
+def MLX_loader(model_name):
+    try:
+        from modules.mlx_loader import MLXModel
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError("Failed to import MLX loader. Please install mlx-lm: pip install mlx-lm")
+
+    result = MLXModel.from_pretrained(model_name)
+    if result is None:
+        raise RuntimeError(f"Failed to load MLX model: {model_name}. Check the logs above for specific error details.")
+
+    return result
+
+
 def unload_model(keep_model_name=False):
     if shared.model is None:
         return
@@ -142,6 +156,8 @@ def unload_model(keep_model_name=False):
         shared.model.unload()
     elif model_class_name in ['Exllamav2Model', 'Exllamav2HF'] and hasattr(shared.model, 'unload'):
         shared.model.unload()
+    elif shared.model.__class__.__name__ == 'MLXModel':
+        shared.model.unload()
 
     shared.model = shared.tokenizer = None
     shared.lora_names = []
diff --git a/modules/models_settings.py b/modules/models_settings.py
index 6dc000b4..0741b3f5 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -208,6 +208,12 @@ def infer_loader(model_name, model_settings, hf_quant_method=None):
         loader = 'llama.cpp'
     elif re.match(r'.*\.gguf', model_name.lower()):
         loader = 'llama.cpp'
+    elif hf_quant_method == 'mlx':
+        loader = 'MLX'
+    elif re.match(r'.*\.mlx', model_name.lower()):
+        loader = 'MLX'
+    elif model_name.lower().startswith('mlx-community'):
+        loader = 'MLX'
     elif hf_quant_method == 'exl3':
         loader = 'ExLlamav3'
     elif hf_quant_method in ['exl2', 'gptq']:
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 27c5de7d..abedbe67 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -40,7 +40,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
             yield ''
             return
 
-        if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel']:
+        if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel', 'MLXModel']:
            generate_func = generate_reply_custom
        else:
            generate_func = generate_reply_HF
@@ -148,7 +148,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
 
-    if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['Exllamav2Model', 'Exllamav3Model', 'TensorRTLLMModel', 'MLXModel'] or shared.args.cpu:
         return input_ids
     else:
         device = get_device()
diff --git a/requirements/full/requirements_apple_silicon.txt b/requirements/full/requirements_apple_silicon.txt
index 6a7ce8a6..b56fe527 100644
--- a/requirements/full/requirements_apple_silicon.txt
+++ b/requirements/full/requirements_apple_silicon.txt
@@ -8,6 +8,7 @@ html2text==2025.4.15
 huggingface-hub==0.36.0
 jinja2==3.1.6
 markdown
+mlx-lm>=0.26.3
 numpy==2.2.*
 pandas
 peft==0.18.*
diff --git a/requirements/portable/requirements_apple_silicon.txt b/requirements/portable/requirements_apple_silicon.txt
index e480db8f..2706af7e 100644
--- a/requirements/portable/requirements_apple_silicon.txt
+++ b/requirements/portable/requirements_apple_silicon.txt
@@ -4,6 +4,7 @@ html2text==2025.4.15
 huggingface-hub==0.36.0
 jinja2==3.1.6
 markdown
+mlx-lm>=0.26.3
 numpy==2.2.*
 pydantic==2.11.0
 PyPDF2==3.0.1
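
Usage sketch (illustrative only, not part of the patch): how the loader added above can be exercised from a Python session inside text-generation-webui, assuming an Apple Silicon machine with mlx-lm installed. The model name below is hypothetical, and the state keys mirror the 'MLX' sampler list registered in modules/loaders.py.

    from modules.mlx_loader import MLXModel

    # Hypothetical mlx-community repository; _resolve_model_path converts the
    # first underscore to a slash when no matching local folder exists.
    model = MLXModel.from_pretrained("mlx-community_Llama-3.2-3B-Instruct-4bit")

    if model is not None:
        state = {
            'max_new_tokens': 128,
            'temperature': 0.7,
            'top_p': 0.9,
            'top_k': 40,
            'min_p': 0.0,
            'repetition_penalty': 1.1,
            'repetition_penalty_range': 1024,
            'seed': -1,
        }
        output = ""
        for output in model.generate_with_streaming("Explain MLX in one sentence.", state):
            pass  # each yield is the accumulated text so far
        print(output)
        model.unload()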