Skip to content
This repository has been archived by the owner on Jan 24, 2025. It is now read-only.

support config model path in extra_model_paths.yaml #346

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 82 additions & 85 deletions WAS_Node_Suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,13 @@ def print(self, **kwargs):
ALLOWED_EXT = ('.jpeg', '.jpg', '.png',
'.tiff', '.gif', '.bmp', '.webp')


#! Get model path with extra_model_paths
def get_model_path_with_extra(model_name):
if model_name in comfy_paths.folder_names_and_paths.keys():
return comfy_paths.folder_names_and_paths[model_name][0][0]
else:
return os.path.join(MODELS_DIR, model_name)

#! INSTALLATION CLEANUP

# Delete legacy nodes
Expand Down Expand Up @@ -230,7 +236,6 @@ def getSuiteConfig():
cstr(f"Unable to load conf file at `{WAS_CONFIG_FILE}`. Using internal config template.").error.print()
return was_conf_template
return was_config
return was_config

def updateSuiteConfig(conf):
try:
Expand Down Expand Up @@ -330,12 +335,18 @@ def updateSuiteConfig(conf):
#! SUITE SPECIFIC CLASSES & FUNCTIONS

# Freeze PIP modules
cached_packages = None
def packages(versions=False):
import sys
import subprocess
return [( r.decode().split('==')[0] if not versions else r.decode() ) for r in subprocess.check_output([sys.executable, '-s', '-m', 'pip', 'freeze']).split()]

global cached_packages
if cached_packages is None:
cached_packages = subprocess.check_output([sys.executable, '-m', 'pip', 'freeze']).split()
if versions:
return [r.decode() for r in cached_packages]
else:
return [r.decode().split('==')[0] for r in cached_packages]

def install_package(package, uninstall_first: Union[List[str], str] = None):
print(f"Installing {package}...")
if os.getenv("WAS_BLOCK_AUTO_INSTALL", 'False').lower() in ('true', '1', 't'):
cstr(f"Preventing package install of '{package}' due to WAS_BLOCK_INSTALL env").msg.print()
else:
Expand Down Expand Up @@ -5840,7 +5851,7 @@ def image_rembg(

from rembg import remove, new_session

os.environ['U2NET_HOME'] = os.path.join(MODELS_DIR, 'rembg')
os.environ['U2NET_HOME'] = get_model_path_with_extra('rembg')
os.makedirs(os.environ['U2NET_HOME'], exist_ok=True)

# Set bgcolor
Expand Down Expand Up @@ -8491,7 +8502,7 @@ def inject_noise(self, samples, noise_std):

class MiDaS_Model_Loader:
def __init__(self):
self.midas_dir = os.path.join(MODELS_DIR, 'midas')
self.midas_dir = get_model_path_with_extra('midas')

@classmethod
def INPUT_TYPES(cls):
Expand Down Expand Up @@ -8551,7 +8562,7 @@ def install_midas(self):

class MiDaS_Depth_Approx:
def __init__(self):
self.midas_dir = os.path.join(MODELS_DIR, 'midas')
self.midas_dir =get_model_path_with_extra('midas')

@classmethod
def INPUT_TYPES(cls):
Expand Down Expand Up @@ -8668,7 +8679,7 @@ def install_midas(self):

class MiDaS_Background_Foreground_Removal:
def __init__(self):
self.midas_dir = os.path.join(MODELS_DIR, 'midas')
self.midas_dir = get_model_path_with_extra('midas')

@classmethod
def INPUT_TYPES(cls):
Expand Down Expand Up @@ -10981,7 +10992,7 @@ def blip_model(self, blip_model):

from .modules.BLIP.blip_module import blip_decoder

blip_dir = os.path.join(MODELS_DIR, 'blip')
blip_dir = get_model_path_with_extra('blip')
if not os.path.exists(blip_dir):
os.makedirs(blip_dir, exist_ok=True)

Expand All @@ -10999,8 +11010,8 @@ def blip_model(self, blip_model):
elif blip_model == 'interrogate':

from .modules.BLIP.blip_module import blip_vqa

blip_dir = os.path.join(MODELS_DIR, 'blip')
blip_dir = get_model_path_with_extra('blip')
if not os.path.exists(blip_dir):
os.makedirs(blip_dir, exist_ok=True)

Expand Down Expand Up @@ -11077,8 +11088,8 @@ def transformImage(input_image, image_size, device):
model = blip_model[0].to(device)
else:
from .modules.BLIP.blip_module import blip_decoder

blip_dir = os.path.join(MODELS_DIR, 'blip')
blip_dir = get_model_path_with_extra('blip')
if not os.path.exists(blip_dir):
os.makedirs(blip_dir, exist_ok=True)

Expand Down Expand Up @@ -11107,7 +11118,7 @@ def transformImage(input_image, image_size, device):
else:
from .modules.BLIP.blip_module import blip_vqa

blip_dir = os.path.join(MODELS_DIR, 'blip')
blip_dir = get_model_path_with_extra('blip')
if not os.path.exists(blip_dir):
os.makedirs(blip_dir, exist_ok=True)

Expand Down Expand Up @@ -11155,7 +11166,7 @@ def INPUT_TYPES(cls):
def clipseg_model(self, model):
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

cache = os.path.join(MODELS_DIR, 'clipseg')
cache = get_model_path_with_extra('clipseg')

inputs = CLIPSegProcessor.from_pretrained(model, cache_dir=cache)
model = CLIPSegForImageSegmentation.from_pretrained(model, cache_dir=cache)
Expand Down Expand Up @@ -11190,7 +11201,7 @@ def CLIPSeg_image(self, image, text=None, clipseg_model=None):
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

image = tensor2pil(image)
cache = os.path.join(MODELS_DIR, 'clipseg')
cache = get_model_path_with_extra('clipseg')

if clipseg_model:
inputs = clipseg_model[0]
Expand Down Expand Up @@ -11284,7 +11295,7 @@ def CLIPSeg_images(self, image_a, image_b, text_a, text_b, image_c=None, image_d
if text_f:
prompts.append(text_f)

cache = os.path.join(MODELS_DIR, 'clipseg')
cache = get_model_path_with_extra('clipseg')

inputs = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined", cache_dir=cache)
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined", cache_dir=cache)
Expand Down Expand Up @@ -11359,8 +11370,8 @@ def sam_load_model(self, model_size):
Repo.clone_from('https://github.com/facebookresearch/segment-anything', os.path.join(WAS_SUITE_ROOT, 'repos'+os.sep+'SAM'))

sys.path.append(os.path.join(WAS_SUITE_ROOT, 'repos'+os.sep+'SAM'))

sam_dir = os.path.join(MODELS_DIR, 'sam')
sam_dir = get_model_path_with_extra('sam')
if not os.path.exists(sam_dir):
os.makedirs(sam_dir, exist_ok=True)

Expand Down Expand Up @@ -13177,23 +13188,6 @@ def IS_CHANGED(cls, **kwargs):

# CUSTOM COMFYUI NODES

class WAS_Checkpoint_Loader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "config_name": (comfy_paths.get_filename_list("configs"), ),
"ckpt_name": (comfy_paths.get_filename_list("checkpoints"), )}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE", TEXT_TYPE)
RETURN_NAMES = ("MODEL", "CLIP", "VAE", "NAME_STRING")
FUNCTION = "load_checkpoint"

CATEGORY = "WAS Suite/Loaders/Advanced"

def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
config_path = comfy_paths.get_full_path("configs", config_name)
ckpt_path = comfy_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=comfy_paths.get_folder_paths("embeddings"))
return (out[0], out[1], out[2], os.path.splitext(os.path.basename(ckpt_name))[0])

class WAS_Checkpoint_Loader:
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -14099,61 +14093,64 @@ def encode(self, clip, text, token_normalization, weight_interpretation, seed=0,
if NODE_CLASS_MAPPINGS.__contains__("CLIPTextEncode (BlenderNeko Advanced + NSP)"):
cstr('`CLIPTextEncode (BlenderNeko Advanced + NSP)` node enabled under `WAS Suite/Conditioning` menu.').msg.print()

# opencv-python-headless handling
if 'opencv-python' in packages() or 'opencv-python-headless' in packages():
try:
import cv2
build_info = ' '.join(cv2.getBuildInformation().split())
if "FFMPEG: YES" in build_info:
if was_config.__contains__('show_startup_junk'):
if was_config['show_startup_junk']:
cstr("OpenCV Python FFMPEG support is enabled").msg.print()
if was_config.__contains__('ffmpeg_bin_path'):
if was_config['ffmpeg_bin_path'] == "/path/to/ffmpeg":
cstr(f"`ffmpeg_bin_path` is not set in `{WAS_CONFIG_FILE}` config file. Will attempt to use system ffmpeg binaries if available.").warning.print()
else:
if was_config.__contains__('show_startup_junk'):
if was_config['show_startup_junk']:
cstr(f"`ffmpeg_bin_path` is set to: {was_config['ffmpeg_bin_path']}").msg.print()
else:
cstr(f"OpenCV Python FFMPEG support is not enabled\033[0m. OpenCV Python FFMPEG support, and FFMPEG binaries is required for video writing.").warning.print()
except ImportError:
cstr("OpenCV Python module cannot be found. Attempting install...").warning.print()
install_package(
package='opencv-python-headless[ffmpeg]',
uninstall_first=['opencv-python', 'opencv-python-headless[ffmpeg]']
)
def check_deps():
# opencv-python-headless handling
if 'opencv-python' in packages() or 'opencv-python-headless' in packages():
try:
import cv2
build_info = ' '.join(cv2.getBuildInformation().split())
if "FFMPEG: YES" in build_info:
if was_config.__contains__('show_startup_junk'):
if was_config['show_startup_junk']:
cstr("OpenCV Python FFMPEG support is enabled").msg.print()
if was_config.__contains__('ffmpeg_bin_path'):
if was_config['ffmpeg_bin_path'] == "/path/to/ffmpeg":
cstr(f"`ffmpeg_bin_path` is not set in `{WAS_CONFIG_FILE}` config file. Will attempt to use system ffmpeg binaries if available.").warning.print()
else:
if was_config.__contains__('show_startup_junk'):
if was_config['show_startup_junk']:
cstr(f"`ffmpeg_bin_path` is set to: {was_config['ffmpeg_bin_path']}").msg.print()
else:
cstr(f"OpenCV Python FFMPEG support is not enabled\033[0m. OpenCV Python FFMPEG support, and FFMPEG binaries is required for video writing.").warning.print()
except ImportError:
cstr("OpenCV Python module cannot be found. Attempting install...").warning.print()
install_package(
package='opencv-python-headless[ffmpeg]',
uninstall_first=['opencv-python', 'opencv-python-headless[ffmpeg]']
)
try:
import cv2
cstr("OpenCV Python installed.").msg.print()
except ImportError:
cstr("OpenCV Python module still cannot be imported. There is a system conflict.").error.print()
else:
install_package('opencv-python-headless[ffmpeg]')
try:
import cv2
cstr("OpenCV Python installed.").msg.print()
except ImportError:
cstr("OpenCV Python module still cannot be imported. There is a system conflict.").error.print()
else:
install_package('opencv-python-headless[ffmpeg]')
try:
import cv2
cstr("OpenCV Python installed.").msg.print()
except ImportError:
cstr("OpenCV Python module still cannot be imported. There is a system conflict.").error.print()

# scipy handling
if 'scipy' not in packages():
install_package('scipy')
# scipy handling
if 'scipy' not in packages():
install_package('scipy')
try:
import scipy
except ImportError as e:
cstr("Unable to import tools for certain masking procedures.").msg.print()
print(e)

# scikit-image handling
try:
import scipy
import skimage
except ImportError as e:
cstr("Unable to import tools for certain masking procedures.").msg.print()
print(e)

# scikit-image handling
try:
import skimage
except ImportError as e:
install_package(
package='scikit-image',
uninstall_first=['scikit-image']
)
import skimage
install_package(
package='scikit-image',
uninstall_first=['scikit-image']
)
import skimage
# Check for dependencies
# check_deps()

was_conf = getSuiteConfig()

Expand Down