diff --git a/modules/api/embeddings.py b/modules/api/embeddings.py index 16cf0482..17e595fb 100644 --- a/modules/api/embeddings.py +++ b/modules/api/embeddings.py @@ -6,6 +6,7 @@ from transformers import AutoModel from .errors import ServiceUnavailableError from .utils import debug_msg, float_list_to_base64 from modules.logging_colors import logger +from modules import shared embeddings_params_initialized = False @@ -41,7 +42,7 @@ def load_embedding_model(model: str): try: logger.info(f"Try embedding model: {model} on {embeddings_device}") if 'jina-embeddings' in model: - embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True) # trust_remote_code is needed to use the encode method + embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=shared.args.trust_remote_code) embeddings_model = embeddings_model.to(embeddings_device) else: embeddings_model = SentenceTransformer(model, device=embeddings_device)