diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index 5953803a..8f1924cb 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -20,6 +20,7 @@ from modules.image_utils import ( convert_pil_to_base64 ) from modules.logging_colors import logger +from modules.utils import resolve_model_path llamacpp_valid_cache_types = {"fp16", "q8_0", "q4_0"} @@ -351,14 +352,12 @@ class LlamaServer: if path.exists(): cmd += ["--mmproj", str(path)] if shared.args.model_draft not in [None, 'None']: - path = Path(shared.args.model_draft) - if not path.exists(): - path = Path(f'{shared.args.model_dir}/{shared.args.model_draft}') + path = resolve_model_path(shared.args.model_draft) if path.is_file(): model_file = path else: - model_file = sorted(Path(f'{shared.args.model_dir}/{shared.args.model_draft}').glob('*.gguf'))[0] + model_file = sorted(path.glob('*.gguf'))[0] cmd += ["--model-draft", model_file] if shared.args.draft_max > 0: