Skip to content

Commit

Permalink
Different options for flux dev and flux schnell.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Aug 4, 2024
1 parent 76cdd93 commit 66f5bf1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui_tensorrt"
description = "TensorRT Node for ComfyUI\nThis node enables the best performance on NVIDIA RTX™ Graphics Cards (GPUs) for Stable Diffusion by leveraging NVIDIA TensorRT."
version = "0.1.6"
version = "0.1.7"
license = "LICENSE"
dependencies = [
"tensorrt>=10.0.1",
Expand Down
11 changes: 8 additions & 3 deletions tensorrt_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class TensorRTLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
"model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow", "flux"], ),
"model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow", "flux_dev", "flux_schnell"], ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"
Expand Down Expand Up @@ -154,11 +154,16 @@ def load_unet(self, unet_name, model_type):
conf = comfy.supported_models.AuraFlow({})
conf.unet_config["disable_unet_model_creation"] = True
model = conf.get_model({})
elif model_type == "flux":
elif model_type == "flux_dev":
conf = comfy.supported_models.Flux({})
conf.unet_config["disable_unet_model_creation"] = True
model = conf.get_model({})
unet.dtype = torch.bfloat16 #TODO: autodetect
elif model_type == "flux_schnell":
conf = comfy.supported_models.FluxSchnell({})
conf.unet_config["disable_unet_model_creation"] = True
model = conf.get_model({})
unet.dtype = torch.bfloat16 #TODO: autodetect
model.diffusion_model = unet
model.memory_required = lambda *args, **kwargs: 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting

Expand All @@ -168,4 +173,4 @@ def load_unet(self, unet_name, model_type):

NODE_CLASS_MAPPINGS = {
"TensorRTLoader": TensorRTLoader,
}
}

0 comments on commit 66f5bf1

Please sign in to comment.