Fix model usage issue

parent 365a997a7f
commit 25c8f1fda3
@@ -39,7 +39,8 @@ class MLXModel:
             model_path = cls._resolve_model_path(model_name)

             logger.info(f"Loading MLX model: {model_path}")
-            model, tokenizer = load(model_path)
+            tokenizer_config = {"trust_remote_code": True}
+            model, tokenizer = load(model_path, tokenizer_config=tokenizer_config)

             instance.model = model
             instance.tokenizer = tokenizer
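The new tokenizer_config dict is forwarded by mlx_lm's load() to the underlying Hugging Face tokenizer, so repositories that ship custom tokenizer code can load. A minimal sketch of the call, assuming mlx-lm is installed; the repo name is only an example, not part of this commit:

    from mlx_lm import load

    # trust_remote_code is passed through to the Hugging Face tokenizer
    # constructor, allowing models whose tokenizers require custom code.
    model, tokenizer = load(
        "mlx-community/Mistral-7B-Instruct-v0.3-4bit",  # example repo only
        tokenizer_config={"trust_remote_code": True},
    )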
@@ -48,8 +49,13 @@ class MLXModel:
             return instance, instance  # Return model, tokenizer tuple for compatibility

         except Exception as e:
-            logger.error(f"Failed to load MLX model {model_name}: {str(e)}")
-            traceback.print_exc()
+            error_msg = str(e)
+            if "not supported" in error_msg.lower():
+                logger.error(f"MLX model {model_name} uses an unsupported model type: {error_msg}")
+                logger.error("Consider using a different loader or updating mlx-lm to a newer version")
+            else:
+                logger.error(f"Failed to load MLX model {model_name}: {error_msg}")
+            traceback.print_exc()
             return None

     @staticmethod
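The rewritten except block separates unsupported-architecture failures (mlx-lm's load errors typically mention "not supported" for unknown model types) from generic load errors. The same classification, pulled out as a standalone helper purely for illustration (describe_load_error is hypothetical, not in the commit):

    def describe_load_error(model_name: str, exc: Exception) -> str:
        # Mirror of the branch added above: match on the error text and
        # return an actionable message instead of a bare traceback.
        error_msg = str(exc)
        if "not supported" in error_msg.lower():
            return (f"MLX model {model_name} uses an unsupported model type: {error_msg}. "
                    "Try a different loader or a newer mlx-lm.")
        return f"Failed to load MLX model {model_name}: {error_msg}"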
@@ -60,11 +66,18 @@ class MLXModel:
         if model_path.exists():
             # Local model path
             return str(model_path)
-        elif model_name.startswith('mlx-community/'):
-            # Already has mlx-community prefix
+        elif '/' in model_name:
+            # Already has repo/model format
             return model_name
+        elif '_' in model_name and not model_name.startswith('_'):
+            # Handle repo_name format - convert first underscore to slash
+            # e.g., "mlx-community_model-name" -> "mlx-community/model-name"
+            parts = model_name.split('_', 1)
+            if len(parts) == 2:
+                return f"{parts[0]}/{parts[1]}"
+            return model_name
         else:
-            # Try to find in mlx-community
+            # Default to mlx-community for standalone model names
             return f"mlx-community/{model_name}"

     def _create_mlx_sampler(self, state):
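The resolution order is: local path first, then any name already containing a slash, then underscore-separated names (the webui downloader typically saves repos to disk as owner_model folders), then a default mlx-community/ prefix. A standalone sketch of just the string rules, with worked examples (resolve is a hypothetical copy; the real method also checks the local models directory first):

    def resolve(model_name: str) -> str:
        if '/' in model_name:
            return model_name                     # already owner/model
        if '_' in model_name and not model_name.startswith('_'):
            owner, rest = model_name.split('_', 1)
            return f"{owner}/{rest}"              # first underscore -> slash
        return f"mlx-community/{model_name}"      # default namespace

    assert resolve("mlx-community_model-name") == "mlx-community/model-name"
    assert resolve("owner/model") == "owner/model"
    assert resolve("model-name") == "mlx-community/model-name"

Note the trade-off: a bare name whose first underscore is not a repo separator (e.g. "my_quantized_model") would be rewritten to "my/quantized_model", which is why the local-path and slash checks run first.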
@@ -177,7 +190,7 @@ class MLXModel:
         if self.tokenizer.chat_template is not None:
             messages = [{"role": "user", "content": prompt}]
             formatted_prompt = self.tokenizer.apply_chat_template(
-                messages, add_generation_prompt=True
+                messages, add_generation_prompt=True, tokenize=False
             )
             return formatted_prompt
         return prompt
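Hugging Face's apply_chat_template defaults to tokenize=True and returns token IDs; passing tokenize=False makes it return the formatted prompt string that this method's callers expect. A short demonstration, assuming a transformers tokenizer with a chat template (the repo name is illustrative):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # example repo
    messages = [{"role": "user", "content": "Hello"}]

    ids = tok.apply_chat_template(messages, add_generation_prompt=True)                   # list of token IDs
    text = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)  # formatted str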
@@ -120,7 +120,7 @@ def MLX_loader(model_name):

     result = MLXModel.from_pretrained(model_name)
     if result is None:
-        raise RuntimeError(f"Failed to load MLX model: {model_name}")
+        raise RuntimeError(f"Failed to load MLX model: {model_name}. Check the logs above for specific error details.")

     return result

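A hypothetical call site showing how the longer RuntimeError message surfaces (the model name is illustrative):

    try:
        model, tokenizer = MLX_loader("mlx-community/some-model-4bit")  # example name
    except RuntimeError as e:
        # The message now directs users to the specific error logged by
        # from_pretrained (unsupported architecture vs. generic failure).
        print(e)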