Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2026-01-05 16:20:04 +01:00)
Fix the return value in MLX loader and add named constants for magic numbers
This commit is contained in:
parent fe0bef40d2
commit 297fd7a67a
@@ -5,6 +5,10 @@ from pathlib import Path
 import modules.shared as shared
 from modules.logging_colors import logger
 
+# Constants for MLX configuration
+MLX_TOP_P_DISABLED = 0.0  # MLX expects 0.0 to disable top_p
+DEFAULT_MAX_TOKENS = 512  # Default maximum tokens for generation
+
 
 def is_apple_silicon():
     """Check if running on Apple Silicon"""
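Note: the new module-level constants name the MLX-specific conventions that previously appeared as bare literals. The hunk's context also shows an is_apple_silicon() guard; a minimal sketch of such a check (an illustration only, not necessarily the loader's exact body) could be:

import platform

def is_apple_silicon():
    """Check if running on Apple Silicon (macOS on arm64), which MLX requires."""
    return platform.system() == "Darwin" and platform.machine() == "arm64"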
@@ -46,7 +50,7 @@ class MLXModel:
             instance.tokenizer = tokenizer
 
             logger.info(f"Successfully loaded MLX model: {model_name}")
-            return instance, instance  # Return model, tokenizer tuple for compatibility
+            return instance  # Return instance for compatibility
 
         except Exception as e:
             error_msg = str(e)
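Note: this is the return-value fix from the commit title. The tokenizer is already attached to the instance, so returning (instance, instance) duplicated the wrapper; the loader now yields a single object. A minimal self-contained sketch of the pattern (a simplified stand-in, not the real MLXModel class):

class MLXModelSketch:
    """Toy wrapper illustrating 'attach tokenizer, return one instance'."""

    @classmethod
    def from_pretrained(cls, model, tokenizer):
        instance = cls()
        instance.model = model
        instance.tokenizer = tokenizer
        return instance  # single object; callers read instance.tokenizer

wrapper = MLXModelSketch.from_pretrained(model="<mlx model>", tokenizer="<tokenizer>")
assert wrapper.tokenizer == "<tokenizer>"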
@@ -112,7 +116,7 @@ class MLXModel:
         # Create the sampler
         sampler = make_sampler(
             temp=temperature,
-            top_p=top_p if top_p < 1.0 else 0.0,  # MLX expects 0.0 to disable
+            top_p=top_p if top_p < 1.0 else MLX_TOP_P_DISABLED,  # MLX expects 0.0 to disable
             top_k=int(top_k) if top_k > 0 else 0,
             min_p=min_p,
             min_tokens_to_keep=1,  # Always keep at least one token
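Note: the named constant makes the MLX convention explicit: the webui treats top_p=1.0 as "disabled", while MLX expects 0.0 for the same thing. A short sketch of that mapping, mirroring the call in the hunk (the import path is assumed from typical mlx_lm usage):

from mlx_lm.sample_utils import make_sampler

MLX_TOP_P_DISABLED = 0.0  # MLX convention: 0.0 means "no top_p filtering"

def build_sampler(temperature, top_p, top_k, min_p):
    # Translate webui-style "disabled" values into the values MLX expects.
    return make_sampler(
        temp=temperature,
        top_p=top_p if top_p < 1.0 else MLX_TOP_P_DISABLED,
        top_k=int(top_k) if top_k > 0 else 0,
        min_p=min_p,
        min_tokens_to_keep=1,  # always keep at least one candidate token
    )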
@@ -162,7 +166,7 @@ class MLXModel:
         if 'max_new_tokens' in state and state['max_new_tokens'] > 0:
             mlx_params['max_tokens'] = state['max_new_tokens']
         else:
-            mlx_params['max_tokens'] = 512  # Default
+            mlx_params['max_tokens'] = DEFAULT_MAX_TOKENS  # Default
 
         # Create custom sampler with advanced parameters
         sampler = self._create_mlx_sampler(state)
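Note: the same constant replaces the literal 512 used as a fallback when the UI does not supply max_new_tokens. A small self-contained sketch of that fallback (the helper name is hypothetical; only the state key comes from the diff):

DEFAULT_MAX_TOKENS = 512  # Default maximum tokens for generation

def max_tokens_from_state(state):
    # Prefer the UI's max_new_tokens when it is a positive value; otherwise fall back.
    if state.get('max_new_tokens', 0) > 0:
        return state['max_new_tokens']
    return DEFAULT_MAX_TOKENS

assert max_tokens_from_state({'max_new_tokens': 256}) == 256
assert max_tokens_from_state({}) == DEFAULT_MAX_TOKENS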