diff --git a/modules/mlx_loader.py b/modules/mlx_loader.py
index 46bac3c9..80aa987a 100644
--- a/modules/mlx_loader.py
+++ b/modules/mlx_loader.py
@@ -39,7 +39,8 @@ class MLXModel:
             model_path = cls._resolve_model_path(model_name)
             logger.info(f"Loading MLX model: {model_path}")
 
-            model, tokenizer = load(model_path)
+            tokenizer_config = {"trust_remote_code": True}
+            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)
 
             instance.model = model
             instance.tokenizer = tokenizer
@@ -48,8 +49,13 @@ class MLXModel:
             return instance, instance  # Return model, tokenizer tuple for compatibility
         except Exception as e:
-            logger.error(f"Failed to load MLX model {model_name}: {str(e)}")
-            traceback.print_exc()
+            error_msg = str(e)
+            if "not supported" in error_msg.lower():
+                logger.error(f"MLX model {model_name} uses an unsupported model type: {error_msg}")
+                logger.error("Consider using a different loader or updating mlx-lm to a newer version")
+            else:
+                logger.error(f"Failed to load MLX model {model_name}: {error_msg}")
+                traceback.print_exc()
             return None
 
     @staticmethod
@@ -60,11 +66,18 @@ class MLXModel:
         if model_path.exists():
             # Local model path
             return str(model_path)
-        elif model_name.startswith('mlx-community/'):
-            # Already has mlx-community prefix
+        elif '/' in model_name:
+            # Already has repo/model format
+            return model_name
+        elif '_' in model_name and not model_name.startswith('_'):
+            # Handle repo_name format - convert first underscore to slash
+            # e.g., "mlx-community_model-name" -> "mlx-community/model-name"
+            parts = model_name.split('_', 1)
+            if len(parts) == 2:
+                return f"{parts[0]}/{parts[1]}"
             return model_name
         else:
-            # Try to find in mlx-community
+            # Default to mlx-community for standalone model names
             return f"mlx-community/{model_name}"
 
     def _create_mlx_sampler(self, state):
@@ -177,7 +190,7 @@ class MLXModel:
         if self.tokenizer.chat_template is not None:
             messages = [{"role": "user", "content": prompt}]
             formatted_prompt = self.tokenizer.apply_chat_template(
-                messages, add_generation_prompt=True
+                messages, add_generation_prompt=True, tokenize=False
             )
             return formatted_prompt
         return prompt
diff --git a/modules/models.py b/modules/models.py
index 05ef7a7b..7b6a6ce1 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -120,7 +120,7 @@ def MLX_loader(model_name):
     result = MLXModel.from_pretrained(model_name)
 
     if result is None:
-        raise RuntimeError(f"Failed to load MLX model: {model_name}")
+        raise RuntimeError(f"Failed to load MLX model: {model_name}. Check the logs above for specific error details.")
 
     return result
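Note on the load() change: tokenizer_config={"trust_remote_code": True} is passed through mlx_lm.load to the Hugging Face tokenizer loader (an assumption about mlx-lm internals that holds in the versions I have seen), so repos that ship custom tokenizer code can load, at the cost of executing code from the repo. A minimal sketch of the call in isolation, with a placeholder repo id:

    from mlx_lm import load

    # Placeholder repo id; tokenizer_config is forwarded to the underlying
    # Hugging Face tokenizer constructor by mlx-lm (assumed behavior).
    model, tokenizer = load(
        "mlx-community/model-name",
        tokenizer_config={"trust_remote_code": True},
    )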
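Note on the new _resolve_model_path branch order: a local directory wins, then any name already containing '/' passes through unchanged, then a name with an underscore is rewritten to repo/model form, and only bare names fall back to the mlx-community org. A standalone sketch for sanity-checking the mapping; resolve_name is a hypothetical mirror of the method, and the local-directory branch is elided because this diff does not show how model_path is built:

    def resolve_name(model_name: str) -> str:
        # Hypothetical mirror of MLXModel._resolve_model_path, minus the
        # local-directory check, to illustrate branch precedence.
        if '/' in model_name:
            return model_name                     # already repo/model
        elif '_' in model_name and not model_name.startswith('_'):
            parts = model_name.split('_', 1)      # first underscore -> slash
            if len(parts) == 2:
                return f"{parts[0]}/{parts[1]}"
            return model_name
        else:
            return f"mlx-community/{model_name}"  # bare-name fallback

    assert resolve_name("mlx-community_model-name") == "mlx-community/model-name"
    assert resolve_name("someorg/some-model") == "someorg/some-model"
    assert resolve_name("some-model") == "mlx-community/some-model"

Because the on-disk check runs first in the real method, a bare name that merely contains an underscore (e.g. a local folder named my_model) is only rewritten when no matching local directory exists, so the branch order matters.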
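Note on the apply_chat_template change: Hugging Face tokenizers default to tokenize=True, so the old call returned token ids where the surrounding code expects a prompt string; tokenize=False makes the return value a str. A quick illustration with a placeholder repo id:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("mlx-community/model-name")  # placeholder
    messages = [{"role": "user", "content": "hello"}]

    ids = tok.apply_chat_template(messages, add_generation_prompt=True)
    # -> token ids, because tokenize defaults to True
    text = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    # -> the formatted prompt string the generation path expects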