diff --git a/modules/loaders.py b/modules/loaders.py
index 6fbd2198..2bfd06da 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -94,6 +94,9 @@ loaders_and_params = OrderedDict({
         'ctx_size',
         'cpp_runner',
         'tensorrt_llm_info',
-    ]
+    ],
+    'MLX': [
+        'ctx_size',
+    ]
 })
 
@@ -325,6 +328,26 @@ loaders_samplers = {
         'presence_penalty',
         'auto_max_new_tokens',
         'ban_eos_token',
-    }
+    },
+    'MLX': {
+        'temperature',
+        'dynatemp_low',
+        'dynatemp_high',
+        'dynatemp_exponent',
+        'top_p',
+        'top_k',
+        'min_p',
+        'xtc_threshold',
+        'xtc_probability',
+        'repetition_penalty',
+        'repetition_penalty_range',
+        'auto_max_new_tokens',
+        'ban_eos_token',
+        'add_bos_token',
+        'skip_special_tokens',
+        'dynamic_temperature',
+        'seed',
+        'sampler_priority',
+    }
 }
 
diff --git a/modules/mlx_loader.py b/modules/mlx_loader.py
new file mode 100644
index 00000000..46bac3c9
--- /dev/null
+++ b/modules/mlx_loader.py
@@ -0,0 +1,354 @@
+import platform
+import traceback
+from pathlib import Path
+
+import modules.shared as shared
+from modules.logging_colors import logger
+
+
+def is_apple_silicon():
+    """Check if running on Apple Silicon"""
+    return platform.system() == "Darwin" and platform.machine() == "arm64"
+
+
+class MLXModel:
+    def __init__(self):
+        self.model = None
+        self.tokenizer = None
+        self.model_name = None
+
+    @classmethod
+    def from_pretrained(cls, model_name):
+        """Load MLX model from path or HuggingFace repository"""
+
+        if not is_apple_silicon():
+            logger.warning("MLX backend is only supported on Apple Silicon. Falling back to Transformers.")
+            return None
+
+        try:
+            from mlx_lm import load
+        except ImportError:
+            logger.error("mlx-lm not found. Please install with: pip install mlx-lm")
+            return None
+
+        instance = cls()
+        instance.model_name = model_name
+
+        try:
+            # Determine the model path/name
+            model_path = cls._resolve_model_path(model_name)
+
+            logger.info(f"Loading MLX model: {model_path}")
+            model, tokenizer = load(model_path)
+
+            instance.model = model
+            instance.tokenizer = tokenizer
+
+            logger.info(f"Successfully loaded MLX model: {model_name}")
+            return instance, instance  # Return model, tokenizer tuple for compatibility
+
+        except Exception as e:
+            logger.error(f"Failed to load MLX model {model_name}: {str(e)}")
+            traceback.print_exc()
+            return None
+
+    @staticmethod
+    def _resolve_model_path(model_name):
+        """Resolve model path - either local path or HuggingFace repo"""
+        model_path = Path(f'{shared.args.model_dir}/{model_name}')
+
+        if model_path.exists():
+            # Local model path
+            return str(model_path)
+        elif model_name.startswith('mlx-community/'):
+            # Already has mlx-community prefix
+            return model_name
+        else:
+            # Try to find in mlx-community
+            return f"mlx-community/{model_name}"
+
+    def _create_mlx_sampler(self, state):
+        """Create MLX sampler with webui parameters"""
+        try:
+            from mlx_lm.sample_utils import make_sampler
+
+            # Extract sampling parameters from state
+            temperature = state.get('temperature', 1.0)
+            top_p = state.get('top_p', 1.0)
+            top_k = state.get('top_k', 0)  # 0 means no top_k limit
+            min_p = state.get('min_p', 0.0)
+
+            # Handle dynamic temperature
+            if state.get('dynamic_temperature', False):
+                temp_low = state.get('dynatemp_low', 1.0)
+                temp_high = state.get('dynatemp_high', 1.0)
+                temperature = (temp_low + temp_high) / 2  # Simple average for now
+
+            # XTC sampling parameters
+            xtc_probability = state.get('xtc_probability', 0.0)
+            xtc_threshold = state.get('xtc_threshold', 0.1)
+
+            # Ensure temperature is not zero (causes issues with MLX)
+            if temperature <= 0.0:
+                temperature = 0.01
+
+            # Log sampling parameters for debugging
+            if shared.args.verbose:
+                logger.info(f"MLX Sampler - temp: {temperature}, top_p: {top_p}, top_k: {top_k}, min_p: {min_p}")
+
+            # Create the sampler
+            sampler = make_sampler(
+                temp=temperature,
+                top_p=top_p if top_p < 1.0 else 0.0,  # MLX expects 0.0 to disable
+                top_k=int(top_k) if top_k > 0 else 0,
+                min_p=min_p,
+                min_tokens_to_keep=1,  # Always keep at least one token
+                xtc_probability=xtc_probability,
+                xtc_threshold=xtc_threshold,
+                xtc_special_tokens=[]  # Could be customized later
+            )
+
+            return sampler
+
+        except ImportError:
+            logger.warning("MLX sampling utilities not available, using default sampler")
+            return None
+        except Exception as e:
+            logger.error(f"Failed to create MLX sampler: {e}")
+            return None
+
+    def _create_logits_processors(self, state):
+        """Create logits processors for repetition penalty, etc."""
+        processors = []
+
+        try:
+            from mlx_lm.sample_utils import make_repetition_penalty
+
+            # Repetition penalty
+            repetition_penalty = state.get('repetition_penalty', 1.0)
+            if repetition_penalty != 1.0:
+                context_size = state.get('repetition_penalty_range', 1024)
+                rep_processor = make_repetition_penalty(
+                    penalty=repetition_penalty,
+                    context_size=context_size
+                )
+                processors.append(rep_processor)
+
+        except ImportError:
+            logger.warning("MLX repetition penalty not available")
+        except Exception as e:
+            logger.error(f"Failed to create repetition penalty processor: {e}")
+
+        return processors if processors else None
+
+    def _map_parameters(self, state):
+        """Map text-generation-webui parameters to MLX parameters"""
+        mlx_params = {}
+
+        # Max tokens
+        if 'max_new_tokens' in state and state['max_new_tokens'] > 0:
+            mlx_params['max_tokens'] = state['max_new_tokens']
+        else:
+            mlx_params['max_tokens'] = 512  # Default
+
+        # Create custom sampler with advanced parameters
+        sampler = self._create_mlx_sampler(state)
+        if sampler:
+            mlx_params['sampler'] = sampler
+
+        # Create logits processors
+        logits_processors = self._create_logits_processors(state)
+        if logits_processors:
+            mlx_params['logits_processors'] = logits_processors
+
+        # Seed handling
+        seed = state.get('seed', -1)
+        if seed != -1:
+            try:
+                import mlx.core as mx
+                mx.random.seed(seed)
+            except Exception as e:
+                logger.warning(f"Failed to set MLX random seed: {e}")
+
+        return mlx_params
+
+    def _prepare_prompt(self, prompt):
+        """Prepare prompt with chat template if available"""
+        if self.tokenizer.chat_template is not None:
+            messages = [{"role": "user", "content": prompt}]
+            formatted_prompt = self.tokenizer.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            return formatted_prompt
+        return prompt
+
+    def generate(self, prompt, state):
+        """Non-streaming generation with advanced sampling"""
+        try:
+            from mlx_lm.generate import generate_step
+            import mlx.core as mx
+        except ImportError:
+            logger.error("mlx-lm not found. Please install with: pip install mlx-lm")
+            return ""
+
+        if self.model is None or self.tokenizer is None:
+            logger.error("MLX model not loaded")
+            return ""
+
+        try:
+            # Prepare the prompt
+            formatted_prompt = self._prepare_prompt(prompt)
+
+            # Tokenize the prompt
+            prompt_tokens = self.tokenizer.encode(formatted_prompt)
+            prompt_array = mx.array(prompt_tokens)
+
+            # Map parameters for MLX
+            mlx_params = self._map_parameters(state)
+
+            # Remove max_tokens from params for generate_step
+            max_tokens = mlx_params.pop('max_tokens', 512)
+
+            # Generate all tokens at once
+            generated_tokens = []
+
+            for token, logprobs in generate_step(
+                prompt_array,
+                self.model,
+                max_tokens=max_tokens,
+                **mlx_params
+            ):
+                # Handle both MLX arrays and direct integers
+                if hasattr(token, 'item'):
+                    token_id = int(token.item())
+                else:
+                    token_id = int(token)
+                generated_tokens.append(token_id)
+
+                # Check for stop conditions
+                if shared.stop_everything:
+                    break
+
+            # Decode all generated tokens
+            if generated_tokens:
+                response = self.tokenizer.decode(generated_tokens)
+                return response
+            else:
+                return ""
+
+        except Exception as e:
+            logger.error(f"MLX generation failed: {str(e)}")
+            traceback.print_exc()
+            return ""
+
+    def generate_with_streaming(self, prompt, state):
+        """True streaming generation using MLX generate_step"""
+        try:
+            from mlx_lm.generate import generate_step
+            import mlx.core as mx
+        except ImportError:
+            logger.error("mlx-lm not found. Please install with: pip install mlx-lm")
+            yield ""
+            return
+
+        if self.model is None or self.tokenizer is None:
+            logger.error("MLX model not loaded")
+            yield ""
+            return
+
+        try:
+            # Prepare the prompt
+            formatted_prompt = self._prepare_prompt(prompt)
+
+            # Tokenize the prompt
+            prompt_tokens = self.tokenizer.encode(formatted_prompt)
+            prompt_array = mx.array(prompt_tokens)
+
+            # Map parameters for MLX
+            mlx_params = self._map_parameters(state)
+
+            # Remove max_tokens from params for generate_step (use different name)
+            max_tokens = mlx_params.pop('max_tokens', 512)
+
+            # Initialize streaming generation
+            generated_tokens = []
+            generated_text = ""
+
+            # Use generate_step for true streaming
+            for token, logprobs in generate_step(
+                prompt_array,
+                self.model,
+                max_tokens=max_tokens,
+                **mlx_params
+            ):
+                # Handle both MLX arrays and direct integers
+                if hasattr(token, 'item'):
+                    token_id = int(token.item())
+                else:
+                    token_id = int(token)
+                generated_tokens.append(token_id)
+
+                # Decode the new token
+                try:
+                    # Decode just the new token
+                    new_text = self.tokenizer.decode([token_id])
+                    generated_text += new_text
+
+                    # Yield the accumulated text so far
+                    yield generated_text
+
+                except Exception as decode_error:
+                    logger.warning(f"Failed to decode token {token_id}: {decode_error}")
+                    continue
+
+                # Check for stop conditions
+                if shared.stop_everything:
+                    break
+
+            # Final yield with complete text
+            if generated_text:
+                yield generated_text
+
+        except Exception as e:
+            logger.error(f"MLX streaming generation failed: {str(e)}")
+            traceback.print_exc()
+            yield ""
+
+    def encode(self, text, add_bos_token=False, **kwargs):
+        """Encode text to tokens"""
+        if self.tokenizer is None:
+            import torch
+            return torch.tensor([[]], dtype=torch.long)
+
+        try:
+            # MLX tokenizer encode method
+            tokens = self.tokenizer.encode(text)
+
+            # Convert to tensor format expected by webui
+            import torch
+            tokens_tensor = torch.tensor([tokens], dtype=torch.long)
+            return tokens_tensor
+        except Exception as e:
+            logger.error(f"MLX tokenization failed: {str(e)}")
+            # Return empty tensor on failure
+            import torch
+            return torch.tensor([[]], dtype=torch.long)
+
+    def decode(self, token_ids, **kwargs):
+        """Decode tokens to text"""
+        if self.tokenizer is None:
+            return ""
+
+        try:
+            # MLX tokenizer decode method
+            text = self.tokenizer.decode(token_ids)
+            return text
+        except Exception as e:
+            logger.error(f"MLX detokenization failed: {str(e)}")
+            return ""
+
+    def unload(self):
+        """Unload the model to free memory"""
+        self.model = None
+        self.tokenizer = None
+        logger.info("MLX model unloaded")
\ No newline at end of file
diff --git a/modules/models.py b/modules/models.py
index c1e7fb56..05ef7a7b 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -22,6 +22,7 @@ def load_model(model_name, loader=None):
         'ExLlamav2_HF': ExLlamav2_HF_loader,
         'ExLlamav2': ExLlamav2_loader,
         'TensorRT-LLM': TensorRT_LLM_loader,
+        'MLX': MLX_loader,
     }
 
     metadata = get_model_metadata(model_name)
@@ -51,7 +52,7 @@
             tokenizer = load_tokenizer(model_name)
 
     shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
-    if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
+    if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp' or loader == 'MLX':
         shared.settings['truncation_length'] = shared.args.ctx_size
 
     logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
@@ -111,6 +112,19 @@
     return model
 
 
+def MLX_loader(model_name):
+    try:
+        from modules.mlx_loader import MLXModel
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError("Failed to import MLX loader. Please install mlx-lm: pip install mlx-lm")
+
+    result = MLXModel.from_pretrained(model_name)
+    if result is None:
+        raise RuntimeError(f"Failed to load MLX model: {model_name}")
+
+    return result
+
+
 def unload_model(keep_model_name=False):
     if shared.model is None:
         return
@@ -118,6 +132,8 @@
     is_llamacpp = (shared.model.__class__.__name__ == 'LlamaServer')
     if shared.model.__class__.__name__ == 'Exllamav3HF':
         shared.model.unload()
+    elif shared.model.__class__.__name__ == 'MLXModel':
+        shared.model.unload()
 
     shared.model = shared.tokenizer = None
     shared.lora_names = []
diff --git a/modules/models_settings.py b/modules/models_settings.py
index bea5b4d6..dfd9df3e 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -174,25 +174,34 @@
 
 
 def infer_loader(model_name, model_settings, hf_quant_method=None):
-    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
-    if not path_to_model.exists():
-        loader = None
-    elif shared.args.portable:
-        loader = 'llama.cpp'
-    elif len(list(path_to_model.glob('*.gguf'))) > 0:
-        loader = 'llama.cpp'
-    elif re.match(r'.*\.gguf', model_name.lower()):
-        loader = 'llama.cpp'
-    elif hf_quant_method == 'exl3':
-        loader = 'ExLlamav3_HF'
-    elif hf_quant_method in ['exl2', 'gptq']:
-        loader = 'ExLlamav2_HF'
-    elif re.match(r'.*exl3', model_name.lower()):
-        loader = 'ExLlamav3_HF'
-    elif re.match(r'.*exl2', model_name.lower()):
-        loader = 'ExLlamav2_HF'
+    import platform
+
+    # Check for MLX models first (before path checks)
+    if (model_name.startswith('mlx-community/') or model_name.startswith('mlx-community_')) and platform.system() == "Darwin" and platform.machine() == "arm64":
+        loader = 'MLX'
+    elif re.match(r'.*\.mlx', model_name.lower()) and platform.system() == "Darwin" and platform.machine() == "arm64":
+        loader = 'MLX'
     else:
-        loader = 'Transformers'
+        # Original logic for other loaders
+        path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
+        if not path_to_model.exists():
+            loader = None
+        elif shared.args.portable:
+            loader = 'llama.cpp'
+        elif len(list(path_to_model.glob('*.gguf'))) > 0:
+            loader = 'llama.cpp'
+        elif re.match(r'.*\.gguf', model_name.lower()):
+            loader = 'llama.cpp'
+        elif hf_quant_method == 'exl3':
+            loader = 'ExLlamav3_HF'
+        elif hf_quant_method in ['exl2', 'gptq']:
+            loader = 'ExLlamav2_HF'
+        elif re.match(r'.*exl3', model_name.lower()):
+            loader = 'ExLlamav3_HF'
+        elif re.match(r'.*exl2', model_name.lower()):
+            loader = 'ExLlamav2_HF'
+        else:
+            loader = 'Transformers'
 
     return loader
diff --git a/modules/text_generation.py b/modules/text_generation.py
index a75141f1..d4dfb123 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -40,7 +40,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
         yield ''
         return
 
-    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']:
+    if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel', 'MLXModel']:
         generate_func = generate_reply_custom
     else:
         generate_func = generate_reply_HF
@@ -153,7 +153,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
 
-    if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
+    if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel', 'MLXModel'] or shared.args.cpu:
         return input_ids
     else:
         device = get_device()
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index e09e292e..2eeeeb3b 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -59,7 +59,7 @@ def create_ui():
             shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
             shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Set trust_remote_code=True while loading the tokenizer/model. To enable this option, start the web UI with the --trust-remote-code flag.', interactive=shared.args.trust_remote_code)
             shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
-
+
             # Speculative decoding
             with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
                 with gr.Row():
diff --git a/requirements/full/requirements_apple_silicon.txt b/requirements/full/requirements_apple_silicon.txt
index e48d4dee..beb8d619 100644
--- a/requirements/full/requirements_apple_silicon.txt
+++ b/requirements/full/requirements_apple_silicon.txt
@@ -7,6 +7,7 @@ gradio==4.37.*
 html2text==2025.4.15
 jinja2==3.1.6
 markdown
+mlx-lm>=0.26.3
 numpy==2.2.*
 pandas
 peft==0.15.*
diff --git a/requirements/portable/requirements_apple_silicon.txt b/requirements/portable/requirements_apple_silicon.txt
index 9f403a0b..2e1017ad 100644
--- a/requirements/portable/requirements_apple_silicon.txt
+++ b/requirements/portable/requirements_apple_silicon.txt
@@ -3,6 +3,7 @@ gradio==4.37.*
 html2text==2025.4.15
 jinja2==3.1.6
 markdown
+mlx-lm>=0.26.3
 numpy==2.2.*
 pydantic==2.8.2
 PyPDF2==3.0.1
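
For reviewers who want to sanity-check the mlx-lm calls that MLXModel builds on (load, make_sampler, make_repetition_penalty, generate_step), they can be exercised outside the webui. A minimal standalone sketch, assuming mlx-lm >= 0.26.3 (as pinned above) on Apple Silicon; the model repo id below is only a placeholder, not something this change depends on:

import mlx.core as mx
from mlx_lm import load
from mlx_lm.generate import generate_step
from mlx_lm.sample_utils import make_repetition_penalty, make_sampler

# Placeholder repo id; substitute any MLX-converted model.
model, tokenizer = load("mlx-community/your-model-4bit")

prompt = mx.array(tokenizer.encode("Hello, how are you?"))
sampler = make_sampler(temp=0.7, top_p=0.9, top_k=40, min_p=0.0, min_tokens_to_keep=1)
processors = [make_repetition_penalty(penalty=1.1, context_size=1024)]

text = ""
for token, logprobs in generate_step(
    prompt,
    model,
    max_tokens=128,
    sampler=sampler,
    logits_processors=processors,
):
    token_id = int(token.item()) if hasattr(token, "item") else int(token)
    if token_id == tokenizer.eos_token_id:  # simple stop condition
        break
    # Same per-token decode strategy as MLXModel.generate_with_streaming()
    text += tokenizer.decode([token_id])

print(text)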
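
Loader selection: with the infer_loader() change above, mlx-community model names (or names matching *.mlx) resolve to the new MLX loader automatically on macOS/arm64, while every other platform keeps the original path-based logic. A rough illustration of the expected behaviour, meant to be run inside the webui environment (model names are examples only):

from modules.models_settings import infer_loader

# On Apple Silicon both calls are expected to return 'MLX'; elsewhere they
# fall through to the original path-based branch.
print(infer_loader("mlx-community/Example-Model-4bit", {}))
print(infer_loader("my-model.mlx", {}))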