Skip to content

Commit

Permalink
Feat: add support for civitai url model download (#60)
Browse files Browse the repository at this point in the history
* Feat: add support for civitai url model download

* Add pytest workflow
  • Loading branch information
yoland68 authored May 15, 2024
1 parent 956cb2e commit 4901cd5
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 10 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.9' # Adjust the version as needed
python-version: '3.9' # Follow the min version in pyproject.toml

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -r requirements.txt # If you have other dependencies
pip install -e .
- name: Run tests
env:
PYTHONPATH: ${{ github.workspace }}
run: |
pytest
pytest
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ repos:
hooks:
- id: pylint
args:
- --disable=R,C,W,E0401
- --disable=R,C,W,E0401
105 changes: 100 additions & 5 deletions comfy_cli/command/models/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pathlib
from typing import List, Optional
from typing import List, Optional, Tuple

import requests
import typer

from typing_extensions import Annotated
Expand All @@ -24,10 +25,91 @@ def potentially_strip_param_url(path_name: str) -> str:
return path_name


# Convert relative path to absolute path based on the current working
# directory
def check_huggingface_url(url: str) -> bool:
return "huggingface.co" in url


def check_civitai_url(url: str) -> Tuple[bool, bool, int, int]:
"""
Returns:
is_civitai_model_url: True if the url is a civitai model url
is_civitai_api_url: True if the url is a civitai api url
model_id: The model id or None if it's api url
version_id: The version id or None if it doesn't have version id info
"""
prefix = "civitai.com"
try:
if prefix in url:
# URL is civitai api download url: https://civitai.com/api/download/models/12345
if "civitai.com/api/download" in url:
# This is a direct download link
version_id = url.strip("/").split("/")[-1]
return False, True, None, int(version_id)

# URL is civitai web url (e.g.
# - https://civitai.com/models/43331
# - https://civitai.com/models/43331/majicmix-realistic
subpath = url[url.find(prefix) + len(prefix) :].strip("/")
url_parts = subpath.split("?")
if len(url_parts) > 1:
model_id = url_parts[0].split("/")[1]
version_id = url_parts[1].split("=")[1]
return True, False, int(model_id), int(version_id)
else:
model_id = subpath.split("/")[1]
return True, False, int(model_id), None
except (ValueError, IndexError):
print("Error parsing Civitai model URL")

return False, False, None, None


def request_civitai_model_version_api(version_id: int):
# Make a request to the Civitai API to get the model information
response = requests.get(
f"https://civitai.com/api/v1/model-versions/{version_id}", timeout=10
)
response.raise_for_status() # Raise an error for bad status codes

model_data = response.json()
for file in model_data["files"]:
if file["primary"]: # Assuming we want the primary file
model_name = file["name"]
download_url = file["downloadUrl"]
return model_name, download_url


def request_civitai_model_api(model_id: int, version_id: int = None):
# Make a request to the Civitai API to get the model information
response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=10)
response.raise_for_status() # Raise an error for bad status codes

model_data = response.json()

# If version_id is None, use the first version
if version_id is None:
version_id = model_data["modelVersions"][0]["id"]

# Find the version with the specified version_id
for version in model_data["modelVersions"]:
if version["id"] == version_id:
# Get the model name and download URL from the files array
for file in version["files"]:
if file["primary"]: # Assuming we want the primary file
model_name = file["name"]
download_url = file["downloadUrl"]
return model_name, download_url

# If the specified version_id is not found, raise an error
raise ValueError(f"Version ID {version_id} not found for model ID {model_id}")


@app.command()
@tracking.track_command("model")
def download(
ctx: typer.Context,
_ctx: typer.Context,
url: Annotated[
str,
typer.Option(
Expand All @@ -42,9 +124,22 @@ def download(
),
] = DEFAULT_COMFY_MODEL_PATH,
):
"""Download a model to a specified relative path if it is not already downloaded."""
# Convert relative path to absolute path based on the current working directory
local_filename = potentially_strip_param_url(url.split("/")[-1])

local_filename = None

is_civitai_model_url, is_civitai_api_url, model_id, version_id = check_civitai_url(
url
)
is_huggingface = False
if is_civitai_model_url:
local_filename, url = request_civitai_model_api(model_id, version_id)
elif is_civitai_api_url:
local_filename, url = request_civitai_model_version_api(version_id)
elif check_huggingface_url(url):
is_huggingface = True
local_filename = potentially_strip_param_url(url.split("/")[-1])
else:
print("Model source is unknown")
local_filename = ui.prompt_input(
"Enter filename to save model as", default=local_filename
)
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ GitPython
requests
pyyaml
typing-extensions
questionary
mixpanel
tomlkit
pathspec
httpx
httpx
packaging
36 changes: 36 additions & 0 deletions tests/comfy_cli/command/models/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from comfy_cli.command.models.models import check_civitai_url


def test_valid_model_url():
url = "https://civitai.com/models/43331"
assert check_civitai_url(url) == (True, False, 43331, None)


def test_valid_model_url_with_version():
url = "https://civitai.com/models/43331/majicmix-realistic"
assert check_civitai_url(url) == (True, False, 43331, None)


def test_valid_model_url_with_query():
url = "https://civitai.com/models/43331?version=12345"
assert check_civitai_url(url) == (True, False, 43331, 12345)


def test_valid_api_url():
url = "https://civitai.com/api/download/models/67890"
assert check_civitai_url(url) == (False, True, None, 67890)


def test_invalid_url():
url = "https://example.com/models/43331"
assert check_civitai_url(url) == (False, False, None, None)


def test_malformed_url():
url = "https://civitai.com/models/"
assert check_civitai_url(url) == (False, False, None, None)


def test_malformed_query_url():
url = "https://civitai.com/models/43331?version="
assert check_civitai_url(url) == (False, False, None, None)

0 comments on commit 4901cd5

Please sign in to comment.