Skip to content

Commit

Permalink
Add support for HF token
Browse files Browse the repository at this point in the history
  • Loading branch information
yoland68 committed Oct 4, 2024
1 parent 6fdc278 commit cdd0c2d
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 11 deletions.
94 changes: 84 additions & 10 deletions comfy_cli/command/models/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import pathlib
import sys
from typing import List, Optional, Tuple
from urllib.parse import unquote, urlparse

import requests
import typer
Expand All @@ -10,7 +12,7 @@
from comfy_cli import constants, tracking, ui
from comfy_cli.config_manager import ConfigManager
from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH
from comfy_cli.file_utils import DownloadException, download_file
from comfy_cli.file_utils import DownloadException, check_unauthorized, download_file
from comfy_cli.workspace_manager import WorkspaceManager

app = typer.Typer()
Expand All @@ -37,10 +39,41 @@ def potentially_strip_param_url(path_name: str) -> str:
return path_name


# Convert relative path to absolute path based on the current working
# directory
def check_huggingface_url(url: str) -> bool:
return "huggingface.co" in url
def check_huggingface_url(url: str) -> Tuple[bool, Optional[str], Optional[str], Optional[str], Optional[str]]:
"""
Check if the given URL is a Hugging Face URL and extract relevant information.
Args:
url (str): The URL to check.
Returns:
Tuple[bool, Optional[str], Optional[str], Optional[str], Optional[str]]:
- is_huggingface_url (bool): True if it's a Hugging Face URL, False otherwise.
- repo_id (Optional[str]): The repository ID if it's a Hugging Face URL, None otherwise.
- filename (Optional[str]): The filename if present, None otherwise.
- folder_name (Optional[str]): The folder name if present, None otherwise.
- branch_name (Optional[str]): The git branch name if present, None otherwise.
"""
parsed_url = urlparse(url)

if parsed_url.netloc != "huggingface.co" and parsed_url.netloc != "huggingface.com":
return False, None, None, None, None

path_parts = [p for p in parsed_url.path.split("/") if p]

if len(path_parts) < 5 or (path_parts[2] != "resolve" and path_parts[2] != "blob"):
return False, None, None, None, None
repo_id = f"{path_parts[0]}/{path_parts[1]}"
branch_name = path_parts[3]

remaining_path = "/".join(path_parts[4:])
folder_name = os.path.dirname(remaining_path) if "/" in remaining_path else None
filename = os.path.basename(remaining_path)

# URL decode the filename
filename = unquote(filename)

return True, repo_id, filename, folder_name, branch_name


def check_civitai_url(url: str) -> Tuple[bool, bool, int, int]:
Expand Down Expand Up @@ -154,6 +187,14 @@ def download(
show_default=False,
),
] = None,
set_hf_api_token: Annotated[
Optional[str],
typer.Option(
"--set-hf-api-token",
help="Set the HuggingFace API token to use for model listing.",
show_default=False,
),
] = None,
):
if relative_path is not None:
relative_path = os.path.expanduser(relative_path)
Expand All @@ -166,8 +207,12 @@ def download(
config_manager.set(constants.CIVITAI_API_TOKEN_KEY, set_civitai_api_token)
civitai_api_token = set_civitai_api_token

if set_hf_api_token is not None:
config_manager.set(constants.HF_API_TOKEN_KEY, set_hf_api_token)
hf_api_token = set_hf_api_token
else:
civitai_api_token = config_manager.get(constants.CIVITAI_API_TOKEN_KEY)
hf_api_token = config_manager.get(constants.HF_API_TOKEN_KEY)

if civitai_api_token is not None:
headers = {
Expand All @@ -176,6 +221,7 @@ def download(
}

is_civitai_model_url, is_civitai_api_url, model_id, version_id = check_civitai_url(url)
is_huggingface_url, repo_id, hf_filename, hf_folder_name, hf_branch_name = check_huggingface_url(url)

if is_civitai_model_url:
local_filename, url, model_type, basemodel = request_civitai_model_api(model_id, version_id, headers)
Expand All @@ -197,7 +243,9 @@ def download(
model_path = ui.prompt_input("Enter model type path (e.g. loras, checkpoints, ...)", default="")

relative_path = os.path.join(DEFAULT_COMFY_MODEL_PATH, model_path, basemodel)
elif check_huggingface_url(url):
elif is_huggingface_url:
model_id = "/".join(url.split("/")[-2:])

local_filename = potentially_strip_param_url(url.split("/")[-1])

if relative_path is None:
Expand Down Expand Up @@ -225,14 +273,40 @@ def download(

local_filepath = get_workspace() / relative_path / local_filename

# Check if the file already exists
if local_filepath.exists():
print(f"[bold red]File already exists: {local_filepath}[/bold red]")
return

# File does not exist, proceed with download
print(f"Start downloading URL: {url} into {local_filepath}")
download_file(url, local_filepath, headers)
if is_huggingface_url and check_unauthorized(url, headers):
if hf_api_token is None:
print(
"Unauthorized access to Hugging Face model. Please set the HuggingFace API token using --set-hf-api-token"
)
return
else:
try:
import huggingface_hub
except ImportError:
print("huggingface_hub not found. Installing...")
import subprocess

subprocess.check_call([sys.executable, "-m", "pip", "install", "huggingface_hub"])
import huggingface_hub

print(f"Downloading model {model_id} from Hugging Face...")
output_path = huggingface_hub.hf_hub_download(
repo_id=repo_id,
filename=hf_filename,
subfolder=hf_folder_name,
revision=hf_branch_name,
token=hf_api_token,
local_dir=get_workspace() / relative_path,
cache_dir=get_workspace() / relative_path,
)
print(f"Model downloaded successfully to: {output_path}")
else:
print(f"Start downloading URL: {url} into {local_filepath}")
download_file(url, local_filepath, headers)


@app.command()
Expand Down
1 change: 1 addition & 0 deletions comfy_cli/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class PROC(str, Enum):
CONFIG_KEY_BACKGROUND = "background"

CIVITAI_API_TOKEN_KEY = "civitai_api_token"
HF_API_TOKEN_KEY = "hf_api_token"

DEFAULT_TRACKING_VALUE = True

Expand Down
19 changes: 19 additions & 0 deletions comfy_cli/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@ def parse_json(input_data):
return f"Unknown error occurred (status code: {status_code})"


def check_unauthorized(url: str, headers: Optional[dict] = None) -> bool:
"""
Perform a GET request to the given URL and check if the response status code is 401 (Unauthorized).
Args:
url (str): The URL to send the GET request to.
headers (Optional[dict]): Optional headers to include in the request.
Returns:
bool: True if the response status code is 401, False otherwise.
"""
try:
response = requests.get(url, headers=headers, allow_redirects=True)
return response.status_code == 401
except requests.RequestException:
# If there's an error making the request, we can't determine if it's unauthorized
return False


def download_file(url: str, local_filepath: pathlib.Path, headers: Optional[dict] = None):
"""Helper function to download a file."""
local_filepath.parent.mkdir(parents=True, exist_ok=True) # Ensure the directory exists
Expand Down
65 changes: 64 additions & 1 deletion tests/comfy_cli/command/models/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from comfy_cli.command.models.models import check_civitai_url
from comfy_cli.command.models.models import check_civitai_url, check_huggingface_url


def test_valid_model_url():
Expand Down Expand Up @@ -34,3 +34,66 @@ def test_malformed_url():
def test_malformed_query_url():
url = "https://civitai.com/models/43331?version="
assert check_civitai_url(url) == (False, False, None, None)


def test_valid_huggingface_url():
url = "https://huggingface.co/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt"
assert check_huggingface_url(url) == (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4.ckpt", None, "main")


def test_valid_huggingface_url_sd_audio():
url = "https://huggingface.co/stabilityai/stable-audio-open-1.0/blob/main/model.safetensors"
assert check_huggingface_url(url) == (True, "stabilityai/stable-audio-open-1.0", "model.safetensors", None, "main")


def test_valid_huggingface_url_with_folder():
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt"
assert check_huggingface_url(url) == (
True,
"runwayml/stable-diffusion-v1-5",
"v1-5-pruned-emaonly.ckpt",
None,
"main",
)


def test_valid_huggingface_url_with_subfolder():
url = "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt"
assert check_huggingface_url(url) == (
True,
"stabilityai/stable-diffusion-2-1",
"v2-1_768-ema-pruned.ckpt",
None,
"main",
)


def test_valid_huggingface_url_with_encoded_filename():
url = "https://huggingface.co/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4%20(1).ckpt"
assert check_huggingface_url(url) == (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4 (1).ckpt", None, "main")


def test_invalid_huggingface_url():
url = "https://example.com/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt"
assert check_huggingface_url(url) == (False, None, None, None, None)


def test_invalid_huggingface_url_structure():
url = "https://huggingface.co/CompVis/stable-diffusion-v1-4/main/sd-v1-4.ckpt"
assert check_huggingface_url(url) == (False, None, None, None, None)


def test_huggingface_url_with_com_domain():
url = "https://huggingface.com/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt"
assert check_huggingface_url(url) == (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4.ckpt", None, "main")


def test_huggingface_url_with_folder_structure():
url = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors"
assert check_huggingface_url(url) == (
True,
"stabilityai/stable-diffusion-xl-base-1.0",
"sd_xl_base_1.0.safetensors",
None,
"main",
)

0 comments on commit cdd0c2d

Please sign in to comment.