Skip to content

Commit

Permalink
Feat: add support for civitai url model download
Browse files Browse the repository at this point in the history
  • Loading branch information
yoland68 committed May 15, 2024
1 parent 85041e0 commit a79fce2
Showing 1 changed file with 65 additions and 5 deletions.
70 changes: 65 additions & 5 deletions comfy_cli/command/models/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pathlib
from typing import List, Optional
from typing import List, Optional, Tuple

import requests
import typer

from typing_extensions import Annotated
Expand All @@ -24,10 +25,60 @@ def potentially_strip_param_url(path_name: str) -> str:
return path_name


# Convert relative path to absolute path based on the current working
# directory
def is_huggingface_model(url: str) -> bool:
return "huggingface.co" in url


def is_civitai_model(url: str) -> Tuple[bool, int, int]:
prefix = "civitai.com"
try:
if prefix in url:
subpath = url[url.find(prefix) + len(prefix) :].strip("/")
url_parts = subpath.split("?")
if len(url_parts) > 1:
model_id = url_parts[0].split("/")[1]
version_id = url_parts[1].split("=")[1]
return True, int(model_id), int(version_id)
else:
model_id = subpath.split("/")[1]
return True, int(model_id), None
except ValueError:
print("Error parsing Civitai model URL")
pass
return False, None, None


def request_civitai_api(model_id: int, version_id: int = None):
# Make a request to the Civitai API to get the model information
response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=10)
response.raise_for_status() # Raise an error for bad status codes

model_data = response.json()

# If version_id is None, use the first version
if version_id is None:
version_id = model_data["modelVersions"][0]["id"]

# Find the version with the specified version_id
for version in model_data["modelVersions"]:
if version["id"] == version_id:
# Get the model name and download URL from the files array
for file in version["files"]:
if file["primary"]: # Assuming we want the primary file
model_name = file["name"]
download_url = file["downloadUrl"]
return model_name, download_url

# If the specified version_id is not found, raise an error
raise ValueError(f"Version ID {version_id} not found for model ID {model_id}")


@app.command()
@tracking.track_command("model")
def download(
ctx: typer.Context,
_ctx: typer.Context,
url: Annotated[
str,
typer.Option(
Expand All @@ -42,9 +93,18 @@ def download(
),
] = DEFAULT_COMFY_MODEL_PATH,
):
"""Download a model to a specified relative path if it is not already downloaded."""
# Convert relative path to absolute path based on the current working directory
local_filename = potentially_strip_param_url(url.split("/")[-1])

local_filename = None

is_civitai, model_id, version_id = is_civitai_model(url)
is_huggingface = False
if is_civitai:
local_filename, url = request_civitai_api(model_id, version_id)
elif is_huggingface_model(url):
is_huggingface = True
local_filename = potentially_strip_param_url(url.split("/")[-1])
else:
print("Model source is unknown")
local_filename = ui.prompt_input(
"Enter filename to save model as", default=local_filename
)
Expand Down

0 comments on commit a79fce2

Please sign in to comment.