Add model download branch handling in download_model_wrapper

Works with `*/tree/<branch>` URL or `*:<branch>` ID
This commit is contained in:
Th-Underscore 2026-04-17 03:50:36 -04:00
parent 145f3297a2
commit b6460908de
No known key found for this signature in database
GPG key ID: 21EEB0243310C90C
2 changed files with 22 additions and 3 deletions

View file

@ -57,8 +57,7 @@ class ModelDownloader:
return session
def sanitize_model_and_branch_names(self, model, branch):
if model[-1] == '/':
model = model[:-1]
model = model.removesuffix("/")
if model.startswith(base + '/'):
model = model[len(base) + 1:]

View file

@ -239,8 +239,27 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
downloader_module = importlib.import_module("download-model")
downloader = downloader_module.ModelDownloader()
update_queue = queue.Queue()
branch = None
try:
# Handle branch in URL
if "/tree/" in repo_id:
try:
repo_id, branch = repo_id.split("/tree/")
except Exception as e:
yield f"Error parsing branch from URL: {e}"
progress(0.0)
return
# Handle branch delimited by ":"
elif ":" in repo_id:
try:
repo_id, branch = repo_id.split(":")
except Exception as e:
yield f"Error parsing branch from repo_id: {e}"
progress(0.0)
return
# Handle direct GGUF URLs
if repo_id.startswith("https://") and ("huggingface.co" in repo_id) and (repo_id.endswith(".gguf") or repo_id.endswith(".gguf?download=true")):
try:
@ -256,6 +275,7 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
progress(0.0)
return
if not repo_id:
yield "Please enter a model path."
progress(0.0)
@ -266,7 +286,7 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
progress(0.0, "Preparing download...")
model, branch = downloader.sanitize_model_and_branch_names(repo_id, None)
model, branch = downloader.sanitize_model_and_branch_names(repo_id, branch)
yield "Getting download links from Hugging Face..."
links, sha256, is_lora, is_llamacpp, file_sizes = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file)