forked from Fannovel16/comfyui_controlnet_aux
DepthAnythingLoader and Zoe_DepthAnythingLoader
Commit 1775bdf (1 parent: 5a049bd)
Showing 1 changed file with 54 additions and 21 deletions.
@@ -1,55 +1,88 @@
-from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT
+from ..utils import common_annotator_call, define_preprocessor_inputs, INPUT, MAX_RESOLUTION
 import comfy.model_management as model_management
 
+class Depth_Anything_Loader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "ckpt_name": (["depth_anything_vitl14.pth", "depth_anything_vitb14.pth", "depth_anything_vits14.pth"], {"default": "depth_anything_vitl14.pth"}) }}
+    RETURN_TYPES = ("DEPTH_MODEL",)
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "ControlNet Preprocessors/Depth Loader"
+
+    def load_checkpoint(self, ckpt_name):
+        from custom_controlnet_aux.depth_anything import DepthAnythingDetector
+        model = DepthAnythingDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
+        return (model, )
+
 class Depth_Anything_Preprocessor:
     @classmethod
     def INPUT_TYPES(s):
-        return define_preprocessor_inputs(
-            ckpt_name=INPUT.COMBO(
-                ["depth_anything_vitl14.pth", "depth_anything_vitb14.pth", "depth_anything_vits14.pth"]
-            ),
-            resolution=INPUT.RESOLUTION()
-        )
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "model": ("DEPTH_MODEL",)
+            },
+            "optional": {
+                "resolution": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64})
+            }
+        }
 
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "execute"
 
     CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
 
-    def execute(self, image, ckpt_name="depth_anything_vitl14.pth", resolution=512, **kwargs):
-        from custom_controlnet_aux.depth_anything import DepthAnythingDetector
-
-        model = DepthAnythingDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
+    def execute(self, image, model, resolution=512, **kwargs):
         out = common_annotator_call(model, image, resolution=resolution)
         del model
         return (out, )
 
+class Zoe_Depth_Anything_Loader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "environment": (["indoor", "outdoor"], {"default": "indoor"})}}
+    RETURN_TYPES = ("ZOEDEPTH_MODEL",)
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "ControlNet Preprocessors/Depth Loader"
+
+    def load_checkpoint(self, environment):
+        from custom_controlnet_aux.zoe import ZoeDepthAnythingDetector
+        ckpt_name = "depth_anything_metric_depth_indoor.pt" if environment == "indoor" else "depth_anything_metric_depth_outdoor.pt"
+        model = ZoeDepthAnythingDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
+        return (model, )
+
 class Zoe_Depth_Anything_Preprocessor:
     @classmethod
     def INPUT_TYPES(s):
-        return define_preprocessor_inputs(
-            environment=INPUT.COMBO(["indoor", "outdoor"]),
-            resolution=INPUT.RESOLUTION()
-        )
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "model": ("ZOEDEPTH_MODEL",)
+            },
+            "optional": {
+                "resolution": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64})
+            }
+        }
 
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "execute"
 
     CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
 
-    def execute(self, image, environment="indoor", resolution=512, **kwargs):
-        from custom_controlnet_aux.zoe import ZoeDepthAnythingDetector
-        ckpt_name = "depth_anything_metric_depth_indoor.pt" if environment == "indoor" else "depth_anything_metric_depth_outdoor.pt"
-        model = ZoeDepthAnythingDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
+    def execute(self, image, model, resolution=512, **kwargs):
         out = common_annotator_call(model, image, resolution=resolution)
         del model
         return (out, )
 
 NODE_CLASS_MAPPINGS = {
+    "DepthAnythingLoader": Depth_Anything_Loader,
     "DepthAnythingPreprocessor": Depth_Anything_Preprocessor,
+    "Zoe_DepthAnythingLoader": Zoe_Depth_Anything_Loader,
     "Zoe_DepthAnythingPreprocessor": Zoe_Depth_Anything_Preprocessor
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
+    "DepthAnythingLoader": "Depth Anything Loader",
     "DepthAnythingPreprocessor": "Depth Anything",
+    "Zoe_DepthAnythingLoader": "Zoe Depth Anything Loader",
     "Zoe_DepthAnythingPreprocessor": "Zoe Depth Anything"
 }
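
The commit splits model loading out of the two preprocessor nodes into dedicated loader nodes, so a checkpoint is loaded once and the resulting DEPTH_MODEL / ZOEDEPTH_MODEL is passed to the preprocessors as a node input instead of being re-created on every execute. A minimal sketch of the new wiring, calling the node classes directly rather than through a ComfyUI graph; the torch import, the dummy image tensor, and having the classes from the diff above importable in scope are assumptions for illustration:

# Hedged sketch, not part of the commit: exercising the new loader/preprocessor
# split directly. Assumes a working ComfyUI install with this fork's custom
# nodes, and that the classes from the diff above are already in scope.
import torch

# ComfyUI IMAGE inputs are float tensors shaped (batch, height, width, channels) in [0, 1].
image = torch.rand(1, 512, 512, 3)

# Load the checkpoint once; the loader node returns a reusable DEPTH_MODEL.
(depth_model,) = Depth_Anything_Loader().load_checkpoint(ckpt_name="depth_anything_vitl14.pth")

# The preprocessor now takes the preloaded model instead of a ckpt_name,
# so several preprocessor calls can share one loaded model.
(depth_map,) = Depth_Anything_Preprocessor().execute(image, depth_model, resolution=512)

# The same pattern applies to the Zoe variant, keyed by environment instead of ckpt_name.
(zoe_model,) = Zoe_Depth_Anything_Loader().load_checkpoint(environment="indoor")
(zoe_map,) = Zoe_Depth_Anything_Preprocessor().execute(image, zoe_model, resolution=512)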