Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LoRA Collection Loader: make the LoRA Collection input optional #7579

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
23 changes: 14 additions & 9 deletions invokeai/app/invocations/flux_lora_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
title="FLUX LoRA Collection Loader",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
version="1.2.0",
classification=Classification.Prototype,
)
class FLUXLoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to a FLUX transformer."""

loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)

transformer: Optional[TransformerField] = InputField(
Expand All @@ -119,7 +119,16 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []

if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)

if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)

for lora in loras:
if lora is None:
continue
assert type(lora) is LoRAField
if lora.lora.key in added_loras:
continue

Expand All @@ -130,14 +139,10 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:

added_loras.append(lora.lora.key)

if self.transformer is not None:
if output.transformer is None:
output.transformer = self.transformer.model_copy(deep=True)
if self.transformer is not None and output.transformer is not None:
output.transformer.loras.append(lora)

if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
if self.clip is not None and output.clip is not None:
output.clip.loras.append(lora)

return output
56 changes: 33 additions & 23 deletions invokeai/app/invocations/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
lora_key = self.lora.key

if not context.models.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!")
raise Exception(f"Unknown lora: {lora_key}!")

if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet')
Expand Down Expand Up @@ -256,12 +256,12 @@ def invoke(self, context: InvocationContext) -> LoRASelectorOutput:
return LoRASelectorOutput(lora=LoRAField(lora=self.lora, weight=self.weight))


@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.0.0")
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.1.0")
class LoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""

loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
Expand All @@ -281,7 +281,15 @@ def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []

if self.unet is not None:
output.unet = self.unet.model_copy(deep=True)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)

for lora in loras:
if lora is None:
continue
assert type(lora) is LoRAField
if lora.lora.key in added_loras:
continue

Expand All @@ -292,14 +300,10 @@ def invoke(self, context: InvocationContext) -> LoRALoaderOutput:

added_loras.append(lora.lora.key)

if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
if self.unet is not None and output.unet is not None:
output.unet.loras.append(lora)

if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
if self.clip is not None and output.clip is not None:
output.clip.loras.append(lora)

return output
Expand Down Expand Up @@ -399,13 +403,13 @@ def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
title="SDXL LoRA Collection Loader",
tags=["model"],
category="model",
version="1.0.0",
version="1.1.0",
)
class SDXLLoRACollectionLoader(BaseInvocation):
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""

loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
Expand All @@ -431,7 +435,19 @@ def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []

if self.unet is not None:
output.unet = self.unet.model_copy(deep=True)

if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)

if self.clip2 is not None:
output.clip2 = self.clip2.model_copy(deep=True)

for lora in loras:
if lora is None:
continue
assert type(lora) is LoRAField
if lora.lora.key in added_loras:
continue

Expand All @@ -442,19 +458,13 @@ def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:

added_loras.append(lora.lora.key)

if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
if self.unet is not None and output.unet is not None:
output.unet.loras.append(lora)

if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
if self.clip is not None and output.clip is not None:
output.clip.loras.append(lora)

if self.clip2 is not None:
if output.clip2 is None:
output.clip2 = self.clip2.model_copy(deep=True)
if self.clip2 is not None and output.clip2 is not None:
output.clip2.loras.append(lora)

return output
Expand All @@ -472,7 +482,7 @@ def invoke(self, context: InvocationContext) -> VAEOutput:
key = self.vae_model.key

if not context.models.exists(key):
raise Exception(f"Unkown vae: {key}!")
raise Exception(f"Unknown vae: {key}!")

return VAEOutput(vae=VAEField(vae=self.vae_model))

Expand Down
11 changes: 3 additions & 8 deletions invokeai/frontend/web/src/services/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6181,7 +6181,7 @@ export type components = {
* @description LoRA models and weights. May be a single LoRA or collection.
* @default null
*/
loras?: components["schemas"]["LoRAField"] | components["schemas"]["LoRAField"][];
loras?: components["schemas"]["LoRAField"] | components["schemas"]["LoRAField"][] | null;
/**
* Transformer
* @description Transformer
Expand Down Expand Up @@ -7716,11 +7716,6 @@ export type components = {
* @description Gets the bounding box of the given mask image.
*/
GetMaskBoundingBoxInvocation: {
/**
* @description Optional metadata to be saved with the image
* @default null
*/
metadata?: components["schemas"]["MetadataField"] | null;
/**
* Id
* @description The id of this instance of an invocation. Must be unique among all instances of invocations.
Expand Down Expand Up @@ -12173,7 +12168,7 @@ export type components = {
* @description LoRA models and weights. May be a single LoRA or collection.
* @default null
*/
loras?: components["schemas"]["LoRAField"] | components["schemas"]["LoRAField"][];
loras?: components["schemas"]["LoRAField"] | components["schemas"]["LoRAField"][] | null;
/**
* UNet
* @description UNet (scheduler, LoRAs)
Expand Down Expand Up @@ -16021,7 +16016,7 @@ export type components = {
* @description LoRA models and weights. May be a single LoRA or collection.
* @default null
*/
loras?: components["schemas"]["LoRAField"] | components["schemas"]["LoRAField"][];
loras?: components["schemas"]["LoRAField"] | components["schemas"]["LoRAField"][] | null;
/**
* UNet
* @description UNet (scheduler, LoRAs)
Expand Down
Loading