Fix the return value in MLX loader and add named constants for magic numbers

This commit is contained in:
SB Yoon 2025-08-08 23:20:23 -06:00
parent fe0bef40d2
commit 297fd7a67a

View file

@@ -5,6 +5,10 @@ from pathlib import Path
import modules.shared as shared
from modules.logging_colors import logger
+# Constants for MLX configuration
+MLX_TOP_P_DISABLED = 0.0 # MLX expects 0.0 to disable top_p
+DEFAULT_MAX_TOKENS = 512 # Default maximum tokens for generation
def is_apple_silicon():
"""Check if running on Apple Silicon"""
@@ -46,7 +50,7 @@ class MLXModel:
instance.tokenizer = tokenizer
logger.info(f"Successfully loaded MLX model: {model_name}")
-return instance, instance # Return model, tokenizer tuple for compatibility
+return instance # Return instance for compatibility
except Exception as e:
error_msg = str(e)
@@ -112,7 +116,7 @@ class MLXModel:
# Create the sampler
sampler = make_sampler(
temp=temperature,
-top_p=top_p if top_p < 1.0 else 0.0, # MLX expects 0.0 to disable
+top_p=top_p if top_p < 1.0 else MLX_TOP_P_DISABLED, # MLX expects 0.0 to disable
top_k=int(top_k) if top_k > 0 else 0,
min_p=min_p,
min_tokens_to_keep=1, # Always keep at least one token
@@ -162,7 +166,7 @@ class MLXModel:
if 'max_new_tokens' in state and state['max_new_tokens'] > 0:
mlx_params['max_tokens'] = state['max_new_tokens']
else:
-mlx_params['max_tokens'] = 512 # Default
+mlx_params['max_tokens'] = DEFAULT_MAX_TOKENS # Default
# Create custom sampler with advanced parameters
sampler = self._create_mlx_sampler(state)