From b11fd3cdf8b46475ee99684bfd5939896a21d621 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Fri, 4 Oct 2024 12:57:36 -0700 Subject: [PATCH] Add support for HF token (#193) --- comfy_cli/command/models/models.py | 94 +++++++++++++++++-- comfy_cli/constants.py | 1 + comfy_cli/file_utils.py | 19 ++++ tests/comfy_cli/command/models/test_models.py | 65 ++++++++++++- 4 files changed, 168 insertions(+), 11 deletions(-) diff --git a/comfy_cli/command/models/models.py b/comfy_cli/command/models/models.py index 82740cb..9aee451 100644 --- a/comfy_cli/command/models/models.py +++ b/comfy_cli/command/models/models.py @@ -1,6 +1,8 @@ import os import pathlib +import sys from typing import List, Optional, Tuple +from urllib.parse import unquote, urlparse import requests import typer @@ -10,7 +12,7 @@ from comfy_cli import constants, tracking, ui from comfy_cli.config_manager import ConfigManager from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH -from comfy_cli.file_utils import DownloadException, download_file +from comfy_cli.file_utils import DownloadException, check_unauthorized, download_file from comfy_cli.workspace_manager import WorkspaceManager app = typer.Typer() @@ -37,10 +39,41 @@ def potentially_strip_param_url(path_name: str) -> str: return path_name -# Convert relative path to absolute path based on the current working -# directory -def check_huggingface_url(url: str) -> bool: - return "huggingface.co" in url +def check_huggingface_url(url: str) -> Tuple[bool, Optional[str], Optional[str], Optional[str], Optional[str]]: + """ + Check if the given URL is a Hugging Face URL and extract relevant information. + + Args: + url (str): The URL to check. + + Returns: + Tuple[bool, Optional[str], Optional[str], Optional[str], Optional[str]]: + - is_huggingface_url (bool): True if it's a Hugging Face URL, False otherwise. + - repo_id (Optional[str]): The repository ID if it's a Hugging Face URL, None otherwise. + - filename (Optional[str]): The filename if present, None otherwise. + - folder_name (Optional[str]): The folder name if present, None otherwise. + - branch_name (Optional[str]): The git branch name if present, None otherwise. + """ + parsed_url = urlparse(url) + + if parsed_url.netloc != "huggingface.co" and parsed_url.netloc != "huggingface.com": + return False, None, None, None, None + + path_parts = [p for p in parsed_url.path.split("/") if p] + + if len(path_parts) < 5 or (path_parts[2] != "resolve" and path_parts[2] != "blob"): + return False, None, None, None, None + repo_id = f"{path_parts[0]}/{path_parts[1]}" + branch_name = path_parts[3] + + remaining_path = "/".join(path_parts[4:]) + folder_name = os.path.dirname(remaining_path) if "/" in remaining_path else None + filename = os.path.basename(remaining_path) + + # URL decode the filename + filename = unquote(filename) + + return True, repo_id, filename, folder_name, branch_name def check_civitai_url(url: str) -> Tuple[bool, bool, int, int]: @@ -154,6 +187,14 @@ def download( show_default=False, ), ] = None, + set_hf_api_token: Annotated[ + Optional[str], + typer.Option( + "--set-hf-api-token", + help="Set the HuggingFace API token to use for model listing.", + show_default=False, + ), + ] = None, ): if relative_path is not None: relative_path = os.path.expanduser(relative_path) @@ -166,8 +207,12 @@ def download( config_manager.set(constants.CIVITAI_API_TOKEN_KEY, set_civitai_api_token) civitai_api_token = set_civitai_api_token + if set_hf_api_token is not None: + config_manager.set(constants.HF_API_TOKEN_KEY, set_hf_api_token) + hf_api_token = set_hf_api_token else: civitai_api_token = config_manager.get(constants.CIVITAI_API_TOKEN_KEY) + hf_api_token = config_manager.get(constants.HF_API_TOKEN_KEY) if civitai_api_token is not None: headers = { @@ -176,6 +221,7 @@ def download( } is_civitai_model_url, is_civitai_api_url, model_id, version_id = check_civitai_url(url) + is_huggingface_url, repo_id, hf_filename, hf_folder_name, hf_branch_name = check_huggingface_url(url) if is_civitai_model_url: local_filename, url, model_type, basemodel = request_civitai_model_api(model_id, version_id, headers) @@ -197,7 +243,9 @@ def download( model_path = ui.prompt_input("Enter model type path (e.g. loras, checkpoints, ...)", default="") relative_path = os.path.join(DEFAULT_COMFY_MODEL_PATH, model_path, basemodel) - elif check_huggingface_url(url): + elif is_huggingface_url: + model_id = "/".join(url.split("/")[-2:]) + local_filename = potentially_strip_param_url(url.split("/")[-1]) if relative_path is None: @@ -225,14 +273,40 @@ def download( local_filepath = get_workspace() / relative_path / local_filename - # Check if the file already exists if local_filepath.exists(): print(f"[bold red]File already exists: {local_filepath}[/bold red]") return - # File does not exist, proceed with download - print(f"Start downloading URL: {url} into {local_filepath}") - download_file(url, local_filepath, headers) + if is_huggingface_url and check_unauthorized(url, headers): + if hf_api_token is None: + print( + "Unauthorized access to Hugging Face model. Please set the HuggingFace API token using --set-hf-api-token" + ) + return + else: + try: + import huggingface_hub + except ImportError: + print("huggingface_hub not found. Installing...") + import subprocess + + subprocess.check_call([sys.executable, "-m", "pip", "install", "huggingface_hub"]) + import huggingface_hub + + print(f"Downloading model {model_id} from Hugging Face...") + output_path = huggingface_hub.hf_hub_download( + repo_id=repo_id, + filename=hf_filename, + subfolder=hf_folder_name, + revision=hf_branch_name, + token=hf_api_token, + local_dir=get_workspace() / relative_path, + cache_dir=get_workspace() / relative_path, + ) + print(f"Model downloaded successfully to: {output_path}") + else: + print(f"Start downloading URL: {url} into {local_filepath}") + download_file(url, local_filepath, headers) @app.command() diff --git a/comfy_cli/constants.py b/comfy_cli/constants.py index 38baead..31763c8 100644 --- a/comfy_cli/constants.py +++ b/comfy_cli/constants.py @@ -42,6 +42,7 @@ class PROC(str, Enum): CONFIG_KEY_BACKGROUND = "background" CIVITAI_API_TOKEN_KEY = "civitai_api_token" +HF_API_TOKEN_KEY = "hf_api_token" DEFAULT_TRACKING_VALUE = True diff --git a/comfy_cli/file_utils.py b/comfy_cli/file_utils.py index 8f65b3c..6e6df0f 100644 --- a/comfy_cli/file_utils.py +++ b/comfy_cli/file_utils.py @@ -46,6 +46,25 @@ def parse_json(input_data): return f"Unknown error occurred (status code: {status_code})" +def check_unauthorized(url: str, headers: Optional[dict] = None) -> bool: + """ + Perform a GET request to the given URL and check if the response status code is 401 (Unauthorized). + + Args: + url (str): The URL to send the GET request to. + headers (Optional[dict]): Optional headers to include in the request. + + Returns: + bool: True if the response status code is 401, False otherwise. + """ + try: + response = requests.get(url, headers=headers, allow_redirects=True) + return response.status_code == 401 + except requests.RequestException: + # If there's an error making the request, we can't determine if it's unauthorized + return False + + def download_file(url: str, local_filepath: pathlib.Path, headers: Optional[dict] = None): """Helper function to download a file.""" local_filepath.parent.mkdir(parents=True, exist_ok=True) # Ensure the directory exists diff --git a/tests/comfy_cli/command/models/test_models.py b/tests/comfy_cli/command/models/test_models.py index 94279e6..8e9bd9f 100644 --- a/tests/comfy_cli/command/models/test_models.py +++ b/tests/comfy_cli/command/models/test_models.py @@ -1,4 +1,4 @@ -from comfy_cli.command.models.models import check_civitai_url +from comfy_cli.command.models.models import check_civitai_url, check_huggingface_url def test_valid_model_url(): @@ -34,3 +34,66 @@ def test_malformed_url(): def test_malformed_query_url(): url = "https://civitai.com/models/43331?version=" assert check_civitai_url(url) == (False, False, None, None) + + +def test_valid_huggingface_url(): + url = "https://huggingface.co/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt" + assert check_huggingface_url(url) == (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4.ckpt", None, "main") + + +def test_valid_huggingface_url_sd_audio(): + url = "https://huggingface.co/stabilityai/stable-audio-open-1.0/blob/main/model.safetensors" + assert check_huggingface_url(url) == (True, "stabilityai/stable-audio-open-1.0", "model.safetensors", None, "main") + + +def test_valid_huggingface_url_with_folder(): + url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt" + assert check_huggingface_url(url) == ( + True, + "runwayml/stable-diffusion-v1-5", + "v1-5-pruned-emaonly.ckpt", + None, + "main", + ) + + +def test_valid_huggingface_url_with_subfolder(): + url = "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt" + assert check_huggingface_url(url) == ( + True, + "stabilityai/stable-diffusion-2-1", + "v2-1_768-ema-pruned.ckpt", + None, + "main", + ) + + +def test_valid_huggingface_url_with_encoded_filename(): + url = "https://huggingface.co/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4%20(1).ckpt" + assert check_huggingface_url(url) == (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4 (1).ckpt", None, "main") + + +def test_invalid_huggingface_url(): + url = "https://example.com/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt" + assert check_huggingface_url(url) == (False, None, None, None, None) + + +def test_invalid_huggingface_url_structure(): + url = "https://huggingface.co/CompVis/stable-diffusion-v1-4/main/sd-v1-4.ckpt" + assert check_huggingface_url(url) == (False, None, None, None, None) + + +def test_huggingface_url_with_com_domain(): + url = "https://huggingface.com/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt" + assert check_huggingface_url(url) == (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4.ckpt", None, "main") + + +def test_huggingface_url_with_folder_structure(): + url = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors" + assert check_huggingface_url(url) == ( + True, + "stabilityai/stable-diffusion-xl-base-1.0", + "sd_xl_base_1.0.safetensors", + None, + "main", + )