diff --git a/modules/logits.py b/modules/logits.py
index 32aef7ae..56a20572 100644
--- a/modules/logits.py
+++ b/modules/logits.py
@@ -7,6 +7,7 @@ from modules import models, shared
 from modules.logging_colors import logger
 from modules.models import load_model
 from modules.text_generation import generate_reply
+from modules.utils import check_model_loaded
 
 global_scores = None
 
@@ -33,9 +34,9 @@ def get_next_logits(*args, **kwargs):
 
 
 def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
-    if shared.model is None:
-        logger.error("No model is loaded! Select one in the Model tab.")
-        return 'Error: No model is loaded1 Select one in the Model tab.', previous
+    model_is_loaded, error_message = check_model_loaded()
+    if not model_is_loaded:
+        return error_message, previous
 
     # llama.cpp case
     if shared.model.__class__.__name__ == 'LlamaServer':
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 7e48a2f6..c0c0350d 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -14,6 +14,7 @@ from modules.callbacks import Iteratorize
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_basic_html
 from modules.logging_colors import logger
+from modules.utils import check_model_loaded
 
 
 def generate_reply(*args, **kwargs):
@@ -34,8 +35,8 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
     # Find the appropriate generation function
     generate_func = apply_extensions('custom_generate_reply')
     if generate_func is None:
-        if shared.model_name == 'None' or shared.model is None:
-            logger.error("No model is loaded! Select one in the Model tab.")
+        model_is_loaded, error_message = check_model_loaded()
+        if not model_is_loaded:
             yield ''
             return
 
diff --git a/modules/utils.py b/modules/utils.py
index 77324139..0e390d08 100644
--- a/modules/utils.py
+++ b/modules/utils.py
@@ -72,6 +72,20 @@ def natural_keys(text):
     return [atoi(c) for c in re.split(r'(\d+)', text)]
 
 
+def check_model_loaded():
+    if shared.model_name == 'None' or shared.model is None:
+        if len(get_available_models()) <= 1:
+            error_msg = "No model is loaded.\n\nTo get started:\n1) Place a GGUF file in your user_data/models folder\n2) Go to the Model tab and select it"
+            logger.error(error_msg)
+            return False, error_msg
+        else:
+            error_msg = "No model is loaded. Please select one in the Model tab."
+            logger.error(error_msg)
+            return False, error_msg
+
+    return True, None
+
+
 def get_available_models():
     # Get all GGUF files
     gguf_files = get_available_ggufs()
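
For reference, the new helper returns a two-element tuple: (True, None) when a model is loaded, or (False, error_msg) after logging the error otherwise. Below is a minimal sketch of the calling pattern introduced by this patch, assuming the text-generation-webui package layout; the handle_request name and its return shape are illustrative only and not part of the change.

    from modules.utils import check_model_loaded


    def handle_request(previous):
        # check_model_loaded() already logs the error, so callers only need to
        # branch on the boolean and surface the message to the UI, mirroring
        # the updated _get_next_logits() above.
        model_is_loaded, error_message = check_model_loaded()
        if not model_is_loaded:
            return error_message, previous

        # ... proceed with generation once a model is confirmed loaded ...
        return 'generated output', previous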