interpolation frames
+ */
+ set animationInterpolation(newFrames) {
+ if (!isEquivalent(this.animationInterpolation, newFrames)) {
+ this.publish("engineAnimationInterpolationChange", newFrames);
+ }
+ this.kwargs.interpolate_frames = newFrames;
+ }
+
+ /**
+ * @return int Animation frame rate
+ */
+ get animationRate() {
+ return this.kwargs.frame_rate || 8;
+ }
+
+ /**
+ * @param int Animation frame rate
+ */
+ set animationRate(newRate) {
+ if (!isEquivalent(this.animationRate, newRate)) {
+ this.publish("engineAnimationRateChange", newRate);
+ }
+ this.kwargs.frame_rate = newRate;
+ }
+
+ /**
+ * @return bool Tile along the horizontal dimension
+ */
+ get tileHorizontal() {
+ let tile = this.kwargs.tile;
+ if (isEmpty(tile)) return false;
+ return tile[0];
+ }
+
+ /**
+ * @param bool Tile along the horizontal dimension
+ */
+ set tileHorizontal(newTile) {
+ if (newTile !== this.tileHorizontal) {
+ this.publish("engineTileHorizontalChange", newTile);
+ }
+ this.kwargs.tile = [newTile, this.tileVertical];
+ }
+
+ /**
+ * @return bool Tile along the vertical dimension
+ */
+ get tileVertical() {
+ let tile = this.kwargs.tile;
+ if (isEmpty(tile)) return false;
+ return tile[1];
+ }
+
+ /**
+ * @param bool Tile along the vertical dimension
+ */
+ set tileVertical(newTile) {
+ if (newTile !== this.tileVertical) {
+ this.publish("engineTileVerticalChange", newTile);
+ }
+ this.kwargs.tile = [this.tileHorizontal, newTile];
+ }
+
+ /**
+ * @return bool outpaint empty space
+ */
+ get outpaint() {
+ return isEmpty(this.kwargs.outpaint) ? true : this.kwargs.outpaint;
+ }
+
+ /**
+ * @param bool outpaint empty space
+ */
+ set outpaint(newOutpaint) {
+ if (this.outpaint !== newOutpaint) {
+ this.publish("engineOutpaintChange", newOutpaint);
+ }
+ this.kwargs.outpaint = newOutpaint;
+ }
+
/**
* On initialization, create DOM elements related to invocations.
*/
@@ -813,10 +1115,9 @@ class InvocationController extends Controller {
E.invocationTask().hide(),
E.invocationRemaining().hide()
);
- this.invocationSampleChooser = E.invocationSampleChooser().hide();
this.engineStop = E.engineStop().content("Stop Engine").on("click", () => { this.stopEngine() });
- (await this.images.getNode()).append(this.loadingBar).append(this.invocationSampleChooser);
this.application.container.appendChild(await this.engineStop.render());
+ this.application.container.appendChild(await this.loadingBar.render());
this.subscribe("engineReady", () => {
this.enableStop();
});
@@ -828,13 +1129,6 @@ class InvocationController extends Controller {
});
}
- /**
- * Hides the sample chooser from outside the controller.
- */
- hideSampleChooser() {
- this.invocationSampleChooser.hide();
- }
-
/**
* Enables the engine stopper.
*/
@@ -887,6 +1181,8 @@ class InvocationController extends Controller {
parseInt(invocationPayload.height) || 512
);
} else {
+ this.startSample = true;
+ this.application.samples.resetState();
await this.canvasInvocation(result.uuid);
}
}
@@ -913,6 +1209,34 @@ class InvocationController extends Controller {
}
}
+ /**
+ * Sets the sample images on the canvas and chooser
+ */
+ setSampleImages(images) {
+ // Determine whether these samples are animation frames before passing them to the samples controller
+ let isAnimation = !isEmpty(this.animationFrames) && this.animationFrames > 0;
+ this.application.samples.setSamples(
+ images,
+ isAnimation
+ );
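+ // On the first sample of a new invocation, show it automatically:
+ // loop playback for animations, or select the first image otherwise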
+ if (this.startSample) {
+ if (isAnimation) {
+ this.application.samples.setLoop(true);
+ this.application.samples.setPlay(true);
+ } else {
+ this.application.samples.setActive(0);
+ }
+ this.startSample = false;
+ }
+ }
+
+ /**
+ * Sets the sample video in the viewer, enabling video operations
+ */
+ setSampleVideo(video) {
+ this.application.samples.setVideo(video);
+ }
+
/**
* This is the meat and potatoes of watching an invocation as it goes; this method will be called by implementing functions with callbacks.
* We estimate using total duration; this will end up being more accurate over the entirety of the invocation, as they will typically
@@ -923,7 +1247,7 @@ class InvocationController extends Controller {
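+ * @param callable onVideoReceived A callback that will receive (string $videoPath) when the sample video is available.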
* @param callable onError A callback that is called when an error occurs.
* @param callable onEstimatedDuration A callback that will receive (int $millisecondsRemaining) when new estimates are available.
*/
- async monitorInvocation(uuid, onTaskChanged, onImagesReceived, onError, onEstimatedDuration) {
+ async monitorInvocation(uuid, onTaskChanged, onImagesReceived, onVideoReceived, onError, onEstimatedDuration) {
const initialInterval = this.application.config.model.invocation.interval || 1000;
const queuedInterval = this.application.config.model.queue.interval || 5000;
const consecutiveErrorCutoff = this.application.config.model.invocation.errors.consecutive || 2;
@@ -932,6 +1256,7 @@ class InvocationController extends Controller {
if (onTaskChanged === undefined) onTaskChanged = () => {};
if (onError === undefined) onError = () => {};
if (onEstimatedDuration === undefined) onEstimatedDuration = () => {};
+ if (onVideoReceived === undefined) onVideoReceived = () => {};
let start = (new Date()).getTime(),
lastTask,
@@ -962,7 +1287,6 @@ class InvocationController extends Controller {
onError();
return;
}
-
if (invokeResult.total !== lastTotal) {
if (!isEmpty(lastTotal)) {
lastTotalDeltaTime = (new Date()).getTime();
@@ -986,6 +1310,10 @@ class InvocationController extends Controller {
isCompleted = invokeResult.status === "completed";
onImagesReceived(imagePaths, isCompleted);
}
+ if (!isEmpty(invokeResult.video)) {
+ let videoPath = `/api/invocation/${invokeResult.video}`;
+ onVideoReceived(videoPath);
+ }
if (invokeResult.status === "error") {
this.notify("error", "Invocation Failed", invokeResult.message);
onError();
@@ -1009,59 +1337,6 @@ class InvocationController extends Controller {
checkInvocation();
}
- /**
- * Sets the sample images on the canvas and chooser
- */
- setSampleImages(images) {
- let currentSampleCount = this.invocationSampleChooser.children().length;
- if (isEmpty(images)) {
- this.images.hideCurrentInvocation();
- this.invocationSampleChooser.empty().hide();
- return;
- } else if (currentSampleCount === 0) {
- if (this.invocationSampleIndex === null) {
- this.invocationSampleIndex = 0;
- }
- this.invocationSampleChooser.append(
- E.invocationSample().class("no-sample").content("×").on("click", () => {
- this.invocationSampleIndex = null;
- this.images.hideCurrentInvocation();
- })
- );
- } else {
- currentSampleCount--;
- }
-
- for (let i = 0; i < images.length; i++) {
- let imageNode = new Image();
- if (i >= currentSampleCount) {
- // Go ahead and add it right away
- imageNode.src = images[i];
- let sampleNode = E.invocationSample().content(imageNode).on("click", () => {
- this.images.setCurrentInvocationImage(images[i]);
- this.invocationSampleIndex = i;
- });
- this.invocationSampleChooser.append(sampleNode);
- } else {
- // Wait for it to load to avoid flash
- let imageContainer = this.invocationSampleChooser.getChild(i+1);
- imageContainer.off("click").on("click", () => {
- this.images.setCurrentInvocationImage(images[i]);
- this.invocationSampleIndex = i;
- });
- imageNode.onload = () => {
- imageContainer.content(imageNode);
- }
- imageNode.src = images[i];
- }
- }
-
- if (!isEmpty(this.invocationSampleIndex)) {
- this.images.setCurrentInvocationImage(images[this.invocationSampleIndex]);
- }
- this.invocationSampleChooser.show().render();
- }
-
/**
* Monitors an invocation on the canvas.
*
@@ -1163,6 +1438,9 @@ class InvocationController extends Controller {
updateImages();
}
},
+ onVideoReceived = async (video) => {
+ this.setSampleVideo(video);
+ },
onTaskChanged = (newTask) => {
lastTask = newTask;
if (isEmpty(newTask)) {
@@ -1199,12 +1477,9 @@ class InvocationController extends Controller {
};
this.loadingBar.addClass("loading");
- this.images.hideCurrentInvocation();
- this.invocationSampleChooser.empty();
- this.invocationSampleIndex = null;
window.requestAnimationFrame(() => updateEstimate());
- this.monitorInvocation(uuid, onTaskChanged, onImagesReceived, onError, onEstimatedDuration);
+ this.monitorInvocation(uuid, onTaskChanged, onImagesReceived, onVideoReceived, onError, onEstimatedDuration);
await waitFor(() => complete);
taskNode.empty().hide();
this.loadingBar.removeClass("loading");
@@ -1259,47 +1534,6 @@ class InvocationController extends Controller {
this.monitorInvocation(uuid, onImagesReceived, onError);
await waitFor(() => receivedFirstImage === true);
}
-
- /**
- * Gets default state, no samples
- */
- getDefaultState() {
- return {
- "samples": null,
- "sample": null
- };
- }
-
- /**
- * Get state is only for UI; only use the sample choosers here
- */
- getState(includeImages = true) {
- if (!includeImages) {
- return this.getDefaultState();
- }
- let chooserChildren = this.invocationSampleChooser.children();
- if (chooserChildren.length < 2) {
- return this.getDefaultState();
- }
- return {
- "sample": this.invocationSampleIndex,
- "samples": chooserChildren.slice(1).map((container) => container.getChild(0).src)
- };
- }
-
- /**
- * Set state is only for UI; set the sample choosers here
- */
- setState(newState) {
- this.invocationSampleChooser.empty();
- if (isEmpty(newState) || isEmpty(newState.samples)) {
- this.invocationSampleIndex = null;
- this.invocationSampleChooser.hide();
- } else {
- this.invocationSampleIndex = newState.sample;
- this.setSampleImages(newState.samples);
- }
- }
}
export { InvocationController };
diff --git a/src/js/controller/common/layers.mjs b/src/js/controller/common/layers.mjs
new file mode 100644
index 00000000..d6464842
--- /dev/null
+++ b/src/js/controller/common/layers.mjs
@@ -0,0 +1,915 @@
+/** @module controllers/common/layers */
+// NOTE: promptFiles and truncate are called below; assumed to be exported by the shared helpers module
+import { isEmpty, promptFiles, truncate } from "../../base/helpers.mjs";
+import { ElementBuilder } from "../../base/builder.mjs";
+import { Controller } from "../base.mjs";
+import { View } from "../../view/base.mjs";
+import { ImageView } from "../../view/image.mjs";
+import { ToolbarView } from "../../view/menu.mjs";
+import {
+ ImageEditorScribbleNodeOptionsFormView,
+ ImageEditorPromptNodeOptionsFormView,
+ ImageEditorImageNodeOptionsFormView,
+ ImageEditorVideoNodeOptionsFormView
+} from "../../forms/enfugue/image-editor.mjs";
+
+const E = new ElementBuilder();
+
+/**
+ * This view holds the menu for an individual layer.
+ */
+class LayerOptionsView extends View {
+ /**
+ * @var string Tag name
+ */
+ static tagName = "enfugue-layer-options-view";
+
+ /**
+ * @var string Text to show when no options
+ */
+ static placeholderText = "No options available. When you select a layer with options, they will appear in this pane.";
+
+ /**
+ * Sets the form
+ */
+ async setForm(formView) {
+ this.node.content(await formView.getNode());
+ }
+
+ /**
+ * Resets the form
+ */
+ async resetForm() {
+ this.node.content(E.div().class("placeholder").content(this.constructor.placeholderText));
+ }
+
+ /**
+ * On first build, append placeholder
+ */
+ async build() {
+ let node = await super.build();
+ node.content(
+ E.div().class("placeholder").content(this.constructor.placeholderText)
+ );
+ return node;
+ }
+}
+
+/**
+ * This view allows you to select between individual layers
+ */
+class LayersView extends View {
+ /**
+ * @var string Tag name
+ */
+ static tagName = "enfugue-layers-view";
+
+ /**
+ * @var string Text to show when no layers
+ */
+ static placeholderText = "No layers yet. Use the buttons above to add layers, drag and drop videos or images onto the canvas, or paste media from your clipboard.";
+
+ /**
+ * On construct, create toolbar
+ */
+ constructor(config) {
+ super(config);
+ this.toolbar = new ToolbarView(config);
+ }
+
+ /**
+ * Empties the layers
+ */
+ async emptyLayers() {
+ this.node.content(
+ await this.toolbar.getNode(),
+ E.div().class("placeholder").content(this.constructor.placeholderText)
+ );
+ }
+
+ /**
+ * Adds a layer
+ */
+ async addLayer(newLayer, resetLayers = false) {
+ if (resetLayers) {
+ this.node.content(
+ await this.toolbar.getNode(),
+ await newLayer.getNode()
+ );
+ } else {
+ this.node.append(await newLayer.getNode());
+ this.node.render();
+ }
+ }
+
+ /**
+ * On first build, append placeholder
+ */
+ async build() {
+ let node = await super.build();
+ node.content(
+ await this.toolbar.getNode(),
+ E.div().class("placeholder").content(this.constructor.placeholderText)
+ );
+ node.on("drop", (e) => {
+ e.preventDefault();
+ e.stopPropagation();
+ });
+ return node;
+ }
+}
+
+/**
+ * This class represents an individual layer
+ */
+class LayerView extends View {
+ /**
+ * @var int Preview width
+ */
+ static previewWidth = 30;
+
+ /**
+ * @var int Preview height
+ */
+ static previewHeight = 30;
+
+ /**
+ * @var string tag name in the layer view
+ */
+ static tagName = "enfugue-layer-view";
+
+ /**
+ * On construct, store editor node and form
+ */
+ constructor(controller, editorNode, form) {
+ super(controller.config);
+ this.controller = controller;
+ this.editorNode = editorNode;
+ this.form = form;
+ this.isActive = false;
+ this.isVisible = true;
+ this.isLocked = false;
+ this.previewImage = new ImageView(controller.config, null, false);
+ this.editorNode.onResize(() => this.resized());
+ this.getLayerImage().then((image) => this.previewImage.setImage(image));
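+ // After the form is submitted, wait briefly so the node can apply the new options before redrawing the preview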
+ this.form.onSubmit(() => { setTimeout(() => { this.drawPreviewImage(); }, 150); });
+ this.subtitle = null;
+ }
+
+ /**
+ * @var string Default foreground style
+ */
+ get foregroundStyle() {
+ return window.getComputedStyle(document.documentElement).getPropertyValue("--theme-color-primary");
+ }
+
+ /**
+ * Gets the layer image
+ */
+ async getLayerImage() {
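+ // Render a small preview of this node by scaling the full canvas down to the
+ // preview size and drawing the node's image, current video frame, or a solid fill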
+ let width = this.controller.images.width,
+ height = this.controller.images.height,
+ maxDimension = Math.max(width, height),
+ scale = this.constructor.previewWidth / maxDimension,
+ widthRatio = width / maxDimension,
+ heightRatio = height / maxDimension,
+ previewWidth = this.constructor.previewWidth * widthRatio,
+ previewHeight = this.constructor.previewHeight * heightRatio,
+ nodeState = this.editorNode.getState(true),
+ scaledX = nodeState.x * scale,
+ scaledY = nodeState.y * scale,
+ scaledWidth = nodeState.w * scale,
+ scaledHeight = nodeState.h * scale,
+ canvas = document.createElement("canvas");
+
+ this.lastCanvasWidth = width;
+ this.lastCanvasHeight = height;
+ this.lastNodeWidth = nodeState.w;
+ this.lastNodeHeight = nodeState.h;
+ this.lastNodeX = nodeState.x;
+ this.lastNodeY = nodeState.y;
+
+ canvas.width = previewWidth;
+ canvas.height = previewHeight;
+
+ let context = canvas.getContext("2d");
+
+ if (nodeState.src) {
+ let imageSource = nodeState.src;
+ if (
+ imageSource.startsWith("data:video") ||
+ imageSource.endsWith("mp4") ||
+ imageSource.endsWith("webp") ||
+ imageSource.endsWith("avi") ||
+ imageSource.endsWith("mov")
+ ) {
+ // Get the current frame
+ let frameCanvas = document.createElement("canvas");
+ frameCanvas.width = this.editorNode.content.video.videoWidth;
+ frameCanvas.height = this.editorNode.content.video.videoHeight;
+
+ let frameContext = frameCanvas.getContext("2d");
+ frameContext.drawImage(this.editorNode.content.video, 0, 0);
+ imageSource = frameCanvas.toDataURL();
+ }
+
+ let imageView = new ImageView(this.config, imageSource);
+ await imageView.waitForLoad();
+
+ let imageTop = 0,
+ imageLeft = 0,
+ scaledImageWidth = imageView.width * scale,
+ scaledImageHeight = imageView.height * scale,
+ nodeAnchor = isEmpty(nodeState.anchor)
+ ? null
+ : nodeState.anchor.split("-");
+
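+ // Position and scale the image inside the node box according to its fit mode
+ // (cover, contain, stretch, or actual size) and its anchor point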
+ if (nodeState.fit === "cover" || nodeState.fit === "contain") {
+ let scaledWidthRatio = scaledWidth / imageView.width,
+ scaledHeightRatio = scaledHeight / imageView.height;
+
+ if (nodeState.fit === "cover") {
+ let horizontalWidth = Math.ceil(imageView.width * scaledWidthRatio),
+ horizontalHeight = Math.ceil(imageView.height * scaledWidthRatio),
+ verticalWidth = Math.ceil(imageView.width * scaledHeightRatio),
+ verticalHeight = Math.ceil(imageView.height * scaledHeightRatio);
+
+ if (scaledWidth <= horizontalWidth && scaledHeight <= horizontalHeight) {
+ scaledImageWidth = horizontalWidth;
+ scaledImageHeight = horizontalHeight;
+ if (!isEmpty(nodeAnchor)) {
+ switch (nodeAnchor[0]) {
+ case "center":
+ imageTop = Math.floor((scaledHeight / 2) - (scaledImageHeight / 2));
+ break;
+ case "bottom":
+ imageTop = scaledHeight - scaledImageHeight;
+ break;
+ }
+ }
+ } else if(scaledWidth <= verticalWidth && scaledHeight <= verticalHeight) {
+ scaledImageWidth = verticalWidth;
+ scaledImageHeight = verticalHeight;
+ if (!isEmpty(nodeAnchor)) {
+ switch (nodeAnchor[1]) {
+ case "center":
+ imageLeft = Math.floor((scaledWidth / 2) - (scaledImageWidth / 2));
+ break;
+ case "right":
+ imageLeft = scaledWidth - scaledImageWidth;
+ break;
+ }
+ }
+ }
+ } else {
+ let horizontalWidth = Math.floor(imageView.width * scaledWidthRatio),
+ horizontalHeight = Math.floor(imageView.height * scaledWidthRatio),
+ verticalWidth = Math.floor(imageView.width * scaledHeightRatio),
+ verticalHeight = Math.floor(imageView.height * scaledHeightRatio);
+
+ if (scaledWidth >= horizontalWidth && scaledHeight >= horizontalHeight) {
+ scaledImageWidth = horizontalWidth;
+ scaledImageHeight = horizontalHeight;
+ if (!isEmpty(nodeAnchor)) {
+ switch (nodeAnchor[0]) {
+ case "center":
+ imageTop = Math.floor((scaledHeight / 2) - (scaledImageHeight / 2));
+ break;
+ case "bottom":
+ imageTop = scaledHeight - scaledImageHeight;
+ break;
+ }
+ }
+ } else if (scaledWidth >= verticalWidth && scaledHeight >= verticalHeight) {
+ scaledImageWidth = verticalWidth;
+ scaledImageHeight = verticalHeight;
+ if (!isEmpty(nodeAnchor)) {
+ switch (nodeAnchor[1]) {
+ case "center":
+ imageLeft = Math.floor((scaledWidth / 2) - (scaledImageWidth / 2));
+ break;
+ case "right":
+ imageLeft = scaledWidth - scaledImageWidth;
+ break;
+ }
+ }
+ }
+ }
+ } else if (nodeState.fit === "stretch") {
+ scaledImageWidth = scaledWidth;
+ scaledImageHeight = scaledHeight;
+ } else if (!isEmpty(nodeAnchor)) {
+ switch (nodeAnchor[0]) {
+ case "center":
+ imageTop = Math.floor((scaledHeight / 2) - (scaledImageHeight / 2));
+ break;
+ case "bottom":
+ imageTop = scaledHeight - scaledImageHeight;
+ break;
+ }
+ switch (nodeAnchor[1]) {
+ case "center":
+ imageLeft = Math.floor((scaledWidth / 2) - (scaledImageWidth / 2));
+ break;
+ case "right":
+ imageLeft = scaledWidth - scaledImageWidth;
+ break;
+ }
+ }
+
+ context.beginPath();
+ context.rect(scaledX, scaledY, scaledWidth, scaledHeight);
+ context.clip();
+
+ context.drawImage(
+ imageView.image,
+ scaledX + imageLeft,
+ scaledY + imageTop,
+ scaledImageWidth,
+ scaledImageHeight
+ );
+ } else {
+ context.fillStyle = this.foregroundStyle;
+ context.fillRect(scaledX, scaledY, scaledWidth, scaledHeight);
+ }
+
+ return canvas.toDataURL();
+ }
+
+ /**
+ * Triggers re-rendering of preview image if needed
+ */
+ async resized() {
+ let width = this.controller.images.width,
+ height = this.controller.images.height,
+ nodeState = this.editorNode.getState();
+
+ if (width !== this.lastCanvasWidth ||
+ height !== this.lastCanvasHeight ||
+ nodeState.w !== this.lastNodeWidth ||
+ nodeState.h !== this.lastNodeHeight ||
+ nodeState.x !== this.lastNodeX ||
+ nodeState.y !== this.lastNodeY
+ ) {
+ this.drawPreviewImage();
+ }
+ }
+
+ /**
+ * Re-renders the preview image
+ */
+ async drawPreviewImage() {
+ this.previewImage.setImage(await this.getLayerImage());
+ }
+
+ /**
+ * Removes this layer
+ */
+ async remove() {
+ this.controller.removeLayer(this);
+ }
+
+ /**
+ * Enables/disables a layer
+ */
+ async setActive(isActive) {
+ this.isActive = isActive;
+ if (this.isActive) {
+ this.addClass("active");
+ } else {
+ this.removeClass("active");
+ }
+ }
+
+ /**
+ * Hides/shows a layer
+ */
+ async setVisible(isVisible) {
+ this.isVisible = isVisible;
+ if (!isEmpty(this.hideShowLayer)) {
+ let hideShowLayerIcon = this.isVisible ? "fa-solid fa-eye": "fa-solid fa-eye-slash";
+ this.hideShowLayer.setIcon(hideShowLayerIcon);
+ }
+ if (this.isVisible) {
+ this.editorNode.show();
+ } else {
+ this.editorNode.hide();
+ }
+ }
+
+ /**
+ * Locks/unlocks a layer
+ */
+ async setLocked(isLocked) {
+ this.isLocked = isLocked;
+ if (!isEmpty(this.lockUnlockLayer)) {
+ let lockUnlockLayerIcon = this.isLocked ? "fa-solid fa-lock" : "fa-solid fa-lock-open";
+ this.lockUnlockLayer.setIcon(lockUnlockLayerIcon);
+ }
+ if (this.isLocked) {
+ this.editorNode.addClass("locked");
+ } else {
+ this.editorNode.removeClass("locked");
+ }
+ }
+
+ /**
+ * Gets the state of editor node and form
+ */
+ getState(includeImages = true) {
+ return {
+ ...this.editorNode.getState(includeImages),
+ ...this.form.values,
+ ...{
+ "isLocked": this.isLocked,
+ "isActive": this.isActive,
+ "isVisible": this.isVisible,
+ }
+ };
+ }
+
+ /**
+ * Sets the state of the editor node and form, then populates DOM
+ */
+ async setState(newState) {
+ await this.editorNode.setState(newState);
+ await this.form.setValues(newState);
+ }
+
+ /**
+ * Sets the name
+ */
+ async setName(name) {
+ if (this.node !== undefined) {
+ this.node.find("span.name").content(name);
+ }
+ }
+
+ /**
+ * Sets the subtitle
+ */
+ async setSubtitle(subtitle) {
+ this.subtitle = subtitle;
+ if (this.node !== undefined) {
+ let subtitleNode = this.node.find("span.subtitle");
+ if (isEmpty(subtitle)) {
+ subtitleNode.empty().hide();
+ } else {
+ subtitleNode.content(subtitle).show();
+ }
+ }
+ }
+
+ /**
+ * On build, populate DOM with known details and buttons
+ */
+ async build() {
+ let node = await super.build();
+
+ this.toolbar = new ToolbarView(this.config);
+
+ let hideShowLayerText = this.isVisible ? "Hide Layer" : "Show Layer",
+ hideShowLayerIcon = this.isVisible ? "fa-solid fa-eye": "fa-solid fa-eye-slash";
+
+ this.hideShowLayer = await this.toolbar.addItem(hideShowLayerText, hideShowLayerIcon);
+
+ let lockUnlockLayerText = this.isLocked ? "Unlock Layer" : "Lock Layer",
+ lockUnlockLayerIcon = this.isLocked ? "fa-solid fa-lock" : "fa-solid fa-lock-open";
+
+ this.lockUnlockLayer = await this.toolbar.addItem(lockUnlockLayerText, lockUnlockLayerIcon);
+ this.hideShowLayer.onClick(() => this.setVisible(!this.isVisible));
+ this.lockUnlockLayer.onClick(() => this.setLocked(!this.isLocked));
+
+ let nameNode = E.span().class("name").content(this.editorNode.name),
+ subtitleNode = E.span().class("subtitle");
+
+ if (isEmpty(this.subtitle)) {
+ subtitleNode.hide();
+ } else {
+ subtitleNode.content(this.subtitle);
+ }
+
+ node.content(
+ await this.hideShowLayer.getNode(),
+ await this.lockUnlockLayer.getNode(),
+ E.div().class("title").content(nameNode, subtitleNode),
+ await this.previewImage.getNode(),
+ E.button().content("×").class("close").on("click", () => this.remove())
+ )
+ .attr("draggable", "true")
+ .on("dragstart", (e) => {
+ e.dataTransfer.effectAllowed = "move";
+ this.controller.draggedLayer = this;
+ this.addClass("dragging");
+ })
+ .on("dragleave", (e) => {
+ this.removeClass("drag-target-below").removeClass("drag-target-above");
+ if (this.controller.dragTarget === this) {
+ this.controller.dragTarget = null;
+ }
+ })
+ .on("dragover", (e) => {
+ if (this.controller.draggedLayer !== this) {
+ let dropBelow = e.layerY > e.target.getBoundingClientRect().height / 2;
+ if (dropBelow) {
+ this.removeClass("drag-target-above").addClass("drag-target-below");
+ } else {
+ this.addClass("drag-target-above").removeClass("drag-target-below");
+ }
+ this.controller.dragTarget = this;
+ this.controller.dropBelow = dropBelow;
+ }
+ })
+ .on("dragend", (e) => {
+ this.controller.dragEnd();
+ this.removeClass("dragging").removeClass("drag-target-below").removeClass("drag-target-above");
+ e.preventDefault();
+ e.stopPropagation();
+ })
+ .on("click", (e) => {
+ this.controller.activate(this);
+ })
+ .on("drop", (e) => {
+ e.preventDefault();
+ e.stopPropagation();
+ });
+
+ return node;
+ }
+}
+
+/**
+ * The LayersController manages the layer menu and holds state for each layer
+ */
+class LayersController extends Controller {
+ /**
+ * Removes layers
+ */
+ removeLayer(layerToRemove, removeNode = true) {
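+ // removeNode is false when the editor node was already removed,
+ // e.g. when the node was closed directly on the canvas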
+ if (removeNode) {
+ layerToRemove.editorNode.remove(false);
+ }
+ let layerIndex = this.layers.indexOf(layerToRemove);
+ if (layerIndex === -1) {
+ console.error("Couldn't find", layerToRemove);
+ return;
+ }
+ this.layers = this.layers.slice(0, layerIndex).concat(this.layers.slice(layerIndex+1));
+ if (this.layers.length === 0) {
+ this.layersView.emptyLayers();
+ this.layerOptions.resetForm();
+ } else {
+ this.layersView.node.remove(layerToRemove.node.element);
+ }
+ if (layerToRemove.isActive) {
+ this.layerOptions.resetForm();
+ }
+ this.layersChanged();
+ }
+
+ /**
+ * Fired when done dragging layers
+ */
+ dragEnd() {
+ if (!isEmpty(this.draggedLayer) && !isEmpty(this.dragTarget) && this.draggedLayer !== this.dragTarget) {
+ this.draggedLayer.removeClass("dragging");
+ this.dragTarget.removeClass("drag-target-above").removeClass("drag-target-below");
+
+ let layerIndex = this.layers.indexOf(this.draggedLayer),
+ targetIndex = this.layers.indexOf(this.dragTarget);
+
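+ // Adjust the insertion index for the dragged layer's removal and for
+ // whether it was dropped above or below the target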
+ if (targetIndex > layerIndex) {
+ targetIndex--;
+ }
+ if (!this.dropBelow) {
+ targetIndex++;
+ }
+
+ if (targetIndex !== layerIndex) {
+ // Re-order on canvas (inverse)
+ this.images.reorderNode(targetIndex, this.draggedLayer.editorNode);
+
+ // Re-order in memory
+ this.layers = this.layers.filter(
+ (layer) => layer !== this.draggedLayer
+ );
+ this.layers.splice(targetIndex, 0, this.draggedLayer);
+
+ // Re-order in DOM
+ this.layersView.node.remove(this.draggedLayer.node);
+ this.layersView.node.insert(targetIndex + 1, this.draggedLayer.node);
+ this.layersView.node.render();
+
+ // Trigger callbacks
+ this.layersChanged();
+ }
+ }
+ this.draggedLayer = null;
+ this.dragTarget = null;
+ }
+
+ /**
+ * Gets the state of all layers.
+ */
+ getState(includeImages = true) {
+ return {
+ "layers": this.layers.map((layer) => layer.getState(includeImages))
+ }
+ }
+
+ /**
+ * Gets the default state on init.
+ */
+ getDefaultState() {
+ return {
+ "layers": []
+ };
+ }
+
+ /**
+ * Sets the state from memory/file
+ */
+ async setState(newState) {
+ this.emptyLayers();
+ if (!isEmpty(newState.layers)) {
+ for (let layer of newState.layers) {
+ await this.addLayerByState(layer);
+ }
+ this.activateLayer(this.layers.length-1);
+ }
+ }
+
+ /**
+ * Adds a layer by state
+ */
+ async addLayerByState(layer, node = null) {
+ let addedLayer;
+ switch (layer.classname) {
+ case "ImageEditorPromptNodeView":
+ addedLayer = await this.addPromptLayer(false, node, layer.name);
+ break;
+ case "ImageEditorScribbleNodeView":
+ addedLayer = await this.addScribbleLayer(false, node, layer.name);
+ break;
+ case "ImageEditorImageNodeView":
+ addedLayer = await this.addImageLayer(layer.src, false, node, layer.name);
+ break;
+ case "ImageEditorVideoNodeView":
+ addedLayer = await this.addVideoLayer(layer.src, false, node, layer.name);
+ break;
+ default:
+ console.error(`Unknown layer class ${layer.classname}, skipping and dumping layer data.`);
+ console.log(layer);
+ console.log(node);
+ }
+ if (!isEmpty(addedLayer)) {
+ await addedLayer.setState(layer);
+ }
+ return addedLayer;
+ }
+
+ /**
+ * Empties layers
+ */
+ async emptyLayers() {
+ for (let layer of this.layers) {
+ this.images.removeNode(layer.editorNode);
+ }
+ this.layers = [];
+ this.layersView.emptyLayers();
+ this.layerOptions.resetForm();
+ this.layersChanged();
+ }
+
+ /**
+ * Activates a layer by index
+ */
+ async activateLayer(layerIndex) {
+ if (layerIndex === -1) {
+ return;
+ }
+ for (let i = 0; i < this.layers.length; i++) {
+ this.layers[i].setActive(i === layerIndex);
+ }
+ this.layerOptions.setForm(this.layers[layerIndex].form);
+ }
+
+ /**
+ * Activates a layer by layer
+ */
+ activate(layer) {
+ return this.activateLayer(
+ this.layers.indexOf(layer)
+ );
+ }
+
+ /**
+ * Adds a layer
+ */
+ async addLayer(newLayer, activate = true) {
+ // Bind editor node events
+ newLayer.editorNode.onNameChange((newName) => {
+ newLayer.setName(newName, false);
+ });
+ newLayer.editorNode.onClose(() => {
+ this.removeLayer(newLayer, false);
+ });
+ newLayer.form.onSubmit(() => {
+ this.layersChanged();
+ });
+ this.layers.push(newLayer);
+ await this.layersView.addLayer(newLayer, this.layers.length === 1);
+ if (activate) {
+ this.activateLayer(this.layers.length-1);
+ }
+ this.layersChanged();
+ }
+
+ /**
+ * Adds a video layer
+ */
+ async addVideoLayer(videoData, activate = true, videoNode = null, name = "Video") {
+ if (isEmpty(videoNode)) {
+ videoNode = await this.images.addVideoNode(videoData, name);
+ }
+
+ let videoForm = new ImageEditorVideoNodeOptionsFormView(this.config),
+ videoLayer = new LayerView(this, videoNode, videoForm);
+
+ videoForm.onSubmit((values) => {
+ let videoRoles = [];
+ if (values.denoise) {
+ videoRoles.push("Video to Video");
+ }
+ if (values.videoPrompt) {
+ videoRoles.push("Prompt");
+ }
+ if (values.control && !isEmpty(values.controlnetUnits)) {
+ let controlNets = values.controlnetUnits.map((unit) => isEmpty(unit.controlnet) ? "canny" : unit.controlnet),
+ uniqueControlNets = controlNets.filter((v, i) => controlNets.indexOf(v) === i);
+ videoRoles.push(`ControlNet (${uniqueControlNets.join(", ")})`);
+ }
+ let subtitle = isEmpty(videoRoles)
+ ? "Passthrough"
+ : videoRoles.join(", ");
+ videoNode.updateOptions(values);
+ videoLayer.setSubtitle(subtitle);
+ });
+
+ await this.addLayer(videoLayer, activate);
+ return videoLayer;
+ }
+
+ /**
+ * Adds an image layer
+ */
+ async addImageLayer(imageData, activate = true, imageNode = null, name = "Image") {
+ if (imageData instanceof ImageView) {
+ imageData = imageData.src;
+ }
+ if (isEmpty(imageNode)) {
+ imageNode = await this.images.addImageNode(imageData, name);
+ }
+
+ let imageForm = new ImageEditorImageNodeOptionsFormView(this.config),
+ imageLayer = new LayerView(this, imageNode, imageForm);
+
+ imageForm.onSubmit((values) => {
+ let imageRoles = [];
+ if (values.denoise) {
+ imageRoles.push("Image to Image");
+ }
+ if (values.imagePrompt) {
+ imageRoles.push("Prompt");
+ }
+ if (values.control && !isEmpty(values.controlnetUnits)) {
+ let controlNets = values.controlnetUnits.map((unit) => isEmpty(unit.controlnet) ? "canny" : unit.controlnet),
+ uniqueControlNets = controlNets.filter((v, i) => controlNets.indexOf(v) === i);
+ imageRoles.push(`ControlNet (${uniqueControlNets.join(", ")})`);
+ }
+ let subtitle = isEmpty(imageRoles)
+ ? "Passthrough"
+ : imageRoles.join(", ");
+ imageNode.updateOptions(values);
+ imageLayer.setSubtitle(subtitle);
+ });
+
+ await this.addLayer(imageLayer, activate);
+ return imageLayer;
+ }
+
+ /**
+ * Adds a scribble layer
+ */
+ async addScribbleLayer(activate = true, scribbleNode = null, name = "Scribble") {
+ if (isEmpty(scribbleNode)) {
+ scribbleNode = await this.images.addScribbleNode(name);
+ }
+
+ let scribbleForm = new ImageEditorScribbleNodeOptionsFormView(this.config),
+ scribbleLayer = new LayerView(this, scribbleNode, scribbleForm),
+ scribbleDrawTimer;
+
+ scribbleNode.content.onDraw(() => {
+ this.activate(scribbleLayer);
+ clearTimeout(scribbleDrawTimer);
+ scribbleDrawTimer = setTimeout(() => {
+ scribbleLayer.drawPreviewImage();
+ }, 100);
+ });
+ await this.addLayer(scribbleLayer, activate);
+
+ return scribbleLayer;
+ }
+
+ /**
+ * Adds a prompt layer
+ */
+ async addPromptLayer(activate = true, promptNode = null, name = "Prompt") {
+ if (isEmpty(promptNode)) {
+ promptNode = await this.images.addPromptNode(name);
+ }
+
+ let promptForm = new ImageEditorPromptNodeOptionsFormView(this.config),
+ promptLayer = new LayerView(this, promptNode, promptForm);
+
+ promptForm.onSubmit((values) => {
+ promptNode.setPrompts(values.prompt, values.negativePrompt);
+ });
+
+ await this.addLayer(promptLayer, activate);
+
+ return promptLayer;
+ }
+
+ /**
+ * Prompts for an image then adds a layer
+ */
+ async promptAddImageLayer() {
+ let imageToLoad;
+ try {
+ imageToLoad = await promptFiles();
+ } catch(e) { }
+ if (!isEmpty(imageToLoad)) {
+ // Triggers necessary state changes
+ this.application.loadFile(imageToLoad, truncate(imageToLoad.name, 16));
+ }
+ }
+
+ /**
+ * Gets the layer corresponding to a node on the editor
+ */
+ getLayerByEditorNode(node) {
+ return this.layers.filter((layer) => layer.editorNode === node).shift();
+ }
+
+ /**
+ * After copying a node, adds a layer
+ */
+ async addCopiedNode(newNode, previousNode) {
+ let existingLayer = this.getLayerByEditorNode(previousNode),
+ existingLayerState = existingLayer.getState(),
+ newNodeState = newNode.getState();
+
+ await this.addLayerByState({...existingLayerState, ...newNodeState}, newNode);
+
+ this.activateLayer(this.layers.length-1);
+ }
+
+ /**
+ * Fired when a layer is changed
+ */
+ async layersChanged() {
+ this.publish("layersChanged", this.getState().layers);
+ }
+
+ /**
+ * On initialize, add menus to view
+ */
+ async initialize() {
+ // Initial layers state
+ this.layers = [];
+ this.layerOptions = new LayerOptionsView(this.config);
+ this.layersView = new LayersView(this.config);
+
+ // Add layer tools
+ let imageLayer = await this.layersView.toolbar.addItem("Image/Video", "fa-regular fa-image"),
+ scribbleLayer = await this.layersView.toolbar.addItem("Draw Scribble", "fa-solid fa-pencil");
+
+ imageLayer.onClick(() => this.promptAddImageLayer());
+ scribbleLayer.onClick(() => this.addScribbleLayer());
+
+ // Add layer options
+ this.application.container.appendChild(await this.layerOptions.render());
+ this.application.container.appendChild(await this.layersView.render());
+
+ // Register callbacks for image editor
+ this.images.onNodeFocus((node) => {
+ this.activate(this.getLayerByEditorNode(node));
+ });
+ this.images.onNodeCopy((newNode, previousNode) => {
+ this.addCopiedNode(newNode, previousNode);
+ });
+ }
+};
+
+export { LayersController };
diff --git a/src/js/controller/common/model-manager.mjs b/src/js/controller/common/model-manager.mjs
index 45bd0fdd..5882fbda 100644
--- a/src/js/controller/common/model-manager.mjs
+++ b/src/js/controller/common/model-manager.mjs
@@ -47,157 +47,120 @@ class ModelManagerController extends Controller {
static managerWindowHeight = 600;
/**
- * Creates the manager table.
+ * Creates a window to edit a configuration
*/
- async createManager() {
- this.tableView = new ModelTableView(this.config, this.model.DiffusionModel);
- this.buttonView = new NewModelInputView(this.config);
+ async showEditModel(model) {
+ let modelValues = model.getAttributes();
- // Set columns, formatters, searchables
- this.tableView.setColumns({
- "name": "Name",
- "model": "Model",
- "size": "Size",
- "prompt": "Prompt",
- "negative_prompt": "Negative Prompt"
- });
- this.tableView.setSearchFields(["name", "prompt", "negative_prompt", "model"]);
- this.tableView.setFormatter("size", (datum) => `${datum}px`);
+ modelValues.checkpoint = modelValues.model;
+ modelValues.lora = isEmpty(model.lora) ? [] : model.lora.map((lora) => lora.getAttributes());
+ modelValues.lycoris = isEmpty(model.lycoris) ? [] : model.lycoris.map((lycoris) => lycoris.getAttributes());
+ modelValues.inversion = isEmpty(model.inversion) ? [] : model.inversion.map((inversion) => inversion.model);
+ modelValues.vae = isEmpty(model.vae) ? null : model.vae[0].name;
+ modelValues.motion_module = isEmpty(model.motion_module) ? null : model.motion_module[0].name;
+
+ if (!isEmpty(model.refiner)) {
+ modelValues.refiner = model.refiner[0].model;
+ }
- // Add the 'Edit' button
- this.tableView.addButton("Edit", "fa-solid fa-edit", async (row) => {
- let modelValues = row.getAttributes();
+ if (!isEmpty(model.inpainter)) {
+ modelValues.inpainter = model.inpainter[0].model;
+ }
+
+ if (!isEmpty(model.config)) {
+ let defaultConfig = {};
+ for (let configItem of model.config) {
+ defaultConfig[configItem.configuration_key] = configItem.configuration_value;
+ }
- modelValues.checkpoint = modelValues.model;
- modelValues.lora = isEmpty(row.lora) ? [] : row.lora.map((lora) => lora.getAttributes());
- modelValues.lycoris = isEmpty(row.lycoris) ? [] : row.lycoris.map((lycoris) => lycoris.getAttributes());
- modelValues.inversion = isEmpty(row.inversion) ? [] : row.inversion.map((inversion) => inversion.model);
- modelValues.vae = isEmpty(row.vae) ? null : row.vae[0].name;
+ modelValues = {...modelValues, ...defaultConfig};
- if (!isEmpty(row.refiner)) {
- modelValues.refiner = row.refiner[0].model;
- modelValues.refiner_size = row.refiner[0].size;
+ if (!isEmpty(defaultConfig.prompt_2)) {
+ modelValues.prompt = [modelValues.prompt, defaultConfig.prompt_2];
}
-
- if (!isEmpty(row.inpainter)) {
- modelValues.inpainter = row.inpainter[0].model;
- modelValues.inpainter_size = row.inpainter[0].size;
+ if (!isEmpty(defaultConfig.negative_prompt_2)) {
+ modelValues.negative_prompt = [defaultConfig.negative_prompt, defaultConfig.negative_prompt_2];
}
+ }
- if (!isEmpty(row.config)) {
- let defaultConfig = {};
- for (let configItem of row.config) {
- defaultConfig[configItem.configuration_key] = configItem.configuration_value;
- }
-
- modelValues = {...modelValues, ...defaultConfig};
+ if (!isEmpty(model.scheduler)) {
+ modelValues.scheduler = model.scheduler[0].name;
+ }
- if (!isEmpty(defaultConfig.prompt_2)) {
- modelValues.prompt = [modelValues.prompt, defaultConfig.prompt_2];
- }
- if (!isEmpty(defaultConfig.negative_prompt_2)) {
- modelValues.negative_prompt = [defaultConfig.negative_prompt, defaultConfig.negative_prompt_2];
- }
- if (!isEmpty(defaultConfig.upscale_diffusion_prompt_2)) {
- modelValues.upscale_diffusion_prompt = defaultConfig.upscale_diffusion_prompt.map(
- (prompt, index) => [prompt, defaultConfig.upscale_diffusion_prompt_2[index]],
- );
- }
- if (!isEmpty(defaultConfig.upscale_diffusion_negative_prompt_2)) {
- modelValues.upscale_diffusion_negative_prompt = defaultConfig.upscale_diffusion_negative_prompt.map(
- (prompt, index) => [prompt, defaultConfig.upscale_diffusion_negative_prompt_2[index]],
- );
- }
+ let modelForm = new ModelFormView(this.config, deepClone(modelValues)),
+ modelWindow;
+
+ modelForm.onChange(async (updatedValues) => {
+ if (!isEmpty(modelForm.values.refiner)) {
+ modelForm.addClass("show-refiner");
+ } else {
+ modelForm.removeClass("show-refiner");
}
-
- if (!isEmpty(row.scheduler)) {
- modelValues.scheduler = row.scheduler[0].name;
+ if (!isEmpty(modelForm.values.inpainter)) {
+ modelForm.addClass("show-inpainter");
+ } else {
+ modelForm.removeClass("show-inpainter");
}
+ });
- let modelForm = new ModelFormView(this.config, deepClone(modelValues)),
- modelWindow;
-
- modelForm.onChange(async (updatedValues) => {
- if (!isEmpty(modelForm.values.refiner)) {
- modelForm.addClass("show-refiner");
- } else {
- modelForm.removeClass("show-refiner");
- }
- if (!isEmpty(modelForm.values.inpainter)) {
- modelForm.addClass("show-inpainter");
- } else {
- modelForm.removeClass("show-inpainter");
- }
- });
+ modelForm.onSubmit(async (updatedValues) => {
+ if (Array.isArray(updatedValues.prompt)) {
+ updatedValues.prompt_2 = updatedValues.prompt[1];
+ updatedValues.prompt = updatedValues.prompt[0];
+ }
+ if (Array.isArray(updatedValues.negative_prompt)) {
+ updatedValues.negative_prompt_2 = updatedValues.negative_prompt[1];
+ updatedValues.negative_prompt = updatedValues.negative_prompt[0];
+ }
- modelForm.onSubmit(async (updatedValues) => {
- if (Array.isArray(updatedValues.prompt)) {
- updatedValues.prompt_2 = updatedValues.prompt[1];
- updatedValues.prompt = updatedValues.prompt[0];
- }
- if (Array.isArray(updatedValues.negative_prompt)) {
- updatedValues.negative_prompt_2 = updatedValues.negative_prompt[1];
- updatedValues.negative_prompt = updatedValues.negative_prompt[0];
- }
- let upscalePrompt = [],
- upscalePrompt2 = [],
- upscaleNegativePrompt = [],
- upscaleNegativePrompt2 = [];
-
- if (!isEmpty(updatedValues.upscale_diffusion_prompt)) {
- for (let promptPart of updatedValues.upscale_diffusion_prompt) {
- if (Array.isArray(promptPart)) {
- upscalePrompt.push(promptPart[0]);
- upscalePrompt2.push(promptPart[1]);
- } else {
- upscalePrompt.push(promptPart);
- upscalePrompt2.push(null);
- }
- }
+ try {
+ await this.model.patch(`/models/${model.name}`, null, null, updatedValues);
+ if (!isEmpty(modelWindow)) {
+ modelWindow.remove();
}
- if (!isEmpty(updatedValues.upscale_diffusion_negative_prompt)) {
- for (let promptPart of updatedValues.upscale_diffusion_negative_prompt) {
- if (Array.isArray(promptPart)) {
- upscaleNegativePrompt.push(promptPart[0]);
- upscaleNegativePrompt2.push(promptPart[1]);
- } else {
- upscaleNegativePrompt.push(promptPart);
- upscaleNegativePrompt2.push(null);
- }
- }
+ if (!isEmpty(this.tableView)) {
+ this.tableView.requery();
}
+ } catch(e) {
+ let errorMessage = isEmpty(e)
+ ? "Couldn't communicate with server."
+ : isEmpty(e.detail)
+ ? `${e}`
+ : e.detail;
- updatedValues.upscale_diffusion_prompt = upscalePrompt;
- updatedValues.upscale_diffusion_prompt_2 = upscalePrompt2;
- updatedValues.upscale_diffusion_negative_prompt = upscaleNegativePrompt;
- updatedValues.upscale_diffusion_negative_prompt_2 = upscaleNegativePrompt2;
+ this.notify("error", "Couldn't update model", errorMessage);
+ modelForm.enable();
+ }
+ });
+ modelForm.onCancel(() => modelWindow.remove());
+ modelWindow = await this.spawnWindow(
+ `Edit ${model.name}`,
+ modelForm,
+ this.constructor.modelWindowWidth,
+ this.constructor.modelWindowHeight
+ );
+ }
- try {
- await this.model.patch(`/models/${row.name}`, null, null, updatedValues);
- if (!isEmpty(modelWindow)) {
- modelWindow.remove();
- }
- this.tableView.requery();
- } catch(e) {
- let errorMessage = isEmpty(e)
- ? "Couldn't communicate with server."
- : isEmpty(e.detail)
- ? `${e}`
- : e.detail;
+ /**
+ * Creates the manager table.
+ */
+ async createManager() {
+ this.tableView = new ModelTableView(this.config, this.model.DiffusionModel);
+ this.buttonView = new NewModelInputView(this.config);
- this.notify("error", "Couldn't update model", errorMessage);
- modelForm.enable();
- }
- });
- modelForm.onCancel(() => modelWindow.remove());
- modelWindow = await this.spawnWindow(
- `Edit ${row.name}`,
- modelForm,
- this.constructor.modelWindowWidth,
- this.constructor.modelWindowHeight
- );
+ // Set columns, formatters, searchables
+ this.tableView.setColumns({
+ "name": "Name",
+ "model": "Model",
+ "prompt": "Prompt",
+ "negative_prompt": "Negative Prompt"
});
+ this.tableView.setSearchFields(["name", "prompt", "negative_prompt", "model"]);
+ this.tableView.setFormatter("size", (datum) => `${datum}px`);
+ // Add the 'Edit' button
+ this.tableView.addButton("Edit", "fa-solid fa-edit", (row) => this.showEditModel(row));
+
// Add the 'Delete' button
this.tableView.addButton("Delete", "fa-solid fa-trash", async (row) => {
try{
@@ -239,38 +202,6 @@ class ModelManagerController extends Controller {
values.negative_prompt_2 = values.negative_prompt[1];
values.negative_prompt = values.negative_prompt[0];
}
- let upscalePrompt = [],
- upscalePrompt2 = [],
- upscaleNegativePrompt = [],
- upscaleNegativePrompt2 = [];
-
- if (!isEmpty(values.upscale_diffusion_prompt)) {
- for (let promptPart of values.upscale_diffusion_prompt) {
- if (Array.isArray(promptPart)) {
- upscalePrompt.push(promptPart[0]);
- upscalePrompt2.push(promptPart[1]);
- } else {
- upscalePrompt.push(promptPart);
- upscalePrompt2.push(null);
- }
- }
- }
- if (!isEmpty(values.upscale_diffusion_negative_prompt)) {
- for (let promptPart of values.upscale_diffusion_negative_prompt) {
- if (Array.isArray(promptPart)) {
- upscaleNegativePrompt.push(promptPart[0]);
- upscaleNegativePrompt2.push(promptPart[1]);
- } else {
- upscaleNegativePrompt.push(promptPart);
- upscaleNegativePrompt2.push(null);
- }
- }
- }
-
- values.upscale_diffusion_prompt = upscalePrompt;
- values.upscale_diffusion_prompt_2 = upscalePrompt2;
- values.upscale_diffusion_negative_prompt = upscaleNegativePrompt;
- values.upscale_diffusion_negative_prompt_2 = upscaleNegativePrompt2;
try {
let response = await this.model.post("/models", null, null, values);
diff --git a/src/js/controller/common/model-picker.mjs b/src/js/controller/common/model-picker.mjs
index 8bb8b040..ab8f2138 100644
--- a/src/js/controller/common/model-picker.mjs
+++ b/src/js/controller/common/model-picker.mjs
@@ -14,14 +14,6 @@ const E = new ElementBuilder();
* Extend the TableView to disable sorting and add conditional buttons
*/
class ModelTensorRTTableView extends TableView {
- /**
- * Add a parameter for the engine build callable
- */
- constructor(config, data, buildEngine) {
- super(config, data);
- this.buildEngine = buildEngine;
- }
-
/**
* @var bool Disable sorting.
*/
@@ -51,6 +43,14 @@ class ModelTensorRTTableView extends TableView {
}
}
};
+
+ /**
+ * Add a parameter for the engine build callable
+ */
+ constructor(config, data, buildEngine) {
+ super(config, data);
+ this.buildEngine = buildEngine;
+ }
};
/**
@@ -226,6 +226,46 @@ class ModelPickerFormView extends FormView {
};
};
+/**
+ * This view holds the model picker form and the button that opens the full model configuration
+ */
+class ModelPickerFormsView extends View {
+ /**
+ * @var string tag name
+ */
+ static tagName = "enfugue-model-picker";
+
+ /**
+ * @var string Text in the button
+ */
+ static showMoreText = "More Model Configuration";
+
+ /**
+ * Constructor registers forms
+ */
+ constructor(config, pickerForm, onShowMore) {
+ super(config);
+ this.pickerForm = pickerForm;
+ this.onShowMore = onShowMore;
+ }
+
+ /**
+ * On build, append forms
+ */
+ async build() {
+ let node = await super.build(),
+ showMore = E.button().content(this.constructor.showMoreText).on("click", (e) => {
+ e.stopPropagation();
+ this.onShowMore();
+ });
+
+ node.content(
+ await this.pickerForm.getNode(),
+ showMore
+ );
+ return node;
+ }
+}
/**
* The ModelPickerController appends the model chooser input to the image editor view.
@@ -242,6 +282,21 @@ class ModelPickerController extends Controller {
*/
static tensorRTStatusWindowHeight = 750;
+ /**
+ * @var int The width of the model config window
+ */
+ static modelWindowWidth = 500;
+
+ /**
+ * @var int The height of the model config window
+ */
+ static modelWindowHeight = 500;
+
+ /**
+ * @var string Title of the model configuration window
+ */
+ static modelWindowTitle = "More Model Configuration";
+
/**
* Get state from the model picker
*/
@@ -268,7 +323,7 @@ class ModelPickerController extends Controller {
setState(newState) {
if (!isEmpty(newState.model)) {
this.modelPickerFormView.suppressDefaults = true;
- this.modelPickerFormView.setValues(newState.model).then(
+ this.modelPickerFormView.setValues(newState.model, false).then(
() => this.modelPickerFormView.submit()
);
}
@@ -341,15 +396,45 @@ class ModelPickerController extends Controller {
return null;
}
+ /**
+ * Shows the abridged model form
+ */
+ async showModelForm() {
+ if (this.engine.modelType === "model") {
+ let modelData = await this.model.DiffusionModel.query({name: this.engine.model});
+ this.application.modelManager.showEditModel(modelData);
+ } else {
+ if (!isEmpty(this.modelFormWindow)) {
+ this.modelFormWindow.focus();
+ } else {
+ this.modelFormWindow = await this.spawnWindow(
+ this.constructor.modelWindowTitle,
+ this.abridgedModelFormView,
+ this.constructor.modelWindowWidth,
+ this.constructor.modelWindowHeight
+ );
+ this.modelFormWindow.onClose(() => { delete this.modelFormWindow; });
+ setTimeout(() => { this.abridgedModelFormView.setValues(this.abridgedModelFormView.values); }, 100); // Refresh draw
+ }
+ }
+ }
+
/**
* When initialized, append form to container and register callbacks.
*/
async initialize() {
+ this.xl = false;
this.builtEngines = {};
this.modelPickerFormView = new ModelPickerFormView(this.config);
this.abridgedModelFormView = new AbridgedModelFormView(this.config);
+ this.formsView = new ModelPickerFormsView(
+ this.config,
+ this.modelPickerFormView,
+ () => this.showModelForm()
+ );
+
this.modelPickerFormView.onSubmit(async (values) => {
let suppressDefaults = this.modelPickerFormView.suppressDefaults;
this.modelPickerFormView.suppressDefaults = false;
@@ -358,15 +443,16 @@ class ModelPickerController extends Controller {
this.engine.model = selectedName;
this.engine.modelType = selectedType;
if (selectedType === "model") {
- this.abridgedModelFormView.hide();
try {
let fullModel = await this.model.DiffusionModel.query({name: selectedName}),
modelStatus = await fullModel.getStatus(),
tensorRTStatus = {supported: false};
fullModel.status = modelStatus;
+
if (suppressDefaults) {
fullModel._relationships.config = null;
+ fullModel._relationships.scheduler = null;
}
this.publish("modelPickerChange", fullModel);
@@ -386,9 +472,31 @@ class ModelPickerController extends Controller {
console.error(e);
}
} else {
- this.abridgedModelFormView.show();
- this.abridgedModelFormView.submit();
- this.modelPickerFormView.setTensorRTStatus({supported: false});
+ // Query for metadata
+ try {
+ let modelMetadata = await this.model.get(`/models/${selectedName}/status`);
+ if (!isEmpty(modelMetadata.metadata.base)) {
+ if (modelMetadata.metadata.base.inpainter) {
+ this.notify("warn", "Unexpected Configuration", "You've selected an inpainting model as your base model. This will work as expected for inpainting, but if you aren't inpainting, results will be poorer than desired. Expand 'Additional Models' and put your model under 'Inpainting Checkpoint' to only use it when inpainting.");
+ }
+ if (modelMetadata.metadata.base.refiner) {
+ this.notify("warn", "Unexpected Configuration", "You've selected a refining model as your base model. This will work as expected for refining, but if you aren't refining, results will be poorer than desired. Expand 'Additional Models' and put your model under 'Refining Checkpoint' to only use it when refining.");
+ }
+ }
+
+ this.abridgedModelFormView.submit();
+ this.modelPickerFormView.setTensorRTStatus({supported: false});
+ this.publish("modelPickerChange", {"status": modelMetadata, "defaultConfiguration": {}});
+ } catch(e) {
+ let errorMessage = "This model's metadata could not be read. It may still work, but it's possible the file is corrupt or otherwise unsupported.";
+ if (!isEmpty(e.title)) {
+ errorMessage += ` The error was: ${e.title}`;
+ if (!isEmpty(e.detail)) {
+ errorMessage += ` (${e.detail})`;
+ }
+ }
+ this.notify("warn", "Metadata Error", errorMessage);
+ }
}
} else {
this.modelPickerFormView.setTensorRTStatus({supported: false});
@@ -404,6 +512,11 @@ class ModelPickerController extends Controller {
this.engine.refinerVae = values.refiner_vae;
this.engine.inpainter = values.inpainter;
this.engine.inpainterVae = values.inpainter_vae;
+ this.engine.motionModule = values.motion_module;
+ this.abridgedModelFormView.enable();
+ if (!isEmpty(this.modelFormWindow)) {
+ this.modelFormWindow.remove();
+ }
});
this.abridgedModelFormView.onChange(async () => {
@@ -419,8 +532,7 @@ class ModelPickerController extends Controller {
}
});
- this.application.container.appendChild(await this.modelPickerFormView.render());
- this.application.container.appendChild(await this.abridgedModelFormView.render());
+ this.application.container.appendChild(await this.formsView.render());
this.subscribe("invocationError", (payload) => {
if (!isEmpty(payload.metadata) && !isEmpty(payload.metadata.tensorrt_build)) {
diff --git a/src/js/controller/common/prompts.mjs b/src/js/controller/common/prompts.mjs
new file mode 100644
index 00000000..2f180d47
--- /dev/null
+++ b/src/js/controller/common/prompts.mjs
@@ -0,0 +1,648 @@
+/** @module controllers/common/prompts */
+import { isEmpty, bindMouseUntilRelease } from "../../base/helpers.mjs";
+import { ElementBuilder } from "../../base/builder.mjs";
+import { View } from "../../view/base.mjs";
+import { Controller } from "../base.mjs";
+import { PromptTravelFormView } from "../../forms/enfugue/prompts.mjs";
+
+const E = new ElementBuilder();
+
+/**
+ * This class represents a single prompt in the prompts track
+ */
+class PromptView extends View {
+ /**
+ * @var string css classname
+ */
+ static className = "prompt-view";
+
+ /**
+ * @var string classname of the edit icon
+ */
+ static editIcon = "fa-solid fa-edit";
+
+ /**
+ * @var string classname of the delete icon
+ */
+ static deleteIcon = "fa-solid fa-trash";
+
+ /**
+ * @var int Minimum number of milliseconds between ticks when dragging prompts
+ */
+ static minimumTickInterval = 100;
+
+ /**
+ * @var int Number of pixels from edges to drag
+ */
+ static edgeHandlerTolerance = 15;
+
+ /**
+ * @var int Number of pixels from edges to slide
+ */
+ static slideHandlerTolerance = 40;
+
+ /**
+ * @var int Minimum number of frames per prompt
+ */
+ static minimumFrames = 2;
+
+ constructor(config, total, positive = null, negative = null, start = null, end = null, weight = null) {
+ super(config);
+ this.total = total;
+ this.positive = positive;
+ this.negative = negative;
+ if (isEmpty(start)) {
+ this.start = 0;
+ } else {
+ this.start = start;
+ }
+ if (isEmpty(end)) {
+ this.end = total;
+ } else {
+ this.end = end;
+ }
+ if (isEmpty(weight)) {
+ this.weight = 1.0;
+ } else {
+ this.weight = weight;
+ }
+ this.showEditCallbacks = [];
+ this.onRemoveCallbacks = [];
+ this.onChangeCallbacks = [];
+ }
+
+ /**
+ * Set the position of a node
+ * Do not call this method directly
+ */
+ setPosition(node) {
+ let startRatio, endRatio;
+ if (isEmpty(this.start)) {
+ startRatio = 0.0;
+ } else {
+ startRatio = this.start / this.total;
+ }
+ if (isEmpty(this.end)) {
+ endRatio = 1.0;
+ } else {
+ endRatio = this.end / this.total;
+ }
+ node.css({
+ "margin-left": `${startRatio*100.0}%`,
+ "margin-right": `${(1.0-endRatio)*100.0}%`,
+ });
+ }
+
+ /**
+ * Resets the position of this node to where it should be
+ */
+ resetPosition() {
+ if (!isEmpty(this.node)) {
+ this.setPosition(this.node);
+ }
+ }
+
+ /**
+ * Adds a callback when the edit button is clicked
+ */
+ onShowEdit(callback) {
+ this.showEditCallbacks.push(callback);
+ }
+
+ /**
+ * Triggers showEdit callbacks
+ */
+ showEdit() {
+ for (let callback of this.showEditCallbacks) {
+ callback();
+ }
+ }
+
+ /**
+ * Adds a callback for when this prompt is removed
+ */
+ onRemove(callback) {
+ this.onRemoveCallbacks.push(callback);
+ }
+
+ /**
+ * Triggers remove callbacks
+ */
+ remove() {
+ for (let callback of this.onRemoveCallbacks) {
+ callback();
+ }
+ }
+
+ /**
+ * Adds a callback for when this is changed
+ */
+ onChange(callback) {
+ this.onChangeCallbacks.push(callback);
+ }
+
+ /**
+ * Triggers change callbacks
+ */
+ changed() {
+ let state = this.getState();
+ for (let callback of this.onChangeCallbacks) {
+ callback(state);
+ }
+ }
+
+ /**
+ * Gets the state of this prompt from all sources
+ */
+ getState() {
+ return {
+ "positive": this.positive,
+ "negative": this.negative,
+ "weight": this.weight,
+ "start": this.start,
+ "end": this.end
+ };
+ }
+
+ /**
+ * Sets data that will be handled by an external form
+ */
+ setFormData(newData) {
+ this.positive = newData.positive;
+ this.negative = newData.negative;
+ this.weight = isEmpty(newData.weight) ? 1.0 : newData.weight;
+
+ if (!isEmpty(this.node)) {
+ let positive = this.node.find(".positive"),
+ negative = this.node.find(".negative"),
+ weight = this.node.find(".weight");
+
+ weight.content(`${this.weight.toFixed(2)}`);
+
+ if (isEmpty(this.positive)) {
+ positive.content("(none)");
+ negative.hide();
+ } else {
+ positive.content(this.positive);
+ if (isEmpty(this.negative)) {
+ negative.hide();
+ } else {
+ negative.show().content(this.negative);
+ }
+ }
+ }
+ }
+
+ /**
+ * Sets all state, form data and start/end
+ */
+ setState(newData) {
+ if (isEmpty(newData)) newData = {};
+ this.start = newData.start;
+ this.end = newData.end;
+ this.resetPosition();
+ this.setFormData(newData);
+ }
+
+ /**
+ * On build, append nodes for positive/negative, weight indicator and buttons
+ */
+ async build() {
+ let node = await super.build(),
+ weight = isEmpty(this.weight)
+ ? 1.0
+ : this.weight,
+ edit = E.i().class(this.constructor.editIcon).on("click", () => {
+ this.showEdit();
+ }),
+ remove = E.i().class(this.constructor.deleteIcon).on("click", () => {
+ this.remove();
+ }),
+ positive = E.p().class("positive"),
+ negative = E.p().class("negative"),
+ prompts = E.div().class("prompts").content(positive, negative);
+
+ if (isEmpty(this.positive)) {
+ positive.content("(none)");
+ negative.hide();
+ } else {
+ positive.content(this.positive);
+ if (!isEmpty(this.negative)) {
+ negative.content(this.negative);
+ } else {
+ negative.hide();
+ }
+ }
+
+ let activeLeft = false,
+ activeRight = false,
+ activeSlide = false,
+ canDragLeft = false,
+ canDragRight = false,
+ canSlide = false,
+ lastTick = (new Date()).getTime(),
+ slideStartFrame,
+ slideStartRange,
+ updateFrame = (closestFrame) => {
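+ // Apply the new frame to whichever drag mode is active: resizing the
+ // left edge, resizing the right edge, or sliding the whole range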
+ if (activeLeft) {
+ this.start = Math.min(
+ closestFrame,
+ this.end - this.constructor.minimumFrames
+ );
+ this.setPosition(node);
+ this.changed();
+ } else if(activeRight) {
+ this.end = Math.max(
+ closestFrame,
+ this.start + this.constructor.minimumFrames
+ );
+ this.setPosition(node);
+ this.changed();
+ } else if(activeSlide) {
+ let difference = slideStartFrame - closestFrame,
+ [initialStart, initialEnd] = slideStartRange;
+
+ this.start = Math.max(initialStart - difference, 0);
+ this.end = Math.min(initialEnd - difference, this.total);
+ this.setPosition(node);
+ this.changed();
+ }
+ },
+ updatePosition = (e) => {
+ // e might be relative to window or to prompt container, so get absolute pos
+ let now = (new Date()).getTime();
+ if (now - lastTick < this.constructor.minimumTickInterval) return;
+ lastTick = now;
+
+ let promptPosition = node.element.getBoundingClientRect(),
+ promptWidth = promptPosition.width,
+ relativeLeft = Math.min(
+ Math.max(e.clientX - promptPosition.x, 0),
+ promptWidth
+ ),
+ relativeRight = Math.max(promptWidth - relativeLeft, 0),
+ containerPosition = node.element.parentElement.getBoundingClientRect(),
+ containerWidth = containerPosition.width - 15, // Padding
+ containerRelativeLeft = Math.min(
+ Math.max(e.clientX - containerPosition.x, 0),
+ containerWidth
+ ),
+ containerRelativeRight = Math.max(containerWidth - containerRelativeLeft, 0),
+ ratio = containerRelativeLeft / containerWidth,
+ closestFrame = Math.ceil(ratio * this.total);
+
+ canDragLeft = false;
+ canDragRight = false;
+ canSlide = false;
+
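+ // Determine the interaction mode: near either edge resizes that edge, the middle region slides the whole range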
+ if (relativeLeft < this.constructor.edgeHandlerTolerance) {
+ node.css("cursor", "ew-resize");
+ canDragLeft = true;
+ } else if (relativeRight < this.constructor.edgeHandlerTolerance) {
+ node.css("cursor", "ew-resize");
+ canDragRight = true;
+ } else if (
+ relativeLeft >= this.constructor.slideHandlerTolerance &&
+ relativeRight >= this.constructor.slideHandlerTolerance
+ ) {
+ node.css("cursor", "grab");
+ canSlide = true;
+ if (!activeSlide) {
+ slideStartFrame = closestFrame;
+ }
+ } else if (!activeLeft && !activeRight && !activeSlide) {
+ node.css("cursor", "default");
+ }
+
+ updateFrame(closestFrame);
+ e.preventDefault();
+ e.stopPropagation();
+ };
+
+ node.content(
+ prompts,
+ E.div().class("weight").content(`${weight.toFixed(2)}`),
+ edit,
+ remove
+ ).on("mouseenter", (e) => {
+ updatePosition(e);
+ }).on("mousemove", (e) => {
+ updatePosition(e);
+ }).on("mousedown", (e) => {
+ if (canDragLeft) {
+ activeLeft = true;
+ } else if (canDragRight) {
+ activeRight = true;
+ } else if (canSlide) {
+ activeSlide = true;
+ slideStartRange = [this.start, this.end];
+ }
+ updatePosition(e);
+ bindMouseUntilRelease(
+ (e2) => {
+ updatePosition(e2);
+ },
+ (e2) => {
+ activeLeft = false;
+ activeRight = false;
+ activeSlide = false;
+ }
+ );
+ }).on("dblclick", (e) => {
+ e.preventDefault();
+ e.stopPropagation();
+ this.showEdit();
+ });
+ this.setPosition(node);
+ return node;
+ }
+}
+
+/**
+ * This view manages DOM interactions with the prompt travel sliders
+ */
+class PromptTravelView extends View {
+ /**
+ * @var string DOM tag name
+ */
+ static tagName = "enfugue-prompt-travel-view";
+
+ /**
+ * @var int Width of the prompt change form
+ */
+ static promptFormWindowWidth = 350;
+
+ /**
+ * @var int Height of the prompt change form
+ */
+ static promptFormWindowHeight = 450;
+
+ /**
+ * @var int Minimum number of frames
+ */
+ static minimumFrames = 2;
+
+ /**
+ * Constructor has callback to spawn a window
+ */
+ constructor(config, spawnWindow, length = 16) {
+ super(config);
+ this.length = length;
+ this.spawnWindow = spawnWindow;
+ this.promptViews = [];
+ this.onChangeCallbacks = [];
+ }
+
+ /**
+ * Sets the length of all frames, re-scaling notches and prompts
+ */
+ setLength(newLength) {
+ this.length = newLength;
+ if (this.node !== undefined) {
+ let notchNode = this.node.find(".notches"),
+ newPromptNotches = new Array(this.length).fill(null).map(
+ (_, i) => E.span().class("notch").content(`${i+1}`)
+ );
+
+ notchNode.content(...newPromptNotches);
+ }
+
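+ // Clamp each existing prompt's range to the new length, preserving the minimum frame span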
+ for (let promptView of this.promptViews) {
+ promptView.total = this.length;
+ promptView.end = Math.min(promptView.total, promptView.end);
+ promptView.start = Math.max(0, Math.min(promptView.start, promptView.end-this.constructor.minimumFrames));
+ promptView.resetPosition();
+ }
+ }
+
+ /**
+ * Adds a callback when any prompts are changed
+ */
+ onChange(callback) {
+ this.onChangeCallbacks.push(callback);
+ }
+
+ /**
+ * Triggers changed callbacks
+ */
+ changed() {
+ let state = this.getState();
+ for (let callback of this.onChangeCallbacks) {
+ callback(state);
+ }
+ }
+
+ /**
+ * Removes a prompt view by reference
+ */
+ removePrompt(promptView) {
+ let promptViewIndex = this.promptViews.indexOf(promptView);
+ if (promptViewIndex === -1) {
+ console.error("Couldn't find prompt view in memory", promptView);
+ return;
+ }
+ this.promptViews.splice(promptViewIndex, 1);
+ if (!isEmpty(this.node) && !isEmpty(promptView.node)) {
+ this.node.find(".prompts-container").remove(promptView.node);
+ }
+ }
+
+ /**
+ * Adds a new prompt from an option/state dictionary
+ */
+ async addPrompt(newPrompt) {
+ if (isEmpty(newPrompt)) newPrompt = {};
+ let promptView = new PromptView(
+ this.config,
+ this.length,
+ newPrompt.positive,
+ newPrompt.negative,
+ newPrompt.start,
+ newPrompt.end,
+ newPrompt.weight
+ ),
+ promptFormView = new PromptTravelFormView(
+ this.config,
+ newPrompt
+ ),
+ promptWindow;
+
+ promptFormView.onSubmit((values) => {
+ promptView.setFormData(values);
+ this.changed();
+ });
+ promptView.onChange(() => this.changed());
+ promptView.onRemove(() => {
+ this.removePrompt(promptView);
+ });
+ promptView.onShowEdit(async () => {
+ if (!isEmpty(promptWindow)) {
+ promptWindow.focus();
+ } else {
+ promptWindow = await this.spawnWindow(
+ "Edit Prompt",
+ promptFormView,
+ this.constructor.promptFormWindowWidth,
+ this.constructor.promptFormWindowHeight,
+ );
+ promptWindow.onClose(() => { promptWindow = null; });
+ }
+ });
+ this.promptViews.push(promptView);
+ if (this.node !== undefined) {
+ this.node.find(".prompts-container").append(await promptView.getNode());
+ }
+ }
+
+ /**
+ * Empties the array of prompts in memory and DOM
+ */
+ emptyPrompts() {
+ this.promptViews = [];
+ if (!isEmpty(this.node)) {
+ this.node.find(".prompts-container").empty();
+ }
+ }
+
+ /**
+ * Gets the state of all prompt views
+ */
+ getState() {
+ return this.promptViews.map((view) => view.getState());
+ }
+
+ /**
+ * Sets the state of all prompt views
+ */
+ async setState(newState = []) {
+ this.emptyPrompts();
+ for (let promptState of newState) {
+ await this.addPrompt(promptState);
+ }
+ if (this.node !== undefined) this.node.render();
+ }
+
+ /**
+ * On build, add the prompt track and the 'Add Prompt' button
+ */
+ async build() {
+ let node = await super.build(),
+ addPrompt = E.button().content("Add Prompt").on("click", () => {
+ this.addPrompt();
+ }),
+ promptNotches = new Array(this.length).fill(null).map(
+ (_, i) => E.span().class("notch").content(`${i+1}`)
+ ),
+ promptNotchContainer = E.div().class("notches").content(...promptNotches),
+ promptContainer = E.div().class("prompts-container"),
+ promptsTrack = E.div().class("prompts-track").content(promptNotchContainer, promptContainer);
+
+ for (let promptView of this.promptViews) {
+ promptContainer.append(await promptView.getNode());
+ }
+ node.content(promptsTrack, addPrompt);
+ return node;
+ }
+}
+
+/**
+ * The prompt travel controller is triggered by the prompt sidebar controller
+ * It will show/hide the prompt travel view and manage state with the invocation engine
+ */
+class PromptTravelController extends Controller {
+ /**
+ * By default no prompts are provided
+ */
+ getDefaultState() {
+ return {
+ "travel": []
+ }
+ }
+
+ /**
+ * We use the view class for easy state management
+ */
+ getState() {
+ return {
+ "travel": this.promptView.getState()
+ };
+ }
+
+ /**
+ * Sets the state in the view class
+ */
+ async setState(newState) {
+ if (!isEmpty(newState.travel)) {
+ await this.promptView.setState(newState.travel);
+ }
+ }
+
+ /**
+ * Disables prompt travel entirely (hides PT container)
+ */
+ async disablePromptTravel() {
+ this.promptView.hide();
+ this.application.container.classList.remove("prompt-travel");
+ this.engine.prompts = null;
+ }
+
+ /**
+ * Enables prompt travel
+ * If there was previous state, keep it; if there wasn't, add the current state of the main prompts
+ */
+ async enablePromptTravel() {
+ let currentState = this.promptView.getState();
+ if (currentState.length === 0) {
+ // Add the current prompt
+ let positive = [this.engine.prompt, this.engine.prompt2];
+ let negative = [this.engine.negativePrompt, this.engine.negativePrompt2];
+
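+ // Drop empty prompts and unwrap to a single value when only one primary/secondary prompt is set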
+ positive = positive.filter((value) => !isEmpty(value));
+ if (isEmpty(positive)) {
+ positive = null;
+ } else if(positive.length === 1) {
+ positive = positive[0];
+ }
+
+ negative = negative.filter((value) => !isEmpty(value));
+ if (isEmpty(negative)) {
+ negative = null;
+ } else if(negative.length === 1) {
+ negative = negative[0];
+ }
+
+ this.promptView.addPrompt({
+ positive: positive,
+ negative: negative
+ });
+ }
+ this.promptView.show();
+ this.application.container.classList.add("prompt-travel");
+ this.engine.prompts = this.promptView.getState();
+ }
+
+ /**
+ * On initialize, append and hide prompt travel view, waiting for the enable
+ */
+ async initialize() {
+ this.promptView = new PromptTravelView(
+ this.config,
+ (title, content, w, h, x, y) => this.spawnWindow(title, content, w, h, x, y)
+ );
+ this.promptView.hide();
+ this.promptView.onChange((newPrompts) => {
+ this.engine.prompts = newPrompts;
+ });
+ this.application.container.appendChild(await this.promptView.render());
+ this.subscribe("promptTravelEnabled", () => this.enablePromptTravel());
+ this.subscribe("promptTravelDisabled", () => this.disablePromptTravel());
+
+ // Let the sidebar cascade enable and disable events; we'll just change size
+ this.subscribe("engineAnimationFramesChange", (frames) => {
+ if (!isEmpty(frames) && frames > 0) {
+ this.promptView.setLength(frames);
+ }
+ });
+ }
+}
+
+export { PromptTravelController };
diff --git a/src/js/controller/common/samples.mjs b/src/js/controller/common/samples.mjs
new file mode 100644
index 00000000..c36910be
--- /dev/null
+++ b/src/js/controller/common/samples.mjs
@@ -0,0 +1,886 @@
+/** @module controller/common/samples */
+import {
+ downloadAsDataURL,
+ isEmpty,
+ waitFor,
+ sleep
+} from "../../base/helpers.mjs";
+import { ElementBuilder } from "../../base/builder.mjs";
+import { Controller } from "../base.mjs";
+import { SimpleNotification } from "../../common/notify.mjs";
+import { SampleChooserView } from "../../view/samples/chooser.mjs";
+import { SampleView } from "../../view/samples/viewer.mjs";
+import {
+ ImageAdjustmentView,
+ ImageFilterView
+} from "../../view/samples/filter.mjs";
+import { View } from "../../view/base.mjs";
+import { ImageView } from "../../view/image.mjs";
+import { ToolbarView } from "../../view/menu.mjs";
+import {
+ UpscaleFormView,
+ DownscaleFormView
+} from "../../forms/enfugue/upscale.mjs";
+
+const E = new ElementBuilder();
+
+/**
+ * This is the main controller that manages state and views
+ */
+class SamplesController extends Controller {
+
+ /**
+ * @var int The number of milliseconds to wait after leaving the image to hide tools
+ */
+ static hideTime = 250;
+
+ /**
+ * @var int The width of the adjustment window in pixels
+ */
+ static imageAdjustmentWindowWidth = 750;
+
+ /**
+ * @var int The height of the adjustment window in pixels
+ */
+ static imageAdjustmentWindowHeight = 450;
+
+ /**
+ * @var int The width of the filter window in pixels
+ */
+ static imageFilterWindowWidth = 450;
+
+ /**
+ * @var int The height of the filter window in pixels
+ */
+ static imageFilterWindowHeight = 350;
+
+ /**
+ * @var int The width of the upscale window in pixels
+ */
+ static imageUpscaleWindowWidth = 300;
+
+ /**
+ * @var int The height of the upscale window in pixels
+ */
+ static imageUpscaleWindowHeight = 320;
+
+ /**
+ * @var int The width of the downscale window in pixels
+ */
+ static imageDownscaleWindowWidth = 260;
+
+ /**
+ * @var int The height of the downscale window in pixels
+ */
+ static imageDownscaleWindowHeight = 210;
+
+ /**
+ * Adds the image menu to the passed menu
+ */
+ async prepareImageMenu(menu) {
+ if (!!navigator.clipboard && typeof ClipboardItem === "function") {
+ let copyImage = await menu.addItem("Copy to Clipboard", "fa-solid fa-clipboard", "c");
+ copyImage.onClick(() => this.copyToClipboard());
+ }
+
+ let popoutImage = await menu.addItem("Popout Image", "fa-solid fa-arrow-up-right-from-square", "p");
+ popoutImage.onClick(() => this.sendToWindow());
+
+ let saveImage = await menu.addItem("Save As", "fa-solid fa-floppy-disk", "a");
+ saveImage.onClick(() => this.saveToDisk());
+
+ let adjustImage = await menu.addItem("Adjust Image", "fa-solid fa-sliders", "j");
+ adjustImage.onClick(() => this.startImageAdjustment());
+
+ let filterImage = await menu.addItem("Filter Image", "fa-solid fa-wand-magic-sparkles", "l");
+ filterImage.onClick(() => this.startImageFilter());
+
+ let editImage = await menu.addItem("Edit Image", "fa-solid fa-pen-to-square", "t");
+ editImage.onClick(() => this.sendToCanvas());
+
+ let upscaleImage = await menu.addItem("Upscale Image", "fa-solid fa-up-right-and-down-left-from-center", "u");
+ upscaleImage.onClick(() => this.startImageUpscale());
+
+ let downscaleImage = await menu.addItem("Downscale Image", "fa-solid fa-down-left-and-up-right-to-center", "w");
+ downscaleImage.onClick(() => this.startImageDownscale());
+ }
+
+ /**
+ * Adds the video menu to the passed menu
+ */
+ async prepareVideoMenu(menu) {
+ let popoutVideo = await menu.addItem("Popout Video", "fa-solid fa-arrow-up-right-from-square", "p");
+ popoutVideo.onClick(() => this.sendVideoToWindow());
+
+ let saveVideo = await menu.addItem("Save As", "fa-solid fa-floppy-disk", "a");
+ saveVideo.onClick(() => this.saveVideoToDisk());
+
+ let editVideo = await menu.addItem("Edit Video", "fa-solid fa-pen-to-square", "t");
+ editVideo.onClick(() => this.sendVideoToCanvas());
+
+ let upscaleVideo = await menu.addItem("Upscale Video", "fa-solid fa-up-right-and-down-left-from-center", "u");
+ upscaleVideo.onClick(() => this.startVideoUpscale());
+
+ let gifVideo = await menu.addItem("Get as GIF", "fa-solid fa-file-video", "g");
+ gifVideo.onClick(() => this.getVideoGif());
+ /**
+ let interpolateVideo = await menu.addItem("Interpolate Video", "fa-solid fa-film", "i");
+ interpolateVideo.onClick(() => this.startVideoInterpolate());
+ */
+ }
+
+ /**
+ * Triggers the copy to clipboard
+ */
+ async copyToClipboard() {
+ navigator.clipboard.write([
+ new ClipboardItem({
+ "image/png": await this.sampleViewer.getBlob()
+ })
+ ]);
+ SimpleNotification.notify("Copied to clipboard!", 2000);
+ }
+
+ /**
+ * Saves the image to disk
+ * Asks for a filename first
+ */
+ async saveToDisk() {
+ this.application.saveBlobAs(
+ "Save Image",
+ await this.sampleViewer.getBlob(),
+ ".png"
+ );
+ }
+
+ /**
+ * Saves the video to disk
+ * Asks for a filename first
+ */
+ async saveVideoToDisk() {
+ this.application.saveRemoteAs("Save Video", this.video);
+ }
+
+ /**
+ * Sends the image to a new canvas
+ */
+ async sendToCanvas() {
+ return await this.application.initializeStateFromImage(
+ this.sampleViewer.getDataURL(),
+ true, // Save history
+ null, // Prompt for settings
+ null, // No state overrides
+ false, // Not video
+ );
+ }
+
+ /**
+ * Sends the video to a new canvas
+ */
+ async sendVideoToCanvas() {
+ return await this.application.initializeStateFromImage(
+ await downloadAsDataURL(this.video),
+ true, // Save history
+ null, // Prompt for settings
+ null, // No state overrides
+ true, // Video
+ );
+ }
+
+ /**
+ * Opens the video as a gif
+ */
+ async getVideoGif() {
+ if (isEmpty(this.video) || !this.video.endsWith("mp4")) {
+ throw `Video is empty or not an MP4 URL.`;
+ }
+ let gifURL = this.video.substring(0, this.video.length-4) + ".gif";
+ window.open(gifURL, "_blank");
+ }
+
+ /**
+ * Starts downscaling the image
+ * Replaces the current visible canvas with an in-progress edit.
+ */
+ async startImageDownscale() {
+ if (this.checkActiveTool("downscale")) return;
+
+ let imageBeforeDownscale = this.sampleViewer.getDataURL(),
+ widthBeforeDownscale = this.sampleViewer.width,
+ heightBeforeDownscale = this.sampleViewer.height,
+ setDownscaleAmount = async (amount) => {
+ let image = new ImageView(this.config, imageBeforeDownscale);
+ await image.waitForLoad();
+ await image.downscale(amount);
+ this.sampleViewer.setImage(image.src);
+ this.application.images.setDimension(
+ image.width,
+ image.height,
+ false
+ );
+ },
+ saveResults = false;
+
+ this.imageDownscaleForm = new DownscaleFormView(this.config);
+ this.imageDownscaleWindow = await this.application.windows.spawnWindow(
+ "Downscale Image",
+ this.imageDownscaleForm,
+ this.constructor.imageDownscaleWindowWidth,
+ this.constructor.imageDownscaleWindowHeight
+ );
+ this.imageDownscaleWindow.onClose(() => {
+ this.imageDownscaleForm = null;
+ this.imageDownscaleWindow = null;
+ if (!saveResults) {
+ this.sampleViewer.setImage(imageBeforeDownscale);
+ this.application.images.setDimension(widthBeforeDownscale, heightBeforeDownscale, false);
+ }
+ });
+ this.imageDownscaleForm.onChange(async () => setDownscaleAmount(this.imageDownscaleForm.values.downscale));
+ this.imageDownscaleForm.onCancel(() => this.imageDownscaleWindow.remove());
+ this.imageDownscaleForm.onSubmit(async (values) => {
+ saveResults = true;
+ this.imageDownscaleWindow.remove();
+ });
+ setDownscaleAmount(2); // Default to 2
+ }
+
+ /**
+ * Starts upscaling the image
+ * Does not replace the current visible canvas.
+ * This will use the canvas and upscale settings to send to the backend.
+ */
+ async startImageUpscale() {
+ if (this.checkActiveTool("upscale")) return;
+
+ this.imageUpscaleForm = new UpscaleFormView(this.config);
+ this.imageUpscaleWindow = await this.application.windows.spawnWindow(
+ "Upscale Image",
+ this.imageUpscaleForm,
+ this.constructor.imageUpscaleWindowWidth,
+ this.constructor.imageUpscaleWindowHeight
+ );
+ this.imageUpscaleWindow.onClose(() => {
+ this.imageUpscaleForm = null;
+ this.imageUpscaleWindow = null;
+ });
+ this.imageUpscaleForm.onCancel(() => this.imageUpscaleWindow.remove());
+ this.imageUpscaleForm.onSubmit(async (values) => {
+ await this.application.layers.emptyLayers();
+ await this.application.images.setDimension(
+ this.sampleViewer.width,
+ this.sampleViewer.height,
+ false,
+ true
+ );
+ await this.application.layers.addImageLayer(this.sampleViewer.getDataURL());
+ this.publish("quickUpscale", values);
+ // Remove window
+ this.imageUpscaleWindow.remove();
+ // Show the canvas
+ this.showCanvas();
+ // Wait a tick then trigger invoke
+ setTimeout(() => {
+ this.application.publish("tryInvoke");
+ }, 1000);
+ });
+ }
+
+ /**
+ * Starts upscaling the video
+ * Does not replace the current visible canvas.
+ * This will use the canvas and upscale settings to send to the backend.
+ */
+ async startVideoUpscale() {
+ if (this.checkActiveTool("upscale")) return;
+
+ this.videoUpscaleForm = new UpscaleFormView(this.config);
+ this.videoUpscaleWindow = await this.application.windows.spawnWindow(
+ "Upscale Video",
+ this.videoUpscaleForm,
+ this.constructor.imageUpscaleWindowWidth,
+ this.constructor.imageUpscaleWindowHeight
+ );
+ this.videoUpscaleWindow.onClose(() => {
+ this.videoUpscaleForm = null;
+ this.videoUpscaleWindow = null;
+ });
+ this.videoUpscaleForm.onCancel(() => this.videoUpscaleWindow.remove());
+ this.videoUpscaleForm.onSubmit(async (values) => {
+ await this.application.layers.emptyLayers();
+ await this.application.images.setDimension(
+ this.sampleViewer.width,
+ this.sampleViewer.height,
+ false,
+ true
+ );
+ await this.application.layers.addVideoLayer(
+ await downloadAsDataURL(this.video)
+ );
+ this.publish("quickUpscale", values);
+ // Remove window
+ this.videoUpscaleWindow.remove();
+ // Show the canvas
+ this.showCanvas();
+ // Wait a tick then trigger invoke
+ setTimeout(() => {
+ this.application.publish("tryInvoke");
+ }, 1000);
+ });
+ }
+
+ /**
+ * Starts filtering the image
+ * Replaces the current visible canvas with an in-progress edit.
+ */
+ async startImageFilter() {
+ if (this.checkActiveTool("filter")) return;
+
+ this.imageFilterView = new ImageFilterView(
+ this.config,
+ this.sampleViewer.getDataURL(),
+ this.sampleViewer.node.element.parentElement
+ );
+ this.imageFilterWindow = await this.application.windows.spawnWindow(
+ "Filter Image",
+ this.imageFilterView,
+ this.constructor.imageFilterWindowWidth,
+ this.constructor.imageFilterWindowHeight
+ );
+
+ let reset = () => {
+ try {
+ this.imageFilterView.removeCanvas();
+ } catch(e) { }
+ this.imageFilterView = null;
+ this.imageFilterWindow = null;
+ }
+
+ this.imageFilterWindow.onClose(reset);
+ this.imageFilterView.onSave(async () => {
+ await this.sampleViewer.setImage(this.imageFilterView.getImageSource());
+ setTimeout(() => {
+ this.imageFilterWindow.remove();
+ reset();
+ }, 150);
+ });
+ this.imageFilterView.onCancel(() => {
+ this.imageFilterWindow.remove();
+ reset();
+ });
+ }
+
+ /**
+ * Starts adjusting the image
+ * Replaces the current visible canvas with an in-progress edit.
+ */
+ async startImageAdjustment() {
+ if (this.checkActiveTool("adjust")) return;
+
+ this.imageAdjustmentView = new ImageAdjustmentView(
+ this.config,
+ this.sampleViewer.getDataURL(),
+ this.sampleViewer.node.element.parentElement
+ );
+ this.imageAdjustmentWindow = await this.application.windows.spawnWindow(
+ "Adjust Image",
+ this.imageAdjustmentView,
+ this.constructor.imageAdjustmentWindowWidth,
+ this.constructor.imageAdjustmentWindowHeight
+ );
+
+ let reset = () => {
+ try {
+ this.imageAdjustmentView.removeCanvas();
+ } catch(e) { }
+ this.imageAdjustmentView = null;
+ this.imageAdjustmentWindow = null;
+ }
+
+ this.imageAdjustmentWindow.onClose(reset);
+ this.imageAdjustmentView.onSave(async () => {
+ await this.sampleViewer.setImage(this.imageAdjustmentView.getImageSource());
+ setTimeout(() => {
+ this.imageAdjustmentWindow.remove();
+ reset();
+ }, 150);
+ });
+ this.imageAdjustmentView.onCancel(() => {
+ this.imageAdjustmentWindow.remove();
+ reset();
+ });
+ }
+
+ /**
+ * Checks if there is an active tool and either:
+ * 1. If the active tool matches the intended action, focus on it
+ * 2. If the active tool does not, display a warning
+ * Then return true. If there is no active tool, return false.
+ */
+ checkActiveTool(intendedAction) {
+ if (!isEmpty(this.imageAdjustmentWindow)) {
+ if (intendedAction !== "adjust") {
+ this.notify(
+ "warn",
+ "Finish Adjusting",
+ `Complete adjustments before trying to ${intendedAction}.`
+ );
+ } else {
+ this.imageAdjustmentWindow.focus();
+ }
+ return true;
+ }
+ if (!isEmpty(this.imageFilterWindow)) {
+ if (intendedAction !== "filter") {
+ this.notify(
+ "warn",
+ "Finish Filtering",
+ `Complete filtering before trying to ${intendedAction}.`
+ );
+ } else {
+ this.imageFilterWindow.focus();
+ }
+ return true;
+ }
+ if (!isEmpty(this.imageUpscaleWindow)) {
+ if (intendedAction !== "upscale") {
+ this.notify(
+ "warn",
+ "Finish Upscaling",
+ `Complete your upscale selection or cancel before trying to ${intendedAction}.`
+ );
+ } else {
+ this.imageUpscaleWindow.focus();
+ }
+ return true;
+ }
+ if (!isEmpty(this.imageDownscaleWindow)) {
+ if (intendedAction !== "downscale") {
+ this.notify(
+ "warn",
+ "Finish Downscaling",
+ `Complete your downscale selection or cancel before trying to ${intendedAction}.`
+ );
+ } else {
+ this.imageDownscaleWindow.focus();
+ }
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * Opens the image in a new window
+ */
+ async sendToWindow() {
+ const url = URL.createObjectURL(await this.sampleViewer.getBlob());
+ window.open(url, "_blank");
+ }
+
+ /**
+ * Opens the video in a new window
+ */
+ async sendVideoToWindow() {
+ window.open(this.video, "_blank");
+ }
+
+ /**
+ * The callback when the toolbar has been entered
+ */
+ async toolbarEntered() {
+ this.stopHideTimer();
+ }
+
+ /**
+ * The callback when the toolbar has been left
+ */
+ async toolbarLeft() {
+ this.startHideTimer();
+ }
+
+ /**
+ * Stops the timeout that will hide tools
+ */
+ stopHideTimer() {
+ clearTimeout(this.timer);
+ }
+
+ /**
+ * Start the timeout that will hide tools
+ */
+ startHideTimer() {
+ this.timer = setTimeout(async () => {
+ let release = await this.lock.acquire();
+ release();
+ }, this.constructor.hideTime);
+ }
+
+ /**
+ * The callback for MouseEnter
+ */
+ async onMouseEnter(e) {
+ this.stopHideTimer();
+ }
+
+ /**
+ * The callback for MouseLeave
+ */
+ async onMouseLeave(e) {
+ this.startHideTimer();
+ }
+
+ /**
+ * Gets frame time in milliseconds
+ */
+ get frameTime() {
+ return 1000.0 / this.playbackRate;
+ }
+
+ /**
+ * Gets sample IDs mapped to image URLs
+ */
+ get sampleUrls() {
+ return isEmpty(this.samples)
+ ? []
+ : this.isIntermediate
+ ? this.samples.map((v) => `${this.model.api.baseUrl}/invocation/intermediates/${v}.png`)
+ : this.samples.map((v) => `${this.model.api.baseUrl}/invocation/images/${v}.png`);
+ }
+
+ /**
+ * Gets sample IDs mapped to thumbnail URLs
+ */
+ get thumbnailUrls() {
+ return isEmpty(this.samples)
+ ? []
+ : this.isIntermediate
+ ? this.samples.map((v) => `${this.model.api.baseUrl}/invocation/intermediates/${v}.png`)
+ : this.samples.map((v) => `${this.model.api.baseUrl}/invocation/thumbnails/${v}.png`);
+ }
+
+ /**
+ * Spawns a video player if one doesn't exist
+ */
+ async spawnVideoPlayer() {
+ if (isEmpty(this.videoPlayerWindow)) {
+ this.videoPlayerWindow = await this.application.spawnVideoPlayer(this.video);
+ this.videoPlayerWindow.onClose(() => { delete this.videoPlayerWindow; });
+ } else {
+ this.videoPlayerWindow.focus();
+ }
+ }
+
+ /**
+ * Closes the video player if it exists
+ */
+ closeVideoPlayer() {
+ if (!isEmpty(this.videoPlayerWindow)) {
+ this.videoPlayerWindow.remove();
+ }
+ }
+
+ /**
+ * Sets a final video
+ */
+ setVideo(newVideo) {
+ if (this.video !== newVideo) {
+ this.closeVideoPlayer();
+ }
+ this.video = newVideo;
+ if (!isEmpty(newVideo)) {
+ this.videoToolsMenu.show();
+ this.spawnVideoPlayer();
+ } else {
+ this.videoToolsMenu.hide();
+ }
+ }
+
+ /**
+ * Sets samples
+ */
+ setSamples(sampleImages, isAnimation) {
+ // Get IDs from samples
+ if (isEmpty(sampleImages)) {
+ this.samples = null;
+ this.images.removeClass("has-sample");
+ } else {
+ this.samples = sampleImages.map((v) => v.split("/").slice(-1)[0].split(".")[0]);
+ }
+
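+ // Intermediate samples come from the in-progress invocation endpoint rather than the final image/thumbnail endpoints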
+ this.isIntermediate = !isEmpty(this.samples) && sampleImages[0].indexOf("intermediate") !== -1;
+ this.isAnimation = isAnimation;
+
+ this.sampleChooser.setIsAnimation(isAnimation);
+ this.sampleChooser.setSamples(this.thumbnailUrls).then(() => {
+ this.sampleChooser.setActiveIndex(this.activeIndex, false);
+ });
+ this.sampleViewer.setImage(isAnimation ? this.sampleUrls : isEmpty(this.activeIndex) ? null : this.sampleUrls[this.activeIndex]);
+ if (this.isAnimation) {
+ this.sampleViewer.setFrame(this.activeIndex);
+ }
+ if (!isEmpty(this.activeIndex)) {
+ if (this.isAnimation) {
+ this.imageToolsMenu.hide();
+ if (!isEmpty(this.video)) {
+ this.videoToolsMenu.show();
+ }
+ } else {
+ this.imageToolsMenu.show();
+ this.videoToolsMenu.hide();
+ }
+ sleep(100).then(() => {
+ waitFor(() => !isEmpty(this.sampleViewer.width)).then(() => {
+ this.images.setDimension(this.sampleViewer.width, this.sampleViewer.height, false);
+ this.images.addClass("has-sample");
+ });
+ });
+ }
+ }
+
+ /**
+ * Sets the active index when looking at images
+ */
+ setActive(activeIndex) {
+ this.activeIndex = activeIndex;
+ if (this.isAnimation) {
+ this.sampleChooser.setActiveIndex(activeIndex, false);
+ this.sampleViewer.setFrame(activeIndex);
+ if (!isEmpty(activeIndex)) {
+ this.imageToolsMenu.hide();
+ if (!isEmpty(this.video)) {
+ this.videoToolsMenu.show();
+ }
+ }
+ } else {
+ this.sampleViewer.setImage(this.sampleUrls[this.activeIndex]);
+ if (!isEmpty(activeIndex)) {
+ this.imageToolsMenu.show();
+ this.videoToolsMenu.hide();
+ }
+ }
+
+ if (isEmpty(activeIndex)) {
+ this.images.removeClass("has-sample");
+ this.images.setDimension(this.engine.width, this.engine.height);
+ this.sampleViewer.hide();
+ } else {
+ sleep(100).then(() => {
+ waitFor(() => !isEmpty(this.sampleViewer.width)).then(() => {
+ this.images.setDimension(this.sampleViewer.width, this.sampleViewer.height, false);
+ this.images.addClass("has-sample");
+ });
+ });
+ }
+ }
+
+ /**
+ * Ticks the animation to the next frame
+ */
+ tickAnimation() {
+ if (isEmpty(this.samples)) return;
+ let frameStart = (new Date()).getTime();
+ requestAnimationFrame(() => {
+ let activeIndex = this.activeIndex,
+ frameLength = this.samples.length,
+ nextIndex = activeIndex + 1;
+
+ if (this.isPlaying) {
+ if (nextIndex < frameLength) {
+ this.setActive(nextIndex);
+ } else if(this.isLooping) {
+ this.setActive(0);
+ } else {
+ this.sampleChooser.setPlayAnimation(false);
+ return;
+ }
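+ // Schedule the next tick, compensating for the time spent advancing this frame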
+ let frameTime = (new Date()).getTime() - frameStart;
+ clearTimeout(this.tick);
+ this.tick = setTimeout(
+ () => this.tickAnimation(),
+ this.frameTime - frameTime
+ );
+ }
+ });
+ }
+
+ /**
+ * Modifies playback rate
+ */
+ async setPlaybackRate(playbackRate, updateChooser = true) {
+ this.playbackRate = playbackRate;
+ if (updateChooser) {
+ this.sampleChooser.setPlaybackRate(playbackRate);
+ }
+ }
+
+ /**
+ * Starts/stops playing
+ */
+ async setPlay(playing, updateChooser = true) {
+ this.isPlaying = playing;
+ if (playing) {
+ if (this.activeIndex >= this.samples.length - 1) {
+ // Reset animation
+ this.setActive(0);
+ }
+ clearTimeout(this.tick);
+ this.tick = setTimeout(
+ () => this.tickAnimation(),
+ this.frameTime
+ );
+ } else {
+ clearTimeout(this.tick);
+ }
+ if (updateChooser) {
+ this.sampleChooser.setPlayAnimation(playing);
+ }
+ }
+
+ /**
+ * Enables/disables looping
+ */
+ async setLoop(loop, updateChooser = true) {
+ this.isLooping = loop;
+ if (updateChooser) {
+ this.sampleChooser.setLoopAnimation(loop);
+ }
+ }
+
+ /**
+ * Sets horizontal tiling
+ */
+ async setTileHorizontal(tile, updateChooser = true) {
+ this.tileHorizontal = tile;
+ this.sampleViewer.tileHorizontal = tile;
+ requestAnimationFrame(() => {
+ if (updateChooser) {
+ this.sampleChooser.setHorizontalTile(tile);
+ }
+ this.sampleViewer.checkVisibility();
+ });
+ }
+
+ /**
+ * Sets vertical tiling
+ */
+ async setTileVertical(tile, updateChooser = true) {
+ this.tileVertical = tile;
+ this.sampleViewer.tileVertical = tile;
+ requestAnimationFrame(() => {
+ if (updateChooser) {
+ this.sampleChooser.setVerticalTile(tile);
+ }
+ this.sampleViewer.checkVisibility();
+ });
+ }
+
+ /**
+ * Shows the canvas, hiding samples
+ */
+ async showCanvas(updateChooser = true) {
+ this.setPlay(false);
+ this.sampleViewer.hide();
+ this.imageToolsMenu.hide();
+ this.videoToolsMenu.hide();
+ if (updateChooser) {
+ this.sampleChooser.setActiveIndex(null);
+ }
+ this.images.setDimension(this.engine.width, this.engine.height, false);
+ setTimeout(() => { this.images.removeClass("has-sample"); }, 250);
+ }
+
+ /**
+ * On initialize, add DOM nodes
+ */
+ async initialize() {
+ // Create views
+ this.sampleChooser = new SampleChooserView(this.config);
+ this.sampleViewer = new SampleView(this.config);
+
+ // Bind chooser events
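+ // Chooser-originated changes pass updateChooser=false so they aren't echoed back to the chooser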
+ this.sampleChooser.onShowCanvas(() => this.showCanvas(false));
+ this.sampleChooser.onLoopAnimation((loop) => this.setLoop(loop, false));
+ this.sampleChooser.onPlayAnimation((play) => this.setPlay(play, false));
+ this.sampleChooser.onTileHorizontal((tile) => this.setTileHorizontal(tile, false));
+ this.sampleChooser.onTileVertical((tile) => this.setTileVertical(tile, false));
+ this.sampleChooser.onSetActive((active) => this.setActive(active, false));
+ this.sampleChooser.onSetPlaybackRate((rate) => this.setPlaybackRate(rate, false));
+
+ // Create toolbars
+ this.imageToolsMenu = new ToolbarView(this.config);
+ this.prepareImageMenu(this.imageToolsMenu);
+ this.videoToolsMenu = new ToolbarView(this.config);
+ this.prepareVideoMenu(this.videoToolsMenu);
+
+ // Get initial variables
+ this.activeIndex = 0;
+ this.playbackRate = SampleChooserView.playbackRate;
+
+ // Add chooser to main container
+ this.application.container.appendChild(await this.sampleChooser.render());
+
+ // Get image editor in DOM
+ let imageEditor = await this.images.getNode();
+
+ // Add sample viewer and toolbars to canvas
+ imageEditor.find("enfugue-node-canvas").append(
+ await this.sampleViewer.getNode(),
+ E.div().class("sample-tools-container").content(
+ await this.imageToolsMenu.getNode(),
+ await this.videoToolsMenu.getNode(),
+ )
+ );
+ imageEditor.render();
+ }
+
+ /**
+ * Gets default state, no samples
+ */
+ getDefaultState() {
+ return {
+ "samples": {
+ "urls": null,
+ "active": null,
+ "video": null,
+ "animation": false
+ }
+ };
+ }
+
+ /**
+ * Get state is only for UI; only use the sample choosers here
+ */
+ getState(includeImages = true) {
+ if (!includeImages) {
+ return this.getDefaultState();
+ }
+
+ return {
+ "samples": {
+ "urls": this.sampleUrls,
+ "active": this.activeIndex,
+ "animation": this.isAnimation,
+ "video": this.video
+ }
+ };
+ }
+
+ /**
+ * Set state is only for UI; set the sample choosers here
+ */
+ setState(newState) {
+ if (isEmpty(newState) || isEmpty(newState.samples) || isEmpty(newState.samples.urls)) {
+ this.setSamples(null);
+ this.setVideo(null);
+ } else {
+ this.activeIndex = newState.samples.active;
+ this.setSamples(
+ newState.samples.urls,
+ newState.samples.animation === true
+ );
+ this.setVideo(newState.samples.video);
+ }
+ }
+}
+
+export { SamplesController };
diff --git a/src/js/controller/file/01-new.mjs b/src/js/controller/file/01-new.mjs
index c4ff6c8b..40000417 100644
--- a/src/js/controller/file/01-new.mjs
+++ b/src/js/controller/file/01-new.mjs
@@ -25,6 +25,7 @@ class NewFileController extends MenuController {
*/
async onClick() {
await this.application.resetState();
+ this.notify("info", "Success", "Successfully reset to defaults.");
}
}
diff --git a/src/js/controller/file/04-history.mjs b/src/js/controller/file/04-history.mjs
index 0eb7ee5f..dbbd0510 100644
--- a/src/js/controller/file/04-history.mjs
+++ b/src/js/controller/file/04-history.mjs
@@ -170,11 +170,13 @@ class HistoryController extends MenuController {
if (!isEmpty(item.state.images)) {
let nodeNumber = 1,
nodes;
+
if (Array.isArray(item.state.images)) {
nodes = item.state.images;
} else {
nodes = item.state.images.nodes;
}
+
for (let node of nodes) {
let nodeSummaryParts = {};
switch (node.classname) {
@@ -241,6 +243,7 @@ class HistoryController extends MenuController {
this.resetHistory();
}, this.constructor.searchHistoryDebounceDelay);
});
+
this.historyTable = new HistoryTableView(this.config, await this.getHistory());
let container = new ParentView(this.config);
diff --git a/src/js/controller/file/05-results.mjs b/src/js/controller/file/05-results.mjs
index 7d996061..4cdbc7fa 100644
--- a/src/js/controller/file/05-results.mjs
+++ b/src/js/controller/file/05-results.mjs
@@ -2,7 +2,12 @@
import { MenuController } from "../menu.mjs";
import { ModelTableView } from "../../view/table.mjs";
import { ImageView } from "../../view/image.mjs";
-import { isEmpty, humanDuration, sleep } from "../../base/helpers.mjs";
+import {
+ sleep,
+ isEmpty,
+ humanDuration,
+ downloadAsDataURL
+} from "../../base/helpers.mjs";
import { ElementBuilder } from "../../base/builder.mjs";
const E = new ElementBuilder({
@@ -33,7 +38,7 @@ class InvocationTableView extends ModelTableView {
"label": "Delete",
"click": async function(datum) {
await InvocationTableView.deleteInvocation(datum.id); // Set at init
- await sleep(100); // Wait a tick
+ await sleep(250); // Wait 1/4 second
await this.parent.requery();
}
}
@@ -70,24 +75,89 @@ class InvocationTableView extends ModelTableView {
static columnFormatters = {
"duration": (value) => humanDuration(parseFloat(value), true, true),
"plan": (plan) => {
+ plan.layers = isEmpty(plan.layers)
+ ? "(none)"
+ : `(${plan.layers.length} layer${plan.layers.length==1?'':'s'})`;
return JSON.stringify(plan);
},
"prompt": (_, datum) => datum.plan.prompt,
- "seed": (_, datum) => datum.plan.seed.toString(),
+ "seed": (_, datum) => `${datum.plan.seed}`,
"outputs": async function(outputCount, datum) {
if (outputCount > 0) {
let outputContainer = E.invocationOutputs();
- for (let i = 0; i < outputCount; i++) {
- let imageName = `${datum.id}_${i}.png`,
- imageSource = `/api/invocation/images/${imageName}`,
- thumbnailSource = `/api/invocation/thumbnails/${imageName}`,
- imageView = new ImageView(this.config, thumbnailSource, false),
+ if (!isEmpty(datum.plan.animation_frames) && datum.plan.animation_frames > 0) {
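+ // Animations render a single looping thumbnail video with buttons to view frames, open the GIF, or edit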
+ let videoSource = `/api/invocation/animation/images/${datum.id}.mp4`,
+ gifSource = `/api/invocation/animation/images/${datum.id}.gif`,
+ thumbnailVideoSource = `/api/invocation/animation/thumbnails/${datum.id}.mp4`,
imageContainer = E.invocationOutput()
- .content(await imageView.getNode())
- .on("click", async () => {
- InvocationTableView.setCurrentInvocationImage(imageSource); // Set at init
+ .content(
+ E.video()
+ .content(E.source().src(thumbnailVideoSource))
+ .autoplay(true)
+ .muted(true)
+ .loop(true),
+ E.div().class("buttons").content(
+ E.button()
+ .content(E.i().class("fa-solid fa-film"))
+ .on("click", (e) => {
+ e.stopPropagation();
+ let imageURLs = new Array(datum.plan.animation_frames).fill(null).map((_, i) => {
+ return `/api/invocation/images/${datum.id}_${i}.png`;
+ });
+ InvocationTableView.showAnimationFrames(imageURLs);
+ })
+ .data("tooltip", "Click to View Frames"),
+ E.button()
+ .content(E.i().class("fa-solid fa-file-video"))
+ .on("click", (e) => {
+ e.stopPropagation();
+ window.open(gifSource, "_blank");
+ })
+ .data("tooltip", "Click to View as .GIF"),
+ E.button()
+ .content(E.i().class("fa-solid fa-edit"))
+ .on("click", async (e) => {
+ e.stopPropagation();
+ InvocationTableView.initializeStateFromImage(
+ await downloadAsDataURL(videoSource),
+ true
+ );
+ })
+ .data("tooltip", "Click to Edit")
+ )
+ )
+ .data("tooltip", "Click to View")
+ .on("click", () => {
+ window.open(videoSource, "_blank");
});
- outputContainer.append(imageContainer);
+
+ outputContainer.append(imageContainer);
+ } else {
+ for (let i = 0; i < outputCount; i++) {
+ let imageName = `${datum.id}_${i}.png`,
+ imageSource = `/api/invocation/images/${imageName}`,
+ thumbnailSource = `/api/invocation/thumbnails/${imageName}`,
+ imageView = new ImageView(this.config, thumbnailSource, false),
+ imageContainer = E.invocationOutput()
+ .content(
+ await imageView.getNode(),
+ E.div().class("buttons").content(
+ E.button()
+ .content(E.i().class("fa-solid fa-edit"))
+ .on("click", (e) => {
+ e.stopPropagation();
+ InvocationTableView.initializeStateFromImage(imageSource);
+ })
+ .data("tooltip", "Click to Edit")
+ )
+ )
+ .data("tooltip", "Click to View")
+ .on("click", async () => {
+ window.open(imageSource, "_blank");
+ });
+
+ outputContainer.append(imageContainer);
+ }
}
return outputContainer;
} else if(!isEmpty(datum.error)) {
@@ -146,8 +216,11 @@ class ResultsController extends MenuController {
async initialize() {
await super.initialize();
InvocationTableView.deleteInvocation = (id) => { this.model.delete(`/invocation/${id}`); };
- InvocationTableView.setCurrentInvocationImage = (image) => this.application.images.setCurrentInvocationImage(image);
-
+ InvocationTableView.initializeStateFromImage = (image, isVideo) => this.application.initializeStateFromImage(image, true, null, null, isVideo);
+ InvocationTableView.showAnimationFrames = async (frames) => {
+ await this.application.samples.setSamples(frames, true);
+ setTimeout(() => this.application.samples.setPlay(true), 250);
+ };
}
/**
diff --git a/src/js/controller/models/01-civitai.mjs b/src/js/controller/models/01-civitai.mjs
index 8d0dd117..58624c26 100644
--- a/src/js/controller/models/01-civitai.mjs
+++ b/src/js/controller/models/01-civitai.mjs
@@ -44,7 +44,6 @@ class CivitAIItemView extends View {
*/
async build() {
// TODO: clean this up, it's messy
- console.log(this.item);
let node = await super.build(),
selectedVersion = this.item.modelVersions[0].name,
name = E.h2().content(
@@ -74,7 +73,15 @@ class CivitAIItemView extends View {
),
versionImageNodes = versionImages.map(
(image) => {
- let node = E.img().src(image.url).css({
+ let node;
+ if (this.item.type === "MotionModule") {
+ node = E.video().content(
+ E.source().src(image.url)
+ ).autoplay(true).muted(true).loop(true).controls(false);
+ } else {
+ node = E.img().src(image.url);
+ }
+ node.css({
"max-width": `${((1/versionImages.length)*100).toFixed(2)}%`
});
if (!isEmpty(image.meta) && !isEmpty(image.meta.prompt)) {
@@ -379,6 +386,14 @@ class CivitAIBrowserView extends TabbedView {
(...args) => download("inversion", ...args)
)
);
+ this.addTab(
+ "Motion Modules",
+ new CivitAICategoryBrowserView(
+ config,
+ (...args) => getCategoryData("motion", ...args),
+ (...args) => download("motion", ...args)
+ )
+ );
}
};
diff --git a/src/js/controller/sidebar/01-canvas.mjs b/src/js/controller/sidebar/01-canvas.mjs
new file mode 100644
index 00000000..ffb4518e
--- /dev/null
+++ b/src/js/controller/sidebar/01-canvas.mjs
@@ -0,0 +1,83 @@
+/** @module controller/sidebar/01-canvas */
+import { isEmpty } from "../../base/helpers.mjs";
+import { CanvasFormView } from "../../forms/enfugue/canvas.mjs";
+import { Controller } from "../base.mjs";
+
+/**
+ * Extend the Controller to put the form in the sidebar and trigger changes.
+ */
+class CanvasController extends Controller {
+ /**
+ * Get state from the form
+ */
+ getState(includeImages = true) {
+ return { "canvas": this.canvasForm.values };
+ }
+
+ /**
+ * Gets default state
+ */
+ getDefaultState() {
+ return {
+ "canvas": {
+ "tileHorizontal": false,
+ "tileVertical": false,
+ "width": this.config.model.invocation.width,
+ "height": this.config.model.invocation.height,
+ "useTiling": false,
+ "tilingSize": this.config.model.invocation.tilingSize,
+ "tilingStride": this.config.model.invocation.tilingStride,
+ "tilingMaskType": this.config.model.invocation.tilingMaskType
+ }
+ };
+ }
+
+ /**
+ * Set state on the form
+ */
+ setState(newState) {
+ if (!isEmpty(newState.canvas)) {
+ this.canvasForm.setValues(newState.canvas).then(() => this.canvasForm.submit());
+ }
+ }
+
+ /**
+ * On initialize, create form and bind events.
+ */
+ async initialize() {
+ // Create form
+ this.canvasForm = new CanvasFormView(this.config);
+ this.canvasForm.onSubmit(async (values) => {
+ if (!this.images.hasClass("has-sample")) {
+ this.images.setDimension(values.width, values.height);
+ }
+ this.engine.width = values.width;
+ this.engine.height = values.height;
+ this.engine.tileHorizontal = values.tileHorizontal;
+ this.engine.tileVertical = values.tileVertical;
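+ // Only send tiling parameters when tiling is requested; otherwise a stride of 0 signals no tiling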
+ if (values.useTiling || values.tileHorizontal || values.tileVertical) {
+ this.engine.tilingSize = values.tilingSize;
+ this.engine.tilingMaskType = values.tilingMaskType;
+ this.engine.tilingStride = isEmpty(values.tilingStride) ? 64 : values.tilingStride;
+ } else {
+ this.engine.tilingStride = 0;
+ }
+ });
+
+ // Add form to sidebar
+ this.application.sidebar.addChild(this.canvasForm);
+
+ // Add a callback when the image dimension is manually set
+ this.images.onSetDimension((newWidth, newHeight) => {
+ let currentState = this.canvasForm.values;
+ currentState.width = newWidth;
+ currentState.height = newHeight;
+ this.canvasForm.setValues(currentState).then(() => this.canvasForm.submit());
+ });
+
+ // Trigger once the app is ready to change shapes as needed
+ this.subscribe("applicationReady", () => this.canvasForm.submit());
+ }
+}
+
+export { CanvasController as SidebarController };
diff --git a/src/js/controller/sidebar/01-engine.mjs b/src/js/controller/sidebar/01-engine.mjs
deleted file mode 100644
index 66f8aab1..00000000
--- a/src/js/controller/sidebar/01-engine.mjs
+++ /dev/null
@@ -1,88 +0,0 @@
-/** @module controller/sidebar/01-engine */
-import { isEmpty } from "../../base/helpers.mjs";
-import { EngineFormView } from "../../forms/enfugue/engine.mjs";
-import { Controller } from "../base.mjs";
-
-/**
- * Extend the menu controller to bind initialize
- */
-class EngineController extends Controller {
- /**
- * Return data from the engine form
- */
- getState(includeImages = true) {
- return { "engine": this.engineForm.values };
- }
-
- /**
- * Sets state in the form
- */
- setState(newState) {
- if (!isEmpty(newState.engine)) {
- this.engineForm.setValues(newState.engine).then(() => this.engineForm.submit());
- }
- }
-
- /**
- * Gets default state
- */
- getDefaultState() {
- return {
- "engine": {
- "size": null,
- "refinerSize": null,
- "inpainterSize": null
- }
- }
- };
-
- /**
- * On initialization, append the engine form
- */
- async initialize() {
- // Builds form
- this.engineForm = new EngineFormView(this.config);
-
- // Bind submit
- this.engineForm.onSubmit(async (values) => {
- this.engine.size = values.size;
- this.engine.refinerSize = values.refinerSize;
- this.engine.inpainterSize = values.inpainterSize;
- });
-
- // Add to sidebar
- this.application.sidebar.addChild(this.engineForm);
-
- // Bind events to listen for when to show form and fields
- this.subscribe("engineModelTypeChange", (newType) => {
- if (isEmpty(newType) || newType !== "model") {
- this.engineForm.show();
- } else {
- this.engineForm.hide();
- }
- });
- this.subscribe("engineRefinerChange", (newRefiner) => {
- if (isEmpty(newRefiner)) {
- this.engineForm.removeClass("show-refiner");
- } else {
- this.engineForm.addClass("show-refiner");
- }
- });
- this.subscribe("engineInpainterChange", (newInpainter) => {
- if (isEmpty(newInpainter)) {
- this.engineForm.removeClass("show-inpainter");
- } else {
- this.engineForm.addClass("show-inpainter");
- }
- });
- this.subscribe("engineModelChange", (newModel) => {
- if (isEmpty(newModel)) {
- this.engineForm.show();
- } else if (this.engine.modelType === "model"){
- this.engineForm.hide();
- }
- });
- }
-}
-
-export { EngineController as SidebarController }
diff --git a/src/js/controller/sidebar/02-canvas.mjs b/src/js/controller/sidebar/02-canvas.mjs
deleted file mode 100644
index 7e208c9a..00000000
--- a/src/js/controller/sidebar/02-canvas.mjs
+++ /dev/null
@@ -1,97 +0,0 @@
-/** @module controller/sidebar/02-canvas */
-import { isEmpty } from "../../base/helpers.mjs";
-import { CanvasFormView } from "../../forms/enfugue/canvas.mjs";
-import { Controller } from "../base.mjs";
-
-/**
- * Extend the Controller to put the form in the sidebar and trigger changes.
- */
-class CanvasController extends Controller {
- /**
- * Get state from the form
- */
- getState(includeImages = true) {
- return { "canvas": this.canvasForm.values };
- }
-
- /**
- * Gets default state
- */
- getDefaultState() {
- return {
- "canvas": {
- "width": this.config.model.invocation.width,
- "height": this.config.model.invocation.height,
- "useChunking": false,
- "chunkingSize": this.config.model.invocation.chunkingSize,
- "chunkingMaskType": this.config.model.invocation.chunkingMaskType
- }
- };
- }
-
- /**
- * Set state on the form
- */
- setState(newState) {
- if (!isEmpty(newState.canvas)) {
- this.canvasForm.setValues(newState.canvas).then(() => this.canvasForm.submit());
- }
- }
-
- /**
- * On initialize, create form and bind events.
- */
- async initialize() {
- // Create form
- this.canvasForm = new CanvasFormView(this.config);
- this.canvasForm.onSubmit(async (values) => {
- this.images.setDimension(values.width, values.height);
- this.engine.width = values.width;
- this.engine.height = values.height;
- if (values.useChunking) {
- this.engine.chunkingSize = values.chunkingSize
- this.engine.chunkingMaskType = values.chunkingMaskType;
- } else {
- this.engine.chunkingSize = 0;
- }
- });
-
- // Add form to sidebar
- this.application.sidebar.addChild(this.canvasForm);
-
- // Subscribe to model changes to look for defaults
- this.subscribe("modelPickerChange", async (newModel) => {
- if (!isEmpty(newModel)) {
- let defaultConfig = newModel.defaultConfiguration,
- canvasConfig = {};
-
- if (!isEmpty(defaultConfig.width)) {
- canvasConfig.width = defaultConfig.width;
- }
- if (!isEmpty(defaultConfig.height)) {
- canvasConfig.height = defaultConfig.height;
- }
- if (!isEmpty(defaultConfig.chunking_size)) {
- canvasConfig.chunkingSize = defaultConfig.chunking_size;
- if (canvasConfig.chunkingSize === 0) {
- canvasConfig.useChunking = false;
- }
- }
- if (!isEmpty(defaultConfig.chunking_mask_type)) {
- canvasConfig.chunkingMaskType = defaultConfig.chunking_mask_type;
- }
-
- if (!isEmpty(canvasConfig)) {
- if (isEmpty(canvasConfig.useChunking)) {
- canvasConfig.useChunking = true;
- }
- await this.canvasForm.setValues(canvasConfig);
- await this.canvasForm.submit();
- }
- }
- });
- this.subscribe("applicationReady", () => this.canvasForm.submit());
- }
-}
-
-export { CanvasController as SidebarController };
diff --git a/src/js/controller/sidebar/02-denoising.mjs b/src/js/controller/sidebar/02-denoising.mjs
new file mode 100644
index 00000000..2314e509
--- /dev/null
+++ b/src/js/controller/sidebar/02-denoising.mjs
@@ -0,0 +1,65 @@
+/** @module controller/sidebar/02-denoising */
+import { isEmpty } from "../../base/helpers.mjs";
+import { Controller } from "../base.mjs";
+import { DenoisingFormView } from "../../forms/enfugue/denoising.mjs";
+
+/**
+ * Extends the menu controller for state and init
+ */
+class DenoisingController extends Controller {
+ /**
+ * Get data from the generation form
+ */
+ getState(includeImages = true) {
+ return { "denoising": this.denoisingForm.values };
+ }
+
+ /**
+ * Gets default state
+ */
+ getDefaultState() {
+ return {
+ "denoising": {
+ "strength": 0.99
+ }
+ }
+ }
+
+ /**
+ * Set state in the generation form
+ */
+ setState(newState) {
+ if (!isEmpty(newState.denoising)) {
+ this.denoisingForm.setValues(newState.denoising).then(() => this.denoisingForm.submit());
+ }
+ };
+
+ /**
+ * On init, append form
+ */
+ async initialize() {
+ this.denoisingForm = new DenoisingFormView(this.config);
+ this.denoisingForm.hide();
+ this.denoisingForm.onSubmit(async (values) => {
+ this.engine.strength = values.strength;
+ });
+ this.application.sidebar.addChild(this.denoisingForm);
+ let showForDenoising = false,
+ showForInpainting = false,
+ checkShow = () => {
+ if (showForDenoising || showForInpainting) {
+ this.denoisingForm.show();
+ } else {
+ this.denoisingForm.hide();
+ }
+ };
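+ // Show the form when any layer requests denoising or while inpainting is enabled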
+ this.subscribe("layersChanged", (layers) => {
+ showForDenoising = layers.reduce((carry, item) => carry || item.denoise, false);
+ checkShow();
+ });
+ this.subscribe("inpaintEnabled", () => { showForInpainting = true; checkShow(); });
+ this.subscribe("inpaintDisabled", () => { showForInpainting = false; checkShow(); });
+ }
+}
+
+export { DenoisingController as SidebarController };
diff --git a/src/js/controller/sidebar/03-inpainting.mjs b/src/js/controller/sidebar/03-inpainting.mjs
new file mode 100644
index 00000000..1fdf27b8
--- /dev/null
+++ b/src/js/controller/sidebar/03-inpainting.mjs
@@ -0,0 +1,191 @@
+/** @module controller/sidebar/03-inpainting */
+import { isEmpty } from "../../base/helpers.mjs";
+import { Controller } from "../base.mjs";
+import { ScribbleView } from "../../view/scribble.mjs";
+import { ToolbarView } from "../../view/menu.mjs";
+import { InpaintingFormView } from "../../forms/enfugue/inpainting.mjs";
+
+/**
+ * Register controller to add to sidebar and manage state
+ */
+class InpaintingController extends Controller {
+ /**
+ * When asked for state, return data from form
+ */
+ getState(includeImages = true) {
+ return {
+ "inpainting": {
+ "options": this.inpaintForm.values,
+ "mask": includeImages ? this.scribbleView.src : null,
+ }
+ };
+ }
+
+ /**
+ * Get default state
+ */
+ getDefaultState() {
+ return {
+ "inpainting": {
+ "options": {
+ "outpaint": true,
+ "inpaint": false,
+ "croppedInpaint": true,
+ "croppedInpaintFeather": 32
+ }
+ }
+ };
+ }
+
+ /**
+ * Set state in the inpainting form
+ */
+ setState(newState) {
+ if (!isEmpty(newState.inpainting)) {
+ if (!isEmpty(newState.inpainting.options)) {
+ this.inpaintForm.setValues(newState.inpainting.options).then(() => this.inpaintForm.submit());
+ }
+ if (!isEmpty(newState.inpainting.mask)) {
+ let image = new Image();
+ image.onload = () => {
+ this.scribbleView.setMemory(image);
+ }
+ image.src = newState.inpainting.mask;
+ }
+ }
+ }
+
+ /**
+ * Prepares a menu (either in the header or standalone)
+ */
+ async prepareMenu(menu) {
+ let pencilShape = await menu.addItem("Toggle Pencil Shape", "fa-regular fa-square", "q"),
+ pencilErase = await menu.addItem("Toggle Eraser", "fa-solid fa-eraser", "s"),
+ pencilClear = await menu.addItem("Clear Canvas", "fa-solid fa-delete-left", "l"),
+ pencilFill = await menu.addItem("Fill Canvas", "fa-solid fa-fill-drip", "v"),
+ pencilIncrease = await menu.addItem("Increase Pencil Size", "fa-solid fa-plus", "i"),
+ pencilDecrease = await menu.addItem("Decrease Pencil Size", "fa-solid fa-minus", "d"),
+ hideMask = await menu.addItem("Toggle Mask Visibility", "fa-solid fa-eye", "y"),
+ lockMask = await menu.addItem("Toggle Mask Locked/Unlocked", "fa-solid fa-lock", "k");
+
+ pencilShape.onClick(() => {
+ if (this.scribbleView.shape === "circle") {
+ this.scribbleView.shape = "square";
+ pencilShape.setIcon("fa-regular fa-circle");
+ } else {
+ this.scribbleView.shape = "circle";
+ pencilShape.setIcon("fa-regular fa-square");
+ }
+ });
+ pencilErase.onClick(() => {
+ if (this.scribbleView.isEraser) {
+ this.scribbleView.isEraser = false;
+ pencilErase.setIcon("fa-solid fa-eraser");
+ } else {
+ this.scribbleView.isEraser = true;
+ pencilErase.setIcon("fa-solid fa-pencil");
+ }
+ });
+ pencilClear.onClick(() => { this.scribbleView.clearMemory(); });
+ pencilFill.onClick(() => { this.scribbleView.fillMemory(); });
+ pencilIncrease.onClick(() => { this.scribbleView.increaseSize(); });
+ pencilDecrease.onClick(() => { this.scribbleView.decreaseSize(); });
+ hideMask.onClick(() => {
+ if (this.scribbleView.hidden) {
+ this.scribbleView.show();
+ hideMask.setIcon("fa-solid fa-eye");
+ } else {
+ this.scribbleView.hide();
+ hideMask.setIcon("fa-solid fa-eye-slash");
+ }
+ });
+ lockMask.onClick(() => {
+ if (this.scribbleView.hasClass("locked")) {
+ this.scribbleView.removeClass("locked");
+ lockMask.setIcon("fa-solid fa-lock");
+ } else {
+ this.scribbleView.addClass("locked");
+ lockMask.setIcon("fa-solid fa-unlock");
+ }
+ });
+ }
+
+ /**
+ * Resizes the mask to match the current canvas dimensions
+ */
+ resize() {
+ this.scribbleView.resizeCanvas(
+ this.images.width,
+ this.images.height
+ );
+ }
+
+ /**
+ * On initialize, build sub controllers and add DOM nodes
+ */
+ async initialize() {
+ this.scribbleView = new ScribbleView(
+ this.config,
+ this.engine.width,
+ this.engine.height
+ );
+ this.scribbleView.hide();
+ let setMaskTimer;
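+ // Debounce mask updates so the engine mask only refreshes after drawing pauses briefly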
+ this.scribbleView.onDraw(() => {
+ clearTimeout(setMaskTimer);
+ setMaskTimer = setTimeout(() => {
+ this.engine.mask = this.scribbleView.src;
+ }, 100);
+ });
+ this.scribbleToolbar = new ToolbarView(this.config);
+ this.scribbleToolbar.addClass("inpainting");
+ this.scribbleToolbar.hide();
+
+ await this.prepareMenu(this.scribbleToolbar);
+
+ this.inpaintForm = new InpaintingFormView(this.config);
+ this.inpaintForm.hide();
+ this.inpaintForm.onSubmit((values) => {
+ // Show/hide parts
+ if (values.inpaint) {
+ this.publish("inpaintEnabled");
+ this.scribbleView.show();
+ this.scribbleToolbar.show();
+ this.engine.mask = this.scribbleView.src;
+ } else {
+ this.publish("inpaintDisabled");
+ this.scribbleView.hide();
+ this.scribbleToolbar.hide();
+ this.engine.mask = null;
+ }
+ // Set engine values
+ this.engine.outpaint = values.outpaint;
+ this.engine.cropInpaint = values.cropInpaint;
+ this.engine.inpaintFeather = values.inpaintFeather;
+ });
+
+ this.subscribe("engineWidthChange", () => this.resize());
+ this.subscribe("engineHeightChange", () => this.resize());
+ this.subscribe("layersChanged", (layers) => {
+ if (isEmpty(layers)) {
+ this.inpaintForm.hide();
+ this.scribbleView.hide();
+ this.scribbleToolbar.hide();
+ this.engine.mask = null;
+ } else {
+ this.inpaintForm.show();
+ if (this.inpaintForm.values.inpaint) {
+ this.engine.mask = this.scribbleView.src;
+ this.scribbleView.show();
+ this.scribbleToolbar.show();
+ }
+ }
+ });
+
+ this.application.sidebar.addChild(this.inpaintForm);
+ this.application.container.appendChild(await this.scribbleToolbar.render());
+ (await this.application.images.getNode()).find("enfugue-image-editor-overlay").append(await this.scribbleView.getNode());
+ }
+}
+
+export { InpaintingController as SidebarController };
diff --git a/src/js/controller/sidebar/04-ip-adapter.mjs b/src/js/controller/sidebar/04-ip-adapter.mjs
new file mode 100644
index 00000000..f299c888
--- /dev/null
+++ b/src/js/controller/sidebar/04-ip-adapter.mjs
@@ -0,0 +1,57 @@
+/** @module controller/sidebar/04-ip-adapter */
+import { isEmpty } from "../../base/helpers.mjs";
+import { Controller } from "../base.mjs";
+import { IPAdapterFormView } from "../../forms/enfugue/ip-adapter.mjs";
+
+/**
+ * Extends the menu controller for state and init
+ */
+class IPAdapterController extends Controller {
+ /**
+ * Get data from the IP adapter form
+ */
+ getState(includeImages = true) {
+ return { "ip": this.ipAdapterForm.values };
+ }
+
+ /**
+ * Gets default state
+ */
+ getDefaultState() {
+ return {
+ "ip": {
+ "ipAdapterModel": "default"
+ }
+ }
+ }
+
+ /**
+ * Set state in the IP adapter form
+ */
+ setState(newState) {
+ if (!isEmpty(newState.ip)) {
+ this.ipAdapterForm.setValues(newState.ip).then(() => this.ipAdapterForm.submit());
+ }
+ };
+
+ /**
+ * On init, append form
+ */
+ async initialize() {
+ this.ipAdapterForm = new IPAdapterFormView(this.config);
+ this.ipAdapterForm.hide();
+ this.ipAdapterForm.onSubmit(async (values) => {
+ this.engine.ipAdapterModel = values.ipAdapterModel;
+ });
+ this.application.sidebar.addChild(this.ipAdapterForm);
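+ // Only show the IP adapter options when at least one layer uses image prompting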
+ this.subscribe("layersChanged", (newLayers) => {
+ if (!isEmpty(newLayers) && newLayers.some((layer) => layer.imagePrompt)) {
+ this.ipAdapterForm.show();
+ } else {
+ this.ipAdapterForm.hide();
+ }
+ });
+ }
+}
+
+export { IPAdapterController as SidebarController };
diff --git a/src/js/controller/sidebar/03-tweaks.mjs b/src/js/controller/sidebar/05-tweaks.mjs
similarity index 96%
rename from src/js/controller/sidebar/03-tweaks.mjs
rename to src/js/controller/sidebar/05-tweaks.mjs
index cebc6dda..0376950c 100644
--- a/src/js/controller/sidebar/03-tweaks.mjs
+++ b/src/js/controller/sidebar/05-tweaks.mjs
@@ -1,4 +1,4 @@
-/** @module controller/sidebar/03-tweaks */
+/** @module controller/sidebar/05-tweaks */
import { isEmpty } from "../../base/helpers.mjs";
import { Controller } from "../base.mjs";
import { TweaksFormView } from "../../forms/enfugue/tweaks.mjs";
@@ -39,7 +39,7 @@ class TweaksController extends Controller {
"freeUSkip1": 0.9,
"freeUSkip2": 0.2,
"noiseOffset": 0.0,
- "noiseMethod": "perlin",
+ "noiseMethod": "simplex",
"noiseBlendMethod": "inject"
}
}
@@ -101,12 +101,11 @@ class TweaksController extends Controller {
tweaksConfig.freeUBackbone2 = defaultConfig.freeu_factors[1];
tweaksConfig.freeUSkip1 = defaultConfig.freeu_factors[2];
tweaksConfig.freeUSkip2 = defaultConfig.freeu_factors[3];
- } else {
- tweaksConfig.enableFreeU = false;
}
if (!isEmpty(newModel.scheduler)) {
tweaksConfig.scheduler = newModel.scheduler[0].name;
}
+
if (!isEmpty(tweaksConfig)) {
this.tweaksForm.setValues(tweaksConfig);
}
diff --git a/src/js/controller/sidebar/04-generation.mjs b/src/js/controller/sidebar/06-generation.mjs
similarity index 96%
rename from src/js/controller/sidebar/04-generation.mjs
rename to src/js/controller/sidebar/06-generation.mjs
index 390ff5f5..f06a1fba 100644
--- a/src/js/controller/sidebar/04-generation.mjs
+++ b/src/js/controller/sidebar/06-generation.mjs
@@ -1,4 +1,4 @@
-/** @module controlletr/sidebar/04-generation */
+/** @module controller/sidebar/06-generation */
import { isEmpty } from "../../base/helpers.mjs";
import { Controller } from "../base.mjs";
import { GenerationFormView } from "../../forms/enfugue/generation.mjs";
diff --git a/src/js/controller/sidebar/07-animation.mjs b/src/js/controller/sidebar/07-animation.mjs
new file mode 100644
index 00000000..ecc445f4
--- /dev/null
+++ b/src/js/controller/sidebar/07-animation.mjs
@@ -0,0 +1,88 @@
+/** @module controller/sidebar/07-animation */
+import { isEmpty } from "../../base/helpers.mjs";
+import { Controller } from "../base.mjs";
+import { AnimationFormView } from "../../forms/enfugue/animation.mjs";
+
+/**
+ * Extends the menu controller for state and init
+ */
+class AnimationController extends Controller {
+ /**
+ * Get data from the animation form
+ */
+ getState(includeImages = true) {
+ return { "animation": this.animationForm.values };
+ }
+
+ /**
+ * Gets default state
+ */
+ getDefaultState() {
+ return {
+ "animation": {
+ "animationEnabled": false,
+ "animationFrames": 16,
+ "animationRate": 8,
+ "animationSlicing": true,
+ "animationSize": 16,
+ "animationStride": 8,
+ "animationLoop": null,
+ "animationRate": 8,
+ "animationInterpolation": null,
+ }
+ };
+ }
+
+ /**
+ * Set state in the animation form
+ */
+ setState(newState) {
+ if (!isEmpty(newState.animation)) {
+ this.animationForm.setValues(newState.animation).then(() => this.animationForm.submit());
+ }
+ };
+
+ /**
+ * On init, append form and hide until SDXL gets selected
+ */
+ async initialize() {
+ this.animationForm = new AnimationFormView(this.config);
+ this.animationForm.onSubmit(async (values) => {
+ if (values.animationEnabled) {
+ this.engine.animationFrames = values.animationFrames;
+ this.engine.animationRate = values.animationRate;
+ this.engine.animationInterpolation = values.animationInterpolation;
+ this.engine.animationLoop = values.animationLoop;
+
+ if (values.animationMotionScaleEnabled) {
+ this.engine.animationMotionScale = values.animationMotionScale;
+ } else {
+ this.engine.animationMotionScale = null;
+ }
+ if (values.animationPositionEncodingSliceEnabled) {
+ this.engine.animationPositionEncodingTruncateLength = values.animationPositionEncodingTruncateLength;
+ this.engine.animationPositionEncodingScaleLength = values.animationPositionEncodingScaleLength;
+ } else {
+ this.engine.animationPositionEncodingTruncateLength = null;
+ this.engine.animationPositionEncodingScaleLength = null;
+ }
+
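+ // Frame window options only apply when slicing (looping also uses sliced diffusion)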
+ if (values.animationSlicing || values.animationLoop) {
+ this.engine.animationSize = values.animationSize;
+ this.engine.animationStride = values.animationStride;
+ } else {
+ this.engine.animationSize = null;
+ this.engine.animationStride = null;
+ }
+ } else {
+ this.engine.animationFrames = 0;
+ }
+ });
+ this.application.sidebar.addChild(this.animationForm);
+ }
+}
+
+export { AnimationController as SidebarController };
diff --git a/src/js/controller/sidebar/05-refining.mjs b/src/js/controller/sidebar/08-refining.mjs
similarity index 100%
rename from src/js/controller/sidebar/05-refining.mjs
rename to src/js/controller/sidebar/08-refining.mjs
diff --git a/src/js/controller/sidebar/06-upscale.mjs b/src/js/controller/sidebar/09-upscale.mjs
similarity index 92%
rename from src/js/controller/sidebar/06-upscale.mjs
rename to src/js/controller/sidebar/09-upscale.mjs
index e89adec4..ca4d0320 100644
--- a/src/js/controller/sidebar/06-upscale.mjs
+++ b/src/js/controller/sidebar/09-upscale.mjs
@@ -1,4 +1,4 @@
-/** @module controller/sidebar/05-upscale */
+/** @module controller/sidebar/09-upscale */
import { isEmpty, deepClone } from "../../base/helpers.mjs";
import { Controller } from "../base.mjs";
import { UpscaleStepsFormView } from "../../forms/enfugue/upscale.mjs";
@@ -68,7 +68,9 @@ class UpscaleController extends Controller {
}
}
});
-
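+ // When a quick upscale is requested, replace any configured upscale steps with that single step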
+ this.subscribe("quickUpscale", (upscaleStep) => {
+ this.upscaleForm.setValues({"steps": [upscaleStep]});
+ });
this.application.sidebar.addChild(this.upscaleForm);
}
}
diff --git a/src/js/controller/sidebar/07-prompts.mjs b/src/js/controller/sidebar/11-prompts.mjs
similarity index 52%
rename from src/js/controller/sidebar/07-prompts.mjs
rename to src/js/controller/sidebar/11-prompts.mjs
index d9d4c5d3..0109e420 100644
--- a/src/js/controller/sidebar/07-prompts.mjs
+++ b/src/js/controller/sidebar/11-prompts.mjs
@@ -1,4 +1,4 @@
-/** @module controller/sidebar/06-prompts */
+/** @module controller/sidebar/11-prompts */
import { isEmpty } from "../../base/helpers.mjs";
import { Controller } from "../base.mjs";
import { PromptsFormView } from "../../forms/enfugue/prompts.mjs";
@@ -24,6 +24,7 @@ class PromptsController extends Controller {
"prompts": {
"prompt": null,
"negativePrompt": null,
+ "usePromptTravel": false,
}
};
}
@@ -39,15 +40,40 @@ class PromptsController extends Controller {
* On init, append fields
*/
async initialize() {
+ let isAnimation = false;
this.promptsForm = new PromptsFormView(this.config);
this.promptsForm.onSubmit(async (values) => {
this.engine.prompt = values.prompt;
this.engine.negativePrompt = values.negativePrompt;
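+ // Prompt travel is only available for animations; reflect the current state on the form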
+ if (values.usePromptTravel && isAnimation) {
+ this.publish("promptTravelEnabled");
+ this.promptsForm.addClass("use-prompt-travel");
+ } else {
+ this.publish("promptTravelDisabled");
+ this.promptsForm.removeClass("use-prompt-travel");
+ }
});
this.promptsForm.onShortcutSubmit(() => {
this.application.publish("tryInvoke");
});
this.application.sidebar.addChild(this.promptsForm);
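+ // Show or hide the prompt travel option as the animation frame count changes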
+ this.subscribe("engineAnimationFramesChange", (frames) => {
+ isAnimation = !isEmpty(frames) && frames > 0;
+ if (isAnimation) {
+ this.promptsForm.addClass("show-prompt-travel");
+ if (this.promptsForm.values.usePromptTravel) {
+ this.publish("promptTravelEnabled");
+ this.promptsForm.addClass("use-prompt-travel");
+ } else {
+ this.publish("promptTravelDisabled");
+ this.promptsForm.removeClass("use-prompt-travel");
+ }
+ } else {
+ this.promptsForm.removeClass("show-prompt-travel");
+ this.promptsForm.removeClass("use-prompt-travel");
+ this.publish("promptTravelDisabled");
+ }
+ });
}
}
diff --git a/src/js/controller/common/logs.mjs b/src/js/controller/sidebar/98-logs.mjs
similarity index 91%
rename from src/js/controller/common/logs.mjs
rename to src/js/controller/sidebar/98-logs.mjs
index 32c20d50..34438dbf 100644
--- a/src/js/controller/common/logs.mjs
+++ b/src/js/controller/sidebar/98-logs.mjs
@@ -22,7 +22,7 @@ class LogGlanceView extends View {
/**
* @var int The maximum number of logs to show.
*/
- static maximumLogs = 5;
+ static maximumLogs = 15;
/**
* On construct, set time and hide ourselves
@@ -30,7 +30,6 @@ class LogGlanceView extends View {
constructor(config) {
super(config);
this.start = (new Date()).getTime();
- this.hide();
}
/**
@@ -66,11 +65,17 @@ class LogGlanceView extends View {
async build() {
let node = await super.build();
node.append(
- E.div().class("log-header").content(
- E.h2().content("Most Recent Logs"),
- E.a().href("#").content("Show More").on("click", (e) => { e.preventDefault(); e.stopPropagation(); this.showMore(); })
- ),
- E.div().class("logs")
+ E.div().class("log-container").content(
+ E.button()
+ .content(E.i().class("fa-regular fa-square-caret-right"))
+ .data("tooltip", "Show More")
+ .on("click", (e) => {
+ e.preventDefault();
+ e.stopPropagation();
+ this.showMore();
+ }),
+ E.div().class("logs").content("Welcome to ENFUGUE! When the diffusion engine logs to file, the most recent lines will appear here.")
+ )
);
return node;
}
@@ -181,11 +186,6 @@ class LogsController extends Controller {
*/
static maximumDetailLogs = 100;
- /**
- * @var int Maximum logs to show at once in the glance window
- */
- static maximumGlanceLogs = 5;
-
/**
* @var int Log tail interval in MS
*/
@@ -268,9 +268,11 @@ class LogsController extends Controller {
async initialize() {
this.glanceView = new LogGlanceView(this.config);
this.glanceView.onShowMore = () => this.showLogDetails();
- this.application.container.appendChild(await this.glanceView.render());
+ this.application.sidebar.addChild(this.glanceView);
this.startLogTailer();
}
};
-export { LogsController };
+export {
+ LogsController as SidebarController
+};
diff --git a/src/js/controller/sidebar/99-invoke.mjs b/src/js/controller/sidebar/99-invoke.mjs
index 6d580540..3bef5187 100644
--- a/src/js/controller/sidebar/99-invoke.mjs
+++ b/src/js/controller/sidebar/99-invoke.mjs
@@ -44,150 +44,49 @@ class InvokeButtonController extends Controller {
/**
* Gets the step data from the canvas for invocation.
*/
- getNodes() {
- let canvasState = this.images.getState(),
- nodes = Array.isArray(canvasState)
- ? canvasState
- : canvasState.nodes || [],
- warningMessages = [];
-
- return nodes.map((datum, i) => {
+ getLayers() {
+ let layerState = this.application.layers.getState();
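+ // Convert each canvas layer into the layer payload format used by the invocation request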
+ return layerState.layers.map((datum, i) => {
let formattedState = {
"x": datum.x,
"y": datum.y,
"w": datum.w,
"h": datum.h,
- "inference_steps": datum.inferenceSteps,
- "guidance_scale": datum.guidanceScale,
- "scale_to_model_size": datum.scaleToModelSize,
- "remove_background": datum.removeBackground
+ "remove_background": datum.removeBackground,
+ "image": datum.src
};
-
- if (Array.isArray(datum.prompt)) {
- formattedState["prompt"], formattedState["prompt_2"] = datum.prompt;
- } else {
- formattedState["prompt"] = datum.prompt;
- }
-
- if (Array.isArray(datum.negativePrompt)) {
- formattedState["negative_prompt"], formattedState["negative_prompt_2"] = datum.negativePrompt;
- } else {
- formattedState["negative_prompt"] = datum.negativePrompt;
- }
switch (datum.classname) {
- case "ImageEditorPromptNodeView":
- break;
case "ImageEditorScribbleNodeView":
- formattedState["control_images"] = [
- {"image": datum.src, "process": false, "invert": true, "controlnet": "scribble"}
+ formattedState["control_units"] = [
+ {"process": false, "controlnet": "scribble"}
];
break;
case "ImageEditorImageNodeView":
+ case "ImageEditorVideoNodeView":
formattedState["fit"] = datum.fit;
formattedState["anchor"] = datum.anchor;
- if (datum.infer || datum.inpaint || (!datum.infer && !datum.inpaint && !datum.imagePrompt && !datum.control)) {
- formattedState["image"] = datum.src;
- }
- if (datum.infer) {
- formattedState["strength"] = datum.strength;
- }
- if (datum.inpaint) {
- formattedState["mask"] = datum.scribbleSrc;
- formattedState["invert_mask"] = true; // The UI is inversed
- formattedState["crop_inpaint"] = datum.cropInpaint;
- formattedState["inpaint_feather"] = datum.inpaintFeather;
- }
- if (datum.imagePromptPlus) {
- formattedState["ip_adapter_plus"] = true;
- if (datum.imagePromptFace) {
- formattedState["ip_adapter_face"] = true;
- }
- }
+ formattedState["denoise"] = !!datum.denoise;
if (datum.imagePrompt) {
- formattedState["ip_adapter_images"] = [
- {
- "image": datum.src,
- "scale": datum.imagePromptScale,
- "fit": datum.fit,
- "anchor": datum.anchor
- }
- ];
+ formattedState["ip_adapter_scale"] = datum.imagePromptScale;
}
if (datum.control) {
- formattedState["control_images"] = [
- {
- "image": datum.src,
- "process": datum.processControlImage,
- "invert": datum.invertControlImage === true,
- "controlnet": datum.controlnet,
- "scale": datum.conditioningScale,
- "fit": datum.fit,
- "anchor": datum.anchor,
- "start": datum.conditioningStart,
- "end": datum.conditioningEnd
- }
- ];
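+ // Each ControlNet unit configured on the layer becomes one control unit in the payload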
+ formattedState["control_units"] = datum.controlnetUnits.map((unit) => {
+ return {
+ "process": unit.processControlImage,
+ "start": unit.conditioningStart,
+ "end": unit.conditioningEnd,
+ "scale": unit.conditioningScale,
+ "controlnet": unit.controlnet
+ };
+ });
}
- break;
- case "ImageEditorCompoundImageNodeView":
- let imageNodeIndex, promptImageNodeIndex;
- for (let j = 0; j < datum.children.length; j++) {
- let child = datum.children[j];
- if (child.infer || child.inpaint) {
- if (!isEmpty(imageNodeIndex)) {
- messages.push(`Node {i+1}: Base image set in image {imageNodeIndex+1}, ignoring additional set in {j+1}`);
- } else {
- imageNodeIndex = j;
- formattedState["image"] = child.src;
- formattedState["anchor"] = child.anchor;
- formattedState["fit"] = child.fit;
- }
- }
- if (child.infer && imageNodeIndex == j) {
- formattedState["strength"] = child.strength;
- }
- if (child.inpaint && imageNodeIndex == j) {
- formattedState["mask"] = child.scribbleSrc;
- formattedState["invert_mask"] = true; // The UI is inversed
- formattedState["crop_inpaint"] = child.cropInpaint;
- formattedState["inpaint_feather"] = child.inpaintFeather;
- }
- if (child.imagePrompt) {
- if (isEmpty(formattedState["ip_adapter_images"])) {
- formattedState["ip_adapter_images"] = [];
- }
- if (child.imagePromptPlus) {
- formattedState["ip_adapter_plus"] = true;
- if (child.imagePromptFace) {
- formattedState["ip_adapter_face"] = true;
- }
- }
- formattedState["ip_adapter_images"].push(
- {
- "image": child.src,
- "scale": child.imagePromptScale,
- "fit": child.fit,
- "anchor": child.anchor
- }
- );
- }
- if (child.control) {
- if (isEmpty(formattedState["control_images"])) {
- formattedState["control_images"] = [];
- }
- formattedState["control_images"].push(
- {
- "image": child.src,
- "process": child.processControlImage,
- "invert": child.colorSpace == "invert",
- "controlnet": child.controlnet,
- "scale": child.conditioningScale,
- "fit": child.fit,
- "anchor": child.anchor
- }
- );
- }
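+ // Pass through optional frame skipping/dividing settings when present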
+ if (!isEmpty(datum.skipFrames)) {
+ formattedState["skip_frames"] = datum.skipFrames;
+ }
+ if (!isEmpty(datum.divideFrames)) {
+ formattedState["divide_frames"] = datum.divideFrames;
}
break;
default:
@@ -207,7 +106,7 @@ class InvokeButtonController extends Controller {
this.invokeButton.disable().addClass("sliding-gradient");
try {
this.application.autosave();
- await this.application.invoke({"nodes": this.getNodes()});
+ await this.application.invoke({"layers": this.getLayers()});
} catch(e) {
console.error(e);
let errorMessage = `${e}`;
diff --git a/src/js/controller/system/03-installation.mjs b/src/js/controller/system/03-installation.mjs
index 8f8a147d..45fb05b3 100644
--- a/src/js/controller/system/03-installation.mjs
+++ b/src/js/controller/system/03-installation.mjs
@@ -200,7 +200,7 @@ class InstallationSummaryView extends View {
let node = await super.build();
await this.update();
return node.content(
- E.h2().content("Weights and Configuration"),
+ E.h2().content("Installation Directories"),
await this.summaryTable.getNode(),
E.h2().content("TensorRT Engines"),
await this.engineTable.getNode()
diff --git a/src/js/forms/enfugue/animation.mjs b/src/js/forms/enfugue/animation.mjs
new file mode 100644
index 00000000..c88e7639
--- /dev/null
+++ b/src/js/forms/enfugue/animation.mjs
@@ -0,0 +1,178 @@
+/** @module forms/enfugue/animation */
+import { FormView } from "../base.mjs";
+import {
+ NumberInputView,
+ CheckboxInputView,
+ AnimationLoopInputView,
+ AnimationInterpolationStepsInputView,
+} from "../input.mjs";
+
+/**
+ * The AnimationFormView gathers inputs for AnimateDiff animation
+ */
+class AnimationFormView extends FormView {
+ /**
+ * @var bool Hide submit
+ */
+ static autoSubmit = true;
+
+ /**
+ * @var bool Start collapsed
+ */
+ static collapseFieldSets = true;
+
+ /**
+ * @var object All the inputs in this controller
+ */
+ static fieldSets = {
+ "Animation": {
+ "animationEnabled": {
+ "label": "Enable Animation",
+ "class": CheckboxInputView,
+ },
+ "animationFrames": {
+ "label": "Animation Frames",
+ "class": NumberInputView,
+ "config": {
+ "min": 8,
+ "step": 1,
+ "value": 16,
+ "tooltip": "The number of animation frames the overall animation should be. Divide this number by the animation rate to determine the overall length of the animation in seconds."
+ }
+ },
+ "animationLoop": {
+ "label": "Loop Animation",
+ "class": AnimationLoopInputView
+ },
+ "animationSlicing": {
+ "label": "Use Frame Attention Slicing",
+ "class": CheckboxInputView,
+ "config": {
+ "value": true,
+ "tooltip": "Similar to slicing along the width or height of an image, when using frame slicing, only a portion of the overall animation will be rendered at once. This will reduce the memory required for long animations, but make the process of creating it take longer overall.
Since the animation model is trained on short burts of animation, this can help improve the overall coherence and motion of an animation as well."
+ }
+ },
+ "animationSize": {
+ "label": "Frame Window Size",
+ "class": NumberInputView,
+ "config": {
+ "min": 8,
+ "max": 64,
+ "value": 16,
+ "tooltip": "This is the number of frames to render at once when used sliced animation diffusion. Higher values will require more VRAM, but reduce the overall number of slices needed to render the final animation."
+ }
+ },
+ "animationStride": {
+ "label": "Frame Window Stride",
+ "class": NumberInputView,
+ "config": {
+ "min": 1,
+ "max": 32,
+ "value": 8,
+ "tooltip": "This is the numbers of frames to move the frame window by when using sliced animation diffusion. It is recommended to keep this value at least two fewer than the animation engine size, as that will leave at least two frames of overlap between slices and ease the transition between them."
+ }
+ },
+ "animationMotionScaleEnabled": {
+ "label": "Use Motion Attention Scaling",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "When enabled, a scale will be applied to the influence of motion data on the animation that is proportional to the ratio between the motion training resolution and the image resolution. This will generally increase the amount of motion in the final animation."
+ }
+ },
+ "animationMotionScale": {
+ "label": "Motion Attention Scale Multiplier",
+ "class": NumberInputView,
+ "config": {
+ "min": 0.0,
+ "step": 0.01,
+ "value": 1.0,
+ "tooltip": "When using motion attention scaling, this multiplier will be applied to the scaling. You can use this to decrease the amount of motion (values less than 1.0) or increase the amount of motion (values greater than 1.0) in the resulting animation."
+ }
+ },
+ "animationPositionEncodingSliceEnabled": {
+ "label": "Use Position Encoding Slicing",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "When enabled, you can control the length of position encoding data, slicing it short and/or scaling it linearly. Slicing can improve consistency by removing unused late-animation embeddings beyond a frame window, and scaling can act as a timescale modifier."
+ }
+ },
+ "animationPositionEncodingTruncateLength": {
+ "label": "Position Encoding Truncate Length",
+ "class": NumberInputView,
+ "config": {
+ "min": 8,
+ "max": 24,
+ "value": 16,
+ "tooltip": "Where to end position encoding data. Position tensors are generally 24 frames long, so a value of 16 will remove the final 8 frames of data."
+ }
+ },
+ "animationPositionEncodingScaleLength": {
+ "label": "Position Encoding Scale Length",
+ "class": NumberInputView,
+ "config": {
+ "min": 8,
+ "value": 24,
+ "tooltip": "How long position encoding data should be after truncating and scaling. For example, if you truncate position data to 16 frames and scale position data to 24 frames, you will have removed the final 8 frames of training data, then altered the timescale of the animation by one half - i.e., the animation will appear about 50% slower. This feature is experimental and may result in strange movement."
+ }
+ },
+ "animationRate": {
+ "label": "Frame Rate",
+ "class": NumberInputView,
+ "config": {
+ "min": 8,
+ "value": 8,
+ "max": 128,
+ "tooltip": "The frame rate of the output video. Note that the animations are saved as individual frames, not as videos - so this can be changed later without needing to re-process the invocation. Also note that the frame rate of the AI model is fixed at 8 frames per second, so any values higher than this will result in sped-up motion. Use this value in combination with frame interpolation to control the smoothness of the output video."
+ }
+ },
+ "animationInterpolation": {
+ "label": "Frame Interpolation",
+ "class": AnimationInterpolationStepsInputView
+ }
+ }
+ };
+
+ /**
+ * On submit, add/remove CSS for hiding/showing
+ */
+ async submit() {
+ await super.submit();
+
+ if (this.values.animationEnabled) {
+ this.removeClass("no-animation");
+ } else {
+ this.addClass("no-animation");
+ }
+
+ if (this.values.animationMotionScaleEnabled) {
+ this.removeClass("no-animation-scaling");
+ } else {
+ this.addClass("no-animation-scaling");
+ }
+
+ if (this.values.animationPositionEncodingSliceEnabled) {
+ this.removeClass("no-position-slicing");
+ } else {
+ this.addClass("no-position-slicing");
+ }
+
+ let useSlicing = this.values.animationSlicing,
+ slicingInput = await this.getInputView("animationSlicing");
+
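+ // Looping requires sliced diffusion, so force the slicing checkbox on and lock it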
+ if (this.values.animationLoop === "loop") {
+ useSlicing = true;
+ slicingInput.setValue(true, false);
+ slicingInput.disable();
+ } else {
+ slicingInput.enable();
+ }
+
+ if (useSlicing) {
+ this.removeClass("no-animation-slicing");
+ } else {
+ this.addClass("no-animation-slicing");
+ }
+ }
+}
+
+export { AnimationFormView };
diff --git a/src/js/forms/enfugue/canvas.mjs b/src/js/forms/enfugue/canvas.mjs
index 385766d3..99789529 100644
--- a/src/js/forms/enfugue/canvas.mjs
+++ b/src/js/forms/enfugue/canvas.mjs
@@ -5,12 +5,13 @@ import {
NumberInputView,
CheckboxInputView,
SelectInputView,
- MaskTypeInputView
+ MaskTypeInputView,
+ EngineSizeInputView,
} from "../input.mjs";
let defaultWidth = 512,
defaultHeight = 512,
- defaultChunkingSize = 64;
+ defaultTilingStride = 64;
if (
!isEmpty(window.enfugue) &&
@@ -25,8 +26,8 @@ if (
if (!isEmpty(invocationConfig.height)) {
defaultHeight = invocationConfig.height;
}
- if (!isEmpty(invocationConfig.chunkingSize)) {
- defaultChunkingSize = invocationConfig.chunkingSize;
+ if (!isEmpty(invocationConfig.tilingSize)) {
+ defaultTilingStride = invocationConfig.tilingSize;
}
}
@@ -39,11 +40,6 @@ class CanvasFormView extends FormView {
*/
static className = "canvas-form-view";
- /**
- * @var bool Collapse these fields by default
- */
- static collapseFieldSets = true;
-
/**
* @var bool Hide submit button
*/
@@ -62,7 +58,8 @@ class CanvasFormView extends FormView {
"max": 16384,
"value": defaultWidth,
"step": 8,
- "tooltip": "The width of the canvas in pixels."
+ "tooltip": "The width of the canvas in pixels.",
+ "allowNull": false
}
},
"height": {
@@ -73,28 +70,51 @@ class CanvasFormView extends FormView {
"max": 16384,
"value": defaultHeight,
"step": 8,
- "tooltip": "The height of the canvas in pixels."
+ "tooltip": "The height of the canvas in pixels.",
+ "allowNull": false
+ }
+ },
+ "tileHorizontal": {
+ "label": "Horizontally
Tiling",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "When enabled, the resulting image will tile horizontally, i.e., when duplicated and placed side-by-side, there will be no seams between the copies."
}
},
- "useChunking": {
- "label": "Use Chunking",
+ "tileVertical": {
+ "label": "Vertically
Tiling",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "When enabled, the resulting image will tile vertically, i.e., when duplicated and placed with on image on top of the other, there will be no seams between the copies."
+ }
+ },
+ "useTiling": {
+ "label": "Enabled Tiled Diffusion/VAE",
"class": CheckboxInputView,
"config": {
"tooltip": "When enabled, the engine will only ever process a square in the size of the configured model size at once. After each square, the frame will be moved by the configured amount of pixels along either the horizontal or vertical axis, and then the image is re-diffused. When this is disabled, the entire canvas will be diffused at once. This can have varying results, but a guaranteed result is increased VRAM use.",
"value": false
}
},
- "chunkingSize": {
- "label": "Chunking Size",
+ "tilingSize": {
+ "label": "Tile Size",
+ "class": EngineSizeInputView,
+ "config": {
+ "required": false,
+ "value": null
+ }
+ },
+ "tilingStride": {
+ "label": "Tile Stride",
"class": SelectInputView,
"config": {
"options": ["8", "16", "32", "64", "128", "256", "512"],
- "value": `${defaultChunkingSize}`,
- "tooltip": "The number of pixels to move the frame when doing chunked diffusion. A low number can produce more detailed results, but can be noisy, and takes longer to process. A high number is faster to process, but can have poor results especially along frame boundaries. The recommended value is set by default."
+ "value": `${defaultTilingStride}`,
+ "tooltip": "The number of pixels to move the frame when doing tiled diffusion. A low number can produce more detailed results, but can be noisy, and takes longer to process. A high number is faster to process, but can have poor results especially along frame boundaries. The recommended value is set by default."
}
},
- "chunkingMaskType": {
- "label": "Chunking Mask",
+ "tilingMaskType": {
+ "label": "Tile Mask",
"class": MaskTypeInputView
}
}
@@ -105,10 +125,18 @@ class CanvasFormView extends FormView {
*/
async submit() {
await super.submit();
- if (this.values.useChunking) {
- this.removeClass("no-chunking");
+ let chunkInput = (await this.getInputView("useTiling"));
+ if (this.values.tileHorizontal || this.values.tileVertical) {
+ this.removeClass("no-tiling");
+ tilingInput.setValue(true, false);
+ tilingInput.disable();
} else {
- this.addClass("no-chunking");
+ tilingInput.enable();
+ if (this.values.useTiling) {
+ this.removeClass("no-tiling");
+ } else {
+ this.addClass("no-tiling");
+ }
}
}
};
diff --git a/src/js/forms/enfugue/denoising.mjs b/src/js/forms/enfugue/denoising.mjs
new file mode 100644
index 00000000..92d59ff4
--- /dev/null
+++ b/src/js/forms/enfugue/denoising.mjs
@@ -0,0 +1,36 @@
+/** @module forms/enfugue/denoising */
+import { isEmpty, deepClone } from "../../base/helpers.mjs";
+import { FormView } from "../base.mjs";
+import { SliderPreciseInputView } from "../input.mjs";
+
+/**
+ * The form class containing the strength slider
+ */
+class DenoisingFormView extends FormView {
+ /**
+ * @var object All field sets and their config
+ */
+ static fieldSets = {
+ "Denoising Strength": {
+ "strength": {
+ "class": SliderPreciseInputView,
+ "config": {
+ "value": 0.99,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ "tooltip": "The amount of the image to change. A value of 1.0 means the final image will be completely different from the input image, and a value of 0.0 means the final image will not change from the input image. A value of 0.8 usually represents a good balance of changing the image while maintaining similar features."
+ }
+ }
+ }
+ };
+
+ /**
+ * @var bool Hide submit
+ */
+ static autoSubmit = true;
+};
+
+export {
+ DenoisingFormView
+};
diff --git a/src/js/forms/enfugue/engine.mjs b/src/js/forms/enfugue/engine.mjs
deleted file mode 100644
index e4943ed9..00000000
--- a/src/js/forms/enfugue/engine.mjs
+++ /dev/null
@@ -1,56 +0,0 @@
-/** @module forms/enfugue/engine */
-import { FormView } from "../base.mjs";
-import {
- EngineSizeInputView,
- RefinerEngineSizeInputView,
- InpainterEngineSizeInputView
-} from "../input.mjs";
-
-/**
- * The forms that allow for engine configuration when not using preconfigured models
- */
-class EngineFormView extends FormView {
- /**
- * @var bool Don't show submit
- */
- static autoSubmit = true;
-
- /**
- * @var bool Start collapsed
- */
- static collapseFieldSets = true;
-
- /**
- * @var object The field sets for the form
- */
- static fieldSets = {
- "Engine": {
- "size": {
- "label": "Engine Size",
- "class": EngineSizeInputView,
- "config": {
- "required": false,
- "value": null
- }
- },
- "refinerSize": {
- "label": "Refining Engine Size",
- "class": RefinerEngineSizeInputView,
- "config": {
- "required": false,
- "value": null
- }
- },
- "inpainterSize": {
- "label": "Inpainting Engine Size",
- "class": InpainterEngineSizeInputView,
- "config": {
- "required": false,
- "value": null
- }
- }
- }
- };
-};
-
-export { EngineFormView };
diff --git a/src/js/forms/enfugue/image-editor.mjs b/src/js/forms/enfugue/image-editor.mjs
index d7e08fec..2ebbb87f 100644
--- a/src/js/forms/enfugue/image-editor.mjs
+++ b/src/js/forms/enfugue/image-editor.mjs
@@ -1,617 +1,18 @@
/** @module forms/enfugue/image-editor */
-import { isEmpty } from "../../base/helpers.mjs";
-import { FormView } from "../../forms/base.mjs";
-import {
- PromptInputView,
- FloatInputView,
- NumberInputView,
- CheckboxInputView,
- ImageColorSpaceInputView,
- ControlNetInputView,
- ImageFitInputView,
- ImageAnchorInputView,
- FilterSelectInputView,
- SliderPreciseInputView
-} from "../../forms/input.mjs";
-
-class ImageEditorNodeOptionsFormView extends FormView {
- /**
- * @var object The fieldsets of the options form for image mode.
- */
- static fieldSets = {
- "Prompts": {
- "prompt": {
- "label": "Prompt",
- "class": PromptInputView,
- "config": {
- "tooltip": "This prompt will control what is in this frame. When left blank, the global prompt will be used."
- }
- },
- "negativePrompt": {
- "label": "Negative Prompt",
- "class": PromptInputView,
- "config": {
- "tooltip": "This prompt will control what is in not this frame. When left blank, the global negative prompt will be used."
- }
- },
- },
- "Tweaks": {
- "guidanceScale": {
- "label": "Guidance Scale",
- "class": FloatInputView,
- "config": {
- "min": 0.0,
- "max": 100.0,
- "step": 0.1,
- "value": null,
- "tooltip": "How closely to follow the text prompt; high values result in high-contrast images closely adhering to your text, low values result in low-contrast images with more randomness."
- }
- },
- "inferenceSteps": {
- "label": "Inference Steps",
- "class": NumberInputView,
- "config": {
- "min": 5,
- "max": 250,
- "step": 1,
- "value": null,
- "tooltip": "How many steps to take during primary inference, larger values take longer to process."
- }
- }
- },
- "Other": {
- "scaleToModelSize": {
- "label": "Scale to Model Size",
- "class": CheckboxInputView,
- "config": {
- "tooltip": "When this node has any dimension smaller than the size of the configured model, scale it up so it's smallest dimension is the same size as the model, then scale it down after diffusion.
This generally improves image quality in slightly rectangular shapes or square shapes smaller than the engine size, but can also result in ghosting and increased processing time.
This will have no effect if your node is larger than the model size in all dimensions.
If unchecked and your node is smaller than the model size, TensorRT will be disabled for this node."
- },
- },
- "removeBackground": {
- "label": "Remove Background",
- "class": CheckboxInputView,
- "config": {
- "tooltip": "After diffusion, run the resulting image though an AI background removal algorithm. This can improve image consistency when using multiple nodes."
- }
- }
- },
- };
-
- /**
- * @var bool Never show submit button
- */
- static autoSubmit = true;
-
- /**
- * @var string An additional classname for this form
- */
- static className = "options-form-view";
-
- /**
- * @var array Collapsed field sets
- */
- static collapseFieldSets = ["Tweaks"];
-};
-
-/**
- * This form combines all image options.
- */
-class ImageEditorImageNodeOptionsFormView extends FormView {
- /**
- * @var object The fieldsets of the options form for image mode.
- */
- static fieldSets = {
- "Base": {
- "fit": {
- "label": "Image Fit",
- "class": ImageFitInputView
- },
- "anchor": {
- "label": "Image Anchor",
- "class": ImageAnchorInputView
- },
- "inpaint": {
- "label": "Use for Inpainting",
- "class": CheckboxInputView,
- "config": {
- "tooltip": "When checked, you will be able to paint where on the image you wish for the AI to fill in details. Any gaps in the frame or transparency in the image will also be filled."
- }
- },
- "infer": {
- "label": "Use for Inference",
- "class": CheckboxInputView,
- "config": {
- "tooltip": "When checked, use this image as input for primary diffusion. Inpainting will be performed first, filling any painted sections as well as gaps in the frame and transparency in the image."
- }
- },
- "control": {
- "label": "Use for Control",
- "class": CheckboxInputView,
- "config": {
- "tooltip": "When checked, use this image as input for ControlNet. Inpainting will be performed first, filling any painted sections as well as gaps in the frame and transparency in the image.
Unless otherwise specified, your image will be processed using the appropriate algorithm for the chosen ControlNet."
- }
- },
- "imagePrompt": {
- "label": "Use for Prompt",
- "class": CheckboxInputView,
- "config": {
- "tooltip": "When checked, use this image for Image Prompting. This uses a technique whereby your image is analzyed for descriptors automatically and the 'image prompt' is merged with your real prompt. This can help produce variations of an image without adhering too closely to the original image, and without you having to describe the image yourself."
- }
- }
- },
- "Other": {
- "scaleToModelSize": {
- "label": "Scale to Model Size",
- "class": CheckboxInputView,
- "config": {
- "tooltip": "When this has any dimension smaller than the size of the configured model, scale it up so it's smallest dimension is the same size as the model, then scale it down after diffusion.
This generally improves image quality in rectangular shapes, but can also result in ghosting and increased processing time.
This will have no effect if your node is larger than the model size in all dimensions.
If unchecked and your node is smaller than the model size, TensorRT will be disabled for this node."
- },
- },
- "removeBackground": {
- "label": "Remove Background",
- "class": CheckboxInputView,
- "config": {
- "tooltip": "Before processing, run this image through an AI background removal process. If you are additionally inpainting, inferencing or using this image for ControlNet, that background will then be filled in within this frame. If you are not, that background will be filled when the overall canvas image is finally painted in."
- }
- }
- },
- "Image Prompt": {
- "imagePromptScale": {
- "label": "Image Prompt Scale",
- "class": FloatInputView,
- "config": {
- "tooltip": "How much strength to give to the image. A higher strength will reduce the effect of your prompt, and a lower strength will increase the effect of your prompt but reduce the effect of the image.",
- "min": 0,
- "step": 0.01,
- "value": 0.5
- }
- },
- "imagePromptPlus": {
- "label": "Use Fine-Grained Model",
- "class": CheckboxInputView,
- "config": {
- "tooltip": "Use this to enable fine-grained feature inspection on the source image. This can improve details in the resulting image, but can also make the overall image less similar.
Note that when using multiple source images for image prompting, enabling fine-grained feature inspection on any image enables it for all images."
- }
- },
- "imagePromptFace": {
- "label": "Use Face-Specific Model",
- "class": CheckboxInputView,
- "config": {
- "tooltip": "Use this to focus strongly on a face in the input image. This can work very well to copy a face from one image to another in a natural way, instead of needing a separate face-fixing step.
Note that at present moment, this feature is only available for Stable Diffusion 1.5 models. This checkbox does nothing for SDXL models."
- }
- }
- },
- "Prompts": {
- "prompt": {
- "label": "Prompt",
- "class": PromptInputView,
- "config": {
- "tooltip": "This prompt will control what is in this frame. When left blank, the global prompt will be used."
- }
- },
- "negativePrompt": {
- "label": "Negative Prompt",
- "class": PromptInputView,
- "config": {
- "tooltip": "This prompt will control what is in not this frame. When left blank, the global negative prompt will be used."
- }
- },
- },
- "Tweaks": {
- "guidanceScale": {
- "label": "Guidance Scale",
- "class": FloatInputView,
- "config": {
- "min": 0.0,
- "max": 100.0,
- "step": 0.1,
- "value": null,
- "tooltip": "How closely to follow the text prompt; high values result in high-contrast images closely adhering to your text, low values result in low-contrast images with more randomness."
- }
- },
- "inferenceSteps": {
- "label": "Inference Steps",
- "class": NumberInputView,
- "config": {
- "min": 5,
- "max": 250,
- "step": 1,
- "value": null,
- "tooltip": "How many steps to take during primary inference, larger values take longer to process."
- }
- }
- },
- "Inference": {
- "strength": {
- "label": "Denoising Strength",
- "class": SliderPreciseInputView,
- "config": {
- "min": 0.0,
- "max": 1.0,
- "step": 0.01,
- "value": 0.8,
- "tooltip": "How much of the input image to replace with new information. A value of 1.0 represents total input image destruction, and 0.0 represents no image modifications being made."
- }
- }
- },
- "Inpainting": {
- "cropInpaint": {
- "label": "Use Cropped Inpainting",
- "class": CheckboxInputView,
- "config": {
- "tooltip": "When checked, the image will be cropped to the area you've shaded prior to executing. This will reduce processing time on large images, but can result in losing the composition of the image.",
- "value": true
- }
- },
- "inpaintFeather": {
- "label": "Cropped Inpaint Feather",
- "class": NumberInputView,
- "config": {
- "min": 16,
- "max": 256,
- "step": 8,
- "value": 32,
- "tooltip": "When using cropped inpainting, this is the number of pixels to feather along the edge of the crop in order to help blend in with the rest of the image."
- }
- }
- },
- "Control": {
- "controlnet": {
- "label": "ControlNet",
- "class": ControlNetInputView
- },
- "conditioningScale": {
- "label": "Conditioning Scale",
- "class": FloatInputView,
- "config": {
- "min": 0.0,
- "step": 0.01,
- "value": 1.0,
- "tooltip": "How closely to follow ControlNet's influence. Typical values vary, usually values between 0.5 and 1.0 produce good conditioning with balanced randomness, but other values may produce something closer to the desired result."
- }
- },
- "conditioningStart": {
- "label": "Conditioning Start",
- "class": SliderPreciseInputView,
- "config": {
- "min": 0.0,
- "max": 1.0,
- "step": 0.01,
- "value": 0.0,
- "tooltip": "When to begin using this ControlNet for influence. Defaults to the beginning of generation."
- }
- },
- "conditioningEnd": {
- "label": "Conditioning End",
- "class": SliderPreciseInputView,
- "config": {
- "min": 0.0,
- "max": 1.0,
- "step": 0.01,
- "value": 1.0,
- "tooltip": "When to stop using this ControlNet for influence. Defaults to the end of generation."
- }
- },
- "processControlImage": {
- "label": "Process Image for ControlNet",
- "class": CheckboxInputView,
- "config": {
- "value": true,
- "tooltip": "When checked, the image will be processed through the appropriate edge detection algorithm for the ControlNet. Only uncheck this if your image has already been processed through edge detection."
- }
- },
- "invertControlImage": {
- "label": "Invert Image for ControlNet",
- "class": CheckboxInputView,
- "config": {
- "value": false,
- "tooltip": "Invert the colors of the control image prior to using it."
- }
- }
- }
- };
-
- /**
- * @var object The conditions for display of some inputs.
- */
- static fieldSetConditions = {
- "Prompts": (values) => values.infer || values.inpaint || values.control,
- "Tweaks": (values) => values.infer || values.inpaint || values.control,
- "Inpainting": (values) => values.inpaint,
- "Inference": (values) => values.infer,
- "Control": (values) => values.control,
- "Image Prompt": (values) => values.imagePrompt
- };
-
- /**
- * @var bool Never show submit button
- */
- static autoSubmit = true;
-
- /**
- * @var string An additional classname for this form
- */
- static className = "image-options-form-view";
-
- /**
- * @var array Field sets to collapse
- */
- static collapseFieldSets = ["Tweaks"];
-
- /**
- * On input change, enable/disable flags
- */
- async inputChanged(fieldName, inputView) {
- if (fieldName === "inpaint") {
- let inference = await this.getInputView("infer");
- if (inputView.getValue()) {
- inference.setValue(true, false);
- inference.disable();
- this.values.infer = true;
- this.evaluateConditions();
- } else {
- inference.enable();
- }
- }
- if (fieldName === "processControlImage") {
- if (inputView.getValue()) {
- this.removeClass("no-process");
- } else {
- this.addClass("no-process");
- }
- }
- if (fieldName === "imagePromptPlus") {
- if (inputView.getValue()) {
- this.addClass("prompt-plus");
- } else {
- this.removeClass("prompt-plus");
- }
- }
- return super.inputChanged.call(this, fieldName, inputView);
- }
-
- /**
- * On set values, check and set classes.
- */
- async setValues() {
- await super.setValues.apply(this, Array.from(arguments));
- if (this.values.control && !this.values.processControlImage) {
- this.addClass("no-process");
- } else {
- this.removeClass("no-process");
- }
- if (this.values.imagePromptPlus) {
- this.addClass("prompt-plus");
- } else {
- this.removeClass("prompt-plus");
- }
- let inference = await this.getInputView("infer");
- if (this.values.inpaint) {
- this.values.infer = true;
- inference.setValue(true, false);
- inference.disable();
- } else {
- inference.enable();
- }
- }
-};
-
-/**
- * Creates a common form view base for filter forms
- */
-class ImageFilterFormView extends FormView {
- /**
- * @var bool autosubmit
- */
- static autoSubmit = true;
-
- /**
- * @var bool Disable disabling
- */
- static disableOnSubmit = false;
-
- /**
- * Fieldsets include the main filter, then inputs for filter types
- */
- static fieldSets = {
- "Filter": {
- "filter": {
- "class": FilterSelectInputView,
- }
- },
- "Size": {
- "size": {
- "class": SliderPreciseInputView,
- "config": {
- "min": 4,
- "max": 64,
- "step": 1,
- "value": 4
- }
- }
- },
- "Radius": {
- "radius": {
- "class": SliderPreciseInputView,
- "config": {
- "min": 1,
- "max": 64,
- "step": 1,
- "value": 1
- }
- }
- },
- "Weight": {
- "weight": {
- "class": SliderPreciseInputView,
- "config": {
- "min": 0,
- "max": 100,
- "step": 1,
- "value": 0
- }
- }
- }
- };
-
- /**
- * @var object Default values
- */
- static defaultValues = {
- "filter": null,
- "size": 16,
- "radius": 2,
- "weight": 0
- };
-
- /**
- * @var object Callable conditions for fieldset display
- */
- static fieldSetConditions = {
- "Size": (values) => ["pixelize"].indexOf(values.filter) !== -1,
- "Radius": (values) => ["gaussian", "box", "sharpen"].indexOf(values.filter) !== -1,
- "Weight": (values) => ["sharpen"].indexOf(values.filter) !== -1
- };
-};
-
-/**
- * Creates a form view for controlling the ImageAdjustmentFilter
- */
-class ImageAdjustmentFormView extends ImageFilterFormView {
- /**
- * @var object Various options available
- */
- static fieldSets = {
- "Color Channel Adjustments": {
- "red": {
- "label": "Red Amount",
- "class": SliderPreciseInputView,
- "config": {
- "min": -100,
- "max": 100,
- "value": 0
- }
- },
- "green": {
- "label": "Green Amount",
- "class": SliderPreciseInputView,
- "config": {
- "min": -100,
- "max": 100,
- "value": 0
- }
- },
- "blue": {
- "label": "Blue Amount",
- "class": SliderPreciseInputView,
- "config": {
- "min": -100,
- "max": 100,
- "value": 0
- }
- }
- },
- "Brightness and Contrast": {
- "brightness": {
- "label": "Brightness Adjustment",
- "class": SliderPreciseInputView,
- "config": {
- "min": -100,
- "max": 100,
- "value": 0
- }
- },
- "contrast": {
- "label": "Contrast Adjustment",
- "class": SliderPreciseInputView,
- "config": {
- "min": -100,
- "max": 100,
- "value": 0
- }
- }
- },
- "Hue, Saturation and Lightness": {
- "hue": {
- "label": "Hue Shift",
- "class": SliderPreciseInputView,
- "config": {
- "min": -100,
- "max": 100,
- "value": 0
- }
- },
- "saturation": {
- "label": "Saturation Adjustment",
- "class": SliderPreciseInputView,
- "config": {
- "min": -100,
- "max": 100,
- "value": 0
- }
- },
- "lightness": {
- "label": "Lightness Enhancement",
- "class": SliderPreciseInputView,
- "config": {
- "min": 0,
- "max": 100,
- "value": 0
- }
- }
- },
- "Noise": {
- "hueNoise": {
- "label": "Hue Noise",
- "class": SliderPreciseInputView,
- "config": {
- "min": 0,
- "max": 100,
- "value": 0
- }
- },
- "saturationNoise": {
- "label": "Saturation Noise",
- "class": SliderPreciseInputView,
- "config": {
- "min": 0,
- "max": 100,
- "value": 0
- }
- },
- "lightnessNoise": {
- "label": "Lightness Noise",
- "class": SliderPreciseInputView,
- "config": {
- "min": 0,
- "max": 100,
- "value": 0
- }
- }
- }
- };
-
- /**
- * @var object Default values
- */
- static defaultValues = {
- "red": 0,
- "green": 0,
- "blue": 0,
- "brightness": 0,
- "contrast": 0,
- "hue": 0,
- "saturation": 0,
- "lightness": 0,
- "hueNoise": 0,
- "saturationNoise": 0,
- "lightnessNoise": 0
- };
-};
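+// The node option forms now live in ./image-editor/; re-export them here so existing imports of this module keep working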
+import { ImageEditorPromptNodeOptionsFormView } from "./image-editor/prompt.mjs";
+import { ImageEditorScribbleNodeOptionsFormView } from "./image-editor/scribble.mjs";
+import { ImageEditorImageNodeOptionsFormView } from "./image-editor/image.mjs";
+import { ImageEditorVideoNodeOptionsFormView } from "./image-editor/video.mjs";
+import {
+ ImageFilterFormView,
+ ImageAdjustmentFormView
+} from "./image-editor/filter.mjs";
export {
- ImageEditorNodeOptionsFormView,
+ ImageEditorScribbleNodeOptionsFormView,
+ ImageEditorPromptNodeOptionsFormView,
ImageEditorImageNodeOptionsFormView,
+ ImageEditorVideoNodeOptionsFormView,
ImageFilterFormView,
ImageAdjustmentFormView
};
diff --git a/src/js/forms/enfugue/image-editor/filter.mjs b/src/js/forms/enfugue/image-editor/filter.mjs
new file mode 100644
index 00000000..0c5762c1
--- /dev/null
+++ b/src/js/forms/enfugue/image-editor/filter.mjs
@@ -0,0 +1,225 @@
+/** @module forms/enfugue/image-editor/filter */
+import { isEmpty } from "../../../base/helpers.mjs";
+import { FormView } from "../../../forms/base.mjs";
+import {
+ FilterSelectInputView,
+ SliderPreciseInputView
+} from "../../../forms/input.mjs";
+
+/**
+ * Creates a common form view base for filter forms
+ */
+class ImageFilterFormView extends FormView {
+ /**
+ * @var bool autosubmit
+ */
+ static autoSubmit = true;
+
+ /**
+ * @var bool Disable disabling
+ */
+ static disableOnSubmit = false;
+
+ /**
+ * Fieldsets include the main filter, then inputs for filter types
+ */
+ static fieldSets = {
+ "Filter": {
+ "filter": {
+ "class": FilterSelectInputView,
+ }
+ },
+ "Size": {
+ "size": {
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": 4,
+ "max": 64,
+ "step": 1,
+ "value": 4
+ }
+ }
+ },
+ "Radius": {
+ "radius": {
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": 1,
+ "max": 64,
+ "step": 1,
+ "value": 1
+ }
+ }
+ },
+ "Weight": {
+ "weight": {
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": 0,
+ "max": 100,
+ "step": 1,
+ "value": 0
+ }
+ }
+ }
+ };
+
+ /**
+ * @var object Default values
+ */
+ static defaultValues = {
+ "filter": null,
+ "size": 16,
+ "radius": 2,
+ "weight": 0
+ };
+
+ /**
+ * @var object Callable conditions for fieldset display
+ */
+ static fieldSetConditions = {
+ "Size": (values) => ["pixelize"].indexOf(values.filter) !== -1,
+ "Radius": (values) => ["gaussian", "box", "sharpen"].indexOf(values.filter) !== -1,
+ "Weight": (values) => ["sharpen"].indexOf(values.filter) !== -1
+ };
+};
+
+/**
+ * Creates a form view for controlling the ImageAdjustmentFilter
+ */
+class ImageAdjustmentFormView extends ImageFilterFormView {
+ /**
+ * @var object Various options available
+ */
+ static fieldSets = {
+ "Color Channel Adjustments": {
+ "red": {
+ "label": "Red Amount",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": -100,
+ "max": 100,
+ "value": 0
+ }
+ },
+ "green": {
+ "label": "Green Amount",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": -100,
+ "max": 100,
+ "value": 0
+ }
+ },
+ "blue": {
+ "label": "Blue Amount",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": -100,
+ "max": 100,
+ "value": 0
+ }
+ }
+ },
+ "Brightness and Contrast": {
+ "brightness": {
+ "label": "Brightness Adjustment",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": -100,
+ "max": 100,
+ "value": 0
+ }
+ },
+ "contrast": {
+ "label": "Contrast Adjustment",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": -100,
+ "max": 100,
+ "value": 0
+ }
+ }
+ },
+ "Hue, Saturation and Lightness": {
+ "hue": {
+ "label": "Hue Shift",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": -100,
+ "max": 100,
+ "value": 0
+ }
+ },
+ "saturation": {
+ "label": "Saturation Adjustment",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": -100,
+ "max": 100,
+ "value": 0
+ }
+ },
+ "lightness": {
+ "label": "Lightness Enhancement",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": 0,
+ "max": 100,
+ "value": 0
+ }
+ }
+ },
+ "Noise": {
+ "hueNoise": {
+ "label": "Hue Noise",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": 0,
+ "max": 100,
+ "value": 0
+ }
+ },
+ "saturationNoise": {
+ "label": "Saturation Noise",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": 0,
+ "max": 100,
+ "value": 0
+ }
+ },
+ "lightnessNoise": {
+ "label": "Lightness Noise",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": 0,
+ "max": 100,
+ "value": 0
+ }
+ }
+ }
+ };
+
+ /**
+ * @var object Default values
+ */
+ static defaultValues = {
+ "red": 0,
+ "green": 0,
+ "blue": 0,
+ "brightness": 0,
+ "contrast": 0,
+ "hue": 0,
+ "saturation": 0,
+ "lightness": 0,
+ "hueNoise": 0,
+ "saturationNoise": 0,
+ "lightnessNoise": 0
+ };
+};
+
+export {
+ ImageFilterFormView,
+ ImageAdjustmentFormView
+};
diff --git a/src/js/forms/enfugue/image-editor/image.mjs b/src/js/forms/enfugue/image-editor/image.mjs
new file mode 100644
index 00000000..b64a7463
--- /dev/null
+++ b/src/js/forms/enfugue/image-editor/image.mjs
@@ -0,0 +1,105 @@
+/** @module forms/enfugue/image-editor/image */
+import { isEmpty } from "../../../base/helpers.mjs";
+import { FormView } from "../../../forms/base.mjs";
+import {
+ PromptInputView,
+ ButtonInputView,
+ FloatInputView,
+ NumberInputView,
+ CheckboxInputView,
+ ImageColorSpaceInputView,
+ ControlNetInputView,
+ ImageFitInputView,
+ ImageAnchorInputView,
+ FilterSelectInputView,
+ SliderPreciseInputView,
+ ControlNetUnitsInputView,
+} from "../../../forms/input.mjs";
+
+/**
+ * This form combines all image options.
+ */
+class ImageEditorImageNodeOptionsFormView extends FormView {
+ /**
+ * @var object The fieldsets of the options form for image mode.
+ */
+ static fieldSets = {
+ "Image Modifications": {
+ "fit": {
+ "label": "Image Fit",
+ "class": ImageFitInputView
+ },
+ "anchor": {
+ "label": "Image Anchor",
+ "class": ImageAnchorInputView
+ },
+ "removeBackground": {
+ "label": "Remove Background",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "Before processing, run this image through an AI background removal process. If you are additionally inpainting, inferencing or using this image for ControlNet, that background will then be filled in within this frame. If you are not, that background will be filled when the overall canvas image is finally painted in."
+ }
+ }
+ },
+ "Image Roles": {
+ "denoise": {
+ "label": "Denoising (Image to Image)",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "When checked, use this image as input for primary diffusion. If unchecked and no other roles are selected, this will be treated as a pass-through image to be displayed in it's position as-is."
+ }
+ },
+ "imagePrompt": {
+ "label": "Prompt (IP Adapter)",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "When checked, use this image for Image Prompting. This uses a technique whereby your image is analyzed for descriptors automatically and the 'image prompt' is merged with your real prompt. This can help produce variations of an image without adhering too closely to the original image, and without you having to describe the image yourself."
+ }
+ },
+ "control": {
+ "label": "ControlNet (Canny Edge, Depth, etc.)",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "When checked, use this image for ControlNet input. This is a technique where your image is processed in some way prior to being used alongside primary inference to try and guide the diffusion process. Effectively, this will allow you to 'extract' features from your image such as edges or a depth map, and transfer them to a new image."
+ }
+ }
+ },
+ "Image Prompt": {
+ "imagePromptScale": {
+ "label": "Image Prompt Scale",
+ "class": FloatInputView,
+ "config": {
+ "tooltip": "How much strength to give to the image. A higher strength will reduce the effect of your prompt, and a lower strength will increase the effect of your prompt but reduce the effect of the image.",
+ "min": 0,
+ "step": 0.01,
+ "value": 0.5
+ }
+ },
+ },
+ "ControlNet Units": {
+ "controlnetUnits": {
+ "class": ControlNetUnitsInputView
+ }
+ }
+ };
+
+ /**
+ * @var object The conditions for display of some inputs.
+ */
+ static fieldSetConditions = {
+ "Image Prompt": (values) => values.imagePrompt,
+ "ControlNet Units": (values) => values.control
+ };
+
+ /**
+ * @var bool Never show submit button
+ */
+ static autoSubmit = true;
+
+ /**
+ * @var string An additional classname for this form
+ */
+ static className = "image-options-form-view";
+};
+
+export { ImageEditorImageNodeOptionsFormView };
diff --git a/src/js/forms/enfugue/image-editor/prompt.mjs b/src/js/forms/enfugue/image-editor/prompt.mjs
new file mode 100644
index 00000000..643257b7
--- /dev/null
+++ b/src/js/forms/enfugue/image-editor/prompt.mjs
@@ -0,0 +1,90 @@
+/** @module forms/enfugue/image-editor/prompt */
+import { isEmpty } from "../../../base/helpers.mjs";
+import { FormView } from "../../../forms/base.mjs";
+import {
+ PromptInputView,
+ FloatInputView,
+ NumberInputView,
+ CheckboxInputView,
+} from "../../../forms/input.mjs";
+
+class ImageEditorPromptNodeOptionsFormView extends FormView {
+ /**
+ * @var object The fieldsets of the options form for image mode.
+ */
+ static fieldSets = {
+ "Prompts": {
+ "prompt": {
+ "label": "Prompt",
+ "class": PromptInputView,
+ "config": {
+ "tooltip": "This prompt will control what is in the region prompt node. The global prompt will not be used."
+ }
+ },
+ "negativePrompt": {
+ "label": "Negative Prompt",
+ "class": PromptInputView,
+ "config": {
+ "tooltip": "This prompt will control what is in the region prompt node. The global prompt will not be used."
+ }
+ },
+ },
+ "Node": {
+ "scaleToModelSize": {
+ "label": "Scale to Model Size",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "When this node has any dimension smaller than the size of the configured model, scale it up so it's smallest dimension is the same size as the model, then scale it down after diffusion.
This generally improves image quality in slightly rectangular shapes or square shapes smaller than the engine size, but can also result in ghosting and increased processing time.
This will have no effect if your node is larger than the model size in all dimensions.
If unchecked and your node is smaller than the model size, TensorRT will be disabled for this node."
+ },
+ },
+ "removeBackground": {
+ "label": "Remove Background",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "After diffusion, run the resulting image though an AI background removal algorithm. This can improve image consistency when using multiple nodes."
+ }
+ }
+ },
+ "Global Tweaks Overrides": {
+ "guidanceScale": {
+ "label": "Guidance Scale",
+ "class": FloatInputView,
+ "config": {
+ "min": 0.0,
+ "max": 100.0,
+ "step": 0.1,
+ "value": null,
+ "tooltip": "How closely to follow the text prompt; high values result in high-contrast images closely adhering to your text, low values result in low-contrast images with more randomness. When left blank, the global guidance scale will be used."
+ }
+ },
+ "inferenceSteps": {
+ "label": "Inference Steps",
+ "class": NumberInputView,
+ "config": {
+ "min": 5,
+ "max": 250,
+ "step": 1,
+ "value": null,
+ "tooltip": "How many steps to take during primary inference, larger values take longer to process. When left blank, the global inference steps will be used."
+ }
+ }
+ }
+ };
+
+ /**
+ * @var array Hide override fields
+ */
+ static collapseFieldSets = ["Global Tweaks Overrides"];
+
+ /**
+ * @var bool Never show submit button
+ */
+ static autoSubmit = true;
+
+ /**
+ * @var string An additional classname for this form
+ */
+ static className = "options-form-view";
+};
+
+export { ImageEditorPromptNodeOptionsFormView };
diff --git a/src/js/forms/enfugue/image-editor/scribble.mjs b/src/js/forms/enfugue/image-editor/scribble.mjs
new file mode 100644
index 00000000..d5c631a8
--- /dev/null
+++ b/src/js/forms/enfugue/image-editor/scribble.mjs
@@ -0,0 +1,64 @@
+/** @module forms/enfugue/image-editor/scribble */
+import { isEmpty } from "../../../base/helpers.mjs";
+import { FormView } from "../../../forms/base.mjs";
+import {
+ PromptInputView,
+ FloatInputView,
+ NumberInputView,
+ CheckboxInputView,
+ SliderPreciseInputView
+} from "../../../forms/input.mjs";
+
+class ImageEditorScribbleNodeOptionsFormView extends FormView {
+ /**
+ * @var object The fieldsets of the options form for image mode.
+ */
+ static fieldSets = {
+ "Scribble ControlNet Parameters": {
+ "conditioningScale": {
+ "label": "Conditioning Scale",
+ "class": FloatInputView,
+ "config": {
+ "min": 0.0,
+ "step": 0.01,
+ "value": 1.0,
+ "tooltip": "How closely to follow the Scribble ControlNet's influence. Typical values vary, usually values between 0.5 and 1.0 produce good conditioning with balanced randomness, but other values may produce something closer to the desired result."
+ }
+ },
+ "conditioningStart": {
+ "label": "Conditioning Start",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ "value": 0.0,
+ "tooltip": "When to begin using the Scribble ControlNet for influence. Defaults to the beginning of generation."
+ }
+ },
+ "conditioningEnd": {
+ "label": "Conditioning End",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ "value": 1.0,
+ "tooltip": "When to stop using the Scribble ControlNet for influence. Defaults to the end of generation."
+ }
+ },
+ }
+ };
+
+ /**
+ * @var bool Never show submit button
+ */
+ static autoSubmit = true;
+
+ /**
+ * @var string An additional classname for this form
+ */
+ static className = "options-form-view";
+};
+
+export { ImageEditorScribbleNodeOptionsFormView };
diff --git a/src/js/forms/enfugue/image-editor/video.mjs b/src/js/forms/enfugue/image-editor/video.mjs
new file mode 100644
index 00000000..fe383190
--- /dev/null
+++ b/src/js/forms/enfugue/image-editor/video.mjs
@@ -0,0 +1,35 @@
+/** @module forms/enfugue/image-editor/video */
+import { NumberInputView } from "../../../forms/input.mjs";
+import { ImageEditorImageNodeOptionsFormView } from "./image.mjs";
+
+class ImageEditorVideoNodeOptionsFormView extends ImageEditorImageNodeOptionsFormView {
+ static fieldSets = {
+ ...ImageEditorImageNodeOptionsFormView.fieldSets,
+ ...{
+ "Video Options": {
+ "skipFrames": {
+ "label": "Skip Frames",
+ "class": NumberInputView,
+ "config": {
+ "min": 0,
+ "step": 1,
+ "value": 0,
+ "tooltip": "If set, this many frames will be skipped from the beginning of the video."
+ }
+ },
+ "divideFrames": {
+ "class": NumberInputView,
+ "label": "Divide Frames",
+ "config": {
+ "min": 1,
+ "step": 1,
+ "value": 1,
+ "tooltip": "If set, only the frames that are divided evenly by this number will be extracted. A value of 1 represents all frames being extracted. A value of 2 represents every other frame, 3 every third frame, etc."
+ }
+ }
+ }
+ }
+ }
+};
+
+export { ImageEditorVideoNodeOptionsFormView };
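[Reviewer note] To make the interaction of "Skip Frames" and "Divide Frames" concrete: the actual frame extraction happens on the backend and is not part of this diff, but the tooltips above imply index selection along these lines (selectFrameIndices is purely illustrative):

    // Illustrative only: which source frame indices survive skip/divide.
    function selectFrameIndices(totalFrames, skipFrames = 0, divideFrames = 1) {
        const selected = [];
        for (let i = skipFrames; i < totalFrames; i++) {
            // divideFrames of 1 keeps every remaining frame, 2 keeps every other frame, etc.
            if ((i - skipFrames) % divideFrames === 0) {
                selected.push(i);
            }
        }
        return selected;
    }

    selectFrameIndices(10, 2, 3); // => [2, 5, 8]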
diff --git a/src/js/forms/enfugue/inpainting.mjs b/src/js/forms/enfugue/inpainting.mjs
new file mode 100644
index 00000000..fa41f1cd
--- /dev/null
+++ b/src/js/forms/enfugue/inpainting.mjs
@@ -0,0 +1,87 @@
+/** @module forms/enfugue/inpainting */
+import { FormView } from "../base.mjs";
+import {
+ CheckboxInputView,
+ NumberInputView
+} from "../input.mjs";
+
+/**
+ * The inpainting form controls outpainting, inpainting, and cropped inpainting options
+ */
+class InpaintingFormView extends FormView {
+ /**
+ * @var bool Don't show submit button
+ */
+ static autoSubmit = true;
+
+ /**
+ * @var object The field sets
+ */
+ static fieldSets = {
+ "Inpainting": {
+ "outpaint": {
+ "label": "Enable Outpainting",
+ "class": CheckboxInputView,
+ "config": {
+ "value": true,
+ "tooltip": "When enabled, enfugue will automatically fill any transparency remaining after merging layers down. If there is no transparency, this step will be skipped."
+ }
+ },
+ "inpaint": {
+ "label": "Enable Inpainting",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "When checked, you can additionally draw your own mask over any images on the canvas. If there is transparency, it will also be filled in addition to anywhere you draw."
+ }
+ },
+ "cropInpaint": {
+ "label": "Use Cropped Inpainting/Outpainting",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "When enabled, enfugue will crop to the inpainted region before executing. This saves processing time and can help with small modifications on large images.",
+ "value": true
+ }
+ },
+ "inpaintFeather": {
+ "label": "Feather Amount",
+ "class": NumberInputView,
+ "config": {
+ "min": 0,
+ "max": 512,
+ "value": 32,
+ "step": 1,
+ "tooltip": "The number of pixels to use as a blending region for cropped inpainting. These will be blended smoothly into the final image to relieve situations where the cropped inpaint is noticably different from the rest of the image."
+ }
+ }
+ }
+ };
+
+ /**
+ * Check classes on submit
+ */
+ async submit() {
+ await super.submit();
+
+ if (this.values.inpaint || this.values.outpaint) {
+ this.removeClass("no-inpaint");
+ } else {
+ this.addClass("no-inpaint");
+ }
+
+ if (this.values.cropInpaint) {
+ this.removeClass("no-cropped-inpaint");
+ } else {
+ this.addClass("no-cropped-inpaint");
+ }
+
+ let outpaintInput = (await this.getInputView("outpaint"));
+ if (this.values.inpaint) {
+ outpaintInput.setValue(true, false);
+ outpaintInput.disable();
+ } else {
+ outpaintInput.enable();
+ }
+ }
+}
+
+export { InpaintingFormView };
diff --git a/src/js/forms/enfugue/ip-adapter.mjs b/src/js/forms/enfugue/ip-adapter.mjs
new file mode 100644
index 00000000..b5a57651
--- /dev/null
+++ b/src/js/forms/enfugue/ip-adapter.mjs
@@ -0,0 +1,36 @@
+/** @module forms/enfugue/ip-adapter */
+import { FormView } from "../base.mjs";
+import { SelectInputView } from "../input.mjs";
+
+/**
+ * The form class containing the strength slider
+ */
+class IPAdapterFormView extends FormView {
+ /**
+ * @var object All field sets and their config
+ */
+ static fieldSets = {
+ "IP Adapter": {
+ "ipAdapterModel": {
+ "label": "Model",
+ "class": SelectInputView,
+ "config": {
+ "value": "default",
+ "options": {
+ "default": "Default",
+ "plus": "Plus",
+ "plus-face": "Plus Face"
+ },
+ "tooltip": "Which IP adapter model to use. 'Plus' will in general find more detail in the source image while considerably adjusting the impact of your prompt, and 'Plus Face' will ignore much of the image except for facial features."
+ }
+ }
+ }
+ };
+
+ /**
+ * @var bool Hide submit
+ */
+ static autoSubmit = true;
+};
+
+export { IPAdapterFormView };
diff --git a/src/js/forms/enfugue/models.mjs b/src/js/forms/enfugue/models.mjs
index 784e5622..0e311e3b 100644
--- a/src/js/forms/enfugue/models.mjs
+++ b/src/js/forms/enfugue/models.mjs
@@ -14,14 +14,12 @@ import {
ModelMergeModeInputView,
VaeInputView,
CheckpointInputView,
- EngineSizeInputView,
- InpainterEngineSizeInputView,
- RefinerEngineSizeInputView,
SchedulerInputView,
PromptInputView,
SliderPreciseInputView,
FloatInputView,
- MaskTypeInputView
+ MaskTypeInputView,
+ MotionModuleInputView,
} from "../input.mjs";
/**
@@ -107,22 +105,9 @@ class ModelFormView extends FormView {
"class": VaeInputView,
"label": "Inpainting VAE"
},
- },
- "Engine": {
- "size": {
- "class": EngineSizeInputView,
- "label": "Size",
- "config": {
- "required": true,
- }
- },
- "refiner_size": {
- "class": RefinerEngineSizeInputView,
- "label": "Refiner Size"
- },
- "inpainter_size": {
- "class": InpainterEngineSizeInputView,
- "label": "Inpainter Size"
+ "motion_module": {
+ "label": "Motion Module",
+ "class": MotionModuleInputView
}
},
"Prompts": {
@@ -142,43 +127,6 @@ class ModelFormView extends FormView {
"class": SchedulerInputView,
"label": "Scheduler"
},
- "width": {
- "label": "Width",
- "class": NumberInputView,
- "config": {
- "tooltip": "The width of the canvas in pixels.",
- "min": 128,
- "max": 16384,
- "step": 8,
- "value": null
- }
- },
- "height": {
- "label": "Height",
- "class": NumberInputView,
- "config": {
- "tooltip": "The height of the canvas in pixels.",
- "min": 128,
- "max": 16384,
- "step": 8,
- "value": null
- }
- },
- "chunking_size": {
- "label": "Chunk Size",
- "class": NumberInputView,
- "config": {
- "tooltip": "The number of pixels to move the frame when doing chunked diffusion.
When this number is greater than 0, the engine will only ever process a square in the size of the configured model size at once. After each square, the frame will be moved by this many pixels along either the horizontal or vertical axis, and then the image is re-diffused. When this number is 0, chunking is disabled, and the entire canvas will be diffused at once.
Disabling this (setting it to 0) can have varying visual results, but a guaranteed result is drastically increased VRAM usage for large images. A low number can produce more detailed results, but can be noisy, and takes longer to process. A high number is faster to process, but can have poor results especially along frame boundaries. The recommended value is set by default.
",
- "min": 0,
- "max": 2048,
- "step": 8,
- "value": null
- }
- },
- "chunking_mask_type": {
- "label": "Chunking Mask",
- "class": MaskTypeInputView
- },
"num_inference_steps": {
"label": "Inference Steps",
"class": NumberInputView,
@@ -274,16 +222,6 @@ class ModelFormView extends FormView {
}
};
- /**
- * @var array Fieldsets to hide
- */
- static collapseFieldSets = [
- "Adaptations and Modifications",
- "Additional Models",
- "Defaults",
- "Refining Defaults"
- ];
-
static fieldSetConditions = {
"Refining Defaults": (values) => !isEmpty(values.refiner)
};
@@ -298,20 +236,15 @@ class AbridgedModelFormView extends ModelFormView {
*/
static className = "model-configuration-form-view";
- /**
- * @var boolean no submit button
- */
- static autoSubmit = true;
-
/**
* @var bool No cancel
*/
static canCancel = false;
/**
- * @var boolean Start hidden
+ * @var boolean No hiding
*/
- static collapseFieldSets = true;
+ static collapseFieldSets = false;
/**
* @var object one fieldset describes all inputs
@@ -366,6 +299,10 @@ class AbridgedModelFormView extends ModelFormView {
"inpainter_vae": {
"label": "Inpainting VAE",
"class": VaeInputView
+ },
+ "motion_module": {
+ "label": "Motion Module",
+ "class": MotionModuleInputView
}
}
};
diff --git a/src/js/forms/enfugue/prompts.mjs b/src/js/forms/enfugue/prompts.mjs
index 89d39b4f..961aeaa5 100644
--- a/src/js/forms/enfugue/prompts.mjs
+++ b/src/js/forms/enfugue/prompts.mjs
@@ -1,6 +1,10 @@
/** @module forms/enfugue/prompts */
import { FormView } from "../base.mjs";
-import { PromptInputView } from "../input.mjs";
+import {
+ PromptInputView,
+ NumberInputView,
+ CheckboxInputView
+} from "../input.mjs";
/**
* Extends the prompt input view to look for ctrl+enter to auto-submit parent form
@@ -39,6 +43,13 @@ class PromptsFormView extends FormView {
*/
static fieldSets = {
"Prompts": {
+ "usePromptTravel": {
+ "label": "Use Prompt Travel",
+ "class": CheckboxInputView,
+ "config": {
+ "tooltip": "When enabled, you can change prompts throughout an animation using a timeline interface. When disabled, the same problem will be used throughout the entire animation."
+ }
+ },
"prompt": {
"label": "Prompt",
"class": SubmitPromptInputView
@@ -68,4 +79,44 @@ class PromptsFormView extends FormView {
}
}
-export { PromptsFormView };
+/**
+ * The prompt travel form is for prompts with start/end frames
+ */
+class PromptTravelFormView extends FormView {
+ /**
+ * @var bool Don't show submit button
+ */
+ static autoSubmit = true;
+
+ /**
+ * @var object The field sets
+ */
+ static fieldSets = {
+ "Prompts": {
+ "positive": {
+ "label": "Prompt",
+ "class": SubmitPromptInputView
+ },
+ "negative": {
+ "label": "Negative Prompt",
+ "class": SubmitPromptInputView
+ }
+ },
+ "Weight": {
+ "weight": {
+ "class": NumberInputView,
+ "config": {
+ "min": 0.01,
+ "value": 1.0,
+ "step": 0.01,
+ "tooltip": "The weight of this prompt. It is recommended to keep your highest-weight prompt at 1.0 and scale others relative to that, but this is unconstrained."
+ }
+ }
+ }
+ };
+}
+
+export {
+ PromptsFormView,
+ PromptTravelFormView
+};
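[Reviewer note] For context on how these two forms relate: usePromptTravel in PromptsFormView gates the timeline interface, and each timeline entry is presumably backed by an auto-submitting PromptTravelFormView. Inferred from the fieldsets above, one entry's values would look like the following (the surrounding array and literal prompts are illustrative only):

    // Hypothetical per-segment values gathered from PromptTravelFormView instances.
    const examplePromptTravel = [
        { positive: "a snowy mountain pass at dawn", negative: "blurry", weight: 1.0 },
        { positive: "the same pass at golden hour", negative: "blurry", weight: 0.75 }
    ];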
diff --git a/src/js/forms/enfugue/tweaks.mjs b/src/js/forms/enfugue/tweaks.mjs
index 69b06d3f..587bdd79 100644
--- a/src/js/forms/enfugue/tweaks.mjs
+++ b/src/js/forms/enfugue/tweaks.mjs
@@ -73,7 +73,7 @@ class TweaksFormView extends FormView {
"label": "Noise Method",
"class": NoiseMethodInputView,
"config": {
- "value": "perlin"
+ "value": "simplex"
}
},
"noiseBlendMethod": {
diff --git a/src/js/forms/enfugue/upscale.mjs b/src/js/forms/enfugue/upscale.mjs
index f0df6586..751ced30 100644
--- a/src/js/forms/enfugue/upscale.mjs
+++ b/src/js/forms/enfugue/upscale.mjs
@@ -72,17 +72,17 @@ class UpscaleFormView extends FormView {
"label": "Guidance Scale",
"class": UpscaleDiffusionGuidanceScaleInputView
},
- "chunkingSize": {
- "label": "Chunking Size",
+ "tilingStride": {
+ "label": "Tiling Stride",
"class": SelectInputView,
"config": {
- "options": ["8", "16", "32", "64", "128", "256", "512"],
+ "options": ["0", "8", "16", "32", "64", "128", "256", "512"],
"value": "128",
- "tooltip": "The number of pixels to move the frame by during diffusion. Smaller values produce better results, but take longer."
+ "tooltip": "The number of pixels to move the frame by during diffusion. Smaller values produce better results, but take longer. Set to 0 to disable."
}
},
- "chunkingMaskType": {
- "label": "Chunking Mask",
+ "tilingMaskType": {
+ "label": "Tiling Mask",
"class": MaskTypeInputView,
},
"noiseOffset": {
@@ -99,7 +99,7 @@ class UpscaleFormView extends FormView {
"label": "Noise Method",
"class": NoiseMethodInputView,
"config": {
- "value": "perlin"
+ "value": "simplex"
}
},
"noiseBlendMethod": {
diff --git a/src/js/forms/input.mjs b/src/js/forms/input.mjs
index c4dc5e71..ffaac2d9 100644
--- a/src/js/forms/input.mjs
+++ b/src/js/forms/input.mjs
@@ -43,18 +43,24 @@ import {
MultiLoraInputView,
MultiLycorisInputView,
MultiInversionInputView,
- EngineSizeInputView,
- RefinerEngineSizeInputView,
- InpainterEngineSizeInputView,
VaeInputView,
DefaultVaeInputView,
- SchedulerInputView,
ModelPickerStringInputView,
ModelPickerListInputView,
ModelPickerInputView,
- MaskTypeInputView,
- ModelMergeModeInputView
+ ModelMergeModeInputView,
+ MotionModuleInputView,
} from "./input/enfugue/models.mjs";
+import {
+ EngineSizeInputView,
+ RefinerEngineSizeInputView,
+ InpainterEngineSizeInputView,
+ SchedulerInputView,
+ MaskTypeInputView,
+ ControlNetInputView,
+ ControlNetUnitsInputView,
+ ImageColorSpaceInputView,
+} from "./input/enfugue/engine.mjs";
import {
PipelineInpaintingModeInputView,
PipelineSwitchModeInputView,
@@ -62,10 +68,8 @@ import {
PipelinePrecisionModeInputView
} from "./input/enfugue/settings.mjs";
import {
- ControlNetInputView,
ImageAnchorInputView,
ImageFitInputView,
- ImageColorSpaceInputView,
FilterSelectInputView
} from "./input/enfugue/image-editor.mjs";
import {
@@ -84,7 +88,10 @@ import {
NoiseMethodInputView,
BlendMethodInputView
} from "./input/enfugue/noise.mjs";
-
+import {
+ AnimationLoopInputView,
+ AnimationInterpolationStepsInputView,
+} from "./input/enfugue/animation.mjs";
export {
InputView,
HiddenInputView,
@@ -115,6 +122,7 @@ export {
CivitAISortInputView,
CivitAITimePeriodInputView,
ControlNetInputView,
+ ControlNetUnitsInputView,
ImageAnchorInputView,
ImageFitInputView,
ImageColorSpaceInputView,
@@ -152,5 +160,8 @@ export {
UpscaleDiffusionGuidanceScaleInputView,
NoiseOffsetInputView,
NoiseMethodInputView,
- BlendMethodInputView
+ BlendMethodInputView,
+ AnimationLoopInputView,
+ AnimationInterpolationStepsInputView,
+ MotionModuleInputView,
};
diff --git a/src/js/forms/input/enfugue/animation.mjs b/src/js/forms/input/enfugue/animation.mjs
new file mode 100644
index 00000000..79e0aab1
--- /dev/null
+++ b/src/js/forms/input/enfugue/animation.mjs
@@ -0,0 +1,72 @@
+/** @module forms/input/enfugue/animation */
+import { SelectInputView } from "../enumerable.mjs";
+import { RepeatableInputView } from "../parent.mjs";
+import { NumberInputView } from "../numeric.mjs";
+
+/**
+ * This class allows selecting looping options (or none)
+ */
+class AnimationLoopInputView extends SelectInputView {
+ /**
+ * @var bool enable selecting null
+ */
+ static allowEmpty = true;
+
+ /**
+ * @var string Text to show in null option
+ */
+ static placeholder = "No Looping";
+
+ /**
+ * @var object Selectable options
+ */
+ static defaultOptions = {
+ "reflect": "Reflect",
+ "loop": "Loop"
+ };
+
+ /**
+ * @var string tooltip to display
+ */
+ static tooltip = "When enabled the animation will loop seamlessly such that there will be no hitches when the animation is repeated. Reflect will add a reverse version of the animation at the end, with interpolation to ease the inflection points. Loop will create a properly looking animation through sliced diffusion; this will increase render time.";
+}
+
+/**
+ * Provides a repeatable input for interpolation steps
+ */
+class AnimationInterpolationStepsInputView extends RepeatableInputView {
+ /**
+ * @var class member class
+ */
+ static memberClass = NumberInputView;
+
+ /**
+ * @var object config for member class
+ */
+ static memberConfig = {
+ "min": 2,
+ "max": 10,
+ "step": 1,
+ "value": 2
+ };
+
+ /**
+ * @var string tooltip
+ */
+ static tooltip = "Interpolation is the process of trying to calculate a frame between two other frames such that when the calculated frame is placed between the two other frames, there appears to be a smooth motion betweent he three.
You can add multiple interpolation factors, where a value of 2 means that there will be two frames for every one frame (thus one frame placed in-between every frame,) a value of 3 means there will be three frames for every one frame (and thus two frames placed in-between every frame, attempting to maintain a smooth motion across all of them.) Multiple factors will be executed recursively. The smoothest results can be obtained via repeated factors of 2.";
+
+ /**
+ * @var string Text to show when no items
+ */
+ static noItemsLabel = "No Interpolation";
+
+ /**
+ * @var string Text to show in the 'add item' button
+ */
+ static addItemLabel = "Add Interpolation Step";
+}
+
+export {
+ AnimationLoopInputView,
+ AnimationInterpolationStepsInputView,
+}
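[Reviewer note] A small, purely illustrative sketch of the recursive-factor arithmetic described in the tooltip above (interpolatedFrameCount is not part of the codebase):

    // Each factor f inserts (f - 1) new frames between every pair of existing
    // frames, turning N frames into f*N - (f - 1); factors apply recursively.
    function interpolatedFrameCount(frameCount, factors) {
        return factors.reduce(
            (count, factor) => factor * count - (factor - 1),
            frameCount
        );
    }

    interpolatedFrameCount(16, [2]);    // => 31 (one new frame between each pair)
    interpolatedFrameCount(16, [3]);    // => 46 (two new frames between each pair)
    interpolatedFrameCount(16, [2, 2]); // => 61 (two passes of factor 2)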
diff --git a/src/js/forms/input/enfugue/engine.mjs b/src/js/forms/input/enfugue/engine.mjs
new file mode 100644
index 00000000..7a6cee12
--- /dev/null
+++ b/src/js/forms/input/enfugue/engine.mjs
@@ -0,0 +1,345 @@
+/** @module forms/input/enfugue/engine */
+import { isEmpty, deepClone, createElementsFromString } from "../../../base/helpers.mjs";
+import {
+ NumberInputView,
+ FloatInputView,
+ SliderPreciseInputView,
+} from "../numeric.mjs";
+import {
+ FormInputView,
+ RepeatableInputView
+} from "../parent.mjs";
+import { FormView } from "../../base.mjs";
+import { CheckboxInputView } from "../bool.mjs";
+import { SelectInputView } from "../enumerable.mjs";
+
+/**
+ * Engine size input
+ */
+class EngineSizeInputView extends NumberInputView {
+ /**
+ * @var int Minimum pixel size
+ */
+ static min = 128;
+
+ /**
+ * @var int Maximum pixel size
+ */
+ static max = 2048;
+
+ /**
+ * @var int Multiples of 8
+ */
+ static step = 8;
+
+ /**
+ * @var string The tooltip to display to the user
+ */
+ static tooltip = "When using tiled diffusion, this is the size of the window (in pixels) that will be encoded, decoded or inferred at once. When left blank, the tile size is equal to the training size of the base model - 512 for Stable Diffusion 1.5, or 1024 for Stable Diffusion XL.";
+};
+
+/**
+ * Default VAE Input View
+ */
+class DefaultVaeInputView extends SelectInputView {
+ /**
+ * @var object Option values and labels
+ */
+ static defaultOptions = {
+ "ema": "EMA 560000",
+ "mse": "MSE 840000",
+ "xl": "SDXL",
+ "xl16": "SDXL FP16",
+ "other": "Other"
+ };
+
+ /**
+ * @var string Default text
+ */
+ static placeholder = "Default";
+
+ /**
+ * @var bool Allow null
+ */
+ static allowEmpty = true;
+
+ /**
+ * @var string Tooltip to display
+ */
+ static tooltip = "Variational Autoencoders are the model that translates images between pixel space - images that you can see - and latent space - images that the AI model understands. In general you do not need to select a particular VAE model, but you may find slight differences in sharpness of resulting images.";
+};
+
+/**
+ * Mask Type Input View
+ */
+class MaskTypeInputView extends SelectInputView {
+ /**
+ * @var object Option values and labels
+ */
+ static defaultOptions = {
+ "constant": "Constant",
+ "bilinear": "Bilinear",
+ "gaussian": "Gaussian"
+ };
+
+ /**
+ * @var string The tooltip
+ */
+ static tooltip = "During multi-diffusion (tiled diffusion), only a square of the size of the engine is rendered at any given time. This can cause hard edges between the frames, especially when using a large stride. Using a mask allows for blending along the edges - this can remove seams, but also reduce precision.";
+
+ /**
+ * @var string Default value
+ */
+ static defaultValue = "bilinear";
+}
+
+/**
+ * Scheduler Input View
+ */
+class SchedulerInputView extends SelectInputView {
+ /**
+ * @var object Option values and labels
+ */
+ static defaultOptions = {
+ "ddim": "DDIM: Denoising Diffusion Implicit Models",
+ "ddpm": "DDPM: Denoising Diffusion Probabilistic Models",
+ "deis": "DEIS: Diffusion Exponential Integrator Sampler",
+ "dpmss": "DPM-Solver++ SDE",
+ "dpmssk": "DPM-Solver++ SDE Karras",
+ "dpmsm": "DPM-Solver++ 2M",
+ "dpmsmk": "DPM-Solver++ 2M Karras",
+ "dpmsms": "DPM-Solver++ 2M SDE",
+ "dpmsmka": "DPM-Solver++ 2M SDE Karras",
+ "heun": "Heun Discrete Scheduler",
+ "dpmd": "DPM Discrete Scheduler (KDPM2)",
+ "dpmdk": "DPM Discrete Scheduler (KDPM2) Karras",
+ "adpmd": "DPM Ancestral Discrete Scheduler (KDPM2A)",
+ "adpmdk": "DPM Ancestral Discrete Scheduler (KDPM2A) Karras",
+ "dpmsde": "DPM Solver SDE Scheduler",
+ "unipc": "UniPC: Predictor (UniP) and Corrector (UniC)",
+ "lmsd": "LMS: Linear Multi-Step Discrete Scheduler",
+ "lmsdk": "LMS: Linear Multi-Step Discrete Scheduler Karras",
+ "pndm": "PNDM: Pseudo Numerical Methods for Diffusion Models",
+ "eds": "Euler Discrete Scheduler",
+ "eads": "Euler Ancestral Discrete Scheduler",
+ };
+
+ /**
+ * @var string The tooltip
+ */
+ static tooltip = "Schedulers control how an image is denoiser over the course of the inference steps. Schedulers can have small effects, such as creating 'sharper' or 'softer' images, or drastically change the way images are constructed. Experimentation is encouraged, if additional information is sought, search Diffusers Schedulers in your search engine of choice.";
+
+ /**
+ * @var string Default text
+ */
+ static placeholder = "Default";
+
+ /**
+ * @var bool Allow null
+ */
+ static allowEmpty = true;
+};
+
+/**
+ * Add text for inpainter engine size
+ */
+class InpainterEngineSizeInputView extends EngineSizeInputView {
+ /**
+ * @var string The tooltip to display to the user
+ */
+ static tooltip = "This engine size functions the same as the base engine size, but only applies when inpainting.\n\n" + EngineSizeInputView.tooltip;
+
+ /**
+ * @var ?int no default value
+ */
+ static defaultValue = null;
+};
+
+/**
+ * Add text for refiner engine size
+ */
+class RefinerEngineSizeInputView extends EngineSizeInputView {
+ /**
+ * @var string The tooltip to display to the user
+ */
+ static tooltip = "This engine size functions the same as the base engine size, but only applies when refining.\n\n" + EngineSizeInputView.tooltip;
+
+ /**
+ * @var ?int no default value
+ */
+ static defaultValue = null;
+};
+
+/**
+ * This input allows the user to specify what colors an image is, so we can determine
+ * if we need to invert them on the backend.
+ */
+class ImageColorSpaceInputView extends SelectInputView {
+ /**
+ * @var object Only one option
+ */
+ static defaultOptions = {
+ "invert": "White on Black"
+ };
+
+ /**
+ * @var string The default option is to invert
+ */
+ static defaultValue = "invert";
+
+ /**
+ * @var string The empty option text
+ */
+ static placeholder = "Black on White";
+
+ /**
+ * @var bool Always show empty
+ */
+ static allowEmpty = true;
+}
+
+/**
+ * Which controlnets are available
+ */
+class ControlNetInputView extends SelectInputView {
+ /**
+ * @var string Set the default to the easiest and fastest
+ */
+ static defaultValue = "canny";
+
+ /**
+ * @var object The options allowed.
+ */
+ static defaultOptions = {
+ "canny": "Canny Edge Detection",
+ "hed": "Holistically-nested Edge Detection (HED)",
+ "pidi": "Soft Edge Detection (PIDI)",
+ "mlsd": "Mobile Line Segment Detection (MLSD)",
+ "line": "Line Art",
+ "anime": "Anime Line Art",
+ "scribble": "Scribble",
+ "depth": "Depth Detection (MiDaS)",
+ "normal": "Normal Detection (Estimate)",
+ "pose": "Pose Detection (DWPose/OpenPose)",
+ "qr": "QR Code"
+ };
+
+ /**
+ * @var string The tooltip to display
+ */
+ static tooltip = "The ControlNet to use depends on your input image. Unless otherwise specified, your input image will be processed through the appropriate algorithm for this ControlNet prior to diffusion.
" +
+ "Canny Edge: This network is trained on images and the edges of that image after having run through Canny Edge detection.
" +
+ "HED: Short for Holistically-Nested Edge Detection, this edge-detection algorithm is best used when the input image is too blurry or too noisy for Canny Edge detection.
" +
+ "Soft Edge Detection: Using a Pixel Difference Network, this edge-detection algorithm can be used in a wide array of applications.
" +
+ "MLSD: Short for Mobile Line Segment Detection, this edge-detection algorithm searches only for straight lines, and is best used for geometric or architectural images.
" +
+ "Line Art: This model is capable of rendering images to line art drawings. The controlnet was trained on the model output, this provides a great way to provide your own hand-drawn pieces as well as another means of edge detection.
" +
+ "Anime Line Art: This is similar to the above, but focusing specifically on anime style.
" +
+ "Scribble: This ControlNet was trained on a variant of the HED edge-detection algorithm, and is good for hand-drawn scribbles with thick, variable lines.
" +
+ "Depth: This uses Intel's MiDaS model to estimate monocular depth from a single image. This uses a greyscale image showing the distance from the camera to any given object.
" +
+ "Normal: Normal maps are similar to depth maps, but instead of using a greyscale depth, three sets of distance data is encoded into red, green and blue channels.
" +
+ "DWPose/OpenPose: OpenPose is an AI model from the Carnegie Mellon University's Perceptual Computing Lab detects human limb, face and digit poses from an image, and DWPose is a faster and more accurate model built on top of OpenPose. Using this data, you can generate different people in the same pose.
" +
+ "QR Code is a specialized control network designed to generate images from QR codes that are scannable QR codes themselves.";
+};
+
+/**
+ * All options for a control image in a form
+ */
+class ControlNetUnitFormView extends FormView {
+ /**
+ * @var object field sets for a control image
+ */
+ static fieldSets = {
+ "Control Unit": {
+ "controlnet": {
+ "label": "ControlNet",
+ "class": ControlNetInputView
+ },
+ "conditioningScale": {
+ "label": "Conditioning Scale",
+ "class": FloatInputView,
+ "config": {
+ "min": 0.0,
+ "step": 0.01,
+ "value": 1.0,
+ "tooltip": "How closely to follow ControlNet's influence. Typical values vary, usually values between 0.5 and 1.0 produce good conditioning with balanced randomness, but other values may produce something closer to the desired result."
+ }
+ },
+ "conditioningStart": {
+ "label": "Conditioning Start",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ "value": 0.0,
+ "tooltip": "When to begin using this ControlNet for influence. Defaults to the beginning of generation."
+ }
+ },
+ "conditioningEnd": {
+ "label": "Conditioning End",
+ "class": SliderPreciseInputView,
+ "config": {
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ "value": 1.0,
+ "tooltip": "When to stop using this ControlNet for influence. Defaults to the end of generation."
+ }
+ },
+ "processControlImage": {
+ "label": "Process Image for ControlNet",
+ "class": CheckboxInputView,
+ "config": {
+ "value": true,
+ "tooltip": "When checked, the image will be processed through the appropriate edge detection algorithm for the ControlNet. Only uncheck this if your image has already been processed through edge detection."
+ }
+ }
+ }
+ };
+
+ /**
+ * @var bool Hide submit button
+ */
+ static autoSubmit = true;
+};
+
+/**
+ * The control image form as an input
+ */
+class ControlNetUnitFormInputView extends FormInputView {
+ /**
+ * @var class The form class
+ */
+ static formClass = ControlNetUnitFormView;
+};
+
+/**
+ * The control image form input as a repeatable input
+ */
+class ControlNetUnitsInputView extends RepeatableInputView {
+ /**
+ * @var class the input class
+ */
+ static memberClass = ControlNetUnitFormInputView;
+
+ /**
+ * @var string The text to show when no items present
+ */
+ static noItemsLabel = "No ControlNet Units";
+
+ /**
+ * @var string Text in the add button
+ */
+ static addItemLabel = "Add ControlNet Unit";
+}
+
+export {
+ EngineSizeInputView,
+ RefinerEngineSizeInputView,
+ InpainterEngineSizeInputView,
+ SchedulerInputView,
+ MaskTypeInputView,
+ ControlNetInputView,
+ ControlNetUnitsInputView,
+ ImageColorSpaceInputView,
+};
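[Reviewer note] Because ControlNetUnitsInputView is a RepeatableInputView over a FormInputView wrapping ControlNetUnitFormView, its value should be an array with one object per unit, keyed by the field names defined above. A hypothetical example of such a value (the shape is inferred from the fieldsets; the concrete numbers are illustrative):

    // Hypothetical ControlNetUnitsInputView value: one object per control unit.
    const exampleControlNetUnits = [
        {
            controlnet: "canny",
            conditioningScale: 1.0,
            conditioningStart: 0.0,
            conditioningEnd: 1.0,
            processControlImage: true
        },
        {
            controlnet: "depth",
            conditioningScale: 0.6,
            conditioningStart: 0.0,
            conditioningEnd: 0.8,
            processControlImage: true
        }
    ];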
diff --git a/src/js/forms/input/enfugue/image-editor.mjs b/src/js/forms/input/enfugue/image-editor.mjs
index db323c92..fcdbdb09 100644
--- a/src/js/forms/input/enfugue/image-editor.mjs
+++ b/src/js/forms/input/enfugue/image-editor.mjs
@@ -1,78 +1,6 @@
/** @module forms/input/enfugue/image-editor */
import { SelectInputView } from "../enumerable.mjs";
-/**
- * This input allows the user to specify what colors an image is, so we can determine
- * if we need to invert them on the backend.
- */
-class ImageColorSpaceInputView extends SelectInputView {
- /**
- * @var object Only one option
- */
- static defaultOptions = {
- "invert": "White on Black"
- };
-
- /**
- * @var string The default option is to invert
- */
- static defaultValue = "invert";
-
- /**
- * @var string The empty option text
- */
- static placeholder = "Black on White";
-
- /**
- * @var bool Always show empty
- */
- static allowEmpty = true;
-}
-
-/**
- * These are the ControlNet options
- */
-class ControlNetInputView extends SelectInputView {
- /**
- * @var string Set the default to the easiest and fastest
- */
- static defaultValue = "canny";
-
- /**
- * @var object The options allowed.
- */
- static defaultOptions = {
- "canny": "Canny Edge Detection",
- "hed": "Holistically-nested Edge Detection (HED)",
- "pidi": "Soft Edge Detection (PIDI)",
- "mlsd": "Mobile Line Segment Detection (MLSD)",
- "line": "Line Art",
- "anime": "Anime Line Art",
- "scribble": "Scribble",
- "depth": "Depth Detection (MiDaS)",
- "normal": "Normal Detection (Estimate)",
- "pose": "Pose Detection (DWPose/OpenPose)",
- "qr": "QR Code"
- };
-
- /**
- * @var string The tooltip to display
- */
- static tooltip = "The ControlNet to use depends on your input image. Unless otherwise specified, your input image will be processed through the appropriate algorithm for this ControlNet prior to diffusion.
" +
- "Canny Edge: This network is trained on images and the edges of that image after having run through Canny Edge detection.
" +
- "HED: Short for Holistically-Nested Edge Detection, this edge-detection algorithm is best used when the input image is too blurry or too noisy for Canny Edge detection.
" +
- "Soft Edge Detection: Using a Pixel Difference Network, this edge-detection algorithm can be used in a wide array of applications.
" +
- "MLSD: Short for Mobile Line Segment Detection, this edge-detection algorithm searches only for straight lines, and is best used for geometric or architectural images.
" +
- "Line Art: This model is capable of rendering images to line art drawings. The controlnet was trained on the model output, this provides a great way to provide your own hand-drawn pieces as well as another means of edge detection.
" +
- "Anime Line Art: This is similar to the above, but focusing specifically on anime style.
" +
- "Scribble: This ControlNet was trained on a variant of the HED edge-detection algorithm, and is good for hand-drawn scribbles with thick, variable lines.
" +
- "Depth: This uses Intel's MiDaS model to estimate monocular depth from a single image. This uses a greyscale image showing the distance from the camera to any given object.
" +
- "Normal: Normal maps are similar to depth maps, but instead of using a greyscale depth, three sets of distance data is encoded into red, green and blue channels.
" +
- "DWPose/OpenPose: OpenPose is an AI model from the Carnegie Mellon University's Perceptual Computing Lab detects human limb, face and digit poses from an image, and DWPose is a faster and more accurate model built on top of OpenPose. Using this data, you can generate different people in the same pose.
" +
- "QR Code is a specialized control network designed to generate images from QR codes that are scannable QR codes themselves.";
-};
-
-
/**
* The fit options
*/
@@ -155,9 +83,7 @@ class FilterSelectInputView extends SelectInputView {
};
export {
- ControlNetInputView,
ImageAnchorInputView,
ImageFitInputView,
- ImageColorSpaceInputView,
FilterSelectInputView
};
diff --git a/src/js/forms/input/enfugue/models.mjs b/src/js/forms/input/enfugue/models.mjs
index 2a36a23f..942164e2 100644
--- a/src/js/forms/input/enfugue/models.mjs
+++ b/src/js/forms/input/enfugue/models.mjs
@@ -60,31 +60,6 @@ class ModelPickerInputView extends SearchListInputView {
static listInputClass = ModelPickerListInputView;
};
-/**
- * Engine size input
- */
-class EngineSizeInputView extends NumberInputView {
- /**
- * @var int Minimum pixel size
- */
- static min = 128;
-
- /**
- * @var int Maximum pixel size
- */
- static max = 2048;
-
- /**
- * @var int Multiples of 8
- */
- static step = 8;
-
- /**
- * @var string The tooltip to display to the user
- */
- static tooltip = "When using chunked diffusion, this is the size of the window (in pixels) that will be encoded, decoded or inferred at once. Set the chunking size to 0 in the sidebar to disable chunked diffusion and always try to process the entire image at once.";
-};
-
/**
* Default VAE Input View
*/
@@ -216,107 +191,6 @@ class VaeInputView extends InputView {
}
};
-/**
- * Mask Type Input View
- */
-class MaskTypeInputView extends SelectInputView {
- /**
- * @var object Option values and labels
- */
- static defaultOptions = {
- "constant": "Constant",
- "bilinear": "Bilinear",
- "gaussian": "Gaussian"
- };
-
- /**
- * @var string The tooltip
- */
- static tooltip = "During multi-diffusion, only a square of the size of the engine is rendereda at any given time. This can cause hard edges between the frames, especially when using a large chunking size. Using a mask allows for blending along the edges - this can remove seams, but also reduce precision.";
-
- /**
- * @var string Default value
- */
- static defaultValue = "bilinear";
-}
-
-/**
- * Scheduler Input View
- */
-class SchedulerInputView extends SelectInputView {
- /**
- * @var object Option values and labels
- */
- static defaultOptions = {
- "ddim": "DDIM: Denoising Diffusion Implicit Models",
- "ddpm": "DDPM: Denoising Diffusion Probabilistic Models",
- "deis": "DEIS: Diffusion Exponential Integrator Sampler",
- "dpmss": "DPM-Solver++ SDE",
- "dpmssk": "DPM-Solver++ SDE Karras",
- "dpmsm": "DPM-Solver++ 2M",
- "dpmsmk": "DPM-Solver++ 2M Karras",
- "dpmsms": "DPM-Solver++ 2M SDE",
- "dpmsmka": "DPM-Solver++ 2M SDE Karras",
- "heun": "Heun Discrete Scheduler",
- "dpmd": "DPM Discrete Scheduler (KDPM2)",
- "dpmdk": "DPM Discrete Scheduler (KDPM2) Karras",
- "adpmd": "DPM Ancestral Discrete Scheduler (KDPM2A)",
- "adpmdk": "DPM Ancestral Discrete Scheduler (KDPM2A) Karras",
- "dpmsde": "DPM Solver SDE Scheduler",
- "unipc": "UniPC: Predictor (UniP) and Corrector (UniC)",
- "lmsd": "LMS: Linear Multi-Step Discrete Scheduler",
- "lmsdk": "LMS: Linear Multi-Step Discrete Scheduler Karras",
- "pndm": "PNDM: Pseudo Numerical Methods for Diffusion Models",
- "eds": "Euler Discrete Scheduler",
- "eads": "Euler Ancestral Discrete Scheduler",
- };
-
- /**
- * @var string The tooltip
- */
- static tooltip = "Schedulers control how an image is denoiser over the course of the inference steps. Schedulers can have small effects, such as creating 'sharper' or 'softer' images, or drastically change the way images are constructed. Experimentation is encouraged, if additional information is sought, search Diffusers Schedulers in your search engine of choice.";
-
- /**
- * @var string Default text
- */
- static placeholder = "Default";
-
- /**
- * @var bool Allow null
- */
- static allowEmpty = true;
-};
-
-/**
- * Add text for inpainter engine size
- */
-class InpainterEngineSizeInputView extends EngineSizeInputView {
- /**
- * @var string The tooltip to display to the user
- */
- static tooltip = "This engine size functions the same as the base engine size, but only applies when inpainting.\n\n" + EngineSizeInputView.tooltip;
-
- /**
- * @var ?int no default value
- */
- static defaultValue = null;
-};
-
-/**
- * Add text for refiner engine size
- */
-class RefinerEngineSizeInputView extends EngineSizeInputView {
- /**
- * @var string The tooltip to display to the user
- */
- static tooltip = "This engine size functions the same as the base engine size, but only applies when refining.\n\n" + EngineSizeInputView.tooltip;
-
- /**
- * @var ?int no default value
- */
- static defaultValue = null;
-};
-
/**
* Inversion input - will be populated at init.
*/
@@ -357,6 +231,16 @@ class CheckpointInputView extends SearchListInputView {
static stringInputClass = ModelPickerStringInputView;
};
+/**
+ * Motion module input - will be populated at init.
+ */
+class MotionModuleInputView extends SearchListInputView {
+ /**
+ * @var class The class of the string input, override so we can override setValue
+ */
+ static stringInputClass = ModelPickerStringInputView;
+};
+
/**
* Lora input additionally has weight; create the FormView here,
* then define a RepeatableInputView of a FormInputView
@@ -530,15 +414,11 @@ export {
MultiLoraInputView,
MultiLycorisInputView,
MultiInversionInputView,
- EngineSizeInputView,
- RefinerEngineSizeInputView,
- InpainterEngineSizeInputView,
VaeInputView,
DefaultVaeInputView,
- SchedulerInputView,
ModelPickerStringInputView,
ModelPickerListInputView,
ModelPickerInputView,
- MaskTypeInputView,
- ModelMergeModeInputView
+ ModelMergeModeInputView,
+ MotionModuleInputView,
};
diff --git a/src/js/forms/input/enfugue/noise.mjs b/src/js/forms/input/enfugue/noise.mjs
index fccf1084..3ac1c633 100644
--- a/src/js/forms/input/enfugue/noise.mjs
+++ b/src/js/forms/input/enfugue/noise.mjs
@@ -27,7 +27,7 @@ class NoiseMethodInputView extends SelectInputView {
/**
* @var string Default value
*/
- static defaultValue = "perlin";
+ static defaultValue = "simplex";
/**
* @var string tooltip
diff --git a/src/js/forms/input/enumerable.mjs b/src/js/forms/input/enumerable.mjs
index d226b479..cb49e205 100644
--- a/src/js/forms/input/enumerable.mjs
+++ b/src/js/forms/input/enumerable.mjs
@@ -562,6 +562,7 @@ class SearchListInputView extends EnumerableInputView {
this.stringInput = new this.constructor.stringInputClass(config, "autofill", {placeholder: this.constructor.placeholder});
this.stringInput.onInput((searchValue) => this.searchChanged(searchValue));
this.stringInput.onFocus(async () => {
+ await this.search(await this.stringInput.getValue());
let stringInputNode = await this.stringInput.getNode(),
listInputNode = await this.listInput.getNode(),
inputPosition = this.node.element.getBoundingClientRect(),
diff --git a/src/js/forms/input/numeric.mjs b/src/js/forms/input/numeric.mjs
index 637dc476..4d0b35f9 100644
--- a/src/js/forms/input/numeric.mjs
+++ b/src/js/forms/input/numeric.mjs
@@ -36,6 +36,11 @@ class NumberInputView extends InputView {
*/
static useMouseWheelAltKey = true;
+ /**
+ * @var bool Whether or not to allow null values
+ */
+ static allowNull = true;
+
/**
* The constructor just sets static values to mutable local ones.
*/
@@ -50,6 +55,9 @@ class NumberInputView extends InputView {
this.stepValue = isEmpty(fieldConfig) || isEmpty(fieldConfig.step)
? this.constructor.step
: fieldConfig.step;
+ this.allowNull = isEmpty(fieldConfig) || isEmpty(fieldConfig.allowNull)
+ ? this.constructor.allowNull
+ : fieldConfig.allowNull;
}
/**
@@ -104,6 +112,12 @@ class NumberInputView extends InputView {
let inputValue = this.getValue(),
lastValue = this.value;
+ if (isEmpty(inputValue) || isNaN(inputValue)) {
+ if (this.allowNull) return;
+ inputValue = isEmpty(this.minValue) ? 0 : this.minValue;
+ this.setValue(inputValue, false);
+ }
+
if (!isEmpty(this.minValue) && inputValue < this.minValue) {
this.setValue(this.minValue, false);
} else if (!isEmpty(this.maxValue) && inputValue > this.maxValue) {
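[Reviewer note] A minimal sketch of how the new allowNull behavior above could be used (RequiredNumberInputView is an illustrative subclass, not an existing class): with the default allowNull of true, a blank or non-numeric entry is simply left empty, while overriding it to false snaps such entries back to the configured minimum (or 0 when no minimum is set).

    // Illustrative subclass: a numeric input that refuses to stay empty.
    class RequiredNumberInputView extends NumberInputView {
        static allowNull = false; // blank or NaN entries snap back to the configured min on change
        static min = 5;
    }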
diff --git a/src/js/forms/input/parent.mjs b/src/js/forms/input/parent.mjs
index 143ad71a..67f57d69 100644
--- a/src/js/forms/input/parent.mjs
+++ b/src/js/forms/input/parent.mjs
@@ -331,6 +331,7 @@ class FormInputView extends InputView {
if (!isEmpty(this.value)) {
this.form.setValues(this.value);
}
+ this.form.onSubmit(() => this.changed());
}
/**
diff --git a/src/js/model/enfugue.mjs b/src/js/model/enfugue.mjs
index 418a008a..6ec55c78 100644
--- a/src/js/model/enfugue.mjs
+++ b/src/js/model/enfugue.mjs
@@ -9,11 +9,33 @@ import { Model, ModelObject } from "./index.mjs";
* @var array $apiScope The scope(s) for this model (the variables that need to be
*/
class DiffusionModel extends ModelObject {
+ /**
+ * @var bool Always ask for inclusions
+ */
static alwaysInclude = true;
+
+ /**
+ * @var array Related models to include
+ */
+ static apiInclude = [
+ "refiner", "inpainter", "lora",
+ "lycoris", "inversion", "scheduler",
+ "vae", "config", "motion_module"
+ ];
+
+ /**
+ * @var string Model root
+ */
static apiRoot = "models";
+
+ /**
+ * @var array Scope for an individual item (i.e. primary keys)
+ */
static apiScope = ["name"];
- static apiInclude = ["refiner", "inpainter", "lora", "lycoris", "inversion", "scheduler", "vae", "config"];
+ /**
+ * Gets the status for a configured model
+ */
getStatus() {
return this.queryModel("get", `${this.url}/status`);
}
@@ -39,6 +61,7 @@ class DiffusionModelInpainter extends ModelObject {};
class DiffusionModelInversion extends ModelObject {};
class DiffusionModelLora extends ModelObject {};
class DiffusionModelLycoris extends ModelObject {};
+class DiffusionModelMotionModule extends ModelObject {};
class DiffusionModelDefaultConfiguration extends ModelObject {};
class DiffusionInvocation extends ModelObject {
@@ -67,6 +90,7 @@ class EnfugueModel extends Model {
DiffusionModelInpainter,
DiffusionModelInversion,
DiffusionModelScheduler,
+ DiffusionModelMotionModule,
DiffusionModelDefaultConfiguration,
DiffusionModelVAE,
DiffusionModelLora,
diff --git a/src/js/nodes/base.mjs b/src/js/nodes/base.mjs
index 194a0101..a15901ad 100644
--- a/src/js/nodes/base.mjs
+++ b/src/js/nodes/base.mjs
@@ -1,27 +1,27 @@
/** @module nodes/base */
-import { isEmpty, camelCase, merge, sleep } from '../base/helpers.mjs';
-import { View } from '../view/base.mjs';
-import { FormView } from '../forms/base.mjs';
-import { InputView, ListInputView } from '../forms/input.mjs';
-import { ElementBuilder } from '../base/builder.mjs';
-import { Point, Drawable } from '../graphics/geometry.mjs';
+import { isEmpty, camelCase, merge, sleep } from "../base/helpers.mjs";
+import { View } from "../view/base.mjs";
+import { FormView } from "../forms/base.mjs";
+import { InputView, ListInputView } from "../forms/input.mjs";
+import { ElementBuilder } from "../base/builder.mjs";
+import { Point, Drawable } from "../graphics/geometry.mjs";
const E = new ElementBuilder({
- nodeContainer: 'enfugue-node-container',
- nodeContents: 'enfugue-node-contents',
- nodeHeader: 'enfugue-node-header',
- nodeName: 'enfugue-node-name',
- nodeButton: 'enfugue-node-button',
- nodeOptionsContents: 'enfugue-node-options-contents',
- nodeOptionsInputsOutputs: 'enfugue-node-options-inputs-outputs',
- nodeOptions: 'enfugue-node-options',
- nodeInputs: 'enfugue-node-inputs',
- inputModes: 'enfugue-node-input-modes',
- nodeOutputs: 'enfugue-node-outputs',
- nodeInput: 'enfugue-node-input',
- nodeInputGroup: 'enfugue-node-input-group',
- nodeOutput: 'enfugue-node-output',
- nodeOutputGroup: 'enfugue-node-output-group'
+ nodeContainer: "enfugue-node-container",
+ nodeContents: "enfugue-node-contents",
+ nodeHeader: "enfugue-node-header",
+ nodeName: "enfugue-node-name",
+ nodeButton: "enfugue-node-button",
+ nodeOptionsContents: "enfugue-node-options-contents",
+ nodeOptionsInputsOutputs: "enfugue-node-options-inputs-outputs",
+ nodeOptions: "enfugue-node-options",
+ nodeInputs: "enfugue-node-inputs",
+ inputModes: "enfugue-node-input-modes",
+ nodeOutputs: "enfugue-node-outputs",
+ nodeInput: "enfugue-node-input",
+ nodeInputGroup: "enfugue-node-input-group",
+ nodeOutput: "enfugue-node-output",
+ nodeOutputGroup: "enfugue-node-output-group"
});
/**
@@ -51,7 +51,7 @@ class NodeView extends View {
/**
* @var string The tag name of the view.
*/
- static tagName = 'enfugue-node';
+ static tagName = "enfugue-node";
/**
* @var int The number of pixels outside of the node to allow for edge handling.
@@ -121,7 +121,7 @@ class NodeView extends View {
/**
* @var string The default cursor to show when no actions are present.
*/
- static defaultCursor = 'default';
+ static defaultCursor = "default";
/**
* @var bool Whether or not the height of the contents are fixed.
@@ -139,22 +139,22 @@ class NodeView extends View {
static nodeButtons = {};
/**
- * @var string The text of the copy button's tooltip
+ * @var string The text of the copy button"s tooltip
*/
static copyText = "Copy";
/**
- * @var string The text of the close button's tooltip
+ * @var string The text of the close button"s tooltip
*/
static closeText = "Close";
/**
- * @var string The text of the flip buttons' tooltip
+ * @var string The text of the flip buttons" tooltip
*/
static headerBottomText = "Flip Header to Bottom";
/**
- * @var string The text of the flip buttons' tooltip on the bottom
+ * @var string The text of the flip buttons" tooltip on the bottom
*/
static headerTopText = "Flip Header to Top";
@@ -168,11 +168,6 @@ class NodeView extends View {
*/
static headerTopIcon = "fa-solid fa-arrow-turn-up";
- /**
- * Whether or not this node can be merged with other nodes of the same class
- */
- static canMerge = false;
-
constructor(editor, name, content, left, top, width, height) {
super(editor.config);
@@ -198,6 +193,7 @@ class NodeView extends View {
this.canMerge = this.constructor.canMerge;
this.closeCallbacks = [];
this.resizeCallbacks = [];
+ this.nameChangeCallbacks = [];
}
/**
@@ -222,7 +218,7 @@ class NodeView extends View {
this.content = newContent;
if (this.node !== undefined) {
this.node
- .find(E.getCustomTag('nodeContents'))
+ .find(E.getCustomTag("nodeContents"))
.content(await this.getContent());
}
return this;
@@ -345,7 +341,7 @@ class NodeView extends View {
);
if (this.node !== undefined) {
- this.node.find(E.getCustomTag('nodeName')).content(this.name);
+ this.node.find(E.getCustomTag("nodeName")).content(this.name);
}
return this;
@@ -356,13 +352,29 @@ class NodeView extends View {
*
* @param string $newName The new name, which will populate in the DOM.
*/
- setName(newName) {
+ setName(newName, fillNode = true, triggerCallbacks = true) {
this.name = newName;
- if (this.node !== undefined) {
- this.node.find(E.getCustomTag('nodeName')).content(newName);
+ if (fillNode) {
+ if (this.node !== undefined) {
+ this.node.find(E.getCustomTag("nodeName")).content(newName);
+ }
+ }
+ if (triggerCallbacks) {
+ for (let callback of this.nameChangeCallbacks) {
+ callback(newName);
+ }
}
}
+ /**
+ * Adds a callback when name is changed
+ *
+ * @param callable $callback The callback to execute
+ */
+ onNameChange(callback) {
+ this.nameChangeCallbacks.push(callback);
+ }
+
/**
* Gets the name either from memory or the DOM.
*
@@ -370,16 +382,18 @@ class NodeView extends View {
*/
getName() {
if (this.node === undefined) return this.name;
- return this.node.find(E.getCustomTag('nodeName')).getText();
+ return this.node.find(E.getCustomTag("nodeName")).getText();
}
/**
* Removes this node from the editor.
*/
- remove() {
+ remove(triggerClose = true) {
this.removed = true;
this.editor.removeNode(this);
- this.closed();
+ if (triggerClose) {
+ this.closed();
+ }
}
/**
@@ -474,9 +488,8 @@ class NodeView extends View {
* @param int $width The width of the node
* @param int $height The height of the node
* @param bool $save Whether or not to save tese dimensions as the new configured dimensions.
- * @param bool $triggerMerge whether or not to trigger merges if there is a mergeable node.
*/
- setDimension(left, top, width, height, save, triggerMerge = true) {
+ setDimension(left, top, width, height, save, triggerEvents = true) {
left = this.constructor.getNearestSnap(left);
top = this.constructor.getNearestSnap(top);
width = this.getWidthSnap(width, left);
@@ -522,7 +535,7 @@ class NodeView extends View {
this.visibleWidth = width;
this.visibleHeight = height;
- if (triggerMerge) {
+ if (triggerEvents) {
if (save) {
this.editor.nodePlaced(this);
} else {
@@ -566,12 +579,12 @@ class NodeView extends View {
let button = E.nodeButton()
.class(`node-button-${camelCase(buttonName)}`)
.content(E.i().class(buttonConfiguration.icon))
- .on('click', (e) => {
+ .on("click", (e) => {
buttonConfiguration.callback.call(buttonConfiguration.context || this, e);
});
if (buttonConfiguration.tooltip) {
- button.data('tooltip', buttonConfiguration.tooltip);
+ button.data("tooltip", buttonConfiguration.tooltip);
}
nodeHeader.append(button);
@@ -590,19 +603,22 @@ class NodeView extends View {
*/
rebuildHeaderButtons() {
if (this.node !== undefined) {
- let nodeHeader = this.node.find(E.getCustomTag("nodeHeader"));
-
- for (let currentButton of nodeHeader.children()) {
- if (currentButton.tagName == E.getCustomTag("nodeButton")) {
- try {
- nodeHeader.remove(currentButton);
- } catch(e) {
- // Might have been removed already, continue
+ this.lock.acquire().then((release) => {
+ let nodeHeader = this.node.find(E.getCustomTag("nodeHeader"));
+
+ for (let currentButton of nodeHeader.children()) {
+ if (currentButton.tagName == E.getCustomTag("nodeButton")) {
+ try {
+ nodeHeader.remove(currentButton);
+ } catch(e) {
+ // Might have been removed already, continue
+ }
}
}
- }
-
- this.buildHeaderButtons(nodeHeader, this.buttons);
+ this.buildHeaderButtons(nodeHeader, this.buttons);
+ nodeHeader.render();
+ release();
+ });
}
}
@@ -621,12 +637,14 @@ class NodeView extends View {
.content(nodeName)
.css({
height: `${this.constructor.headerHeight}px`,
- 'line-height': `${this.constructor.headerHeight}px`
+ "line-height": `${this.constructor.headerHeight}px`
});
if (this.constructor.canRename) {
- nodeName.editable();
- nodeHeader.on('dblclick', (e) => {
+ nodeName.editable().on("input", () => {
+ this.setName(nodeName.getText(), false);
+ });
+ nodeHeader.on("dblclick", (e) => {
e.preventDefault();
e.stopPropagation();
nodeName.focus();
@@ -635,7 +653,6 @@ class NodeView extends View {
if (this.constructor.hideHeader) {
node.addClass("hide-header");
- nodeHeader.css('height', 0);
}
let buttons = {};
@@ -645,7 +662,7 @@ class NodeView extends View {
}
if (this.constructor.canCopy) {
buttons.copy = {
- icon: 'fa-solid fa-copy',
+ icon: "fa-solid fa-copy",
tooltip: this.constructor.copyText,
shortcut: "p",
context: this,
@@ -668,7 +685,7 @@ class NodeView extends View {
if (this.constructor.canClose) {
buttons.close = {
shortcut: "v",
- icon: 'fa-solid fa-window-close',
+ icon: "fa-solid fa-window-close",
tooltip: this.constructor.closeText,
context: this,
callback: () => {
@@ -772,22 +789,7 @@ class NodeView extends View {
height: `${this.height}px`,
padding: `${this.constructor.padding}px`
})
- .on('mouseenter', (e) => {
- if (this.fixed) return;
- if (this.constructor.hideHeader) {
- nodeHeader.css(
- 'height',
- `${this.constructor.headerHeight}px`
- );
- }
- })
- .on('mouseleave', (e) => {
- if (this.fixed) return;
- if (this.constructor.hideHeader) {
- nodeHeader.css('height', '0');
- }
- })
- .on('mousemove', (e) => {
+ .on("mousemove", (e) => {
if (this.fixed) return;
if (cursorMode == NodeCursorMode.NONE) {
// If there is no cursor mode assigned,
@@ -860,31 +862,31 @@ class NodeView extends View {
switch (nextMode) {
// Set the cursor as the indication to the user.
case NodeCursorMode.MOVE:
- node.css('cursor', 'grab');
+ node.css("cursor", "grab");
break;
case NodeCursorMode.RESIZE_NE:
case NodeCursorMode.RESIZE_SW:
- node.css('cursor', 'nesw-resize');
+ node.css("cursor", "nesw-resize");
break;
case NodeCursorMode.RESIZE_N:
case NodeCursorMode.RESIZE_S:
- node.css('cursor', 'ns-resize');
+ node.css("cursor", "ns-resize");
break;
case NodeCursorMode.RESIZE_E:
case NodeCursorMode.RESIZE_W:
- node.css('cursor', 'ew-resize');
+ node.css("cursor", "ew-resize");
break;
case NodeCursorMode.RESIZE_NW:
case NodeCursorMode.RESIZE_SE:
- node.css('cursor', 'nwse-resize');
+ node.css("cursor", "nwse-resize");
break;
default:
- node.css('cursor', this.constructor.defaultCursor);
+ node.css("cursor", this.constructor.defaultCursor);
break;
}
}
})
- .on('mousedown', (e) => {
+ .on("mousedown", (e) => {
if (
this.fixed ||
e.which !== 1 ||
@@ -892,7 +894,7 @@ class NodeView extends View {
cursorMode !== NodeCursorMode.NONE
)
return;
- /* On mousedown, we'll determine the final mode of the cursor,
+ /* On mousedown, we"ll determine the final mode of the cursor,
* and initiate the action.
*
* We also bind the mousemove() and mouseup() listeners within
@@ -921,27 +923,27 @@ class NodeView extends View {
switch (cursorMode) {
case NodeCursorMode.MOVE:
- this.editor.node.css('cursor', 'grab');
+ this.editor.node.css("cursor", "grab");
break;
case NodeCursorMode.RESIZE_NE:
case NodeCursorMode.RESIZE_SW:
- this.editor.node.css('cursor', 'nesw-resize');
+ this.editor.node.css("cursor", "nesw-resize");
break;
case NodeCursorMode.RESIZE_N:
case NodeCursorMode.RESIZE_S:
- this.editor.node.css('cursor', 'ns-resize');
+ this.editor.node.css("cursor", "ns-resize");
break;
case NodeCursorMode.RESIZE_E:
case NodeCursorMode.RESIZE_W:
- this.editor.node.css('cursor', 'ew-resize');
+ this.editor.node.css("cursor", "ew-resize");
break;
case NodeCursorMode.RESIZE_NW:
case NodeCursorMode.RESIZE_SE:
- this.editor.node.css('cursor', 'nwse-resize');
+ this.editor.node.css("cursor", "nwse-resize");
break;
default:
this.editor.node.css(
- 'cursor',
+ "cursor",
this.constructor.defaultCursor
);
break;
@@ -953,27 +955,27 @@ class NodeView extends View {
cursorMode = NodeCursorMode.NONE;
[startPositionX, startPositionY] = [null, null];
this.editor.node
- .off('mouseup,mouseleave,mousemove')
- .css('cursor', this.constructor.defaultCursor);
- node.off('mouseup');
+ .off("mouseup,mouseleave,mousemove")
+ .css("cursor", this.constructor.defaultCursor);
+ node.off("mouseup");
if (this.editor.constructor.disableCursor) {
this.editor.node.css("pointer-events", "none");
}
};
this.editor.node
- .on('mousemove', (e2) => {
+ .on("mousemove", (e2) => {
e2.preventDefault();
e2.stopPropagation();
setNodeDimension(e2, false);
})
- .on('mouseup,mouseleave', (e2) => {
+ .on("mouseup,mouseleave", (e2) => {
e2.preventDefault();
e2.stopPropagation();
endCursor(e2);
});
- node.on('mouseup', (e2) => {
+ node.on("mouseup", (e2) => {
endCursor(e2);
});
@@ -997,7 +999,7 @@ class NodeView extends View {
contentContainer.content(content);
}
- contentContainer.on('mousedown', (e) => {
+ contentContainer.on("mousedown", (e) => {
if (this.fixed) return;
this.editor.focusNode(this);
});
@@ -1007,10 +1009,10 @@ class NodeView extends View {
this.constructor.hideHeader ||
this.constructor.fixedHeader
) {
- contentContainer.css('height', '100%');
+ contentContainer.css("height", "100%");
} else {
contentContainer.css(
- 'height',
+ "height",
`calc(100% - ${this.constructor.headerHeight}px)`
);
}
@@ -1019,468 +1021,6 @@ class NodeView extends View {
nodeContainer.append(contentContainer);
return node;
}
-
- /**
- * Determines if a node can be merged with another.
- * @return bool True if these nodes can be merged.
- */
- canMergeWith(node) {
- return (
- this instanceof node.constructor ||
- (node instanceof CompoundNodeView && node.canMergeWith(this))
- ) && node.canMerge && this.canMerge && !node.removed && !this.removed;
- }
-
- /**
- * Merges with a target node.
- */
- mergeWith(node) {
- if (!this.canMergeWith(node)) {
- if (!this.canMerge) {
- console.warn("This node is tagged as unmergeable.", this);
- }
- if (!node.canMerge) {
- console.warn("The target node is tagged as unmergeable.", node);
- }
- throw "Nodes cannot be merged.";
- }
- if (node instanceof CompoundNodeView) {
- this.fixed = true;
- this.canMerge = false;
- return node.mergeWith(this);
- } else {
- this.fixed = true;
- node.fixed = true;
- this.canMerge = false;
- node.canMerge = false;
- let compoundNodeClass = this.constructor.compoundNodeClass || CompoundNodeView;
- return new compoundNodeClass(
- this.editor,
- "Merged Node",
- new CompoundNodeContentView(this.config, [this, node]),
- this.left,
- this.top,
- this.width,
- this.height
- );
- }
- }
-}
-
-/**
- * The OptionsNodeView extends the NodeView by additionally offering a place for
- * a form or an input to go.
- */
-class OptionsNodeView extends NodeView {
- /**
- * @var FormView|InputView The options view.
- */
- static nodeOptions = null;
-
- /**
- * @var int The height of the options node in pixels.
- */
- static optionsHeight = 0;
-
- constructor(editor, name, content, left, top, width, height) {
- super(editor, name, content, left, top, width, height);
- this.options = this.constructor.nodeOptions;
- }
-
- /**
- * Overridde setState() to additionally populate the form/input.
- */
- async setState(newState) {
- await super.setState(newState);
- if (typeof this.options == 'function') {
- this.options = new this.options(this.config);
- }
-
- if (!isEmpty(newState.options)) {
- if (isEmpty(this.options)) {
- console.warn('Options passed, but no options present on node.');
- } else {
- if (this.options instanceof InputView) {
- let setValue =
- typeof newState.options == 'object' &&
- !isEmpty(newState.options.default)
- ? newState.options.default
- : newState.options;
- await this.options.setValue(setValue);
- } else if (this.options instanceof FormView) {
- await this.options.setValues(newState.options);
- this.options.submit();
- } else {
- this.options = newState.options;
- }
- }
- }
- return this;
- }
-
- /**
- * Override getContent() to additionally embed the options.
- */
- async getContent() {
- let node = E.nodeOptionsContents(),
- optionsNode = E.nodeOptions().css(
- 'height',
- `${this.constructor.optionsHeight}px`
- );
-
- if (!isEmpty(this.options)) {
- if (typeof this.options == 'function') {
- this.options = new this.options(this.config);
- optionsNode.append(await this.options.getNode());
- } else if (this.options instanceof View) {
- optionsNode.append(await this.options.getNode());
- } else {
- optionsNode.append(this.options);
- }
- }
-
- node.append(optionsNode, await super.getContent());
-
- return node;
- }
-
- /**
- * Override getState to additionally include the data from the form/input.
- */
- getState() {
- let parentState = super.getState();
-
- parentState.options = isEmpty(this.options)
- ? null
- : this.options instanceof InputView
- ? this.options.getValue()
- : this.options instanceof FormView
- ? this.options.values
- : this.options;
-
- return parentState;
- }
-}
-
-/**
- * Provides a view to select the current visible node in a compound node
- */
-class CompoundNodeContentView extends View {
- /**
- * @var string Custom tag name
- */
- static tagName = "enfugue-node-compound-contents";
-
- /**
- * On construct, pass children
- */
- constructor(config, children) {
- super(config);
- this.children = children;
- this.activeIndex = 0;
- this.chooser = new ListInputView(config, "activeNode", {"options": this.options, "value": "node-0"});
- this.chooser.onChange(() => {
- this.setActiveIndex(parseInt(this.chooser.value.split("-")[1]), false);
- });
- }
-
- /**
- * Sets the index of the active node
- */
- async setActiveIndex(newIndex, updateChooser = true){
- this.activeIndex = newIndex;
- if (updateChooser) {
- this.chooser.setValue(`node-${newIndex}`, false);
- }
- if (this.node !== undefined) {
- this.node.content(
- await this.chooser.getNode(),
- await this.selectedNode.getNode()
- );
- }
- }
-
- /**
- * Get the selected node
- */
- get selectedNode() {
- return this.children[this.activeIndex];
- }
-
- /**
- * Get options for the chooser
- */
- get options() {
- return this.children.reduce((carry, item, index) => {
- carry[`node-${index}`] = `${index+1}. ${item.name}`;
- return carry;
- }, {});
- }
-
- /**
- * When adding a node, re-build selector if needed but don't select
- */
- addNode(newNode){
- this.children.push(newNode);
- this.chooser.setOptions(this.options);
- }
-
- /**
- * When removing a node, it gets popped out back onto the editor
- */
- removeNode(newNode){
- let childIndex = this.children.indexOf(newNode);
- if (childIndex !== -1) {
- this.children = this.children.slice(0, childIndex).concat(this.children.slice(childIndex+1));
- }
- this.chooser.setOptions(this.options);
- if (this.activeIndex >= this.children.length) {
- this.setActiveIndex(this.children.length - 1);
- } else if(this.activeIndex == childIndex) {
- this.setActiveIndex(this.activeIndex); // Reset to build again
- }
- }
-
- /**
- * Add a setDimension similar to NodeView.setDimension that calls child functions
- */
- setDimension(width, height, save) {
- for (let child of this.children) {
- child.setDimension(
- -child.constructor.padding,
- -child.constructor.padding,
- width,
- height,
- save,
- false
- );
- child.resized();
- }
- }
-
- /**
- * On build, get selector + current node header + current node contents
- */
- async build() {
- let node = await super.build();
- node.content(
- await this.chooser.getNode(),
- await this.selectedNode.getNode()
- );
- return node;
- }
-}
-
-/**
- * Extend the node view to allow nodes to be compounded into tabs
- */
-class CompoundNodeView extends NodeView {
- /**
- * @var bool Disable copying compound nodes
- */
- static canCopy = false;
-
- /**
- * @var bool enable merging on already merged nodes
- */
- static canMerge = true;
-
- /**
- * On construct, bind content events and hijack button build
- */
- constructor(editor, name, content, left, top, width, height) {
- super(editor, name, content, left, top, width, height);
- if (!isEmpty(content)) {
- for (let node of content.children) {
- node.rebuildHeaderButtons = () => this.rebuildHeaderButtons();
- }
- this.content.chooser.onChange(() => {
- this.rebuildHeaderButtons();
- });
- }
- }
-
- /**
- * Determines if a node can be merged with another.
- * @return bool True if these nodes can be merged.
- */
- canMergeWith(node) {
- return this.content.children[0] instanceof node.constructor;
- }
-
- /**
- * Merges this node with another.
- */
- mergeWith(node) {
- // Hijack button build
- node.rebuildHeaderButtons = () => this.rebuildHeaderButtons();
- this.content.addNode(node);
- node.setDimension(
- -node.constructor.padding,
- -node.constructor.padding,
- this.width,
- this.height,
- true,
- false
- );
- return this;
- }
-
- /**
- * Rebuilds a node and re-adds it to the canvas.
- */
- async rebuildMergedNode(node, offset) {
- if (offset === undefined) offset = this.constructor.snapSize;
- let nodeState = node.getState();
- nodeState.x = this.left + offset;
- nodeState.y = this.top + offset;
- nodeState.w = this.width;
- nodeState.h = this.height;
- let newNode = new node.constructor(this.editor);
- await newNode.setState(nodeState);
- this.editor.addNode(newNode);
- return newNode;
- }
-
-
- /**
- * Unmerges the currently active node.
- */
- async unmergeNode() {
- let nodeToRemove = this.content.selectedNode;
- this.content.removeNode(nodeToRemove);
- this.rebuildMergedNode(nodeToRemove);
- if (this.content.children.length <= 1) {
- nodeToRemove = this.content.children[0];
- this.rebuildMergedNode(nodeToRemove, 0);
- this.editor.removeNode(this);
- }
- }
-
- /**
- * When setting dimensions for a compound node, similarly trigger on children
- */
- setDimension(left, top, width, height, save, triggerMerge) {
- super.setDimension(left, top, width, height, save, false);
- if (!isEmpty(this.content)) {
- this.content.setDimension(this.visibleWidth, this.visibleHeight, save, triggerMerge);
- }
- }
-
- /**
- * Gets the buttons for the current selected node
- */
- get selectedButtons() {
- let selectedNodeButtons = isEmpty(this.content) || isEmpty(this.content.selectedNode)
- ? {}
- : this.content.selectedNode.getButtons();
- if (isEmpty(selectedNodeButtons)) {
- selectedNodeButtons = {};
- } else {
- delete selectedNodeButtons.copy;
- delete selectedNodeButtons.flip;
- delete selectedNodeButtons.close;
- }
- return selectedNodeButtons;
- }
-
- /**
- * Gets the current button set plus current node buttons plus popout button
- */
- getButtons() {
- let x = merge(
- this.selectedButtons,
- {
- unmerge: {
- icon: "fa-solid fa-up-right-from-square",
- tooltip: "Unmerge Image",
- shorcut: "g",
- context: this,
- disabled: false,
- callback: () => {
- this.unmergeNode();
- }
- }
- },
- super.getButtons()
- );
- return x;
- }
-
- /**
- * Intercept buildHeaderButtons to add our current nodes buttons
- */
- buildHeaderButtons(nodeHeader, buttons) {
- return super.buildHeaderButtons(
- nodeHeader,
- merge(
- this.selectedButtons,
- {
- unmerge: {
- icon: "fa-solid fa-up-right-from-square",
- tooltip: "Unmerge Image",
- shorcut: "g",
- context: this,
- disabled: false,
- callback: () => {
- this.unmergeNode();
- }
- }
- },
- buttons
- )
- );
- }
-
- /**
- * On getState, return the child node state.
- */
- getState() {
- let baseState = super.getState();
- baseState.children = this.content.children.map((child) => {
- let childState = child.getState.apply(child, Array.from(arguments));
- delete childState.x;
- delete childState.y;
- delete childState.w;
- delete childState.h;
- return childState;
- });
- baseState.active = this.content.activeIndex;
- return baseState;
- }
-
- /**
- * On setState, set this nodes state then trigger child nodes */
- async setState(newState) {
- await super.setState(newState);
- let childNodes = [];
- if (!isEmpty(newState.children)) {
- for (let childState of newState.children) {
- let nodeClass = this.editor.getNodeClass(childState.classname),
- newNode = new nodeClass(this.editor);
- childState.x = 0;
- childState.y = 0;
- childState.w = newState.w;
- childState.h = newState.h;
- await newNode.setState(childState);
- newNode.rebuildHeaderButtons = () => this.rebuildHeaderButtons();
- childNodes.push(newNode);
- }
- }
-
- await this.setContent(new CompoundNodeContentView(this.editor.config, childNodes));
- this.content.chooser.onChange(() => {
- setTimeout(() => this.rebuildHeaderButtons(), 150);
- });
- if (newState.active !== undefined) {
- this.content.setActiveIndex(newState.active);
- }
- setTimeout(() => this.rebuildHeaderButtons(), 250);
- }
}
-export {
- NodeView,
- CompoundNodeView,
- OptionsNodeView
-};
+export { NodeView };
diff --git a/src/js/nodes/editor.mjs b/src/js/nodes/editor.mjs
index b9b65aff..dd7ebaad 100644
--- a/src/js/nodes/editor.mjs
+++ b/src/js/nodes/editor.mjs
@@ -9,12 +9,7 @@ import {
NodeEditorDecorationsView,
NodeConnectionSpline
} from './decorations.mjs';
-
-import {
- NodeView,
- OptionsNodeView,
- CompoundNodeView
-} from './base.mjs';
+import { NodeView } from './base.mjs';
const E = new ElementBuilder({
node: "enfugue-node",
@@ -93,7 +88,7 @@ class NodeEditorView extends View {
/**
* @var array All supported node classes. Used when re-instantiating from static data.
*/
- static nodeClasses = [NodeView, OptionsNodeView, CompoundNodeView];
+ static nodeClasses = [NodeView];
/**
* @var array Any number of classes
@@ -136,8 +131,12 @@ class NodeEditorView extends View {
this.left = 0;
this.top = 0;
}
+
this.nodes = [];
this.nodeClasses = [].concat(this.constructor.nodeClasses);
+ this.nodeFocusCallbacks = [];
+ this.nodeCopyCallbacks = [];
+ this.setDimensionCallbacks = [];
this.decorations = new NodeEditorDecorationsView(
config,
@@ -150,6 +149,22 @@ class NodeEditorView extends View {
window.addEventListener('resize', (e) => this.windowResized(e));
}
+ /**
+ * Gets a unique name for a node, adding numbers if needed.
+ *
+ * @param string $name The name of the node.
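+ * @return string The unique name, e.g. "Image" becomes "Image 2" if "Image" is already taken.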
+ */
+ getUniqueNodeName(name) {
+ let currentName = name,
+ currentNames = this.nodes.map((node) => node.getName()),
+ duplicates = 1;
+
+ while (currentNames.indexOf(currentName) !== -1) {
+ currentName = `${name} ${++duplicates}`;
+ }
+ return currentName;
+ }
+
/**
* @param callable $callback A callback to perform when the window is resized
*/
@@ -157,6 +172,20 @@ class NodeEditorView extends View {
this.resizeCallbacks.push(callback);
}
+ /**
+ * @param callable $callback A callback to perform when a node is focused
+ */
+ onNodeFocus(callback) {
+ this.nodeFocusCallbacks.push(callback);
+ }
+
+ /**
+ * @param callable $callback A callback to perform when a node is copied
+ */
+ onNodeCopy(callback) {
+ this.nodeCopyCallbacks.push(callback);
+ }
+
/**
* Called when the window is resized.
*/
@@ -260,13 +289,21 @@ class NodeEditorView extends View {
}
}
+ /**
+ * Adds a callback when dimensions are set
+ * @param callable $callback The function to execute
+ */
+ onSetDimension(callback) {
+ this.setDimensionCallbacks.push(callback);
+ }
+
/**
* Sets a new width and height for this editor.
* @param int $newWidth The new width to set.
* @param int $newHeight The new height to set.
* @param bool $resetNodes Whether or not to reset the dimensions of the nodes on this canvas.
*/
- setDimension(newWidth, newHeight, resetNodes = true) {
+ setDimension(newWidth, newHeight, resetNodes = true, triggerCallbacks = false) {
if (isEmpty(newWidth)){
newWidth = this.canvasWidth;
}
@@ -288,6 +325,11 @@ class NodeEditorView extends View {
}
}
}
+ if (triggerCallbacks) {
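+ // Notify any callbacks registered via onSetDimension() of the new canvas size.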
+ for (let callback of this.setDimensionCallbacks) {
+ callback(newWidth, newHeight);
+ }
+ }
}
/**
@@ -304,45 +346,7 @@ class NodeEditorView extends View {
* @param Node $movedNode The node that was moved.
*/
nodeMoved(movedNode) {
- let mergeNode;
- for (let node of this.nodes) {
- node.removeClass("merge-source");
- node.removeClass("merge-target");
-
- // Check if nodes have zero intersection
- if (node == movedNode ||
- !node.canMergeWith(movedNode) ||
- movedNode.visibleLeft >= node.visibleRight ||
- movedNode.visibleRight <= node.visibleLeft ||
- movedNode.visibleTop >= node.visibleTop + node.visibleHeight ||
- movedNode.visibleTop + movedNode.visibleHeight <= node.visibleTop
- ) {
- continue;
- }
-
- // Check if dragged header is near target header
- if (Math.abs(movedNode.visibleTop - node.visibleTop) > movedNode.constructor.headerHeight / 2) {
- continue;
- }
-
- // Check if dragged node is sufficiently intersected by canvas node
- let intersectLeft = Math.max(movedNode.visibleLeft, node.visibleLeft),
- intersectTop = Math.max(movedNode.visibleTop, node.visibleTop),
- intersectRight = Math.min(movedNode.visibleLeft + movedNode.visibleWidth, node.visibleLeft + node.visibleWidth),
- intersectBottom = Math.min(movedNode.visibleTop + movedNode.visibleHeight, node.visibleTop + node.visibleHeight),
- intersectArea = (intersectRight - intersectLeft) * (intersectBottom - intersectTop),
- intersectRatio = intersectArea / (movedNode.visibleWidth * movedNode.visibleHeight);
-
- if (intersectRatio >= 0.33) {
- // Set node merge targets
- mergeNode = node;
- }
- }
-
- if (!isEmpty(mergeNode)) {
- movedNode.addClass("merge-source");
- mergeNode.addClass("merge-target");
- }
+ // TODO
}
/**
@@ -351,35 +355,7 @@ class NodeEditorView extends View {
*/
nodePlaced(node) {
this.nodeMoved(node);
- let sourceNode, targetNode;
- for (let childNode of this.nodes) {
- if (childNode.hasClass("merge-source")) {
- sourceNode = childNode;
- }
- if (childNode.hasClass("merge-target")) {
- targetNode = childNode;
- }
- }
- if (!isEmpty(sourceNode) && !isEmpty(targetNode)) {
- this.mergeNodes(sourceNode, targetNode);
- }
- }
-
- /**
- * Calls callbacks for when a node is placed (released somewhere or programmatically set)
- * @param Node $movedNode The node that was placed.
- */
- mergeNodes(sourceNode, targetNode) {
- sourceNode.removeClass("merge-source");
- targetNode.removeClass("merge-target");
- try {
- let mergedNode = sourceNode.mergeWith(targetNode);
- this.removeNode(sourceNode);
- this.removeNode(targetNode);
- this.addNode(mergedNode);
- } catch(e) {
- console.log("Experienced error merging nodes, ignoring", e);
- }
+ // TODO
}
/**
@@ -450,9 +426,6 @@ class NodeEditorView extends View {
newNode = nodeClass;
nodeClass.editor = this;
}
-
- let enableMerge = newNode.canMerge;
- newNode.canMerge = false;
this.nodes.push(newNode);
this.nodes = this.nodes.map((v, i) => {
v.index = i;
@@ -466,7 +439,6 @@ class NodeEditorView extends View {
canvas.append(childNode);
}
- newNode.canMerge = enableMerge;
return newNode;
}
@@ -495,20 +467,33 @@ class NodeEditorView extends View {
}
/**
- * Focus on an individual node by reordering it.
- * Pops out of the node array and DOM, then adds at the end.
- * @param object $node The node to focus on.
+ * Triggers callbacks for node focus.
+ */
+ async focusNode(node) {
+ for (let focusCallback of this.nodeFocusCallbacks) {
+ await focusCallback(node);
+ }
+ return;
+ }
+
+ /**
+ * Re-orders a node.
*/
- focusNode(node) {
- let nodeIndex = this.nodes.indexOf(node);
- if (nodeIndex === this.nodes.length - 1 || nodeIndex === -1) {
+ reorderNode(index, node) {
+ let currentNodeIndex = this.nodes.indexOf(node);
+ if (currentNodeIndex === -1) {
+ console.error("Couldn't reorder node, not found in array.");
return;
}
- this.nodes = this.nodes.slice(0, nodeIndex).concat(this.nodes.slice(nodeIndex + 1));
- this.nodes.push(node);
+ this.nodes = this.nodes.slice(0, currentNodeIndex).concat(this.nodes.slice(currentNodeIndex + 1));
+ this.nodes.splice(index, 0, node);
let nodeCanvas = this.node.find(E.getCustomTag("nodeCanvas"));
- nodeCanvas.remove(node.node).append(node.node);
- return;
+ nodeCanvas.remove(node.node);
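+ // The insert offset likely accounts for the canvas's non-node children; one extra when moving forward because the node was just removed above.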
+ if (index > currentNodeIndex) {
+ nodeCanvas.insert(index + 3, node.node);
+ } else {
+ nodeCanvas.insert(index + 2, node.node);
+ }
}
/**
@@ -518,13 +503,26 @@ class NodeEditorView extends View {
async copyNode(node) {
let data = node.getState(),
newNode = await this.addNode(node.constructor);
+ data.name += " (copy)";
data.x += node.constructor.padding;
data.y += node.constructor.padding;
await newNode.setState(data);
+ for (let copyCallback of this.nodeCopyCallbacks) {
+ await copyCallback(newNode, node);
+ }
this.focusNode(newNode);
return newNode;
}
+ /**
+ * Resets position and zoom.
+ */
+ resetCanvasPosition() {
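+ // Reuses the zoom-reset control's click handler to re-center the canvas and reset zoom.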
+ if (!isEmpty(this.node)) {
+ this.node.find(E.getCustomTag("zoomReset")).trigger("click");
+ }
+ }
+
/**
* The build function creates nodes and binds handlers.
*/
@@ -585,6 +583,7 @@ class NodeEditorView extends View {
}
node.append(canvas);
+
if (this.constructor.disableCursor) {
node.css('pointer-events', 'none');
} else {
@@ -622,8 +621,8 @@ class NodeEditorView extends View {
let canvasReadoutX, canvasReadoutY;
if (this.constructor.centered) {
- let canvasCenterX = (this.constructor.canvasWidth / 2) - (node.element.clientWidth / 2),
- canvasCenterY = (this.constructor.canvasHeight / 2) - (node.element.clientHeight / 2);
+ let canvasCenterX = (this.width / 2) - (node.element.clientWidth / 2),
+ canvasCenterY = (this.height / 2) - (node.element.clientHeight / 2);
canvasReadoutX = -canvasCenterX - newX / this.zoom,
canvasReadoutY = -canvasCenterY - newY / this.zoom;
@@ -659,8 +658,8 @@ class NodeEditorView extends View {
let canvasReadoutX, canvasReadoutY;
if (this.constructor.centered) {
- let canvasCenterX = (this.constructor.canvasWidth / 2) - (node.element.clientWidth / 2),
- canvasCenterY = (this.constructor.canvasHeight / 2) - (node.element.clientHeight / 2);
+ let canvasCenterX = (this.width / 2) - (node.element.clientWidth / 2),
+ canvasCenterY = (this.height / 2) - (node.element.clientHeight / 2);
canvasReadoutX = -canvasCenterX - this.left / this.zoom,
canvasReadoutY = -canvasCenterY - this.top / this.zoom;
@@ -681,8 +680,8 @@ class NodeEditorView extends View {
e.preventDefault();
e.stopPropagation();
if (this.constructor.centered) {
- this.left = -(this.constructor.canvasWidth / 2) * this.zoom + node.element.clientWidth / 2;
- this.top = -(this.constructor.canvasHeight / 2) * this.zoom + node.element.clientHeight / 2;
+ this.left = -(this.width / 2) * this.zoom + node.element.clientWidth / 2;
+ this.top = -(this.height / 2) * this.zoom + node.element.clientHeight / 2;
} else {
this.left = 0;
this.top = 0;
@@ -721,8 +720,8 @@ class NodeEditorView extends View {
let canvasReadoutX, canvasReadoutY;
if (this.constructor.centered) {
- let canvasCenterX = (this.constructor.canvasWidth / 2) - (node.element.clientWidth / 2),
- canvasCenterY = (this.constructor.canvasHeight / 2) - (node.element.clientHeight / 2);
+ let canvasCenterX = (this.width / 2) - (node.element.clientWidth / 2),
+ canvasCenterY = (this.height / 2) - (node.element.clientHeight / 2);
canvasReadoutX = -canvasCenterX - this.left / this.zoom,
canvasReadoutY = -canvasCenterY - this.top / this.zoom;
diff --git a/src/js/nodes/image-editor.mjs b/src/js/nodes/image-editor.mjs
index e60375bd..60d960be 100644
--- a/src/js/nodes/image-editor.mjs
+++ b/src/js/nodes/image-editor.mjs
@@ -2,16 +2,14 @@
import { isEmpty, filterEmpty } from "../base/helpers.mjs";
import { ElementBuilder } from "../base/builder.mjs";
import { NodeEditorView } from "./editor.mjs";
-import { ImageView } from "../view/image.mjs";
-import { CompoundNodeView } from "./base.mjs";
+import { ImageView, BackgroundImageView } from "../view/image.mjs";
+import { VideoView } from "../view/video.mjs";
import { ImageEditorNodeView } from "./image-editor/base.mjs";
import { ImageEditorScribbleNodeView } from "./image-editor/scribble.mjs";
import { ImageEditorPromptNodeView } from "./image-editor/prompt.mjs";
-import {
- ImageEditorImageNodeView,
- ImageEditorCompoundImageNodeView
-} from "./image-editor/image.mjs";
-import { CurrentInvocationImageView } from "./image-editor/invocation.mjs";
+import { ImageEditorImageNodeView } from "./image-editor/image.mjs";
+import { ImageEditorVideoNodeView } from "./image-editor/video.mjs";
+import { NoImageView, NoVideoView } from "./image-editor/common.mjs";
const E = new ElementBuilder();
@@ -26,8 +24,6 @@ class ImageEditorView extends NodeEditorView {
constructor(application) {
super(application.config, window.innerWidth-300, window.innerHeight-70);
this.application = application;
- this.currentInvocation = new CurrentInvocationImageView(this);
- this.currentInvocation.hide();
}
/**
@@ -59,11 +55,9 @@ class ImageEditorView extends NodeEditorView {
* @var array The node classes for state set/get
*/
static nodeClasses = [
- CompoundNodeView,
ImageEditorScribbleNodeView,
ImageEditorImageNodeView,
- ImageEditorPromptNodeView,
- ImageEditorCompoundImageNodeView
+ ImageEditorVideoNodeView,
];
/**
@@ -72,10 +66,10 @@ class ImageEditorView extends NodeEditorView {
async focusNode(node) {
super.focusNode(node);
this.focusedNode = node;
- this.application.menu.removeCategory("Node");
+ this.application.menu.removeCategory("Element");
let nodeButtons = node.getButtons();
if (!isEmpty(nodeButtons)) {
- let menuCategory = await this.application.menu.addCategory("Node", "n");
+ let menuCategory = await this.application.menu.addCategory("Element", "e");
for (let buttonName in nodeButtons) {
let buttonConfiguration = nodeButtons[buttonName];
let menuItem = await menuCategory.addItem(
@@ -95,54 +89,7 @@ class ImageEditorView extends NodeEditorView {
super.removeNode(node);
if (this.focusedNode === node) {
this.focusedNode = null;
- this.application.menu.removeCategory("Node");
- }
- }
-
- /**
- * Removes the current invocation from the canvas view.
- */
- hideCurrentInvocation() {
- this.currentInvocation.hide();
- if (this.hasClass("has-image")) {
- this.removeClass("has-image");
- this.application.menu.removeCategory("Image");
- }
- this.resetDimension(false);
- }
-
- /**
- * Resets the editor to the previous set of dimensions
- */
- resetDimension(resetNodes = true) {
- if (!isEmpty(this.configuredWidth) && !isEmpty(this.configuredHeight)) {
- this.setDimension(this.configuredWidth, this.configuredHeight, resetNodes);
- this.configuredHeight = null;
- this.configuredWidth = null;
- }
- }
-
- /**
- * Sets a current invocation on the canvas view.
- * @param string $href The image source.
- */
- async setCurrentInvocationImage(href) {
- this.currentInvocation.setImage(href);
- await this.currentInvocation.waitForLoad();
- if (this.currentInvocation.width != this.width || this.currentInvocation.height != this.height) {
- if (isEmpty(this.configuredWidth)) {
- this.configuredWidth = this.width;
- }
- if (isEmpty(this.configuredHeight)) {
- this.configuredHeight = this.height;
- }
- this.setDimension(this.currentInvocation.width, this.currentInvocation.height, false);
- }
- this.currentInvocation.show();
- if (!this.hasClass("has-image")) {
- this.addClass("has-image");
- let menuCategory = await this.application.menu.addCategory("Image", "e");
- await this.currentInvocation.prepareMenu(menuCategory);
+ this.application.menu.removeCategory("Element");
}
}
@@ -170,50 +117,73 @@ class ImageEditorView extends NodeEditorView {
if (imageSource instanceof ImageView) {
imageView = imageSource;
+ } else if (!isEmpty(imageSource)) {
+ imageView = new BackgroundImageView(this.config, imageSource, false);
} else {
- imageView = new ImageView(this.config, imageSource);
+ imageView = new NoImageView(this.config);
+ }
+
+ if (imageView instanceof ImageView) {
+ await imageView.waitForLoad();
}
- await imageView.waitForLoad();
let newNode = await this.addNode(
ImageEditorImageNodeView,
- imageName,
+ this.getUniqueNodeName(imageName),
imageView,
x,
y,
imageView.width,
imageView.height
);
- setTimeout(() => newNode.toggleOptions(), 500);
+
return newNode;
}
/**
- * This is a shorthand helper for adding a scribble node.
- * @return NodeView The added view
+ * This is a shorthand helper function for adding a video node.
+ * @param string $videoSource The source of the video - likely a data URL.
+ * @return NodeView The added view.
*/
- async addScribbleNode(scribbleName = "Scribble") {
- let [x, y] = this.getNextNodePoint();
- return await this.addNode(
- ImageEditorScribbleNodeView,
- scribbleName,
- null,
+ async addVideoNode(videoSource, videoName = "Video") {
+ let videoView = null,
+ [x, y] = this.getNextNodePoint();
+
+ if (videoSource instanceof VideoView) {
+ videoView = videoSource;
+ } else if (!isEmpty(videoSource)) {
+ videoView = new VideoView(this.config, videoSource, false);
+ } else {
+ videoView = new NoVideoView(this.config);
+ }
+
+ if (videoView instanceof VideoView) {
+ await videoView.waitForLoad();
+ }
+
+ let newNode = await this.addNode(
+ ImageEditorVideoNodeView,
+ this.getUniqueNodeName(videoName),
+ videoView,
x,
y,
- 256,
- 256
+ videoView.width,
+ videoView.height
);
+
+ return newNode;
}
-
+
+
/**
- * This is a shorthand helper for adding a prompt node.
- * @return NodeView The added view.
+ * This is a shorthand helper for adding a scribble node.
+ * @return NodeView The added view
*/
- async addPromptNode(promptName = "Prompt") {
+ async addScribbleNode(scribbleName = "Scribble") {
let [x, y] = this.getNextNodePoint();
return await this.addNode(
- ImageEditorPromptNodeView,
- promptName,
+ ImageEditorScribbleNodeView,
+ this.getUniqueNodeName(scribbleName),
null,
x,
y,
@@ -227,41 +197,12 @@ class ImageEditorView extends NodeEditorView {
*/
async build() {
let node = await super.build(),
+ overlays = E.createElement("enfugue-image-editor-overlay"),
grid = E.createElement("enfugue-image-editor-grid");
- node.find("enfugue-node-canvas").append(grid, await this.currentInvocation.getNode());
+ node.find("enfugue-node-canvas").append(overlays.content(grid));
return node;
}
- /**
- * Get state, includes current invocation
- */
- getState(includeImages = true) {
- let state = super.getState(includeImages);
- if (this.hasClass("has-image") && includeImages) {
- return {
- "image": this.currentInvocation.src,
- "nodes": state
- };
- }
- return state;
- }
-
- /**
- * Set state, may include current invocation
- */
- setState(newState) {
- if (Array.isArray(newState)) {
- this.hideCurrentInvocation();
- return super.setState(newState);
- }
- if (isEmpty(newState.image)) {
- this.hideCurrentInvocation();
- } else {
- this.setCurrentInvocationImage(newState.image);
- }
- return super.setState(newState.nodes);
- }
-
/**
* Gets base state when initializing from an image
*/
@@ -284,5 +225,4 @@ export {
ImageEditorNodeView,
ImageEditorImageNodeView,
ImageEditorScribbleNodeView,
- ImageEditorPromptNodeView
};
diff --git a/src/js/nodes/image-editor/base.mjs b/src/js/nodes/image-editor/base.mjs
index b66c7fd8..a99b6e8c 100644
--- a/src/js/nodes/image-editor/base.mjs
+++ b/src/js/nodes/image-editor/base.mjs
@@ -1,17 +1,11 @@
/** @module nodes/image-editor/base.mjs */
import { isEmpty } from "../../base/helpers.mjs";
-import { ImageEditorNodeOptionsFormView } from "../../forms/enfugue/image-editor.mjs";
import { NodeView } from "../base.mjs";
/**
* Nodes on the Image Editor use multiples of 8 instead of 10
*/
class ImageEditorNodeView extends NodeView {
- /**
- * @var bool Disable merging for most nodes
- */
- static canMerge = false;
-
/**
* @var string The name to show in the menu
*/
@@ -25,12 +19,12 @@ class ImageEditorNodeView extends NodeView {
/**
* @var int The minimum height, much smaller than normal minimum.
*/
- static minHeight = 32;
+ static minHeight = 64;
/**
* @var int The minimum width, much smaller than normal minimum.
*/
- static minWidth = 32;
+ static minWidth = 64;
/**
* @var int Change snap size from 10 to 8
@@ -47,11 +41,6 @@ class ImageEditorNodeView extends NodeView {
*/
static edgeHandlerTolerance = 8;
- /**
- * @var bool All nodes on the image editor try to be as minimalist as possible.
- */
- static hideHeader = true;
-
/**
* @var string Change from 'Close' to 'Remove'
*/
@@ -62,77 +51,61 @@ class ImageEditorNodeView extends NodeView {
* @see view/nodes/base
*/
static nodeButtons = {
- options: {
- icon: "fa-solid fa-sliders",
- tooltip: "Show/Hide Options",
- shortcut: "o",
- callback: function() {
- this.toggleOptions();
+ "nodeToCanvas": {
+ "icon": "fa-solid fa-maximize",
+ "tooltip": "Scale to Canvas Size",
+ "shortcut": "z",
+ "callback": function() {
+ this.scaleToCanvasSize();
+ }
+ },
+ "canvasToNode": {
+ "icon": "fa-solid fa-minimize",
+ "tooltip": "Scale Canvas to Image Size",
+ "shortcut": "g",
+ "callback": function() {
+ this.scaleCanvasToSize();
}
}
};
/**
- * @var class The form to use. Each node should have their own.
- */
- static optionsFormView = ImageEditorNodeOptionsFormView;
-
- /**
- * Can be overridden in the node classes; this is called when their options are changed.
+ * Gets the size to scale the canvas to; can be overridden
*/
- async updateOptions(values) {
- this.prompt = values.prompt;
- this.negativePrompt = values.negativePrompt;
- this.guidanceScale = values.guidanceScale;
- this.inferenceSteps = values.inferenceSteps;
- this.scaleToModelSize = values.scaleToModelSize;
- this.removeBackground = values.removeBackground;
+ async getCanvasScaleSize() {
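+ // Default behavior: this node's own dimensions, less its padding on each side.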
+ return [
+ this.width - this.constructor.padding*2,
+ this.height - this.constructor.padding*2
+ ];
}
/**
- * Shows the options view.
+ * Scales this node to the size of the canvas
*/
- async toggleOptions() {
- if (isEmpty(this.optionsForm)) {
- this.optionsForm = new this.constructor.optionsFormView(this.config);
- this.optionsForm.onSubmit((values) => this.updateOptions(values));
- let optionsNode = await this.optionsForm.getNode();
- this.optionsForm.setValues(this.getState(), false);
- this.node.find("enfugue-node-contents").append(optionsNode);
- } else if (this.optionsForm.hidden) {
- this.optionsForm.show();
- } else {
- this.optionsForm.hide();
- }
+ async scaleToCanvasSize() {
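+ // Offset by the padding so the node's visible content covers the whole canvas.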
+ this.setDimension(
+ -this.constructor.padding,
+ -this.constructor.padding,
+ this.editor.width+this.constructor.padding*2,
+ this.editor.height+this.constructor.padding*2,
+ true
+ );
}
/**
- * When state is set, send to form
- */
- async setState(newState) {
- await super.setState(newState);
- this.updateOptions(newState);
-
- if (!isEmpty(this.optionsForm)) {
- this.optionsForm.setValues(newState);
- }
- }
-
- /**
- * Gets the base state and appends form values.
- */
- getState(includeImages = true) {
- let state = super.getState();
- state.prompt = this.prompt || null;
- state.negativePrompt = this.negativePrompt || null;
- state.guidanceScale = this.guidanceScale || null;
- state.inferenceSteps = this.inferenceSteps || null;
- state.removeBackground = this.removeBackground || false;
- state.scaleToModelSize = this.scaleToModelSize || false;
- return state;
+ * Scales the canvas to this node's size
+ */
+ async scaleCanvasToSize() {
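+ // Resize the canvas first (triggering dimension callbacks), then stretch this node to cover it.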
+ let [scaleWidth, scaleHeight] = await this.getCanvasScaleSize();
+ this.editor.setDimension(scaleWidth, scaleHeight, true, true);
+ this.setDimension(
+ -this.constructor.padding,
+ -this.constructor.padding,
+ scaleWidth+this.constructor.padding*2,
+ scaleHeight+this.constructor.padding*2,
+ true
+ );
}
};
-export {
- ImageEditorNodeView
-};
+export { ImageEditorNodeView };
diff --git a/src/js/nodes/image-editor/common.mjs b/src/js/nodes/image-editor/common.mjs
new file mode 100644
index 00000000..19127ddf
--- /dev/null
+++ b/src/js/nodes/image-editor/common.mjs
@@ -0,0 +1,52 @@
+/** @module nodes/image-editor/common */
+import { ElementBuilder } from "../../base/builder.mjs";
+import { View } from "../../view/base.mjs";
+
+const E = new ElementBuilder();
+
+/**
+ * This class is a placeholder for an image
+ */
+class NoImageView extends View {
+ /**
+ * @var string The tag name
+ */
+ static tagName = "enfugue-placeholder-view";
+
+ /**
+ * @var string The icon
+ */
+ static placeholderIcon = "fa-solid fa-link-slash";
+
+ /**
+ * @var string The text
+ */
+ static placeholderText = "No image, use the options menu to add one.";
+
+ /**
+ * On build, append text and icon
+ */
+ async build() {
+ let node = await super.build();
+ node.content(
+ E.i().class(this.constructor.placeholderIcon),
+ E.p().content(this.constructor.placeholderText)
+ );
+ return node;
+ }
+}
+
+/**
+ * This class is a placeholder for a video
+ */
+class NoVideoView extends NoImageView {
+ /**
+ * @var string The text
+ */
+ static placeholderText = "No video, use the options menu to add one.";
+}
+
+export {
+ NoImageView,
+ NoVideoView
+};
diff --git a/src/js/nodes/image-editor/image.mjs b/src/js/nodes/image-editor/image.mjs
index 1e8809a4..effdf40c 100644
--- a/src/js/nodes/image-editor/image.mjs
+++ b/src/js/nodes/image-editor/image.mjs
@@ -1,292 +1,19 @@
-/** @module nodes/image-editor/image-node.mjs */
-import { isEmpty } from "../../base/helpers.mjs";
+/** @module nodes/image-editor/image.mjs */
+import { isEmpty, promptFiles } from "../../base/helpers.mjs";
import { View } from "../../view/base.mjs";
-import { ScribbleView } from "../../view/scribble.mjs";
import { ImageView, BackgroundImageView } from "../../view/image.mjs";
-import { ImageEditorImageNodeOptionsFormView } from "../../forms/enfugue/image-editor.mjs";
-import { CompoundNodeView } from "../base.mjs";
import { ImageEditorNodeView } from "./base.mjs";
-import { ImageEditorScribbleNodeView } from "./scribble.mjs";
-
-/**
- * Extend the compound node to help manage image merging settings
- */
-class ImageEditorCompoundImageNodeView extends CompoundNodeView {
- /**
- * @var bool Hide the header
- */
- static hideHeader = true;
-
- /**
- * @var int Modify snap size to 8
- */
- static snapSize = 8;
-
- /**
- * @var string The name to show in the menu
- */
- static nodeTypeName = "Images";
-
- /**
- * @var bool Enable header flipping
- */
- static canFlipHeader = true;
-
- /**
- * @var int Modify padding to 8
- */
- static padding = 8;
-
- /**
- * @var int Modify edge handler tolerance to 8
- */
- static edgeHandlerTolerance = 8;
-
- /**
- * @var int Increase min height
- */
- static minHeight = 32;
-
- /**
- * @var int Increase min width
- */
- static minWidth = 32;
-
- /**
- * @var string Change from 'Close' to 'Remove'
- */
- static closeText = "Remove";
-
- /**
- * @var array Methods to pass through (when calling from menu)
- */
- static passThroughMethods = [
- "clearMemory", "increaseSize", "decreaseSize",
- "togglePencilShape", "toggleEraser", "rotateClockwise",
- "rotateCounterClockwise", "mirrorHorizontally", "mirrorVertically",
- "toggleOptions"
- ];
-
- /**
- * On construct, bind pass-through methods.
- */
- constructor(config, editor, name, content, left, top, width, height) {
- super(config, editor, name, content, left, top, width, height);
- for (let methodName of this.constructor.passThroughMethods) {
- this[methodName] = function () {
- return this.content.selectedNode[methodName].apply(
- this.content.selectedNode,
- Array.from(arguments)
- );
- }
- }
- }
-}
-
-/**
- * A small class containing the scribble and image
- */
-class ImageScribbleView extends View {
- /**
- * @var string Custom tag name
- */
- static tagName = "enfugue-image-scribble-view";
-
- /**
- * On construct, add sub views
- */
- constructor(config, src, width, height) {
- super(config);
- this.src = src;
- if (!isEmpty(src)) {
- this.image = new BackgroundImageView(config, src);
- }
- this.scribble = new ScribbleView(config, width, height)
- this.clearScribble();
- }
-
- /**
- * @return bool If the scribble view is erasing
- */
- get isEraser() {
- return this.scribble.isEraser;
- }
-
- /**
- * @param bool If the scribble view is erasing
- */
- set isEraser(newIsEraser) {
- this.scribble.isEraser = newIsEraser;
- }
-
- /**
- * @return string The shape of the scribble tool
- */
- get shape() {
- return this.scribble.shape;
- }
-
- /**
- * @param string The shape of the scribble tool
- */
- set shape(newShape) {
- this.scribble.shape = newShape;
- }
-
- /**
- * Sets the scribble to an image source, then resizes
- */
- setScribble(source, width, height) {
- this.scribble.setMemory(source);
- this.scribble.resizeCanvas(width, height);
- this.scribble.show();
- }
-
- /**
- * Clears the scribble memory and hides it
- */
- clearScribble() {
- this.scribble.clearMemory();
- this.scribble.hide();
- }
-
- /**
- * Clears the scribble memory
- */
- clearMemory(){
- this.scribble.clearMemory();
- }
-
- /**
- * Increase the scribble size
- */
- increaseSize() {
- this.scribble.increaseSize();
- }
-
- /**
- * Decrease the scribble size
- */
- decreaseSize() {
- this.scribble.decreaseSize();
- }
-
- /**
- * Shows the scribble
- */
- showScribble() {
- this.scribble.show();
- }
-
- /**
- * Resizes the scribble canvas
- */
- resize(width, height) {
- this.scribble.resizeCanvas(width, height);
- }
-
- /**
- * @return string The data URI of the scribble
- */
- get scribbleSrc() {
- return this.scribble.src;
- }
-
- /**
- * @return string The data URI or source of the imgae
- */
- get imageSrc() {
- return isEmpty(this.image)
- ? this.src
- : this.image.src;
- }
-
- /**
- * Mirrors the image horizontally
- */
- mirrorHorizontally() {
- if (!isEmpty(this.image)) {
- return this.image.mirrorHorizontally();
- }
- }
-
- /**
- * Mirrors the image vertically
- */
- mirrorVertically() {
- if (!isEmpty(this.image)) {
- return this.image.mirrorVertically();
- }
- }
-
- /**
- * Rotates the image clockwise by 90 degrees
- */
- rotateClockwise() {
- if (!isEmpty(this.image)) {
- return this.image.rotateClockwise();
- }
- }
-
- /**
- * Rotates the image counter-clockwise by 90 degrees
- */
- rotateCounterClockwise() {
- if (!isEmpty(this.image)) {
- return this.image.rotateCounterClockwise();
- }
- }
-
- /**
- * Adds a class to the image node
- */
- addImageClass(className) {
- if (!isEmpty(this.image)) {
- this.image.addClass(className);
- }
- }
-
- /**
- * Removes a class from the image node
- */
- removeImageClass(className) {
- if (!isEmpty(this.image)) {
- this.image.removeClass(className);
- }
- }
-
- /**
- * On build, append child views
- */
- async build() {
- let node = await super.build();
- node.content(await this.scribble.getNode());
- if (!isEmpty(this.image)) {
- node.append(await this.image.getNode());
- }
- return node;
- }
-}
+import { NoImageView } from "./common.mjs";
/**
* When pasting images on the image editor, allow a few fit options
*/
class ImageEditorImageNodeView extends ImageEditorNodeView {
/**
- * @var bool Hide this header
+ * @var bool Hide header (position absolutely)
*/
static hideHeader = true;
- /**
- * @var bool Enable merging
- */
- static canMerge = true;
-
- /**
- * @var class A class to help manage merging images
- */
- static compoundNodeClass = ImageEditorCompoundImageNodeView;
-
/**
* @var string The name to show in the menu
*/
@@ -305,13 +32,6 @@ class ImageEditorImageNodeView extends ImageEditorNodeView {
"center-left", "center-center", "center-right",
"bottom-left", "bottom-center", "bottom-right"
];
-
- /**
- * @var array The node buttons that pertain to scribble.
- */
- static scribbleButtons = [
- "erase", "shape", "clear", "increase", "decrease"
- ];
/**
* @var string Add the classname for CSS
@@ -319,16 +39,19 @@ class ImageEditorImageNodeView extends ImageEditorNodeView {
static className = 'image-editor-image-node-view';
/**
- * @var object Buttons to control the scribble. Shortcuts are registered on the view itself.
+ * @var object Buttons to control the image.
*/
static nodeButtons = {
...ImageEditorNodeView.nodeButtons,
...{
- "shape": {"disabled": true, ...ImageEditorScribbleNodeView.nodeButtons.shape},
- "erase": {"disabled": true, ...ImageEditorScribbleNodeView.nodeButtons.erase},
- "clear": {"disabled": true, ...ImageEditorScribbleNodeView.nodeButtons.clear},
- "increase": {"disabled": true, ...ImageEditorScribbleNodeView.nodeButtons.increase},
- "decrease": {"disabled": true, ...ImageEditorScribbleNodeView.nodeButtons.decrease},
+ "replace-image": {
+ "icon": "fa-solid fa-upload",
+ "tooltip": "Replace Image",
+ "shortcut": "c",
+ "callback": function() {
+ this.replaceImage();
+ }
+ },
"mirror-x": {
"icon": "fa-solid fa-left-right",
"tooltip": "Mirror Horizontally",
@@ -364,112 +87,13 @@ class ImageEditorImageNodeView extends ImageEditorNodeView {
}
};
- /**
- * @var class The form for this node.
- */
- static optionsFormView = ImageEditorImageNodeOptionsFormView;
-
- /**
- * Intercept the constructor to set the contents to use view container.
- */
- constructor(editor, name, content, left, top, width, height) {
- super(
- editor,
- name,
- new ImageScribbleView(
- editor.config,
- isEmpty(content) ? null : content.src,
- width,
- height
- ),
- left,
- top,
- width,
- height
- );
- }
-
- /**
- * On resize, resize the content as well
- */
- async resized() {
- await super.resized();
- this.content.resize(
- this.visibleWidth - this.constructor.padding * 2,
- this.visibleHeight - this.constructor.padding * 2
- );
- }
-
- /**
- * Clears the content memory
- */
- clearMemory(){
- this.content.clearMemory();
- }
-
- /**
- * Increase the content size
- */
- increaseSize() {
- this.content.increaseSize();
- }
-
- /**
- * Decrease the content size
- */
- decreaseSize() {
- this.content.decreaseSize();
- }
-
/**
* Updates the options after a user makes a change.
*/
async updateOptions(newOptions) {
- super.updateOptions(newOptions);
-
// Reflected in DOM
this.updateFit(newOptions.fit);
this.updateAnchor(newOptions.anchor);
-
- // Flags
- this.infer = newOptions.infer;
- this.control = newOptions.control;
- this.inpaint = newOptions.inpaint;
- this.imagePrompt = newOptions.imagePrompt;
-
- // Conditional inputs
- this.strength = newOptions.strength;
- this.imagePromptScale = newOptions.imagePromptScale;
- this.imagePromptPlus = newOptions.imagePromptPlus;
- this.imagePromptFace = newOptions.imagePromptFace;
- this.controlnet = newOptions.controlnet;
- this.conditioningScale = newOptions.conditioningScale;
- this.conditioningStart = newOptions.conditioningStart;
- this.conditioningEnd = newOptions.conditioningEnd;
- this.processControlImage = newOptions.processControlImage;
- this.invertControlImage = newOptions.invertControlImage;
- this.cropInpaint = newOptions.cropInpaint;
- this.inpaintFeather = newOptions.inpaintFeather;
-
- // Update scribble view if inpainting
- if (this.node !== undefined) {
- if (this.inpaint) {
- this.content.showScribble();
- // Make sure the scribble view is the right size
- this.content.resize(this.w, this.h);
- } else {
- this.content.clearScribble();
- }
- }
-
- // Buttons
- if (!isEmpty(this.buttons)) {
- for (let button of this.constructor.scribbleButtons) {
- this.buttons[button].disabled = !newOptions.inpaint;
- }
- }
-
- this.rebuildHeaderButtons();
};
/**
@@ -479,10 +103,10 @@ class ImageEditorImageNodeView extends ImageEditorNodeView {
this.fit = newFit;
this.content.fit = newFit;
for (let fitMode of this.constructor.allFitModes) {
- this.content.removeImageClass(`fit-${fitMode}`);
+ this.content.removeClass(`fit-${fitMode}`);
}
if (!isEmpty(newFit)) {
- this.content.addImageClass(`fit-${newFit}`);
+ this.content.addClass(`fit-${newFit}`);
}
};
@@ -493,13 +117,35 @@ class ImageEditorImageNodeView extends ImageEditorNodeView {
this.anchor = newAnchor;
this.content.anchor = newAnchor;
for (let anchorMode of this.constructor.allAnchorModes) {
- this.content.removeImageClass(`anchor-${anchorMode}`);
+ this.content.removeClass(`anchor-${anchorMode}`);
}
if (!isEmpty(newAnchor)) {
- this.content.addImageClass(`anchor-${newAnchor}`);
+ this.content.addClass(`anchor-${newAnchor}`);
}
}
-
+
+ /**
+ * Prompts for a new image
+ */
+ async replaceImage() {
+ let imageToLoad;
+ try {
+ imageToLoad = await promptFiles("image/*");
+ } catch(e) {
+ // No files selected
+ }
+ if (!isEmpty(imageToLoad)) {
+ let reader = new FileReader();
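+ // Read the chosen file as a data URL, swap it into the node, then re-apply fit and anchor.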
+ reader.addEventListener("load", async () => {
+ let imageView = new BackgroundImageView(this.config, reader.result, false);
+ await this.setContent(imageView);
+ this.updateFit(this.fit);
+ this.updateAnchor(this.anchor);
+ });
+ reader.readAsDataURL(imageToLoad);
+ }
+ }
+
/**
* Mirrors the image horizontally
*/
@@ -528,97 +174,44 @@ class ImageEditorImageNodeView extends ImageEditorNodeView {
return this.content.rotateCounterClockwise();
}
- /**
- * Toggle the shape of the scribble pencil
- */
- togglePencilShape() {
- let currentShape = this.content.shape;
-
- if (currentShape === "circle") {
- this.content.shape = "square";
- this.buttons.shape.tooltip = ImageEditorScribbleNodeView.pencilCircleTooltip;
- this.buttons.shape.icon = ImageEditorScribbleNodeView.pencilCircleIcon;
- } else {
- this.content.shape = "circle";
- this.buttons.shape.tooltip = ImageEditorScribbleNodeView.pencilSquareTooltip;
- this.buttons.shape.icon = ImageEditorScribbleNodeView.pencilSquareIcon;
- }
-
- this.rebuildHeaderButtons();
- };
-
- /**
- * Toggles erase mode
- */
- toggleEraser() {
- let currentEraser = this.content.isEraser === true;
-
- if (currentEraser) {
- this.content.isEraser = false;
- this.buttons.erase.icon = ImageEditorScribbleNodeView.eraserIcon;
- this.buttons.erase.tooltip = ImageEditorScribbleNodeView.eraserTooltip;
- } else {
- this.content.isEraser = true;
- this.buttons.erase.icon = ImageEditorScribbleNodeView.pencilIcon;
- this.buttons.erase.tooltip = ImageEditorScribbleNodeView.pencilTooltip;
- }
-
- this.rebuildHeaderButtons();
- };
-
/**
* Override getState to include the image, fit and anchor
*/
getState(includeImages = true) {
let state = super.getState(includeImages);
- state.scribbleSrc = includeImages ? this.content.scribbleSrc : null;
- state.src = includeImages ? this.content.imageSrc : null;
+ state.src = includeImages ? this.content.src : null;
state.anchor = this.anchor || null;
state.fit = this.fit || null;
- state.infer = this.infer || false;
- state.control = this.control || false;
- state.inpaint = this.inpaint || false;
- state.imagePrompt = this.imagePrompt || false;
- state.imagePromptPlus = this.imagePromptPlus || false;
- state.imagePromptFace = this.imagePromptFace || false;
- state.imagePromptScale = this.imagePromptScale || 0.5;
- state.strength = this.strength || 0.8;
- state.controlnet = this.controlnet || null;
- state.cropInpaint = this.cropInpaint !== false;
- state.inpaintFeather = this.inpaintFeather || 32;
- state.conditioningScale = this.conditioningScale || 1.0;
- state.conditioningStart = this.conditioningStart || 0.0;
- state.conditioningEnd = this.conditioningEnd || 1.0;
- state.processControlImage = this.processControlImage !== false;
- state.invertControlImage = this.invertControlImage === true;
- state.removeBackground = this.removeBackground === true;
- state.scaleToModelSize = this.scaleToModelSize === true;
return state;
}
/**
- * Override setState to add the image and scribble
+ * Override setState to add the image
*/
async setState(newState) {
await super.setState(newState);
- await this.setContent(new ImageScribbleView(this.config, newState.src, newState.w, newState.h));
+ if (isEmpty(newState.src)) {
+ await this.setContent(new NoImageView(this.config));
+ } else {
+ await this.setContent(new BackgroundImageView(this.config, newState.src, false));
+ }
await this.updateAnchor(newState.anchor);
await this.updateFit(newState.fit);
- if (newState.inpaint) {
- let scribbleImage = new Image();
- scribbleImage.onload = () => {
- this.content.setScribble(scribbleImage, this.w, this.h);
- };
- scribbleImage.src = newState.scribbleSrc;
+ }
+
+ /**
+ * Gets the size of the image when scaling the node
+ */
+ async getCanvasScaleSize() {
+ if (isEmpty(this.content.src)) {
+ return await super.getCanvasScaleSize();
} else {
- this.content.clearScribble();
- }
- if (!isEmpty(this.buttons)) {
- for (let button of this.constructor.scribbleButtons) {
- this.buttons[button].disabled = !newState.inpaint;
- }
+ await this.content.waitForLoad();
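+ // Round the image dimensions down to multiples of 8, matching the image editor's node sizing.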
+ return [
+ Math.floor(this.content.width / 8) * 8,
+ Math.floor(this.content.height / 8) * 8
+ ];
}
- this.rebuildHeaderButtons();
}
/**
@@ -627,44 +220,8 @@ class ImageEditorImageNodeView extends ImageEditorNodeView {
static getDefaultState() {
return {
"classname": this.name,
- "inpaint": false,
- "control": false,
- "inpaint": false,
- "imagePrompt": false,
- "imagePromptPlus": false,
- "imagePromptFace": false,
- "cropInpaint": true,
- "inpaintFeather": 32,
- "inferenceSteps": null,
- "guidanceScale": null,
- "imagePromptScale": 0.5,
- "strength": 0.8,
- "processControlImage": true,
- "invertControlImage": false,
- "conditioningScale": 1.0,
- "conditioningStart": 0.0,
- "conditioningEnd": 1.0,
- "removeBackground": false,
- "scaleToModelSize": false,
};
}
-
- /**
- * Catch on build to ensure buttons are correct
- */
- async build() {
- let node = await super.build();
- if (this.inpaint === true) {
- for (let button of this.constructor.scribbleButtons) {
- this.buttons[button].disabled = false;
- }
- setTimeout(() => this.rebuildHeaderButtons(), 250);
- }
- return node;
- }
};
-export {
- ImageEditorImageNodeView,
- ImageEditorCompoundImageNodeView
-};
+export { ImageEditorImageNodeView };
diff --git a/src/js/nodes/image-editor/invocation.mjs b/src/js/nodes/image-editor/invocation.mjs
index f9ed3d03..ecb3ffd0 100644
--- a/src/js/nodes/image-editor/invocation.mjs
+++ b/src/js/nodes/image-editor/invocation.mjs
@@ -1,6 +1,7 @@
-/** @module nodes/image-editor/invocation.mjs */
+/** @module nodes/image-editor/invocation */
import { SimpleNotification } from "../../common/notify.mjs";
import { isEmpty } from "../../base/helpers.mjs";
+import { View } from "../../view/base.mjs";
import { ImageView } from "../../view/image.mjs";
import { ToolbarView } from "../../view/menu.mjs";
import {
@@ -49,13 +50,14 @@ class InvocationToolbarView extends ToolbarView {
/**
* Create a small extension of the ImageView to change the class name for CSS.
*/
-class CurrentInvocationImageView extends ImageView {
+class CurrentInvocationImageView extends View {
/**
* Constructed by the editor, pass reference so we can call other functions
*/
constructor(editor) {
super(editor.config);
this.editor = editor;
+ this.imageView = new ImageView(this.config);
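+ // Now wraps an ImageView instead of extending it; the image node is appended during build().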
}
/**
@@ -185,19 +187,26 @@ class CurrentInvocationImageView extends ImageView {
* Override parent setImage to also set the image on the adjustment canvas, if present
*/
setImage(newImage) {
- super.setImage(newImage);
+ this.imageView.setImage(newImage);
if (!isEmpty(this.imageAdjuster)) {
this.imageAdjuster.setImage(newImage);
}
}
+ /**
+ * Pass some functions through to the wrapped ImageView
+ */
+ async waitForLoad() {
+ await this.imageView.waitForLoad();
+ }
+
/**
* Triggers the copy to clipboard
*/
async copyToClipboard() {
navigator.clipboard.write([
new ClipboardItem({
- "image/png": await this.getBlob()
+ "image/png": await this.imageView.getBlob()
})
]);
SimpleNotification.notify("Copied to clipboard!", 2000);
@@ -208,7 +217,7 @@ class CurrentInvocationImageView extends ImageView {
* Asks for a filename first
*/
async saveToDisk() {
- this.editor.application.saveBlobAs("Save Image", await this.getBlob(), ".png");
+ this.editor.application.saveBlobAs("Save Image", await this.imageView.getBlob(), ".png");
}
/**
@@ -217,7 +226,7 @@ class CurrentInvocationImageView extends ImageView {
*/
async sendToCanvas() {
this.editor.application.initializeStateFromImage(
- await this.getImageAsDataURL(),
+ await this.imageView.getImageAsDataURL(),
true, // Save history
null, // Prompt for current state
{
@@ -233,14 +242,14 @@ class CurrentInvocationImageView extends ImageView {
async startImageDownscale() {
if (this.checkActiveTool("downscale")) return;
- let imageBeforeDownscale = this.src,
- widthBeforeDownscale = this.width,
- heightBeforeDownscale = this.height,
+ let imageBeforeDownscale = this.imageView.src,
+ widthBeforeDownscale = this.imageView.width,
+ heightBeforeDownscale = this.imageView.height,
setDownscaleAmount = async (amount) => {
let image = new ImageView(this.config, imageBeforeDownscale);
await image.waitForLoad();
await image.downscale(amount);
- this.setImage(image.src);
+ this.imageView.setImage(image.src);
this.editor.setDimension(image.width, image.height, false);
},
saveResults = false;
@@ -256,7 +265,7 @@ class CurrentInvocationImageView extends ImageView {
this.imageDownscaleForm = null;
this.imageDownscaleWindow = null;
if (!saveResults) {
- this.setImage(imageBeforeDownscale);
+ this.imageView.setImage(imageBeforeDownscale);
this.editor.setDimension(widthBeforeDownscale, heightBeforeDownscale, false);
}
});
@@ -291,7 +300,7 @@ class CurrentInvocationImageView extends ImageView {
this.imageUpscaleForm.onCancel(() => this.imageUpscaleWindow.remove());
this.imageUpscaleForm.onSubmit(async (values) => {
await this.editor.application.initializeStateFromImage(
- await this.getImageAsDataURL(),
+ await this.imageView.getImageAsDataURL(),
true, // Save history
true, // Keep current state, except for...
{
@@ -321,7 +330,7 @@ class CurrentInvocationImageView extends ImageView {
async startImageFilter() {
if (this.checkActiveTool("filter")) return;
- this.imageFilterView = new ImageFilterView(this.config, this.src, this.node.element.parentElement),
+ this.imageFilterView = new ImageFilterView(this.config, this.imageView.src, this.node.element.parentElement),
this.imageFilterWindow = await this.editor.application.windows.spawnWindow(
"Filter Image",
this.imageFilterView,
@@ -339,7 +348,7 @@ class CurrentInvocationImageView extends ImageView {
this.imageFilterWindow.onClose(reset);
this.imageFilterView.onSave(async () => {
- this.setImage(this.imageFilterView.getImageSource());
+ this.imageView.setImage(this.imageFilterView.getImageSource());
setTimeout(() => {
this.imageFilterWindow.remove();
reset();
@@ -358,7 +367,7 @@ class CurrentInvocationImageView extends ImageView {
async startImageAdjustment() {
if (this.checkActiveTool("adjust")) return;
- this.imageAdjustmentView = new ImageAdjustmentView(this.config, this.src, this.node.element.parentElement),
+ this.imageAdjustmentView = new ImageAdjustmentView(this.config, this.imageView.src, this.node.element.parentElement),
this.imageAdjustmentWindow = await this.editor.application.windows.spawnWindow(
"Adjust Image",
this.imageAdjustmentView,
@@ -376,7 +385,7 @@ class CurrentInvocationImageView extends ImageView {
this.imageAdjustmentWindow.onClose(reset);
this.imageAdjustmentView.onSave(async () => {
- this.setImage(this.imageAdjustmentView.getImageSource());
+ this.imageView.setImage(this.imageAdjustmentView.getImageSource());
await this.waitForLoad();
setTimeout(() => {
this.imageAdjustmentWindow.remove();
@@ -511,6 +520,7 @@ class CurrentInvocationImageView extends ImageView {
*/
async build() {
let node = await super.build();
+ node.content(await this.imageView.getNode());
node.on("mouseenter", (e) => this.onMouseEnter(e));
node.on("mouseleave", (e) => this.onMouseLeave(e));
return node;
diff --git a/src/js/nodes/image-editor/prompt.mjs b/src/js/nodes/image-editor/prompt.mjs
index 64572c6e..893156f6 100644
--- a/src/js/nodes/image-editor/prompt.mjs
+++ b/src/js/nodes/image-editor/prompt.mjs
@@ -1,6 +1,52 @@
/** @module nodes/image-editor/prompt.mjs */
+import { ElementBuilder } from "../../base/builder.mjs";
+import { isEmpty } from "../../base/helpers.mjs";
+import { View } from "../../view/base.mjs";
import { ImageEditorNodeView } from "./base.mjs";
-import { ImageEditorNodeOptionsFormView } from "../../forms/enfugue/image-editor.mjs";
+
+const E = new ElementBuilder();
+
+class PromptNodeContentView extends View {
+ /**
+ * @var string tag name
+ */
+ static tagName = "enfugue-region-prompts";
+
+ /**
+ * @var string The text to display initially
+ */
+ static placeholderText = "Use the layer options menu to add a prompt. This region will be filled with an image generated from that prompt, instead of the global prompt. Any remaining empty regions will be inpainted. Check Remove Background to remove the background before merging down and inpainting.";
+
+ /**
+ * Sets the prompts
+ */
+ setPrompts(positive, negative = null) {
+ if (isEmpty(positive)) {
+ this.node.content(E.p().content(this.constructor.placeholderText));
+ } else {
+ let positiveContent = Array.isArray(positive)
+ ? positive.join(", ")
+ : positive;
+ let contentArray = [E.p().content(positiveContent)];
+ if (!isEmpty(negative)) {
+ let negativeContent = Array.isArray(negative)
+ ? negative.join(", ")
+ : negative;
+ contentArray.push(E.p().content(negativeContent));
+ }
+ this.node.content(...contentArray);
+ }
+ }
+
+ /**
+ * On first build, append placeholder
+ */
+ async build() {
+ let node = await super.build();
+ node.content(E.p().content(this.constructor.placeholderText));
+ return node;
+ }
+}
/**
* The PromptNode just allows for regions to have different prompts.
@@ -35,27 +81,15 @@ class ImageEditorPromptNodeView extends ImageEditorNodeView {
* Intercept the constructor to set the contents to the options.
*/
constructor(editor, name, content, left, top, width, height) {
- let realContent = new ImageEditorNodeOptionsFormView(editor.config);
+ let realContent = new PromptNodeContentView(editor.config);
super(editor, name, realContent, left, top, width, height);
- realContent.onSubmit((values) => this.updateOptions(values));
- }
-
- /**
- * Gets state from the content
- */
- getState(includeImages = true) {
- let state = super.getState(includeImages);
- state = {...state, ...this.content.values};
- return state;
}
/**
- * Set the state on the content
+ * Sets the prompts
*/
- async setState(newState) {
- await super.setState(newState);
- await this.content.getNode(); // Wait for first build
- await this.content.setValues(newState);
+ setPrompts(positive, negative = null) {
+ this.content.setPrompts(positive, negative);
}
};
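
A minimal usage sketch of the simplified prompt node above (not part of the diff); `promptNode` stands in for an ImageEditorPromptNodeView already placed on the canvas:

    promptNode.setPrompts("a lighthouse at dusk", "blurry, low quality"); // positive and negative prompts
    promptNode.setPrompts(["a lighthouse at dusk", "dramatic sky"]);      // arrays are joined with ", "
    promptNode.setPrompts(null);                                          // reverts to the placeholder text
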
diff --git a/src/js/nodes/image-editor/scribble.mjs b/src/js/nodes/image-editor/scribble.mjs
index fdb5ff8c..ef054b6a 100644
--- a/src/js/nodes/image-editor/scribble.mjs
+++ b/src/js/nodes/image-editor/scribble.mjs
@@ -12,6 +12,11 @@ class ImageEditorScribbleNodeView extends ImageEditorNodeView {
*/
static nodeTypeName = "Scribble";
+ /**
+ * @var bool hide header
+ */
+ static hideHeader = true;
+
/**
* @var string The icon for changing the cursor to a square.
*/
diff --git a/src/js/nodes/image-editor/video.mjs b/src/js/nodes/image-editor/video.mjs
new file mode 100644
index 00000000..aff7a88c
--- /dev/null
+++ b/src/js/nodes/image-editor/video.mjs
@@ -0,0 +1,167 @@
+/** @module nodes/image-editor/video.mjs */
+import { isEmpty, promptFiles } from "../../base/helpers.mjs";
+import { View } from "../../view/base.mjs";
+import { VideoView } from "../../view/video.mjs";
+import { ImageEditorNodeView } from "./base.mjs";
+import { NoVideoView } from "./common.mjs";
+
+/**
+ * When pasting videos on the video editor, allow a few fit options
+ */
+class ImageEditorVideoNodeView extends ImageEditorNodeView {
+ /**
+ * @var bool Hide header (position absolutely)
+ */
+ static hideHeader = true;
+
+ /**
+ * @var string The name to show in the menu
+ */
+ static nodeTypeName = "Video";
+
+ /**
+ * @var array All fit modes.
+ */
+ static allFitModes = ["actual", "stretch", "cover", "contain"];
+
+ /**
+ * @var array All anchor modes.
+ */
+ static allAnchorModes = [
+ "top-left", "top-center", "top-right",
+ "center-left", "center-center", "center-right",
+ "bottom-left", "bottom-center", "bottom-right"
+ ];
+
+ /**
+ * @var string Add the classname for CSS
+ */
+ static className = 'image-editor-video-node-view';
+
+ /**
+ * @var object Buttons to control the video node. Shortcuts are registered on the view itself.
+ */
+ static nodeButtons = {
+ ...ImageEditorNodeView.nodeButtons,
+ ...{
+ "replace-video": {
+ "icon": "fa-solid fa-upload",
+ "tooltip": "Replace Video",
+ "shortcut": "c",
+ "callback": function() {
+ this.replaceVideo();
+ }
+ }
+ }
+ };
+
+ /**
+ * Updates the options after a user makes a change.
+ */
+ async updateOptions(newOptions) {
+ // Reflected in DOM
+ this.updateFit(newOptions.fit);
+ this.updateAnchor(newOptions.anchor);
+ };
+
+ /**
+ * Updates the video fit
+ */
+ async updateFit(newFit) {
+ this.fit = newFit;
+ this.content.fit = newFit;
+ for (let fitMode of this.constructor.allFitModes) {
+ this.content.removeClass(`fit-${fitMode}`);
+ }
+ if (!isEmpty(newFit)) {
+ this.content.addClass(`fit-${newFit}`);
+ }
+ };
+
+ /**
+ * Updates the video anchor
+ */
+ async updateAnchor(newAnchor) {
+ this.anchor = newAnchor;
+ this.content.anchor = newAnchor;
+ for (let anchorMode of this.constructor.allAnchorModes) {
+ this.content.removeClass(`anchor-${anchorMode}`);
+ }
+ if (!isEmpty(newAnchor)) {
+ this.content.addClass(`anchor-${newAnchor}`);
+ }
+ }
+
+ /**
+ * Prompts for a new video
+ */
+ async replaceVideo() {
+ let videoToLoad;
+ try {
+ videoToLoad = await promptFiles("video/*");
+ } catch(e) {
+ // No files selected
+ }
+ if (!isEmpty(videoToLoad)) {
+ let reader = new FileReader();
+ reader.addEventListener("load", async () => {
+ let videoView = new VideoView(this.config, reader.result);
+ await this.setContent(videoView);
+ this.updateFit(this.fit);
+ this.updateAnchor(this.anchor);
+ });
+ reader.readAsDataURL(videoToLoad);
+ }
+ }
+
+ /**
+ * Override getState to include the video, fit and anchor
+ */
+ getState(includeImages = true) {
+ let state = super.getState(includeImages);
+ state.src = includeImages ? this.content.src : null;
+ state.anchor = this.anchor || null;
+ state.fit = this.fit || null;
+ return state;
+ }
+
+ /**
+ * Override setState to add the video
+ */
+ async setState(newState) {
+ await super.setState(newState);
+ if (isEmpty(newState.src)) {
+ await this.setContent(new NoVideoView(this.config));
+ } else {
+ await this.setContent(new VideoView(this.config, newState.src));
+ }
+ await this.updateAnchor(newState.anchor);
+ await this.updateFit(newState.fit);
+ }
+
+ /**
+ * Gets the size of the video when scaling the node
+ */
+ async getCanvasScaleSize() {
+ if (isEmpty(this.content.src)) {
+ return await super.getCanvasScaleSize();
+ } else {
+ await this.content.waitForLoad();
+ return [
+ Math.floor(this.content.width / 8) * 8,
+ Math.floor(this.content.height / 8) * 8
+ ];
+ }
+ }
+
+ /**
+ * Provide a default state for when we are initializing from a video
+ */
+ static getDefaultState() {
+ return {
+ "classname": this.name,
+ };
+ }
+};
+
+export { ImageEditorVideoNodeView };
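
As a rough sketch of how the new video node's state might round-trip (assumptions: `videoNode` is an ImageEditorVideoNodeView already on the canvas, `videoDataURL` a data URL such as the one produced by replaceVideo()):

    async function restoreVideoNode(videoNode, videoDataURL) {
        await videoNode.setState({
            "src": videoDataURL,
            "fit": "cover",            // one of allFitModes
            "anchor": "center-center"  // one of allAnchorModes
        });
        // getState() echoes the same keys back, plus whatever the base node serializes
        return videoNode.getState();
    }
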
diff --git a/src/js/view/animation.mjs b/src/js/view/animation.mjs
new file mode 100644
index 00000000..d1af31aa
--- /dev/null
+++ b/src/js/view/animation.mjs
@@ -0,0 +1,131 @@
+/** @module view/animation */
+import { isEmpty, waitFor } from "../base/helpers.mjs";
+import { View } from "./base.mjs";
+import { ImageView } from "./image.mjs";
+
+/**
+ * The AnimationView mirrors the ImageView API for working with animations.
+ */
+class AnimationView extends View {
+ /**
+ * @var string The tag name
+ */
+ static tagName = "enfugue-animation-view";
+
+ /**
+ * On construct, check if we're initializing with sources
+ */
+ constructor(config, images = []){
+ super(config);
+ this.canvas = document.createElement("canvas");
+ this.loadedCallbacks = [];
+ this.setImages(images);
+ }
+
+ /**
+ * Adds a callback to fire when images are loaded
+ */
+ onLoad(callback) {
+ if (this.loaded) {
+ callback(this);
+ } else {
+ this.loadedCallbacks.push(callback);
+ }
+ }
+
+ /**
+ * On set image, wait for load then trigger callbacks
+ */
+ setImages(images) {
+ this.images = images;
+ if (isEmpty(images)) {
+ this.loaded = true;
+ this.clearCanvas();
+ } else {
+ this.loaded = false;
+ this.imageViews = images.map(
+ (image) => new ImageView(this.config, image, false)
+ );
+ Promise.all(
+ this.imageViews.map(
+ (imageView) => imageView.waitForLoad()
+ )
+ ).then(() => this.imagesLoaded());
+ }
+ }
+
+ /**
+ * When images are loaded, fire callbacks
+ */
+ async imagesLoaded() {
+ this.loaded = true;
+
+ if (!isEmpty(this.imageViews)) {
+ this.width = this.imageViews[0].width;
+ this.height = this.imageViews[0].height;
+
+ this.canvas.width = this.width;
+ this.canvas.height = this.height;
+
+ let context = this.canvas.getContext("2d");
+ context.drawImage(this.imageViews[0].image, 0, 0);
+
+ if (this.node !== undefined) {
+ this.node.css({
+ "width": this.width,
+ "height": this.height
+ });
+ }
+ }
+
+ for (let callback of this.loadedCallbacks) {
+ await callback();
+ }
+ }
+
+ /**
+ * Waits for the loaded flag to be set
+ */
+ waitForLoad() {
+ return waitFor(() => this.loaded);
+ }
+
+ /**
+ * Sets the frame index
+ */
+ setFrame(index) {
+ if (isEmpty(index)) index = 0;
+ this.frame = index;
+ if (this.loaded) {
+ let context = this.canvas.getContext("2d");
+ context.drawImage(this.imageViews[this.frame].image, 0, 0);
+ } else {
+ this.waitForLoad().then(() => this.setFrame(index));
+ }
+ }
+
+ /**
+ * Clears the canvas
+ */
+ clearCanvas() {
+ let context = this.canvas.getContext("2d");
+ context.clearRect(0, 0, this.canvas.width, this.canvas.height);
+ }
+
+ /**
+ * On build, append canvas
+ */
+ async build() {
+ let node = await super.build();
+ node.content(this.canvas);
+ if (this.loaded) {
+ node.css({
+ "width": this.width,
+ "height": this.height
+ });
+ }
+ return node;
+ }
+}
+
+export { AnimationView };
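
A short sketch of the AnimationView contract, assuming `config` is the usual application configuration object and the frame URLs are placeholders that resolve to same-sized images:

    async function showAnimation(config, container) {
        // the import path is relative to the caller's location under src/js
        const { AnimationView } = await import("../view/animation.mjs");
        let animation = new AnimationView(config, ["frame_0.png", "frame_1.png", "frame_2.png"]);
        animation.onLoad(() => animation.setFrame(1)); // draw the second frame once every image has loaded
        container.appendChild(await animation.render());
        return animation;
    }
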
diff --git a/src/js/view/image.mjs b/src/js/view/image.mjs
index d9afab5c..cd5bfe71 100644
--- a/src/js/view/image.mjs
+++ b/src/js/view/image.mjs
@@ -23,13 +23,21 @@ class ImageView extends View {
* @param object $config The base config object
* @param string $src The image source
*/
- constructor(config, src, usePng = true) {
+ constructor(config, src, usePng) {
super(config);
this.src = src;
this.usePng = usePng;
this.loadedCallbacks = [];
this.metadata = {};
if (!isEmpty(src)) {
+ if (usePng === null || usePng === undefined) {
+ if (src.startsWith("data")) {
+ let fileType = src.substring(5, src.indexOf(";"));
+ usePng = fileType === "image/png";
+ } else {
+ usePng = src.endsWith(".png");
+ }
+ }
if (usePng) {
let callable = PNG.fromURL;
if (src instanceof File) {
@@ -51,7 +59,7 @@ class ImageView extends View {
});
} else {
this.image = new Image();
- this.image.onload = this.imageLoaded();
+ this.image.onload = () => this.imageLoaded();
this.image.src = this.src;
}
}
@@ -254,12 +262,10 @@ class ImageView extends View {
async build() {
let node = await super.build();
if (!isEmpty(this.src)) {
- await this.waitForLoad();
node.attr("src", this.src);
}
return node;
}
-
}
/**
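
The net effect of the constructor change is that PNG-specific loading is now inferred from the source when the flag is omitted; a short sketch (with `config` assumed):

    let photo  = new ImageView(config, "photo.jpeg");      // inferred usePng = false; loads through a plain <img>
    let result = new ImageView(config, "result.png");      // inferred usePng = true; loads through PNG.fromURL
    let masked = new ImageView(config, "mask.png", false); // an explicit flag still overrides the inference
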
diff --git a/src/js/view/menu.mjs b/src/js/view/menu.mjs
index 4585a8f4..e538c90b 100644
--- a/src/js/view/menu.mjs
+++ b/src/js/view/menu.mjs
@@ -80,6 +80,22 @@ class MenuView extends ParentView {
}
}
+ /**
+ * Starts a hide timer to hide self
+ */
+ startHideTimer() {
+ this.hideTimer = setTimeout(() => {
+ this.hideCategories();
+ }, 500);
+ }
+
+ /**
+ * Stops the hide timer
+ */
+ stopHideTimer() {
+ clearTimeout(this.hideTimer);
+ }
+
/**
* Toggles a specific category
*
@@ -88,6 +104,7 @@ class MenuView extends ParentView {
*/
toggleCategory(name) {
let found = false, newValue;
+ this.stopHideTimer();
for (let child of this.children) {
if (child instanceof MenuCategoryView){
if (child.name === name) {
@@ -243,7 +260,8 @@ class MenuCategoryView extends ParentView {
* @return IconItemView
*/
async addItem(name, icon, shortcut) {
- return this.addChild(IconItemView, name, icon, shortcut);
+ let itemView = await this.addChild(IconItemView, name, icon, shortcut);
+ return itemView;
}
/**
@@ -257,7 +275,11 @@ class MenuCategoryView extends ParentView {
header.append(button);
}
- node.prepend(header).on("click", () => this.parent.toggleCategory(this.name));
+ node.prepend(header)
+ .on("click", () => this.parent.toggleCategory(this.name))
+ .on("mouseleave", () => this.parent.startHideTimer())
+ .on("mouseenter,mousemove", () => this.parent.stopHideTimer());
+
return node;
}
}
@@ -318,9 +340,9 @@ class MenuItemView extends View {
async build() {
let node = await super.build();
node.prepend(E.span().content(formatName(this.name, this.shortcut)));
- node.on("click", async(e) => {
+ node.on("click", (e) => {
e.preventDefault();
- this.activate()
+ this.activate();
});
return node;
}
@@ -340,6 +362,20 @@ class IconItemView extends MenuItemView {
this.icon = icon;
}
+ /**
+ * Sets the icon after initialization
+ */
+ setIcon(newIcon) {
+ this.icon = newIcon;
+ if (!isEmpty(this.node)) {
+ if (newIcon.startsWith("http")) {
+ this.node.find("img").src(newIcon);
+ } else {
+ this.node.find("i").class(newIcon)
+ }
+ }
+ }
+
/**
* On build, append icon.
*/
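
Two behavioral notes on the menu changes, shown as a small sketch: leaving a category header now starts a 500 ms timer that collapses open categories unless the pointer re-enters, and icon items can be re-iconed after they are built. `fileCategory` is an assumed MenuCategoryView already attached to the application menu:

    async function addSaveItem(fileCategory) {
        let saveItem = await fileCategory.addItem("Save As", "fa-solid fa-floppy-disk", "s");
        saveItem.setIcon("fa-solid fa-circle-check"); // swaps the class on the rendered <i> in place
        return saveItem;
    }
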
diff --git a/src/js/view/samples/chooser.mjs b/src/js/view/samples/chooser.mjs
new file mode 100644
index 00000000..89318513
--- /dev/null
+++ b/src/js/view/samples/chooser.mjs
@@ -0,0 +1,511 @@
+/** @module view/samples/chooser */
+import { isEmpty, isEquivalent, bindMouseUntilRelease } from "../../base/helpers.mjs";
+import { View } from "../base.mjs";
+import { ImageView } from "../image.mjs";
+import { ElementBuilder } from "../../base/builder.mjs";
+import { NumberInputView } from "../../forms/input.mjs";
+
+const E = new ElementBuilder();
+
+class SampleChooserView extends View {
+ /**
+ * @var string Custom tag name
+ */
+ static tagName = "enfugue-sample-chooser";
+
+ /**
+ * @var string Show canvas icon
+ */
+ static showCanvasIcon = "fa-solid fa-table-cells";
+
+ /**
+ * @var string Show canvas tooltip
+ */
+ static showCanvasTooltip = "Show the canvas, hiding any sample currently visible on-screen and revealing the grid and any nodes you've placed on it.";
+
+ /**
+ * @var string Loop video icon
+ */
+ static loopIcon = "fa-solid fa-rotate-left";
+
+ /**
+ * @var string Loop video tooltip
+ */
+ static loopTooltip = "Loop the video, restarting it after it has completed.";
+
+ /**
+ * @var string Play video icon
+ */
+ static playIcon = "fa-solid fa-play";
+
+ /**
+ * @var string Play video tooltip
+ */
+ static playTooltip = "Play the animation.";
+
+ /**
+ * @var string Tile vertical icon
+ */
+ static tileVerticalIcon = "fa-solid fa-ellipsis-vertical";
+
+ /**
+ * @var string Tile vertical tooltip
+ */
+ static tileVerticalTooltip = "Show the image tiled vertically.";
+
+ /**
+ * @var string Tile horizontal icon
+ */
+ static tileHorizontalIcon = "fa-solid fa-ellipsis";
+
+ /**
+ * @var string Tile horizontal tooltip
+ */
+ static tileHorizontalTooltip = "Show the image tiled horizontally.";
+
+ /**
+ * @var int default playback rate
+ */
+ static playbackRate = 8;
+
+ /**
+ * @var string playback rate tooltip
+ */
+ static playbackRateTooltip = "The playback rate of the animation in frames per second.";
+
+ /**
+ * @var string Text to show when there are no samples
+ */
+ static noSamplesLabel = "No samples yet. When you generate one or more images, their thumbnails will appear here.";
+
+ /**
+ * Constructor creates arrays for callbacks
+ */
+ constructor(config, samples = [], isAnimation = false) {
+ super(config);
+ this.showCanvasCallbacks = [];
+ this.loopAnimationCallbacks = [];
+ this.playAnimationCallbacks = [];
+ this.tileHorizontalCallbacks = [];
+ this.tileVerticalCallbacks = [];
+ this.setActiveCallbacks = [];
+ this.setPlaybackRateCallbacks = [];
+ this.imageViews = [];
+ this.isAnimation = isAnimation;
+ this.samples = samples;
+ this.activeIndex = 0;
+ this.playbackRate = this.constructor.playbackRate;
+ this.playbackRateInput = new NumberInputView(config, "playbackRate", {
+ "min": 1,
+ "max": 60,
+ "value": this.constructor.playbackRate,
+ "tooltip": this.constructor.playbackRateTooltip,
+ "allowNull": false
+ });
+ this.playbackRateInput.onChange(
+ () => this.setPlaybackRate(this.playbackRateInput.getValue(), false)
+ );
+ }
+
+ // ADD CALLBACK FUNCTIONS
+
+ /**
+ * Adds a callback to the show canvas button
+ */
+ onShowCanvas(callback){
+ this.showCanvasCallbacks.push(callback);
+ }
+
+ /**
+ * Adds a callback to the loop animation button
+ */
+ onLoopAnimation(callback) {
+ this.loopAnimationCallbacks.push(callback);
+ }
+
+ /**
+ * Adds a callback to the play animation button
+ */
+ onPlayAnimation(callback) {
+ this.playAnimationCallbacks.push(callback);
+ }
+
+ /**
+ * Adds a callback to the tile horizontal button
+ */
+ onTileHorizontal(callback) {
+ this.tileHorizontalCallbacks.push(callback);
+ }
+
+ /**
+ * Adds a callback to the tile vertical button
+ */
+ onTileVertical(callback) {
+ this.tileVerticalCallbacks.push(callback);
+ }
+
+ /**
+ * Adds a callback to when active is set
+ */
+ onSetActive(callback) {
+ this.setActiveCallbacks.push(callback);
+ }
+
+ /**
+ * Adds a callback to when playback rate is set
+ */
+ onSetPlaybackRate(callback) {
+ this.setPlaybackRateCallbacks.push(callback);
+ }
+
+ // EXECUTE CALLBACK FUNCTIONS
+
+ /**
+ * Calls show canvas callbacks
+ */
+ showCanvas() {
+ this.setActiveIndex(null);
+ for (let callback of this.showCanvasCallbacks) {
+ callback();
+ }
+ }
+
+ /**
+ * Sets whether or not the samples should be controlled as an animation
+ */
+ setIsAnimation(isAnimation) {
+ this.isAnimation = isAnimation;
+ if (!isEmpty(this.node)) {
+ if (isAnimation) {
+ this.node.addClass("animation");
+ } else {
+ this.node.removeClass("animation");
+ }
+ }
+ }
+
+ /**
+ * Calls tile horizontal callbacks
+ */
+ setHorizontalTile(tileHorizontal, updateDom = true) {
+ for (let callback of this.tileHorizontalCallbacks) {
+ callback(tileHorizontal);
+ }
+ if (!isEmpty(this.node) && updateDom) {
+ let tileButton = this.node.find(".tile-horizontal");
+ if (tileHorizontal) {
+ tileButton.addClass("active");
+ } else {
+ tileButton.removeClass("active");
+ }
+ }
+ }
+
+ /**
+ * Calls tile vertical callbacks
+ */
+ setVerticalTile(tileVertical, updateDom = true) {
+ for (let callback of this.tileVerticalCallbacks) {
+ callback(tileVertical);
+ }
+ if (!isEmpty(this.node) && updateDom) {
+ let tileButton = this.node.find(".tile-vertical");
+ if (tileVertical) {
+ tileButton.addClass("active");
+ } else {
+ tileButton.removeClass("active");
+ }
+ }
+ }
+
+ /**
+ * Calls loop animation callbacks
+ */
+ setLoopAnimation(loopAnimation, updateDom = true) {
+ for (let callback of this.loopAnimationCallbacks) {
+ callback(loopAnimation);
+ }
+ if (!isEmpty(this.node) && updateDom) {
+ let loopButton = this.node.find(".loop");
+ if (loopAnimation) {
+ loopButton.addClass("active");
+ } else {
+ loopButton.removeClass("active");
+ }
+ }
+ }
+
+ /**
+ * Calls play animation callbacks
+ */
+ setPlayAnimation(playAnimation, updateDom = true) {
+ for (let callback of this.playAnimationCallbacks) {
+ callback(playAnimation);
+ }
+ if (!isEmpty(this.node) && updateDom) {
+ let playButton = this.node.find(".play");
+ if (playAnimation) {
+ playButton.addClass("active");
+ } else {
+ playButton.removeClass("active");
+ }
+ }
+ }
+
+ /**
+ * Sets the active sample in the chooser
+ */
+ setActiveIndex(activeIndex, invokeCallbacks = true) {
+ this.activeIndex = activeIndex;
+ if (invokeCallbacks) {
+ for (let callback of this.setActiveCallbacks) {
+ callback(activeIndex);
+ }
+ }
+ if (!isEmpty(this.imageViews)) {
+ for (let i in this.imageViews) {
+ let child = this.imageViews[i];
+ if (i == activeIndex) {
+ child.addClass("active");
+ } else {
+ child.removeClass("active");
+ }
+ }
+ }
+ }
+
+ /**
+ * Sets the playback rate
+ */
+ setPlaybackRate(playbackRate, updateDom = true) {
+ this.playbackRate = playbackRate;
+ for (let callback of this.setPlaybackRateCallbacks) {
+ callback(playbackRate);
+ }
+ if (updateDom) {
+ this.playbackRateInput.setValue(playbackRate, false);
+ }
+ }
+
+ /**
+ * Sets samples after initialization
+ */
+ async setSamples(samples) {
+ let isChanged = !isEquivalent(this.samples, samples);
+ this.samples = samples;
+
+ if (!isEmpty(this.node)) {
+ let samplesContainer = await this.node.find(".samples");
+ if (isEmpty(this.samples)) {
+ samplesContainer.content(
+ E.div().class("no-samples").content(this.constructor.noSamplesLabel)
+ );
+ this.imageViews = [];
+ } else if (isChanged) {
+ let samplesContainer = await this.node.find(".samples"),
+ render = false;
+
+ if (isEmpty(this.imageViews)) {
+ samplesContainer.empty();
+ render = true;
+ }
+ for (let i in this.samples) {
+ let imageView,
+ imageViewNode,
+ sample = this.samples[i];
+
+ if (this.imageViews.length <= i) {
+ imageView = new ImageView(this.config, sample, false);
+ await imageView.waitForLoad();
+ imageViewNode = await imageView.getNode();
+ imageViewNode.on("click", () => {
+ this.setActiveIndex(i);
+ });
+ this.imageViews.push(imageView);
+ samplesContainer.append(imageViewNode);
+ render = true;
+ } else {
+ imageView = this.imageViews[i];
+ imageView.setImage(sample);
+ await imageView.waitForLoad();
+ imageViewNode = await imageView.getNode();
+ }
+
+ if (this.activeIndex !== null && this.activeIndex == i) {
+ imageView.addClass("active");
+ } else {
+ imageView.removeClass("active");
+ }
+
+ if (this.isAnimation) {
+ let widthPercentage = 100.0 / this.samples.length;
+ imageViewNode.css("width", `${widthPercentage}%`);
+ } else {
+ imageViewNode.css("width", null);
+ }
+ }
+ if (render) {
+ samplesContainer.render();
+ }
+ }
+ }
+ }
+
+ /**
+ * On build, add icons and selectors as needed
+ */
+ async build() {
+ let node = await super.build(),
+ showCanvas = E.i()
+ .addClass("show-canvas")
+ .addClass(this.constructor.showCanvasIcon)
+ .data("tooltip", this.constructor.showCanvasTooltip)
+ .on("click", () => this.showCanvas()),
+ tileHorizontal = E.i()
+ .addClass("tile-horizontal")
+ .addClass(this.constructor.tileHorizontalIcon)
+ .data("tooltip", this.constructor.tileHorizontalTooltip)
+ .on("click", () => {
+ tileHorizontal.toggleClass("active");
+ this.setHorizontalTile(tileHorizontal.hasClass("active"), false);
+ }),
+ tileVertical = E.i()
+ .addClass("tile-vertical")
+ .addClass(this.constructor.tileVerticalIcon)
+ .data("tooltip", this.constructor.tileVerticalTooltip)
+ .on("click", () => {
+ tileVertical.toggleClass("active");
+ this.setVerticalTile(tileVertical.hasClass("active"), false);
+ }),
+ loopAnimation = E.i()
+ .addClass("loop")
+ .addClass(this.constructor.loopIcon)
+ .data("tooltip", this.constructor.loopTooltip)
+ .on("click", () => {
+ loopAnimation.toggleClass("active");
+ this.setLoopAnimation(loopAnimation.hasClass("active"), false);
+ }),
+ playAnimation = E.i()
+ .addClass("play")
+ .addClass(this.constructor.playIcon)
+ .data("tooltip", this.constructor.playTooltip)
+ .on("click", () => {
+ playAnimation.toggleClass("active");
+ this.setPlayAnimation(playAnimation.hasClass("active"), false);
+ }),
+ samplesContainer = E.div().class("samples");
+
+ let isScrubbing = false,
+ getFrameIndexFromMousePosition = (e) => {
+ let sampleContainerPosition = samplesContainer.element.getBoundingClientRect(),
+ clickRatio = e.clientX < sampleContainerPosition.left
+ ? 0
+ : e.clientX > sampleContainerPosition.left + sampleContainerPosition.width
+ ? 1
+ : (e.clientX - sampleContainerPosition.left) / sampleContainerPosition.width;
+
+ return Math.min(
+ Math.floor(clickRatio * this.samples.length),
+ this.samples.length - 1
+ );
+ };
+
+ samplesContainer
+ .on("wheel", (e) => {
+ e.preventDefault();
+ samplesContainer.element.scrollLeft += e.deltaY / 10;
+ })
+ .on("mousedown", (e) => {
+ if (this.isAnimation) {
+ e.preventDefault();
+ e.stopPropagation();
+
+ isScrubbing = true;
+ this.setActiveIndex(getFrameIndexFromMousePosition(e));
+
+ bindMouseUntilRelease(
+ (e2) => {
+ if (isScrubbing) {
+ this.setActiveIndex(getFrameIndexFromMousePosition(e2));
+ }
+ },
+ (e2) => {
+ isScrubbing = false;
+ }
+ );
+ }
+ })
+ .on("mousemove", (e) => {
+ if (this.isAnimation) {
+ e.preventDefault();
+ e.stopPropagation();
+ if (isScrubbing) {
+ this.setActiveIndex(getFrameIndexFromMousePosition(e));
+ }
+ }
+ })
+ .on("mouseup", (e) => {
+ if (this.isAnimation) {
+ e.preventDefault();
+ e.stopPropagation();
+ isScrubbing = false;
+ }
+ });
+
+ if (isEmpty(this.samples)) {
+ samplesContainer.append(
+ E.div().class("no-samples").content(this.constructor.noSamplesLabel)
+ );
+ } else {
+ for (let i in this.samples) {
+ let imageView,
+ imageViewNode,
+ sample = this.samples[i];
+
+ if (this.imageViews.length <= i) {
+ imageView = new ImageView(this.config, sample, false);
+ imageViewNode = await imageView.getNode();
+ imageViewNode.on("click", () => {
+ this.setActiveIndex(i);
+ });
+ this.imageViews.push(imageView);
+ } else {
+ imageView = this.imageViews[i];
+ imageView.setImage(sample);
+ imageViewNode = await imageView.getNode();
+ }
+
+ if (this.activeIndex !== null && this.activeIndex == i) {
+ imageView.addClass("active");
+ } else {
+ imageView.removeClass("active");
+ }
+
+ samplesContainer.append(imageViewNode);
+ }
+ }
+
+ node.content(
+ showCanvas,
+ E.div().class("tile-buttons").content(
+ tileHorizontal,
+ tileVertical
+ ),
+ samplesContainer,
+ E.div().class("playback-rate").content(
+ await this.playbackRateInput.getNode(),
+ E.span().content("fps")
+ ),
+ loopAnimation,
+ playAnimation
+ );
+
+ if (this.isAnimation) {
+ node.addClass("animation");
+ }
+
+ return node;
+ }
+};
+
+export { SampleChooserView };
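
A sketch of wiring the chooser into a consumer; `config`, `sampleUrls`, and `sampleViewer` are assumptions standing in for the application's own objects:

    async function buildChooser(config, sampleUrls, sampleViewer) {
        let chooser = new SampleChooserView(config, sampleUrls, false);
        chooser.onSetActive((index) => {
            if (index !== null) {                   // null means "show canvas" was clicked
                sampleViewer.setImage(sampleUrls[index]);
            }
        });
        chooser.onSetPlaybackRate((fps) => console.log(`Playback rate: ${fps} fps`));
        // later results can be appended; existing thumbnails are re-used and updated in place
        await chooser.setSamples(sampleUrls.concat(["one_more_sample.png"]));
        return chooser;
    }
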
diff --git a/src/js/view/samples/filter.mjs b/src/js/view/samples/filter.mjs
new file mode 100644
index 00000000..6da0c6eb
--- /dev/null
+++ b/src/js/view/samples/filter.mjs
@@ -0,0 +1,191 @@
+/** @module view/samples/filter.mjs */
+import { isEmpty } from "../../base/helpers.mjs";
+import { ElementBuilder } from "../../base/builder.mjs";
+import { View } from "../../view/base.mjs";
+import { ImageAdjustmentFilter } from "../../graphics/image-adjust.mjs";
+import { ImagePixelizeFilter } from "../../graphics/image-pixelize.mjs";
+import { ImageSharpenFilter } from "../../graphics/image-sharpen.mjs";
+import {
+ ImageBoxBlurFilter,
+ ImageGaussianBlurFilter
+} from "../../graphics/image-blur.mjs";
+import {
+ ImageFilterFormView,
+ ImageAdjustmentFormView
+} from "../../forms/enfugue/image-editor.mjs";
+
+const E = new ElementBuilder();
+
+/**
+ * Combines the filter form view and various buttons for executing it
+ */
+class ImageFilterView extends View {
+ /**
+ * @var class The class of the filter form.
+ */
+ static filterFormView = ImageFilterFormView;
+
+ /**
+ * On construct, build form and bind submit
+ */
+ constructor(config, image, container) {
+ super(config);
+ this.image = image;
+ this.container = container;
+ this.cancelCallbacks = [];
+ this.saveCallbacks = [];
+ this.formView = new this.constructor.filterFormView(config);
+ this.formView.onSubmit((values) => {
+ this.setFilter(values);
+ });
+ }
+
+ /**
+ * Creates a GPU-accelerated filter helper using the image
+ */
+ createFilter(filterType, execute = true) {
+ switch (filterType) {
+ case "box":
+ return new ImageBoxBlurFilter(this.image, execute);
+ case "gaussian":
+ return new ImageGaussianBlurFilter(this.image, execute);
+ case "sharpen":
+ return new ImageSharpenFilter(this.image, execute);
+ case "pixelize":
+ return new ImagePixelizeFilter(this.image, execute);
+ case "adjust":
+ return new ImageAdjustmentFilter(this.image, execute);
+ case "invert":
+ return new ImageAdjustmentFilter(this.image, execute, {invert: 1});
+ default:
+ console.error("Bad filter", filterType);
+ }
+ }
+
+ /**
+ * Gets the image source from the filter, if present
+ */
+ getImageSource() {
+ if (!isEmpty(this.filter)) {
+ return this.filter.imageSource;
+ }
+ return this.image;
+ }
+
+ /**
+ * Sets the filter and filter constants
+ */
+ setFilter(values) {
+ if (values.filter === null) {
+ this.removeCanvas();
+ } else if (values.filter !== undefined && this.filterType !== values.filter) {
+ // Filter changed
+ this.removeCanvas();
+ this.filter = this.createFilter(values.filter, false);
+ this.filterType = values.filter;
+ this.filter.getCanvas().then((canvas) => {
+ this.filter.setConstants(values);
+ this.canvas = canvas;
+ this.container.appendChild(this.canvas);
+ });
+ }
+
+ if (!isEmpty(this.filter)) {
+ this.filter.setConstants(values);
+ }
+ }
+
+ /**
+ * Removes the canvas if it's attached
+ */
+ removeCanvas() {
+ if (!isEmpty(this.canvas)) {
+ try {
+ this.container.removeChild(this.canvas);
+ } catch(e) { }
+ this.canvas = null;
+ }
+ }
+
+ /**
+ * @param callable $callback Method to call when 'cancel' is clicked
+ */
+ onCancel(callback) {
+ this.cancelCallbacks.push(callback);
+ }
+
+ /**
+ * @param callable $callback Method to call when 'save' is clicked
+ */
+ onSave(callback) {
+ this.saveCallbacks.push(callback);
+ }
+
+ /**
+ * Call all save callbacks
+ */
+ async saved() {
+ for (let saveCallback of this.saveCallbacks) {
+ await saveCallback();
+ }
+ }
+
+ /**
+ * Call all cancel callbacks
+ */
+ async canceled() {
+ for (let cancelCallback of this.cancelCallbacks) {
+ await cancelCallback();
+ }
+ }
+
+ /**
+ * On build, add buttons and bind callbacks
+ */
+ async build() {
+ let node = await super.build(),
+ reset = E.button().class("column").content("Reset"),
+ save = E.button().class("column").content("Save"),
+ cancel = E.button().class("column").content("Cancel"),
+ nodeButtons = E.div().class("flex-columns half-spaced margin-top padded-horizontal").content(
+ reset,
+ save,
+ cancel
+ );
+
+ reset.on("click", () => {
+ this.formView.setValues(this.constructor.filterFormView.defaultValues);
+ setTimeout(() => { this.formView.submit(); }, 100);
+ });
+ save.on("click", () => this.saved());
+ cancel.on("click", () => this.canceled());
+ node.content(
+ await this.formView.getNode(),
+ nodeButtons
+ );
+ return node;
+ }
+};
+
+/**
+ * Combines the adjustment form view and application buttons
+ */
+class ImageAdjustmentView extends ImageFilterView {
+ /**
+ * @var class The class of the filter form.
+ */
+ static filterFormView = ImageAdjustmentFormView;
+
+ /**
+ * On construct, build form and bind submit
+ */
+ constructor(config, image, container) {
+ super(config, image, container);
+ this.setFilter({"filter": "adjust"});
+ }
+}
+
+export {
+ ImageFilterView,
+ ImageAdjustmentView
+};
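
A sketch mirroring how the invocation node above drives these classes; `config`, `imageSource`, `containerElement`, and `onAdjusted` are assumptions:

    function spawnAdjuster(config, imageSource, containerElement, onAdjusted) {
        let adjuster = new ImageAdjustmentView(config, imageSource, containerElement);
        adjuster.onSave(async () => {
            onAdjusted(adjuster.getImageSource()); // the filtered source, or the original if untouched
            adjuster.removeCanvas();
        });
        adjuster.onCancel(() => adjuster.removeCanvas());
        return adjuster;
    }
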
diff --git a/src/js/view/samples/viewer.mjs b/src/js/view/samples/viewer.mjs
new file mode 100644
index 00000000..541fcc10
--- /dev/null
+++ b/src/js/view/samples/viewer.mjs
@@ -0,0 +1,189 @@
+/** @module view/samples/viewer */
+import { isEmpty, isEquivalent } from "../../base/helpers.mjs";
+import { ElementBuilder } from "../../base/builder.mjs";
+import { SimpleNotification } from "../../common/notify.mjs";
+import { View } from "../../view/base.mjs";
+import { ImageView } from "../../view/image.mjs";
+import { AnimationView } from "../../view/animation.mjs";
+import { ToolbarView } from "../../view/menu.mjs";
+import {
+ UpscaleFormView,
+ DownscaleFormView
+} from "../../forms/enfugue/upscale.mjs";
+import {
+ ImageAdjustmentView,
+ ImageFilterView
+} from "./filter.mjs";
+
+const E = new ElementBuilder();
+
+/**
+ * This view represents the visible image(s) on the canvas
+ */
+class SampleView extends View {
+ /**
+ * On construct, create the image and animation views for each tile position
+ */
+ constructor(config) {
+ super(config);
+ this.animationViews = (new Array(9)).fill(null).map(() => new AnimationView(this.config));
+ this.imageViews = (new Array(9)).fill(null).map((_, i) => new ImageView(this.config, null, i === 4));
+ this.image = null;
+ this.tileHorizontal = false;
+ this.tileVertical = false;
+ }
+
+ /**
+ * @var string The tag name
+ */
+ static tagName = "enfugue-sample";
+
+ /**
+ * @return int width of the sample(s)
+ */
+ get width() {
+ if (isEmpty(this.image)) {
+ return null;
+ }
+ if (Array.isArray(this.image)) {
+ return this.animationViews[4].width;
+ }
+ return this.imageViews[4].width;
+ }
+
+ /**
+ * @return int height of the sample(s)
+ */
+ get height() {
+ if (isEmpty(this.image)) {
+ return null;
+ }
+ if (Array.isArray(this.image)) {
+ return this.animationViews[4].height;
+ }
+ return this.imageViews[4].height;
+ }
+
+ /**
+ * Checks and shows what should be shown (if anything)
+ */
+ checkVisibility() {
+ for (let i = 0; i < 9; i++) {
+ let imageView = this.imageViews[i],
+ animationView = this.animationViews[i],
+ isVisible = true;
+
+ switch (i) {
+ case 0:
+ case 2:
+ case 6:
+ case 8:
+ isVisible = this.tileHorizontal && this.tileVertical;
+ break;
+ case 1:
+ case 7:
+ isVisible = this.tileVertical;
+ break;
+ case 3:
+ case 5:
+ isVisible = this.tileHorizontal;
+ break;
+ }
+
+ if (!isVisible) {
+ imageView.hide();
+ animationView.hide();
+ } else if (Array.isArray(this.image)) {
+ imageView.hide();
+ animationView.show();
+ } else {
+ imageView.show();
+ animationView.hide();
+ }
+ }
+ }
+
+ /**
+ * Gets the image view as a blob
+ */
+ async getBlob() {
+ return await this.imageViews[4].getBlob();
+ }
+
+ /**
+ * Gets the image view as a data URL
+ */
+ getDataURL() {
+ return this.imageViews[4].getDataURL()
+ }
+
+ /**
+ * Sets the image, either a single image or multiple
+ */
+ setImage(image) {
+ if (isEquivalent(this.image, image)) {
+ return;
+ }
+ this.image = image;
+ if (Array.isArray(image)) {
+ for (let animationView of this.animationViews) {
+ animationView.setImages(image);
+ }
+ window.requestAnimationFrame(() => {
+ this.checkVisibility();
+ window.requestAnimationFrame(() => {
+ this.show();
+ });
+ });
+ } else if (!isEmpty(this.image)) {
+ for (let imageView of this.imageViews) {
+ imageView.setImage(image);
+ }
+ Promise.all(this.imageViews.map((v) => v.waitForLoad())).then(() => {
+ this.checkVisibility();
+ window.requestAnimationFrame(() => {
+ this.show();
+ });
+ });
+ } else {
+ this.hide();
+ for (let animationView of this.animationViews) {
+ animationView.clearCanvas();
+ }
+ }
+ }
+
+ /**
+ * Sets the frame for animations
+ */
+ setFrame(frame) {
+ this.show();
+ for (let animationView of this.animationViews) {
+ animationView.setFrame(frame);
+ }
+ }
+
+ /**
+ * On build, add image and animation containers
+ */
+ async build() {
+ let node = await super.build(),
+ imageContainer = E.div().class("images-container"),
+ animationContainer = E.div().class("animation-container");
+
+ for (let imageView of this.imageViews) {
+ imageView.hide();
+ imageContainer.append(await imageView.getNode());
+ }
+
+ for (let animationView of this.animationViews) {
+ animationView.hide();
+ animationContainer.append(await animationView.getNode());
+ }
+
+ node.content(imageContainer, animationContainer);
+ return node;
+ }
+};
+
+export { SampleView };
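
A sketch of the viewer's two modes, assuming `config` is the application configuration and `container` an attached DOM element; the file names are placeholders:

    async function demoSampleView(config, container) {
        let sample = new SampleView(config);
        container.appendChild(await sample.render());

        sample.setImage("result.png");        // a single image: only the center tile is shown
        sample.tileHorizontal = true;
        sample.checkVisibility();             // reveals the left/right copies for tiled previews

        sample.setImage(["frame_0.png", "frame_1.png"]); // an array switches to the animation tiles
        sample.setFrame(1);
        return sample;
    }
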
diff --git a/src/js/view/scribble.mjs b/src/js/view/scribble.mjs
index 969c73d1..5353ba0b 100644
--- a/src/js/view/scribble.mjs
+++ b/src/js/view/scribble.mjs
@@ -29,11 +29,12 @@ class ScribbleView extends View {
/**
* Allows for a simple 'scribble' interface, a canvas that can be painted on in pure white/black.
*/
- constructor(config, width, height) {
+ constructor(config, width, height, invert = true) {
super(config);
this.width = width;
this.height = height;
this.active = false;
+ this.invert = invert;
this.shape = this.constructor.defaultPencilShape;
this.size = this.constructor.defaultPencilSize;
@@ -42,6 +43,8 @@ class ScribbleView extends View {
this.memoryCanvas = document.createElement("canvas");
this.visibleCanvas = document.createElement("canvas");
+ this.onDrawCallbacks = [];
+
if (!isEmpty(width) && !isEmpty(height)) {
this.memoryCanvas.width = width;
this.memoryCanvas.height = height;
@@ -50,6 +53,40 @@ class ScribbleView extends View {
}
}
+ /**
+ * Gets the active color
+ */
+ get activeColor() {
+ return this.invert
+ ? "white"
+ : "black";
+ }
+
+ /**
+ * Gets the background color
+ */
+ get backgroundColor() {
+ return this.invert
+ ? "black"
+ : "white";
+ }
+
+ /**
+ * Adds a drawing callback
+ */
+ onDraw(callback) {
+ this.onDrawCallbacks.push(callback);
+ }
+
+ /**
+ * Triggers draw callbacks
+ */
+ drawn() {
+ for (let callback of this.onDrawCallbacks) {
+ callback();
+ }
+ }
+
/**
* Gets the canvas image as a data URL.
* We use the visible canvas so that we crop appropriately.
@@ -59,14 +96,50 @@ class ScribbleView extends View {
return this.visibleCanvas.toDataURL();
}
+ /**
+ * Gets the inverted canvas image as a data URL.
+ */
+ get invertSrc() {
+ let canvas = document.createElement("canvas");
+ canvas.width = this.visibleCanvas.width;
+ canvas.height = this.visibleCanvas.height
+ let context = canvas.getContext("2d");
+
+ context.drawImage(this.visibleCanvas, 0, 0);
+ context.globalCompositeOperation = "difference";
+ context.fillStyle = "white";
+ context.fillRect(0, 0, canvas.width, canvas.height);
+
+ return canvas.toDataURL();
+ }
+
/**
* Clears the canvas in memory.
*/
clearMemory() {
let memoryContext = this.memoryCanvas.getContext("2d");
- memoryContext.fillStyle = "#ffffff";
+ memoryContext.fillStyle = this.backgroundColor;
+ memoryContext.fillRect(0, 0, this.memoryCanvas.width, this.memoryCanvas.height);
+ this.updateVisibleCanvas();
+ this.drawn();
+ }
+
+ /**
+ * Fills the canvas in memory.
+ */
+ fillMemory() {
+ let memoryContext = this.memoryCanvas.getContext("2d");
+ memoryContext.fillStyle = this.activeColor;
memoryContext.fillRect(0, 0, this.memoryCanvas.width, this.memoryCanvas.height);
this.updateVisibleCanvas();
+ this.drawn();
+ }
+
+ /**
+ * Inverts the canvas in memory.
+ */
+ invertMemory() {
+ this.setMemory(this.invertSrc);
}
/**
@@ -85,6 +158,7 @@ class ScribbleView extends View {
this.memoryCanvas = newMemoryCanvas;
this.updateVisibleCanvas();
+ this.drawn();
}
/**
@@ -101,12 +175,13 @@ class ScribbleView extends View {
newMemoryCanvas.width = width;
newMemoryCanvas.height = height;
let newMemoryContext = newMemoryCanvas.getContext("2d");
- newMemoryContext.fillStyle = "#ffffff";
+ newMemoryContext.fillStyle = this.backgroundColor;
newMemoryContext.fillRect(0, 0, width, height);
newMemoryContext.drawImage(this.memoryCanvas, 0, 0);
this.memoryCanvas = newMemoryCanvas;
}
this.updateVisibleCanvas();
+ this.drawn();
}
/**
@@ -116,7 +191,7 @@ class ScribbleView extends View {
let canvasContext = this.visibleCanvas.getContext("2d");
canvasContext.beginPath();
canvasContext.rect(0, 0, this.width, this.height);
- canvasContext.fillStyle = "white";
+ canvasContext.fillStyle = this.backgroundColor;
canvasContext.fill();
canvasContext.drawImage(this.memoryCanvas, 0, 0);
}
@@ -236,10 +311,11 @@ class ScribbleView extends View {
context.save();
this.drawPencilShape(context, x, y);
context.clip();
- context.fillStyle = "#ffffff";
+ context.fillStyle = this.backgroundColor;
context.fillRect(0, 0, this.memoryCanvas.width, this.memoryCanvas.height);
context.restore();
this.updateVisibleCanvas();
+ this.drawn();
}
/**
@@ -248,12 +324,13 @@ class ScribbleView extends View {
drawMemory(x, y) {
let context = this.memoryCanvas.getContext("2d");
this.drawPencilShape(context, x, y);
- context.fillStyle = "#000000";
+ context.fillStyle = this.activeColor;
context.fill();
this.updateVisibleCanvas();
this.lastX = x;
this.lastY = y;
this.lastDrawTime = (new Date()).getTime();
+ this.drawn();
}
/**
@@ -264,12 +341,12 @@ class ScribbleView extends View {
let context = this.visibleCanvas.getContext("2d");
this.size -= 1;
this.drawPencilShape(context, x, y);
- context.strokeStyle = "#ffffff";
+ context.strokeStyle = this.backgroundColor;
context.lineWidth = 1;
context.stroke();
this.size += 1;
this.drawPencilShape(context, x, y);
- context.strokeStyle = "#000000";
+ context.strokeStyle = this.activeColor;
context.lineWidth = 1;
context.stroke();
}
@@ -282,9 +359,10 @@ class ScribbleView extends View {
context.beginPath();
context.moveTo(this.lastX, this.lastY);
context.lineTo(x, y);
- context.strokeStyle = "#000000";
+ context.strokeStyle = this.activeColor;
context.lineWidth = this.size;
context.stroke();
+ this.drawn();
}
/**
@@ -300,6 +378,7 @@ class ScribbleView extends View {
top = Math.max(0, y - this.size / 2),
right = Math.min(left + this.size, this.width),
bottom = Math.min(top + this.size, this.height);
+
context.moveTo(left, top);
context.lineTo(right, top);
context.lineTo(right, bottom);
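
A sketch of the inversion-aware scribble API: with the default `invert = true` the background color is black and the pencil draws white, and every mutation now reports through onDraw (`config` is assumed):

    function createScribble(config) {
        let scribble = new ScribbleView(config, 512, 512);       // invert defaults to true
        scribble.onDraw(() => console.log("scribble changed"));  // draw, erase, fill, clear and resize all fire this
        scribble.fillMemory();                                   // flood with the active (pencil) color
        scribble.invertMemory();                                 // swap black and white in place
        console.log(scribble.src, scribble.invertSrc);           // visible and inverted canvases as data URLs
        return scribble;
    }
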
diff --git a/src/js/view/video.mjs b/src/js/view/video.mjs
new file mode 100644
index 00000000..2b6830a3
--- /dev/null
+++ b/src/js/view/video.mjs
@@ -0,0 +1,99 @@
+/** @module view/video */
+import { View } from "./base.mjs";
+import { waitFor, isEmpty } from "../base/helpers.mjs";
+import { ElementBuilder } from "../base/builder.mjs";
+
+const E = new ElementBuilder();
+
+/**
+ * The VideoView mimics the capabilities of the ImageView
+ */
+class VideoView extends View {
+ /**
+ * @var string Tag name; videos get their own element rather than reusing the image view's
+ */
+ static tagName = "enfugue-video-view";
+
+ /**
+ * Construct with source
+ */
+ constructor(config, src) {
+ super(config);
+ this.loaded = false;
+ this.loadedCallbacks = [];
+ this.setVideo(src);
+ }
+
+ /**
+ * Adds a callback to the list of loaded callbacks
+ */
+ onLoad(callback) {
+ if (this.loaded) {
+ callback(this);
+ } else {
+ this.loadedCallbacks.push(callback);
+ }
+ }
+
+ /**
+ * Wait for the video to be loaded
+ */
+ waitForLoad() {
+ return waitFor(() => this.loaded);
+ }
+
+ /**
+ * Sets the video source after initialization
+ */
+ setVideo(src) {
+ if (this.src === src) {
+ return;
+ }
+ this.loaded = false;
+ this.src = src;
+ this.video = document.createElement("video");
+ this.video.onloadedmetadata = () => this.videoLoaded();
+ this.video.autoplay = true;
+ this.video.loop = true;
+ this.video.muted = true;
+ this.video.src = src;
+ }
+
+ /**
+ * Trigger video load callbacks
+ */
+ videoLoaded() {
+ this.loaded = true;
+ this.width = this.video.videoWidth;
+ this.height = this.video.videoHeight;
+ for (let callback of this.loadedCallbacks) {
+ callback();
+ }
+ }
+
+ /**
+ * Build the container and append the DOM node
+ */
+ async build() {
+ let node = await super.build();
+ node.content(this.video);
+ return node;
+ }
+}
+
+class VideoPlayerView extends VideoView {
+ setVideo(src) {
+ super.setVideo(src);
+ this.video.controls = true;
+ }
+
+ async build() {
+ let node = await super.build();
+ return node;
+ }
+}
+
+export {
+ VideoView,
+ VideoPlayerView
+};
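
A sketch of the video view contract; `config`, `videoURL`, and `container` are assumptions:

    async function showVideo(config, videoURL, container) {
        let video = new VideoPlayerView(config, videoURL); // VideoView plus native playback controls
        container.appendChild(await video.render());
        await video.waitForLoad();                         // resolves once the metadata has loaded
        console.log(`Video is ${video.width}x${video.height}`);
        return video;
    }
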
diff --git a/src/python/enfugue/api/controller/downloads.py b/src/python/enfugue/api/controller/downloads.py
index 72bcf145..8759bcbc 100644
--- a/src/python/enfugue/api/controller/downloads.py
+++ b/src/python/enfugue/api/controller/downloads.py
@@ -90,6 +90,7 @@ def civitai_lookup(self, request: Request, response: Response, lookup: str) -> L
"poses": "Poses",
"hypernetwork": "Hypetnetwork",
"gradient": "AestheticGradient",
+ "motion": "MotionModule",
}.get(lookup, None)
if lookup_type is None:
diff --git a/src/python/enfugue/api/controller/invocation.py b/src/python/enfugue/api/controller/invocation.py
index f6f92907..67025b31 100644
--- a/src/python/enfugue/api/controller/invocation.py
+++ b/src/python/enfugue/api/controller/invocation.py
@@ -1,9 +1,12 @@
+from __future__ import annotations
+
import os
import glob
import PIL
import PIL.Image
-from typing import Dict, List, Any, Union, Tuple, Optional
+from typing import Dict, List, Any, Union, Tuple, Optional, TYPE_CHECKING
+
from webob import Request, Response
from pibble.ext.user.server.base import (
@@ -16,17 +19,14 @@
from pibble.api.middleware.database.orm import ORMMiddlewareBase
from pibble.api.exceptions import NotFoundError, BadRequestError
-from enfugue.diffusion.plan import DiffusionPlan
-from enfugue.diffusion.constants import (
- DEFAULT_MODEL,
- DEFAULT_INPAINTING_MODEL,
- DEFAULT_SDXL_MODEL,
- DEFAULT_SDXL_REFINER,
- DEFAULT_SDXL_INPAINTING_MODEL,
-)
+from enfugue.diffusion.invocation import LayeredInvocation
+from enfugue.diffusion.constants import *
from enfugue.util import find_file_in_directory
from enfugue.api.controller.base import EnfugueAPIControllerBase
+if TYPE_CHECKING:
+ import cv2
+
__all__ = ["EnfugueAPIInvocationController"]
DEFAULT_MODEL_CKPT = os.path.basename(DEFAULT_MODEL)
@@ -34,6 +34,14 @@
DEFAULT_SDXL_MODEL_CKPT = os.path.basename(DEFAULT_SDXL_MODEL)
DEFAULT_SDXL_REFINER_CKPT = os.path.basename(DEFAULT_SDXL_REFINER)
DEFAULT_SDXL_INPAINTING_CKPT = os.path.basename(DEFAULT_SDXL_INPAINTING_MODEL)
+MOTION_LORA_ZOOM_OUT_CKPT = os.path.basename(MOTION_LORA_ZOOM_OUT)
+MOTION_LORA_ZOOM_IN_CKPT = os.path.basename(MOTION_LORA_ZOOM_IN)
+MOTION_LORA_PAN_LEFT_CKPT = os.path.basename(MOTION_LORA_PAN_LEFT)
+MOTION_LORA_PAN_RIGHT_CKPT = os.path.basename(MOTION_LORA_PAN_RIGHT)
+MOTION_LORA_TILT_UP_CKPT = os.path.basename(MOTION_LORA_TILT_UP)
+MOTION_LORA_TILT_DOWN_CKPT = os.path.basename(MOTION_LORA_TILT_DOWN)
+MOTION_LORA_ROLL_CLOCKWISE_CKPT = os.path.basename(MOTION_LORA_ROLL_CLOCKWISE)
+MOTION_LORA_ROLL_ANTI_CLOCKWISE_CKPT = os.path.basename(MOTION_LORA_ROLL_ANTI_CLOCKWISE)
class EnfugueAPIInvocationController(EnfugueAPIControllerBase):
handlers = UserExtensionHandlerRegistry()
@@ -43,7 +51,7 @@ def thumbnail_height(self) -> int:
"""
Gets the height of thumbnails.
"""
- return self.configuration.get("enfugue.thumbnail", 200)
+ return self.configuration.get("enfugue.thumbnail", 100)
def get_default_model(self, model: str) -> Optional[str]:
"""
@@ -60,6 +68,22 @@ def get_default_model(self, model: str) -> Optional[str]:
return DEFAULT_SDXL_REFINER
if base_model_name == DEFAULT_SDXL_INPAINTING_CKPT:
return DEFAULT_SDXL_INPAINTING_MODEL
+ if base_model_name == MOTION_LORA_ZOOM_OUT_CKPT:
+ return MOTION_LORA_ZOOM_OUT
+ if base_model_name == MOTION_LORA_ZOOM_IN_CKPT:
+ return MOTION_LORA_ZOOM_IN
+ if base_model_name == MOTION_LORA_PAN_LEFT_CKPT:
+ return MOTION_LORA_PAN_LEFT
+ if base_model_name == MOTION_LORA_PAN_RIGHT_CKPT:
+ return MOTION_LORA_PAN_RIGHT
+ if base_model_name == MOTION_LORA_TILT_UP_CKPT:
+ return MOTION_LORA_TILT_UP
+ if base_model_name == MOTION_LORA_TILT_DOWN_CKPT:
+ return MOTION_LORA_TILT_DOWN
+ if base_model_name == MOTION_LORA_ROLL_CLOCKWISE_CKPT:
+ return MOTION_LORA_ROLL_CLOCKWISE
+ if base_model_name == MOTION_LORA_ROLL_ANTI_CLOCKWISE_CKPT:
+ return MOTION_LORA_ROLL_ANTI_CLOCKWISE
return None
def get_default_size_for_model(self, model: Optional[str]) -> int:
@@ -120,17 +144,95 @@ def check_find_adaptations(
model_name = model.get("model", None)
model_weight = model.get("weight", 1.0)
if not model_name:
- raise BadRequestError(f"Bad model format for type `{model_type}` - missing required dictionary key `model`")
+ return []
if is_weighted:
return [(self.check_find_model(model_type, model_name), model_weight)]
return [self.check_find_model(model_type, model_name)]
elif isinstance(model, list):
models = []
for item in model:
- models.extend(self.check_find_adaptations(model_type, is_weighted, item))
+ models.extend(
+ self.check_find_adaptations(model_type, is_weighted, item)
+ )
return models
raise BadRequestError(f"Bad format for {model_type} - must be either a single string, a dictionary with the key `model` and optionally `weight`, or a list of the same (got {model})")
+ def convert_animation(
+ self,
+ source_path: str,
+ dest_path: str,
+ rate: float,
+ ) -> str:
+ """
+ Converts animation file formats
+ """
+ def on_open(capture: cv2.VideoCapture) -> None:
+ nonlocal rate
+ import cv2
+ rate = capture.get(cv2.CAP_PROP_FPS)
+
+ from enfugue.diffusion.util import Video
+ frames = [
+ frame for frame in
+ Video.file_to_frames(
+ source_path,
+ on_open=on_open,
+ )
+ ] # Materialize the generator so on_open captures the source frame rate
+ Video(frames).save(
+ dest_path,
+ rate=rate,
+ overwrite=True
+ )
+
+ return dest_path
+
+ def get_animation(
+ self,
+ file_path: str,
+ rate: float=8.0,
+ overwrite: bool=False,
+ ) -> str:
+ """
+ Gets an animation
+ """
+ video_path = os.path.join(self.manager.engine_image_dir, file_path)
+ base, ext = os.path.splitext(video_path)
+ if not os.path.exists(video_path) or overwrite:
+ if ext != ".mp4":
+ # Look for mp4
+ mp4_path = f"{base}.mp4"
+ if os.path.exists(mp4_path):
+ return self.convert_animation(mp4_path, video_path, rate)
+
+ from enfugue.diffusion.util import Video
+
+ images = []
+ image_id, _ = os.path.splitext(os.path.basename(video_path))
+ frame = 0
+
+ while True:
+ image_path = os.path.join(self.manager.engine_image_dir, f"{image_id}_{frame}.png")
+ if not os.path.exists(image_path):
+ break
+ images.append(image_path)
+ frame += 1
+
+ if not images:
+ raise NotFoundError(f"No images for ID {image_id}")
+
+ frames = [
+ PIL.Image.open(image) for image in images
+ ]
+
+ Video(frames).save(
+ video_path,
+ rate=rate,
+ overwrite=True
+ )
+
+ return video_path
+
@handlers.path("^/api/invoke$")
@handlers.methods("POST")
@handlers.format()
@@ -164,8 +266,10 @@ def invoke_engine(self, request: Request, response: Response) -> Dict[str, Any]:
"inpainter",
"inpainter_size",
"inpainter_vae",
+ "motion_module",
]:
request.parsed.pop(ignored_arg, None)
+
elif model_name and model_type in ["checkpoint", "diffusers", "checkpoint+diffusers"]:
if model_type == "diffusers":
plan_kwargs["model"] = model_name # Hope for the best
@@ -207,23 +311,31 @@ def invoke_engine(self, request: Request, response: Response) -> Dict[str, Any]:
scheduler = request.parsed.pop("scheduler", None)
if scheduler:
plan_kwargs["scheduler"] = scheduler
+
disable_decoding = request.parsed.pop("intermediates", None) == False
ui_state: Optional[str] = None
+ video_rate: Optional[float] = None
+
for key, value in request.parsed.items():
if key == "state":
ui_state = value
+ elif key == "frame_rate":
+ video_rate = value
elif value is not None:
plan_kwargs[key] = value
if not plan_kwargs.get("size", None):
- plan_kwargs["size"] = self.get_default_size_for_model(plan_kwargs.get("model", None))
+ plan_kwargs["size"] = self.get_default_size_for_model(
+ plan_kwargs.get("model", None)
+ )
- plan = DiffusionPlan.assemble(**plan_kwargs)
+ plan = LayeredInvocation.assemble(**plan_kwargs)
return self.invoke(
request.token.user.id,
plan,
ui_state=ui_state,
+ video_rate=video_rate,
disable_intermediate_decoding=disable_decoding
).format()
@@ -280,12 +392,40 @@ def download_image(self, request: Request, response: Response, file_path: str) -
raise NotFoundError(f"No image at {file_path}")
return image_path
+ @handlers.path("^/api/invocation/animation/images/(?P.+)$")
+ @handlers.download()
+ @handlers.methods("GET")
+ @handlers.compress()
+ @handlers.cache()
+ @handlers.reverse("Animation", "/api/invocation/animation/images/{file_path}")
+ @handlers.bypass(
+ UserRESTExtensionServerBase,
+ UserExtensionServerBase,
+ ORMMiddlewareBase,
+ SessionExtensionServerBase,
+ UserExtensionTemplateServer,
+ ) # bypass processing for speed
+ def download_animation(self, request: Request, response: Response, file_path: str) -> str:
+ """
+ Downloads all results of an invocation as a video
+ """
+ video_path = os.path.join(self.manager.engine_image_dir, file_path)
+ try:
+ rate = float(request.params.get("rate", 8.0))
+ except:
+ rate = 8.0
+ return self.get_animation(
+ file_path,
+ rate=rate,
+ overwrite=bool(request.params.get("overwrite", 0))
+ )
+
@handlers.path("^/api/invocation/thumbnails/(?P.+)$")
@handlers.download()
@handlers.methods("GET")
@handlers.compress()
@handlers.cache()
- @handlers.reverse("Image", "/api/invocation/thumbnails/{file_path}")
+ @handlers.reverse("Thumbnail", "/api/invocation/thumbnails/{file_path}")
@handlers.bypass(
UserRESTExtensionServerBase,
UserExtensionServerBase,
@@ -300,6 +440,7 @@ def download_image_thumbnail(self, request: Request, response: Response, file_pa
image_path = os.path.join(self.manager.engine_image_dir, file_path)
if not os.path.exists(image_path):
raise NotFoundError(f"No image at {file_path}")
+
image_name, ext = os.path.splitext(os.path.basename(image_path))
thumbnail_path = os.path.join(self.manager.engine_image_dir, f"{image_name}_thumb{ext}")
if not os.path.exists(thumbnail_path):
@@ -309,6 +450,59 @@ def download_image_thumbnail(self, request: Request, response: Response, file_pa
image.resize((int(width * scale), int(height * scale))).save(thumbnail_path)
return thumbnail_path
+ @handlers.path("^/api/invocation/animation/thumbnails/(?P.+)$")
+ @handlers.download()
+ @handlers.methods("GET")
+ @handlers.compress()
+ @handlers.cache()
+ @handlers.reverse("AnimationThumbnail", "/api/invocation/animation/thumbnails/{file_path}")
+ @handlers.bypass(
+ UserRESTExtensionServerBase,
+ UserExtensionServerBase,
+ ORMMiddlewareBase,
+ SessionExtensionServerBase,
+ UserExtensionTemplateServer,
+ ) # bypass processing for speed
+ def download_animation_thumbnail(self, request: Request, response: Response, file_path: str) -> str:
+ """
+ Downloads all results of an invocation as a thumbnail video
+ """
+ image_name, ext = os.path.splitext(os.path.basename(file_path))
+ thumbnail_path = os.path.join(self.manager.engine_image_dir, f"{image_name}_thumb{ext}")
+
+ if not os.path.exists(thumbnail_path):
+ try:
+ rate = float(request.params.get("rate", 8.0))
+ except:
+ rate = 8.0
+
+ video_path = self.get_animation(
+ file_path,
+ rate=rate,
+ )
+
+ def on_open(capture: cv2.VideoCapture) -> None:
+ nonlocal rate
+ import cv2
+ rate = capture.get(cv2.CAP_PROP_FPS)
+
+ from enfugue.diffusion.util import Video
+ frames = [
+ frame for frame in
+ Video.file_to_frames(
+ video_path,
+ on_open=on_open,
+ resolution=self.thumbnail_height
+ )
+ ] # Materialize the generator so on_open captures the source frame rate
+ Video(frames).save(
+ thumbnail_path,
+ rate=rate,
+ overwrite=True
+ )
+
+ return thumbnail_path
+
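The `on_open` callback above exists so the thumbnail video is re-encoded at the source file's actual frame rate rather than the rate passed on the request. Outside of `Video.file_to_frames`, the same probe is just a `cv2` property read (a minimal sketch; the input path is hypothetical):

```python
import cv2

capture = cv2.VideoCapture("animation.mp4")  # hypothetical input file
if not capture.isOpened():
    raise IOError("Could not open video file")

# CAP_PROP_FPS reports the container's frame rate, which may differ
# from whatever was supplied in the "rate" query parameter.
rate = capture.get(cv2.CAP_PROP_FPS)
capture.release()
print(f"Source frame rate: {rate}")
```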
@handlers.path("^/api/invocation/nsfw$")
@handlers.methods("GET")
@handlers.cache()
@@ -376,7 +570,7 @@ def delete_invocation(self, request: Request, response: Response, uuid: str) ->
raise NotFoundError(f"No invocation with ID {uuid}")
for dirname in [self.manager.engine_image_dir, self.manager.engine_intermediate_dir]:
- for invocation_image in glob.glob(f"{uuid}*.png", root_dir=dirname):
+ for invocation_image in glob.glob(f"{uuid}*.*", root_dir=dirname):
os.remove(os.path.join(dirname, invocation_image))
self.database.delete(database_invocation)
self.database.commit()
@@ -390,3 +584,4 @@ def stop_engine(self, request: Request, response: Response) -> None:
Stops the engine and any invocations.
"""
self.manager.stop_engine()
+ self.manager.stop_interpolator()
diff --git a/src/python/enfugue/api/controller/models.py b/src/python/enfugue/api/controller/models.py
index bbf71bee..cf2aad2a 100644
--- a/src/python/enfugue/api/controller/models.py
+++ b/src/python/enfugue/api/controller/models.py
@@ -1,41 +1,36 @@
+from __future__ import annotations
+
+import re
import os
import glob
import PIL
import PIL.Image
import shutil
-from typing import List, Dict, Any
+from typing import List, Dict, Optional, Any
from webob import Request, Response
from pibble.api.exceptions import BadRequestError, NotFoundError
from pibble.util.files import load_json
from pibble.ext.user.server.base import UserExtensionHandlerRegistry
-from enfugue.util import find_files_in_directory, find_file_in_directory
+from enfugue.util import find_files_in_directory, find_file_in_directory, logger
from enfugue.api.controller.base import EnfugueAPIControllerBase
from enfugue.database.models import DiffusionModel
from enfugue.diffusion.manager import DiffusionPipelineManager
-from enfugue.diffusion.plan import DiffusionPlan, DiffusionStep, DiffusionNode
-from enfugue.diffusion.constants import (
- DEFAULT_MODEL,
- DEFAULT_INPAINTING_MODEL,
- DEFAULT_SDXL_MODEL,
- DEFAULT_SDXL_REFINER,
- DEFAULT_SDXL_INPAINTING_MODEL,
-)
+from enfugue.diffusion.invocation import LayeredInvocation
+from enfugue.diffusion.constants import *
__all__ = ["EnfugueAPIModelsController"]
class EnfugueAPIModelsController(EnfugueAPIControllerBase):
- handlers = UserExtensionHandlerRegistry()
+ XL_BASE_KEY = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
+ XL_REFINER_KEY = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
+
+ INPUT_BLOCK_KEY = "model.diffusion_model.input_blocks.0.0.weight"
MODEL_DEFAULT_FIELDS = [
- "width",
- "height",
- "chunking_size",
- "chunking_mask_type",
- "chunking_mask_kwargs",
"num_inference_steps",
"guidance_scale",
"refiner_start",
@@ -48,8 +43,7 @@ class EnfugueAPIModelsController(EnfugueAPIControllerBase):
"refiner_negative_prompt",
"refiner_negative_prompt_2",
"prompt_2",
- "negative_prompt_2",
- "upscale_steps",
+ "negative_prompt_2"
]
DEFAULT_CHECKPOINTS = [
@@ -60,7 +54,21 @@ class EnfugueAPIModelsController(EnfugueAPIControllerBase):
os.path.basename(DEFAULT_SDXL_INPAINTING_MODEL),
]
- def get_models_in_directory(self, directory: str) -> List[str]:
+ DEFAULT_LORA = [
+ os.path.basename(MOTION_LORA_PAN_LEFT),
+ os.path.basename(MOTION_LORA_PAN_RIGHT),
+ os.path.basename(MOTION_LORA_ROLL_CLOCKWISE),
+ os.path.basename(MOTION_LORA_ROLL_ANTI_CLOCKWISE),
+ os.path.basename(MOTION_LORA_TILT_UP),
+ os.path.basename(MOTION_LORA_TILT_DOWN),
+ os.path.basename(MOTION_LORA_ZOOM_IN),
+ os.path.basename(MOTION_LORA_ZOOM_OUT),
+ ]
+
+ handlers = UserExtensionHandlerRegistry()
+
+ @staticmethod
+ def get_models_in_directory(directory: str) -> List[str]:
"""
Gets stored AI model networks in a directory (.safetensors, .ckpt, etc.)
"""
@@ -71,6 +79,78 @@ def get_models_in_directory(self, directory: str) -> List[str]:
)
)
+ def check_name(self, name: str) -> None:
+ """
+ Raises an exception if a name contains invalid characters
+ """
+ if re.match(r".*[./\\].*", name):
+ raise BadRequestError("Name cannot contain the following characters: ./\\")
+
+ def get_model_metadata(self, model: str) -> Optional[Dict[str, Any]]:
+ """
+ Gets metadata for a checkpoint if it exists
+ ONLY reads safetensors files
+ """
+ directory = self.get_configured_directory("checkpoint")
+ model_path = find_file_in_directory(directory, model)
+ if model_path is None:
+ return None
+
+ # Start with naive name checks; we'll do better checks in a moment
+ model_name, model_ext = os.path.splitext(model)
+ model_metadata = {
+ "xl": "xl" in model_name.lower(),
+ "refiner": "refiner" in model_name.lower(),
+ "inpainter": "inpaint" in model_name.lower()
+ }
+
+ if model_name in self.get_diffusers_models():
+ # Read diffusers cache
+ diffusers_cache_dir = os.path.join(self.get_configured_directory("diffusers"), model_name)
+ model_metadata["xl"] = os.path.exists(os.path.join(diffusers_cache_dir, "text_encoder_2"))
+ model_metadata["refiner"] = (
+ model_metadata["xl"] and not
+ os.path.exists(os.path.join(diffusers_cache_dir, "text_encoder"))
+ )
+ unet_config_file = os.path.join(diffusers_cache_dir, "unet", "config.json")
+ if os.path.exists(unet_config_file):
+ unet_config = load_json(unet_config_file)
+ model_metadata["inpainter"] = unet_config.get("in_channels", 4) == 9
+ else:
+ logger.warning(f"Diffusers model {model_name} with no UNet is unexpected, errors may occur.")
+ model_metadata["inpainter"] = False
+ elif model_ext == ".safetensors":
+ # Reads safetensors metadata
+ import safetensors
+ with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
+ keys = list(f.keys())
+ xl_base = self.XL_BASE_KEY in keys
+ xl_refiner = self.XL_REFINER_KEY in keys
+ model_metadata["xl"] = xl_base or xl_refiner
+ model_metadata["refiner"] = xl_refiner
+ if self.INPUT_BLOCK_KEY in keys:
+ input_weights = f.get_tensor(self.INPUT_BLOCK_KEY)
+ model_metadata["inpainter"] = input_weights.shape[1] == 9 # type: ignore[union-attr]
+ else:
+ logger.warning(f"Checkpoint file {model_path} with no input block shape is unexpected, errors may occur.")
+ model_metadata["inpainter"] = False
+
+ return model_metadata
+
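As a standalone illustration of the safetensors branch, the same key checks can be run against any local checkpoint; this sketch reuses the class constants above and assumes only a path to a `.safetensors` file:

```python
import safetensors

XL_BASE_KEY = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
XL_REFINER_KEY = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
INPUT_BLOCK_KEY = "model.diffusion_model.input_blocks.0.0.weight"

def inspect_checkpoint(path: str) -> dict:
    """Classify a .safetensors checkpoint by inspecting its tensor keys."""
    with safetensors.safe_open(path, framework="pt", device="cpu") as f:
        keys = set(f.keys())
        metadata = {
            "xl": XL_BASE_KEY in keys or XL_REFINER_KEY in keys,
            "refiner": XL_REFINER_KEY in keys,
            "inpainter": False,
        }
        if INPUT_BLOCK_KEY in keys:
            # Inpainting UNets take 9 input channels (latents + masked latents + mask).
            metadata["inpainter"] = f.get_tensor(INPUT_BLOCK_KEY).shape[1] == 9
    return metadata

# Hypothetical usage:
# print(inspect_checkpoint("/models/checkpoint/sd_xl_base_1.0.safetensors"))
```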
+ def get_diffusers_models(self) -> List[str]:
+ """
+ Gets diffusers models in the configured directory.
+ """
+ diffusers_dir = self.get_configured_directory("diffusers")
+ diffusers_models = []
+ if os.path.exists(diffusers_dir):
+ diffusers_models = [
+ dirname
+ for dirname in os.listdir(diffusers_dir)
+ if os.path.exists(os.path.join(diffusers_dir, dirname, "model_index.json"))
+ ]
+ return diffusers_models
+
@handlers.path("^/api/checkpoints$")
@handlers.methods("GET")
@handlers.format()
@@ -108,6 +188,10 @@ def get_lora(self, request: Request, response: Response) -> List[Dict[str, Any]]
}
for filename in self.get_models_in_directory(lora_dir)
]
+ for default_lora in self.DEFAULT_LORA:
+ if default_lora not in [l["name"] for l in lora]:
+ lora.append({"name": default_lora, "directory": "available for download"})
+
return lora
@handlers.path("^/api/lycoris$")
@@ -146,6 +230,24 @@ def get_inversions(self, request: Request, response: Response) -> List[Dict[str,
]
return inversions
+ @handlers.path("^/api/motion$")
+ @handlers.methods("GET")
+ @handlers.format()
+ @handlers.secured()
+ def get_motion(self, request: Request, response: Response) -> List[Dict[str, Any]]:
+ """
+ Gets installed motion modules
+ """
+ motion_dir = self.get_configured_directory("motion")
+ motion = [
+ {
+ "name": os.path.basename(filename),
+ "directory": os.path.relpath(os.path.dirname(filename), motion_dir)
+ }
+ for filename in self.get_models_in_directory(motion_dir)
+ ]
+ return motion
+
@handlers.path("^/api/tensorrt$")
@handlers.methods("GET")
@handlers.format()
@@ -269,21 +371,30 @@ def delete_tensorrt_engine(
raise NotFoundError(f"Couldn't find {engine_type} TensorRT engine for {model_name} with key {engine_key}")
shutil.rmtree(engine_dir)
- @handlers.path("^/api/models/(?P[^\/]+)/status$")
+ @handlers.path("^/api/models/(?P[^\/]+)/status$")
@handlers.methods("GET")
@handlers.format()
@handlers.secured("DiffusionModel", "read")
- def get_model_status(self, request: Request, response: Response, model_name: str) -> Dict[str, Any]:
+ def get_model_status(self, request: Request, response: Response, model_name_or_ckpt: str) -> Dict[str, Any]:
"""
Gets status for a particular model
"""
+ if "." in model_name_or_ckpt:
+ return {
+ "model": model_name_or_ckpt,
+ "metadata": {
+ "base": self.get_model_metadata(model_name_or_ckpt)
+ }
+ }
+
model = (
self.database.query(self.orm.DiffusionModel)
- .filter(self.orm.DiffusionModel.name == model_name)
+ .filter(self.orm.DiffusionModel.name == model_name_or_ckpt)
.one_or_none()
)
+
if not model:
- raise NotFoundError(f"No model named {model_name}")
+ raise NotFoundError(f"No model named {model_name_or_ckpt}")
main_model_status = DiffusionPipelineManager.get_status(
self.engine_root,
@@ -293,9 +404,11 @@ def get_model_status(self, request: Request, response: Response, model_name: str
[(lycoris.model, lycoris.weight) for lycoris in model.lycoris],
[inversion.model for inversion in model.inversion],
)
+ main_model_metadata = self.get_model_metadata(model.model)
if model.inpainter:
inpainter_model = model.inpainter[0].model
+ inpainter_model_metadata = self.get_model_metadata(inpainter_model)
inpainter_model_status = DiffusionPipelineManager.get_status(
self.engine_root,
model.inpainter[0].model,
@@ -304,6 +417,7 @@ def get_model_status(self, request: Request, response: Response, model_name: str
else:
model_name, ext = os.path.splitext(model.model)
inpainter_model = f"{model_name}-inpainting{ext}"
+ inpainter_model_metadata = self.get_model_metadata(inpainter_model)
inpainter_model_status = DiffusionPipelineManager.get_status(
self.engine_root,
inpainter_model,
@@ -312,6 +426,7 @@ def get_model_status(self, request: Request, response: Response, model_name: str
if model.refiner:
refiner_model = model.refiner[0].model
+ refiner_model_metadata = self.get_model_metadata(refiner_model)
refiner_model_status = DiffusionPipelineManager.get_status(
self.engine_root,
refiner_model,
@@ -319,6 +434,7 @@ def get_model_status(self, request: Request, response: Response, model_name: str
)
else:
refiner_model = None
+ refiner_model_metadata = None
refiner_model_status = None
return {
@@ -330,6 +446,11 @@ def get_model_status(self, request: Request, response: Response, model_name: str
"inpainter": inpainter_model_status,
"refiner": refiner_model_status,
},
+ "metadata": {
+ "base": main_model_metadata,
+ "inpainter": inpainter_model_metadata,
+ "refiner": refiner_model_metadata
+ }
}
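When the requested name contains a `.` it is treated as a raw checkpoint and the handler short-circuits with only the metadata block; an illustrative (not captured) response for that branch looks like:

```python
# Illustrative shape only; the checkpoint name and flags are hypothetical.
{
    "model": "sd_xl_base_1.0.safetensors",
    "metadata": {
        "base": {"xl": True, "refiner": False, "inpainter": False}
    }
}
```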
@handlers.path("^/api/models/(?P[^\/]+)/tensorrt/(?P[^\/]+)$")
@@ -342,29 +463,33 @@ def create_model_tensorrt_engine(
"""
Issues a job to create an engine.
"""
- plan = DiffusionPlan.assemble(**self.get_plan_kwargs_from_model(model_name, include_prompts=False))
- plan.build_tensorrt = True
+ plan = LayeredInvocation.assemble(**self.get_plan_kwargs_from_model(model_name, include_prompts=False))
+ if not plan.tiling_size:
+ raise ValueError("Tiling must be enabled for TensorRT.")
- step = DiffusionStep(prompt="a green field, blue sky, outside", width=plan.size, height=plan.size)
+ plan.build_tensorrt = True
network_name = network_name.lower()
if network_name == "inpaint_unet":
- step.image = PIL.Image.new("RGB", (plan.size, plan.size))
- step.mask = PIL.Image.new("RGB", (plan.size, plan.size))
- step.strength = 1.0
+ plan.layers = [{"image": PIL.Image.new("RGB", (plan.tiling_size, plan.tiling_size))}]
+ plan.mask = PIL.Image.new("RGB", (plan.tiling_size, plan.tiling_size))
+ plan.strength = 1.0
elif network_name == "controlled_unet":
- step.control_images = [{
- "controlnet": "canny",
- "image": PIL.Image.new("RGB", (plan.size, plan.size)),
- "scale": 1.0,
- "process": True,
- "invert": False,
+ plan.layers = [{
+ "control_units": [
+ {
+ "controlnet": "canny",
+ "scale": 1.0,
+ "process": True,
+ }
+ ],
+ "image": PIL.Image.new("RGB", (plan.tiling_size, plan.tiling_size)),
}]
elif network_name != "unet":
raise BadRequestError(f"Unknown or unsupported network {network_name}")
build_metadata = {"model": model_name, "network": network_name}
- plan.nodes = [DiffusionNode([(0, 0), (plan.size, plan.size)], step)]
+
return self.invoke(
request.token.user.id,
plan,
@@ -379,8 +504,12 @@ def create_model_tensorrt_engine(
@handlers.secured("DiffusionModel", "update")
def modify_model(self, request: Request, response: Response, model_name: str) -> DiffusionModel:
"""
- Asks the pipeline manager for information about models.
+ Modifies a model.
"""
+ # Check arguments
+ if "name" in request.parsed:
+ self.check_name(request.parsed["name"])
+
model = (
self.database.query(self.orm.DiffusionModel)
.filter(self.orm.DiffusionModel.name == model_name)
@@ -411,6 +540,9 @@ def modify_model(self, request: Request, response: Response, model_name: str) ->
for existing_config in model.config:
self.database.delete(existing_config)
+ for existing_motion_module in model.motion_module:
+ self.database.delete(existing_motion_module)
+
for existing_vae in model.vae:
self.database.delete(existing_vae)
@@ -427,7 +559,7 @@ def modify_model(self, request: Request, response: Response, model_name: str) ->
model.size = request.parsed.get("size", model.size)
model.prompt = request.parsed.get("prompt", model.prompt)
model.negative_prompt = request.parsed.get("negative_prompt", model.negative_prompt)
-
+
self.database.commit()
refiner = request.parsed.get("refiner", None)
@@ -482,6 +614,15 @@ def modify_model(self, request: Request, response: Response, model_name: str) ->
)
)
+ motion_module = request.parsed.get("motion_module", None)
+ if motion_module:
+ self.database.add(
+ self.orm.DiffusionModelMotionModule(
+ diffusion_model_name=model_name,
+ name=motion_module,
+ )
+ )
+
for lora in request.parsed.get("lora", []):
new_lora = self.orm.DiffusionModelLora(
diffusion_model_name=model.name, model=lora["model"], weight=lora["weight"]
@@ -545,6 +686,8 @@ def delete_model(self, request: Request, response: Response, model_name: str) ->
self.database.delete(vae)
for vae in model.inpainter_vae:
self.database.delete(vae)
+ for motion_module in model.motion_module:
+ self.database.delete(motion_module)
for config in model.config:
self.database.delete(config)
@@ -561,80 +704,92 @@ def create_model(self, request: Request, response: Response) -> DiffusionModel:
Creates a new model.
"""
try:
- new_model = self.orm.DiffusionModel(
- name=request.parsed["name"],
- model=request.parsed["checkpoint"],
- size=request.parsed.get("size", 512),
- prompt=request.parsed.get("prompt", ""),
- negative_prompt=request.parsed.get("negative_prompt", ""),
+ name = request.parsed["name"]
+ model = request.parsed["checkpoint"]
+ size = request.parsed.get("size", 512)
+ prompt = request.parsed.get("prompt", "")
+ negative_prompt = request.parsed.get("negative_prompt", "")
+ self.check_name(name)
+ except KeyError as ex:
+ raise BadRequestError(f"Missing required parameter {ex}")
+
+ new_model = self.orm.DiffusionModel(
+ name=name,
+ model=model,
+ size=size,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ )
+ self.database.add(new_model)
+ self.database.commit()
+ refiner = request.parsed.get("refiner", None)
+ if refiner:
+ new_refiner = self.orm.DiffusionModelRefiner(
+ diffusion_model_name=new_model.name, model=refiner, size=request.parsed.get("refiner_size", None)
)
- self.database.add(new_model)
+ self.database.add(new_refiner)
self.database.commit()
- refiner = request.parsed.get("refiner", None)
- if refiner:
- new_refiner = self.orm.DiffusionModelRefiner(
- diffusion_model_name=new_model.name, model=refiner, size=request.parsed.get("refiner_size", None)
- )
- self.database.add(new_refiner)
- self.database.commit()
- inpainter = request.parsed.get("inpainter", None)
- if inpainter:
- new_inpainter = self.orm.DiffusionModelInpainter(
+ inpainter = request.parsed.get("inpainter", None)
+ if inpainter:
+ new_inpainter = self.orm.DiffusionModelInpainter(
+ diffusion_model_name=new_model.name,
+ model=inpainter,
+ size=request.parsed.get("inpainter_size", None),
+ )
+ self.database.add(new_inpainter)
+ self.database.commit()
+ scheduler = request.parsed.get("scheduler", None)
+ if scheduler:
+ new_scheduler = self.orm.DiffusionModelScheduler(diffusion_model_name=new_model.name, name=scheduler)
+ self.database.add(new_scheduler)
+ self.database.commit()
+ vae = request.parsed.get("vae", None)
+ if vae:
+ new_vae = self.orm.DiffusionModelVAE(diffusion_model_name=new_model.name, name=vae)
+ self.database.add(new_vae)
+ self.database.commit()
+ refiner_vae = request.parsed.get("refiner_vae", None)
+ if refiner_vae:
+ new_refiner_vae = self.orm.DiffusionModelRefinerVAE(diffusion_model_name=new_model.name, name=refiner_vae)
+ self.database.add(new_refiner_vae)
+ self.database.commit()
+ inpainter_vae = request.parsed.get("inpainter_vae", None)
+ if inpainter_vae:
+ new_inpainter_vae = self.orm.DiffusionModelInpainterVAE(diffusion_model_name=new_model.name, name=inpainter_vae)
+ self.database.add(new_inpainter_vae)
+ self.database.commit()
+ motion_module = request.parsed.get("motion_module", None)
+ if motion_module:
+ new_motion_module = self.orm.DiffusionModelMotionModule(diffusion_model_name=new_model.name, name=motion_module)
+ self.database.add(new_motion_module)
+ self.database.commit()
+ for lora in request.parsed.get("lora", []):
+ new_lora = self.orm.DiffusionModelLora(
+ diffusion_model_name=new_model.name, model=lora["model"], weight=lora["weight"]
+ )
+ self.database.add(new_lora)
+ self.database.commit()
+ for lycoris in request.parsed.get("lycoris", []):
+ new_lycoris = self.orm.DiffusionModelLycoris(
+ diffusion_model_name=new_model.name, model=lycoris["model"], weight=lycoris["weight"]
+ )
+ self.database.add(new_lycoris)
+ self.database.commit()
+ for inversion in request.parsed.get("inversion", []):
+ new_inversion = self.orm.DiffusionModelInversion(diffusion_model_name=new_model.name, model=inversion)
+ self.database.add(new_inversion)
+ self.database.commit()
+ for field_name in self.MODEL_DEFAULT_FIELDS:
+ field_value = request.parsed.get(field_name, None)
+ if field_value is not None:
+ new_config = self.orm.DiffusionModelDefaultConfiguration(
diffusion_model_name=new_model.name,
- model=inpainter,
- size=request.parsed.get("inpainter_size", None),
- )
- self.database.add(new_inpainter)
- self.database.commit()
- scheduler = request.parsed.get("scheduler", None)
- if scheduler:
- new_scheduler = self.orm.DiffusionModelScheduler(diffusion_model_name=new_model.name, name=scheduler)
- self.database.add(new_scheduler)
- self.database.commit()
- vae = request.parsed.get("vae", None)
- if vae:
- new_vae = self.orm.DiffusionModelVAE(diffusion_model_name=new_model.name, name=vae)
- self.database.add(new_vae)
- self.database.commit()
- refiner_vae = request.parsed.get("refiner_vae", None)
- if refiner_vae:
- new_refiner_vae = self.orm.DiffusionModelRefinerVAE(diffusion_model_name=new_model.name, name=refiner_vae)
- self.database.add(new_refiner_vae)
- self.database.commit()
- inpainter_vae = request.parsed.get("inpainter_vae", None)
- if inpainter_vae:
- new_inpainter_vae = self.orm.DiffusionModelInpainterVAE(diffusion_model_name=new_model.name, name=inpainter_vae)
- self.database.add(new_inpainter_vae)
- self.database.commit()
- for lora in request.parsed.get("lora", []):
- new_lora = self.orm.DiffusionModelLora(
- diffusion_model_name=new_model.name, model=lora["model"], weight=lora["weight"]
+ configuration_key=field_name,
+ configuration_value=field_value,
)
- self.database.add(new_lora)
- self.database.commit()
- for lycoris in request.parsed.get("lycoris", []):
- new_lycoris = self.orm.DiffusionModelLycoris(
- diffusion_model_name=new_model.name, model=lycoris["model"], weight=lycoris["weight"]
- )
- self.database.add(new_lycoris)
- self.database.commit()
- for inversion in request.parsed.get("inversion", []):
- new_inversion = self.orm.DiffusionModelInversion(diffusion_model_name=new_model.name, model=inversion)
- self.database.add(new_inversion)
+ self.database.add(new_config)
self.database.commit()
- for field_name in self.MODEL_DEFAULT_FIELDS:
- field_value = request.parsed.get(field_name, None)
- if field_value is not None:
- new_config = self.orm.DiffusionModelDefaultConfiguration(
- diffusion_model_name=new_model.name,
- configuration_key=field_name,
- configuration_value=field_value,
- )
- self.database.add(new_config)
- self.database.commit()
- return new_model
- except KeyError as ex:
- raise BadRequestError(f"Missing required parameter {ex}")
+ return new_model
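A sketch of a request body this handler now accepts, including the new `motion_module` association; the field names are the keys read above, and every value here is a hypothetical example:

```python
# Hypothetical create-model payload (illustrative values only).
{
    "name": "my-animation-model",
    "checkpoint": "dreamshaper_8.safetensors",
    "size": 512,
    "prompt": "best quality",
    "negative_prompt": "blurry, low quality",
    "motion_module": "mm_sd_v15_v2.ckpt",
    "lora": [{"model": "add_detail.safetensors", "weight": 0.6}],
    "inversion": ["easynegative.pt"],
    "num_inference_steps": 25,
    "guidance_scale": 7.0
}
```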
@handlers.path("^/api/model-options$")
@handlers.methods("GET")
@@ -665,14 +820,7 @@ def get_all_models(self, request: Request, response: Response) -> List[Dict[str,
})
# Get diffusers caches
- diffusers_dir = self.get_configured_directory("diffusers")
- diffusers_models = []
- if os.path.exists(diffusers_dir):
- diffusers_models = [
- dirname
- for dirname in os.listdir(diffusers_dir)
- if os.path.exists(os.path.join(diffusers_dir, dirname, "model_index.json"))
- ]
+ diffusers_models = self.get_diffusers_models()
diffusers_caches = []
for model in diffusers_models:
found = False
diff --git a/src/python/enfugue/api/controller/system.py b/src/python/enfugue/api/controller/system.py
index fb61ec87..0574f5db 100644
--- a/src/python/enfugue/api/controller/system.py
+++ b/src/python/enfugue/api/controller/system.py
@@ -95,9 +95,9 @@ def get_settings(self, request: Request, response: Response) -> Dict[str, Any]:
settings = {
"safe": self.configuration.get("enfugue.safe", True),
"auth": not (self.configuration.get("enfugue.noauth", True)),
- "max_queued_invocations": self.manager.max_queued_invocations,
- "max_queued_downloads": self.manager.max_queued_downloads,
- "max_concurrent_downloads": self.manager.max_concurrent_downloads,
+ "max_queued_invocations": self.configuration.get("enfugue.queue", self.manager.DEFAULT_MAX_QUEUED_INVOCATIONS),
+ "max_queued_downloads": self.configuration.get("enfugue.downloads.queue", self.manager.DEFAULT_MAX_QUEUED_DOWNLOADS),
+ "max_concurrent_downloads": self.configuration.get("enfugue.downloads.concurrent", self.manager.DEFAULT_MAX_CONCURRENT_DOWNLOADS),
"switch_mode": self.configuration.get("enfugue.pipeline.switch", "offload"),
"sequential": self.configuration.get("enfugue.pipeline.sequential", False),
"cache_mode": self.configuration.get("enfugue.pipeline.cache", None),
@@ -170,15 +170,15 @@ def update_settings(self, request: Request, response: Response) -> None:
self.configuration["enfugue.pipeline.inpainter"] = None
else:
self.user_config["enfugue.pipeline.inpainter"] = False
-
- for key in [
- "max_queued_invocation",
- "max_queued_downloads",
- "max_concurrent_downloads",
- ]:
- if key in request.parsed:
- self.user_config[f"enfugue.{key}"] = request.parsed[key]
+ if "max_queued_invocations" in request.parsed:
+ self.user_config["enfugue.queue"] = request.parsed["max_queued_invocations"]
+
+ if "max_queued_downloads" in request.parsed:
+ self.user_config["enfugue.downloads.queue"] = request.parsed["max_queued_downloads"]
+
+ if "max_concurrent_downloads" in request.parsed:
+ self.user_config["enfugue.downloads.concurrent"] = request.parsed["max_concurrent_downloads"]
for controlnet in self.CONTROLNETS:
if controlnet in request.parsed:
@@ -316,7 +316,7 @@ def get_installation_summary(self, request: Request, response: Response) -> Dict
Gets a summary of files and filesize in the installation
"""
sizes = {}
- for dirname in ["cache", "diffusers", "checkpoint", "lora", "lycoris", "inversion", "tensorrt", "other"]:
+ for dirname in ["cache", "diffusers", "checkpoint", "lora", "lycoris", "inversion", "tensorrt", "other", "images", "intermediate"]:
directory = self.get_configured_directory(dirname)
items, files, size = get_directory_size(directory)
sizes[dirname] = {"items": items, "files": files, "bytes": size, "path": directory}
diff --git a/src/python/enfugue/api/invocations.py b/src/python/enfugue/api/invocations.py
index 82da09c4..06e79f1e 100644
--- a/src/python/enfugue/api/invocations.py
+++ b/src/python/enfugue/api/invocations.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
import re
import datetime
+import traceback
from typing import Any, Optional, List, Dict
from pibble.util.helpers import resolve
@@ -7,7 +10,8 @@
from enfugue.util import logger, get_version
from enfugue.diffusion.engine import DiffusionEngine
-from enfugue.diffusion.plan import DiffusionPlan
+from enfugue.diffusion.interpolate import InterpolationEngine
+from enfugue.diffusion.invocation import LayeredInvocation
from multiprocessing import Lock
from PIL.PngImagePlugin import PngInfo
@@ -18,16 +22,21 @@
class TerminatedError(Exception):
pass
+def get_relative_paths(paths: List[str]) -> List[str]:
+ """
+ Gets relative paths from a list of paths (OS-agnostic)
+ """
+ return ["/".join(re.split(r"/|\\", path)[-2:]) for path in paths]
class Invocation:
"""
Holds the details for a single invocation
"""
-
start_time: Optional[datetime.datetime]
end_time: Optional[datetime.datetime]
last_intermediate_time: Optional[datetime.datetime]
results: Optional[List[str]]
+ interpolate_result: Optional[str]
last_images: Optional[List[str]]
last_step: Optional[int]
last_total: Optional[int]
@@ -38,7 +47,8 @@ class Invocation:
def __init__(
self,
engine: DiffusionEngine,
- plan: DiffusionPlan,
+ interpolator: InterpolationEngine,
+ plan: LayeredInvocation,
engine_image_dir: str,
engine_intermediate_dir: str,
ui_state: Optional[str] = None,
@@ -46,11 +56,15 @@ def __init__(
communication_timeout: Optional[int] = 180,
metadata: Optional[Dict[str, Any]] = None,
save: bool = True,
+ video_format: str = "mp4",
+ video_codec: str = "avc1",
+ video_rate: float = 8.0,
**kwargs: Any,
) -> None:
self.lock = Lock()
self.uuid = get_uuid()
self.engine = engine
+ self.interpolator = interpolator
self.plan = plan
self.results_dir = engine_image_dir
@@ -61,17 +75,27 @@ def __init__(
self.metadata = metadata
self.save = save
+ self.video_format = video_format
+ self.video_codec = video_codec
+ self.video_rate = video_rate
+
self.id = None
+ self.interpolate_id = None
self.error = None
+
self.start_time = None
self.end_time = None
+ self.interpolate_start_time = None
+ self.interpolate_end_time = None
self.last_intermediate_time = None
+
self.last_step = None
self.last_total = None
self.last_rate = None
self.last_images = None
self.last_task = None
self.results = None
+ self.interpolate_result = None
def _communicate(self) -> None:
"""
@@ -90,6 +114,7 @@ def _communicate(self) -> None:
setattr(self, f"last_{key}", last_intermediate[key])
self.last_intermediate_time = datetime.datetime.now()
end_comm = (datetime.datetime.now() - start_comm).total_seconds()
+
try:
result = self.engine.wait(self.id, timeout=0.1)
except TimeoutError:
@@ -97,10 +122,12 @@ def _communicate(self) -> None:
except Exception as ex:
result = None
self.error = ex
+
if result is not None:
# Complete
self.results = []
self.end_time = datetime.datetime.now()
+
if "images" in result:
is_nsfw = result.get("nsfw_content_detected", [])
for i, image in enumerate(result["images"]):
@@ -120,12 +147,34 @@ def _communicate(self) -> None:
self.results.append(image_path)
else:
self.results.append("unsaved")
- elif "error" in result:
+
+ if self.plan.animation_frames:
+ if self.plan.interpolate_frames or self.plan.reflect:
+ # Start interpolation
+ self.start_interpolate()
+ else:
+ # Save video
+ try:
+ from enfugue.diffusion.util.video_util import Video
+ video_path = f"{self.results_dir}/{self.uuid}.{self.video_format}"
+ Video(result["images"]).save(
+ video_path,
+ rate=self.video_rate,
+ encoder=self.video_codec
+ )
+ self.interpolate_result = video_path # type: ignore[assignment]
+ except Exception as ex:
+ self.error = ex
+ logger.error(f"Couldn't save video: {ex}")
+ logger.debug(traceback.format_exc())
+
+ if "error" in result:
error_type = resolve(result["error"])
self.error = error_type(result["message"])
if "traceback" in result:
logger.error(f"Traceback for invocation {self.uuid}:")
logger.debug(result["traceback"])
+
if self.metadata is not None and "tensorrt_build" in self.metadata:
logger.info("TensorRT build complete, terminating engine to start fresh on next invocation.")
self.engine.terminate_process()
@@ -146,7 +195,7 @@ def start(self) -> None:
"""
with self.lock:
self.start_time = datetime.datetime.now()
- payload = self.plan.get_serialization_dict(self.intermediate_dir)
+ payload = self.plan.serialize(self.intermediate_dir)
payload["intermediate_dir"] = self.intermediate_dir
payload["intermediate_steps"] = self.intermediate_steps
self.id = self.engine.dispatch("plan", payload)
@@ -158,6 +207,76 @@ def poll(self) -> None:
with self.lock:
self._communicate()
+ def start_interpolate(self) -> None:
+ """
+ Starts the interpolation (is locked when called)
+ """
+ from PIL import Image
+ if self.interpolate_id is not None:
+ raise IOError("Interpolation already began.")
+ assert isinstance(self.results, list), "Must have a list of image results"
+ self.interpolate_start_time = datetime.datetime.now()
+ self.interpolate_id = self.interpolator.dispatch("plan", {
+ "reflect": self.plan.reflect,
+ "frames": self.plan.interpolate_frames,
+ "images": [
+ Image.open(path) for path in self.results
+ ],
+ "save_path": f"{self.results_dir}/{self.uuid}.{self.video_format}",
+ "video_rate": self.video_rate,
+ "video_codec": self.video_codec
+ })
+
+ def _interpolate_communicate(self) -> None:
+ """
+ Tries to communicate with the engine to see what's going on.
+ """
+ if self.interpolate_id is None:
+ raise IOError("Interpolation not started yet.")
+ if self.interpolate_result is not None: # type: ignore[unreachable]
+ raise IOError("Interpolation already completed.")
+ try:
+ start_comm = datetime.datetime.now()
+ last_intermediate = self.interpolator.last_intermediate(self.interpolate_id)
+ if last_intermediate is not None:
+ for key in ["step", "total", "rate", "task"]:
+ if key in last_intermediate:
+ setattr(self, f"last_{key}", last_intermediate[key])
+ self.last_intermediate_time = datetime.datetime.now()
+ end_comm = (datetime.datetime.now() - start_comm).total_seconds()
+
+ try:
+ result = self.interpolator.wait(self.interpolate_id, timeout=0.1)
+ except TimeoutError:
+ raise
+ except Exception as ex:
+ result = None
+ self.error = ex
+
+ if result is not None:
+ # Complete
+ if isinstance(result, list):
+ from enfugue.diffusion.util.video_util import Video
+ self.interpolate_end_time = datetime.datetime.now()
+ video_path = f"{self.results_dir}/{self.uuid}.{self.video_format}"
+ Video(result).save(
+ video_path,
+ rate=self.video_rate,
+ encoder=self.video_codec
+ )
+ self.interpolate_result = video_path
+ else:
+ self.interpolate_result = result
+ except TimeoutError:
+ return
+
+ def poll_interpolator(self) -> None:
+ """
+ Calls communicate on the interpolator once (locks)
+ """
+ with self.lock:
+ self._interpolate_communicate()
+
@property
def is_dangling(self) -> bool:
"""
@@ -205,21 +324,63 @@ def format(self) -> Dict[str, Any]:
return {"status": "queued", "uuid": self.uuid}
if self.error is not None:
+ if self.results:
+ images = get_relative_paths(self.results)
+ else:
+ images = None
+ if self.interpolate_result:
+ video = get_relative_paths([self.interpolate_result])[0] # type: ignore[unreachable]
+ else:
+ video = None
+ if self.end_time is not None:
+ duration = (self.end_time - self.start_time).total_seconds()
+ else:
+ duration = 0
return {
"status": "error",
"uuid": self.uuid,
"message": str(self.error),
+ "images": images,
+ "video": video,
+ "duration": duration,
}
images = None
+ video = None
+
if self.results is not None:
- status = "completed"
- images = ["/".join(re.split(r"/|\\", path)[-2:]) for path in self.results]
+ if self.plan.animation_frames:
+ if self.interpolate_result:
+ status = "completed" # type: ignore[unreachable]
+ video = get_relative_paths([self.interpolate_result])[0]
+ elif self.plan.interpolate_frames or self.plan.reflect:
+ status = "interpolating"
+ self._interpolate_communicate()
+ else:
+ # Saving
+ status = "processing"
+ else:
+ status = "completed"
+ if self.plan.animation_frames:
+ video = get_relative_paths([self.interpolate_result])[0] # type: ignore[list-item]
+ images = get_relative_paths(self.results)
else:
status = "processing"
self._communicate()
- if self.last_images is not None:
- images = ["/".join(re.split(r"/|\\", path)[-2:]) for path in self.last_images]
+ if self.results is not None:
+ # Finished during the preceding _communicate() call
+ if self.plan.animation_frames: # type: ignore[unreachable]
+ if self.plan.interpolate_frames or self.plan.reflect:
+ # Interpolation just started
+ ...
+ elif self.interpolate_result:
+ status = "completed"
+ video = get_relative_paths([self.interpolate_result])[0]
+ else:
+ status = "completed"
+ images = get_relative_paths(self.results)
+ elif self.last_images is not None:
+ images = get_relative_paths(self.last_images)
if self.end_time is None:
duration = (datetime.datetime.now() - self.start_time).total_seconds()
@@ -230,7 +391,7 @@ def format(self) -> Dict[str, Any]:
if self.last_total is not None and self.last_total > 0:
total = self.last_total
if self.last_step is not None:
- step = total if self.results is not None else self.last_step
+ step = self.last_total if status == "completed" else self.last_step
if total is not None and step is not None:
progress = step / total
if self.last_rate is not None:
@@ -238,6 +399,9 @@ def format(self) -> Dict[str, Any]:
elif step is not None:
rate = step / duration
+ if video:
+ video = f"animation/{video}"
+
formatted = {
"id": self.id,
"uuid": self.uuid,
@@ -247,11 +411,14 @@ def format(self) -> Dict[str, Any]:
"duration": duration,
"total": total,
"images": images,
+ "video": video,
"rate": rate,
"task": self.last_task
}
+
if self.metadata:
formatted["metadata"] = self.metadata
+
return formatted
def __str__(self) -> str:
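Taken together, `format()` now moves an animated invocation through `queued` → `processing` → `interpolating` (when `interpolate_frames` or `reflect` is set) → `completed`, populating `video` for animations and `images` for stills. A hedged polling sketch against an `Invocation` instance (the presence of a top-level `status` key mirrors the early-return branches above and is otherwise assumed):

```python
import time

def wait_for_result(invocation) -> dict:
    """Poll an enfugue.api.invocations.Invocation until it settles."""
    while True:
        state = invocation.format()
        status = state.get("status")
        if status == "completed":
            return state  # state["video"] (e.g. "animation/images/<uuid>.mp4") or state["images"]
        if status == "error":
            raise RuntimeError(state.get("message"))
        # "processing" or "interpolating": report the last known task and keep waiting.
        print(status, state.get("task"))
        time.sleep(1)
```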
diff --git a/src/python/enfugue/api/manager.py b/src/python/enfugue/api/manager.py
index 4f77812a..519113b6 100644
--- a/src/python/enfugue/api/manager.py
+++ b/src/python/enfugue/api/manager.py
@@ -14,7 +14,9 @@
from enfugue.api.downloads import Download
from enfugue.api.invocations import Invocation
from enfugue.diffusion.engine import DiffusionEngine
-from enfugue.diffusion.plan import DiffusionPlan
+from enfugue.diffusion.interpolate import InterpolationEngine
+
+from enfugue.diffusion.invocation import LayeredInvocation
from enfugue.util import logger, check_make_directory, find_file_in_directory
from enfugue.diffusion.constants import (
DEFAULT_MODEL,
@@ -100,6 +102,7 @@ def __init__(self, configuration: APIConfiguration) -> None:
self.active_invocation = None
self.configuration = configuration
self.engine = DiffusionEngine(self.configuration)
+ self.interpolator = InterpolationEngine(self.configuration)
self.downloads = {}
self.invocations = {}
self.download_queue = []
@@ -141,7 +144,10 @@ def engine_image_dir(self) -> str:
"""
Gets the location for image result outputs.
"""
- directory = os.path.join(self.engine_root_dir, "images")
+ directory = self.configuration.get(
+ "enfugue.engine.images",
+ os.path.join(self.engine_root_dir, "images")
+ )
check_make_directory(directory)
return directory
@@ -150,7 +156,10 @@ def engine_intermediate_dir(self) -> str:
"""
Gets the location for image intermediate outputs.
"""
- directory = os.path.join(self.engine_root_dir, "intermediates")
+ directory = self.configuration.get(
+ "enfugue.engine.intermediate",
+ os.path.join(self.engine_root_dir, "intermediate")
+ )
check_make_directory(directory)
return directory
@@ -328,7 +337,10 @@ def active_downloads(self) -> List[Download]:
Gets a list of active downloads
"""
return [
- download for download_list in self.downloads.values() for download in download_list if not download.complete and download.started
+ download
+ for download_list in self.downloads.values()
+ for download in download_list
+ if not download.complete and download.started
]
@property
@@ -444,9 +456,12 @@ def cancel_download(self, url: str) -> bool:
def invoke(
self,
user_id: int,
- plan: DiffusionPlan,
+ plan: LayeredInvocation,
ui_state: Optional[str] = None,
disable_intermediate_decoding: bool = False,
+ video_rate: Optional[float] = None,
+ video_codec: Optional[str] = None,
+ video_format: Optional[str] = None,
**kwargs: Any,
) -> Invocation:
"""
@@ -463,8 +478,16 @@ def invoke(
else:
kwargs["decode_nth_intermediate"] = self.engine_intermediate_steps
+ if video_rate is not None:
+ kwargs["video_rate"] = video_rate
+ if video_codec is not None:
+ kwargs["video_codec"] = video_codec
+ if video_format is not None:
+ kwargs["video_format"] = video_format
+
invocation = Invocation(
engine=self.engine,
+ interpolator=self.interpolator,
plan=plan,
engine_image_dir=self.engine_image_dir,
engine_intermediate_dir=self.engine_intermediate_dir,
@@ -479,6 +502,7 @@ def invoke(
self.invocation_queue.append(invocation)
if user_id not in self.invocations:
self.invocations[user_id] = []
+
self.invocations[user_id].append(invocation)
return invocation
@@ -496,17 +520,23 @@ def get_invocations(self, user_id: int) -> List[Invocation]:
def stop_engine(self) -> None:
"""
- Stops the engine forcibly.
+ stops the engine forcibly.
"""
if self.active_invocation is not None:
try:
self.active_invocation.terminate()
time.sleep(5)
except Exception as ex:
- logger.info(f"Ignoring exception during invocation termination: {ex}")
+ logger.info(f"ignoring exception during invocation termination: {ex}")
self.active_invocation = None
self.engine.terminate_process()
+ def stop_interpolator(self) -> None:
+ """
+ stops the interpolator forcibly.
+ """
+ self.interpolator.terminate_process()
+
def clean_intermediates(self) -> None:
"""
Cleans up intermediate files
diff --git a/src/python/enfugue/api/server.py b/src/python/enfugue/api/server.py
index 809f768e..4e3ade48 100644
--- a/src/python/enfugue/api/server.py
+++ b/src/python/enfugue/api/server.py
@@ -19,7 +19,7 @@
from pibble.util.encryption import Password
from pibble.util.helpers import OutputCatcher
-from enfugue.diffusion.plan import DiffusionPlan
+from enfugue.diffusion.invocation import LayeredInvocation
from enfugue.database import *
from enfugue.api.controller import *
@@ -152,11 +152,21 @@ def on_destroy(self) -> None:
logger.debug("Stopping system manager")
self.manager.stop_monitor()
self.manager.stop_engine()
+ self.manager.stop_interpolator()
- def format_plan(self, plan: DiffusionPlan) -> Dict[str, Any]:
+ def format_plan(self, plan: LayeredInvocation) -> Dict[str, Any]:
"""
Formats a plan for inserting into the database
"""
+ def get_image_metadata(image: PIL.Image.Image) -> Dict[str, Any]:
+ """
+ Gets metadata from an image
+ """
+ width, height = image.size
+ metadata = {"width": width, "height": height, "mode": image.mode}
+ if hasattr(image, "filename"):
+ metadata["filename"] = image.filename
+ return metadata
def replace_images(serialized: Dict[str, Any]) -> Dict[str, Any]:
"""
@@ -164,18 +174,21 @@ def replace_images(serialized: Dict[str, Any]) -> Dict[str, Any]:
"""
for key, value in serialized.items():
if isinstance(value, PIL.Image.Image):
- width, height = value.size
- metadata = {"width": width, "height": height, "mode": value.mode}
- if hasattr(value, "filename"):
- metadata["filename"] = value.filename
- serialized[key] = metadata
+ serialized[key] = get_image_metadata(value)
elif isinstance(value, dict):
serialized[key] = replace_images(value)
elif isinstance(value, list):
- serialized[key] = [replace_images(part) if isinstance(part, dict) else part for part in value]
+ serialized[key] = [
+ replace_images(part)
+ if isinstance(part, dict)
+ else get_image_metadata(part)
+ if isinstance(part, PIL.Image.Image)
+ else part
+ for part in value
+ ]
return serialized
- return replace_images(plan.get_serialization_dict())
+ return replace_images(plan.serialize())
def get_plan_kwargs_from_model(
self,
@@ -198,32 +211,29 @@ def get_plan_kwargs_from_model(
lora_dir = self.get_configured_directory("lora")
lycoris_dir = self.get_configured_directory("lycoris")
inversion_dir = self.get_configured_directory("inversion")
+ motion_dir = self.get_configured_directory("motion")
model = find_file_in_directory(checkpoint_dir, diffusion_model.model)
if not model:
raise ValueError(f"Could not find {diffusion_model.model} in {checkpoint_dir}")
- size = diffusion_model.size
-
refiner = diffusion_model.refiner
if refiner:
- refiner_size = refiner[0].size
refiner_model = find_file_in_directory(checkpoint_dir, refiner[0].model)
if not refiner_model:
raise ValueError(f"Could not find {refiner[0].model} in {checkpoint_dir}")
refiner = refiner_model
else:
- refiner, refiner_size = None, None
+ refiner = None
inpainter = diffusion_model.inpainter
if inpainter:
- inpainter_size = inpainter[0].size
inpainter_model = os.path.join(checkpoint_dir, inpainter[0].model)
if not inpainter_model:
raise ValueError(f"Could not find {inpainter[0].model} in {checkpoint_dir}")
inpainter = inpainter_model
else:
- inpainter, inpainter_size = None, None
+ inpainter = None
scheduler = diffusion_model.scheduler
if scheduler:
@@ -247,6 +257,12 @@ def get_plan_kwargs_from_model(
else:
inpainter_vae = None
+ motion_module = diffusion_model.motion_module
+ if motion_module:
+ motion_module = diffusion_model.motion_module[0].name
+ else:
+ motion_module = None
+
lora = []
for lora_model in diffusion_model.lora:
lora_model_path = find_file_in_directory(lora_dir, lora_model.model)
@@ -274,17 +290,15 @@ def get_plan_kwargs_from_model(
plan_kwargs: Dict[str, Any] = {
"model": model,
"refiner": refiner,
- "refiner_size": refiner_size,
"inpainter": inpainter,
- "inpainter_size": inpainter_size,
- "size": size,
"lora": lora,
"lycoris": lycoris,
"inversion": inversion,
"scheduler": scheduler,
"vae": vae,
"refiner_vae": refiner_vae,
- "inpainter_vae": inpainter_vae
+ "inpainter_vae": inpainter_vae,
+ "motion_module": motion_module
}
model_config = {}
@@ -304,10 +318,11 @@ def get_plan_kwargs_from_model(
def invoke(
self,
user_id: int,
- plan: DiffusionPlan,
+ plan: LayeredInvocation,
save: bool = True,
ui_state: Optional[str] = None,
disable_intermediate_decoding: bool = False,
+ video_rate: Optional[float] = None,
**kwargs: Any,
) -> Invocation:
"""
@@ -318,6 +333,9 @@ def invoke(
plan,
ui_state=ui_state,
disable_intermediate_decoding=disable_intermediate_decoding,
+ video_rate=video_rate,
+ video_codec=self.configuration.get("enfugue.video.codec", "avc1"),
+ video_format=self.configuration.get("enfugue.video.format", "mp4"),
**kwargs
)
if save:
@@ -419,6 +437,8 @@ def announcements(self, request: Request, response: Response) -> List[Dict[str,
"inversion",
"other",
"tensorrt",
+ "images",
+ "intermediate",
]:
directories[dirname] = self.configuration.get(
f"enfugue.engine.{dirname}", os.path.join(self.engine_root, dirname)
diff --git a/src/python/enfugue/client/client.py b/src/python/enfugue/client/client.py
index 3e97ee96..f5e9b12e 100644
--- a/src/python/enfugue/client/client.py
+++ b/src/python/enfugue/client/client.py
@@ -11,13 +11,8 @@
from pibble.api.client.webservice.jsonapi import JSONWebServiceAPIClient
from pibble.ext.user.client.base import UserExtensionClientBase
-from enfugue.diffusion.plan import NodeDict, UpscaleStepDict
from enfugue.diffusion.constants import *
-from enfugue.util import (
- logger,
- IMAGE_FIT_LITERAL,
- IMAGE_ANCHOR_LITERAL
-)
+from enfugue.util import logger
from enfugue.client.invocation import RemoteInvocation
__all__ = ["WeightedModelDict", "EnfugueClient"]
@@ -126,6 +121,7 @@ def settings(self) -> Dict[str, Any]:
def invoke(
self,
+ prompts: Optional[List[Dict]] = None,
prompt: Optional[str] = None,
prompt_2: Optional[str] = None,
negative_prompt: Optional[str] = None,
@@ -138,9 +134,10 @@ def invoke(
intermediates: Optional[bool] = None,
width: Optional[int] = None,
height: Optional[int] = None,
- chunking_size: Optional[int] = None,
- chunking_mask_type: Optional[MASK_TYPE_LITERAL] = None,
- chunking_mask_kwargs: Optional[Dict[str, Any]] = None,
+ tiling_size: Optional[int] = None,
+ tiling_stride: Optional[int] = None,
+ tiling_mask_type: Optional[MASK_TYPE_LITERAL] = None,
+ tiling_mask_kwargs: Optional[Dict[str, Any]] = None,
samples: Optional[int] = None,
iterations: Optional[int] = None,
num_inference_steps: Optional[int] = None,
@@ -154,12 +151,9 @@ def invoke(
refiner_prompt_2: Optional[str] = None,
refiner_negative_prompt: Optional[str] = None,
refiner_negative_prompt_2: Optional[str] = None,
- nodes: Optional[List[NodeDict]] = None,
+ layers: Optional[List[Dict[str, Any]]] = None,
model: Optional[str] = None,
model_type: Optional[Literal["checkpoint", "model"]] = None,
- size: Optional[int] = None,
- refiner_size: Optional[int] = None,
- inpainter_size: Optional[int] = None,
inpainter: Optional[str] = None,
refiner: Optional[str] = None,
lora: Optional[List[WeightedModelDict]] = None,
@@ -171,26 +165,21 @@ def invoke(
inpainter_vae: Optional[str] = None,
freeu_factors: Optional[Tuple[float, float, float, float]] = None,
seed: Optional[int] = None,
- image: Optional[Union[str, Image]] = None,
mask: Optional[Union[str, Image]] = None,
- ip_adapter_images: Optional[List[Dict[str, Any]]] = None,
- ip_adapter_plus: bool = False,
- ip_adapter_face: bool = False,
- control_images: Optional[List[Dict[str, Any]]] = None,
+ ip_adapter_model: Optional[IP_ADAPTER_LITERAL] = None,
strength: Optional[float] = None,
- fit: Optional[IMAGE_FIT_LITERAL] = None,
- anchor: Optional[IMAGE_ANCHOR_LITERAL] = None,
- remove_background: Optional[bool] = None,
- fill_background: Optional[bool] = None,
- scale_to_model_size: Optional[bool] = None,
- invert_mask: Optional[bool] = None,
- conditioning_scale: Optional[float] = None,
- crop_inpaint: Optional[bool] = None,
- inpaint_feather: Optional[int] = None,
+ outpaint: Optional[bool] = None,
noise_offset: Optional[float] = None,
noise_method: Optional[NOISE_METHOD_LITERAL] = None,
noise_blend_method: Optional[LATENT_BLEND_METHOD_LITERAL] = None,
- upscale_steps: Optional[Union[UpscaleStepDict, List[UpscaleStepDict]]] = None,
+ upscale: Optional[Union[UpscaleStepDict, List[UpscaleStepDict]]] = None,
+ motion_scale: Optional[float] = None,
+ position_encoding_truncate_length: Optional[int] = None,
+ position_encoding_scale_length: Optional[int] = None,
+ motion_module: Optional[str] = None,
+ animation_frames: Optional[int] = None,
+ loop: Optional[bool] = None,
+ tile: Optional[Union[bool, Tuple[bool, bool], List[bool]]] = None,
) -> RemoteInvocation:
"""
Invokes the engine.
@@ -209,6 +198,8 @@ def invoke(
kwargs["negative_prompt"] = negative_prompt
if negative_prompt_2 is not None:
kwargs["negative_prompt_2"] = negative_prompt_2
+ if prompts is not None:
+ kwargs["prompts"] = prompts
if model_prompt is not None:
kwargs["model_prompt"] = model_prompt
if model_prompt_2 is not None:
@@ -223,12 +214,14 @@ def invoke(
kwargs["width"] = width
if height is not None:
kwargs["height"] = height
- if chunking_size is not None:
- kwargs["chunking_size"] = chunking_size
- if chunking_mask_type is not None:
- kwargs["chunking_mask_type"] = chunking_mask_type
- if chunking_mask_kwargs is not None:
- kwargs["chunking_mask_kwargs"] = chunking_mask_kwargs
+ if tiling_size is not None:
+ kwargs["tiling_size"] = tiling_size
+ if tiling_stride is not None:
+ kwargs["tiling_stride"] = tiling_stride
+ if tiling_mask_type is not None:
+ kwargs["tiling_mask_type"] = tiling_mask_type
+ if tiling_mask_kwargs is not None:
+ kwargs["tiling_mask_kwargs"] = tiling_mask_kwargs
if samples is not None:
kwargs["samples"] = samples
if iterations is not None:
@@ -251,14 +244,6 @@ def invoke(
kwargs["refiner_prompt"] = refiner_prompt
if refiner_negative_prompt is not None:
kwargs["refiner_negative_prompt"] = refiner_negative_prompt
- if nodes is not None:
- kwargs["nodes"] = nodes
- if size is not None:
- kwargs["size"] = size
- if refiner_size is not None:
- kwargs["refiner_size"] = refiner_size
- if inpainter_size is not None:
- kwargs["inpainter_size"] = inpainter_size
if inpainter is not None:
kwargs["inpainter"] = inpainter
if refiner is not None:
@@ -271,42 +256,18 @@ def invoke(
kwargs["scheduler"] = scheduler
if vae is not None:
kwargs["vae"] = vae
- if refiner_vae is not None:
- kwargs["refiner_vae"] = refiner_vae
+ if inpainter_vae is not None:
+ kwargs["inpainter_vae"] = inpainter_vae
if refiner_vae is not None:
kwargs["refiner_vae"] = refiner_vae
if seed is not None:
kwargs["seed"] = seed
- if image is not None:
- kwargs["image"] = image
if mask is not None:
kwargs["mask"] = mask
- if control_images is not None:
- kwargs["control_images"] = control_images
if strength is not None:
kwargs["strength"] = strength
- if fit is not None:
- kwargs["fit"] = fit
- if anchor is not None:
- kwargs["anchor"] = anchor
- if remove_background is not None:
- kwargs["remove_background"] = remove_background
- if fill_background is not None:
- kwargs["fill_background"] = fill_background
- if scale_to_model_size is not None:
- kwargs["scale_to_model_size"] = scale_to_model_size
- if invert_mask is not None:
- kwargs["invert_mask"] = invert_mask
- if crop_inpaint is not None:
- kwargs["crop_inpaint"] = crop_inpaint
- if inpaint_feather is not None:
- kwargs["inpaint_feather"] = inpaint_feather
- if ip_adapter_images is not None:
- kwargs["ip_adapter_images"] = ip_adapter_images
- kwargs["ip_adapter_plus"] = ip_adapter_plus
- kwargs["ip_adapter_face"] = ip_adapter_face
- if upscale_steps is not None:
- kwargs["upscale_steps"] = upscale_steps
+ if upscale is not None:
+ kwargs["upscale"] = upscale
if clip_skip is not None:
kwargs["clip_skip"] = clip_skip
if freeu_factors is not None:
@@ -317,6 +278,24 @@ def invoke(
kwargs["noise_method"] = noise_method
if noise_blend_method is not None:
kwargs["noise_blend_method"] = noise_blend_method
+ if layers is not None:
+ kwargs["layers"] = layers
+ if motion_scale is not None:
+ kwargs["motion_scale"] = motion_scale
+ if ip_adapter_model is not None:
+ kwargs["ip_adapter_model"] = ip_adapter_model
+ if position_encoding_truncate_length is not None:
+ kwargs["position_encoding_truncate_length"] = position_encoding_truncate_length
+ if position_encoding_scale_length is not None:
+ kwargs["position_encoding_scale_length"] = position_encoding_scale_length
+ if animation_frames is not None:
+ kwargs["animation_frames"] = animation_frames
+ if loop is not None:
+ kwargs["loop"] = loop
+ if tile is not None:
+ kwargs["tile"] = tile
+ if outpaint is not None:
+ kwargs["outpaint"] = outpaint
logger.info(f"Invoking with keyword arguments {kwargs}")
diff --git a/src/python/enfugue/database/__init__.py b/src/python/enfugue/database/__init__.py
index 489c9d17..2129d6a9 100644
--- a/src/python/enfugue/database/__init__.py
+++ b/src/python/enfugue/database/__init__.py
@@ -13,9 +13,10 @@
DiffusionModelLycoris,
DiffusionModelInversion,
DiffusionModelDefaultConfiguration,
+ DiffusionModelMotionModule,
)
-EnfugueObjectBase, DiffusionModel, DiffusionModelRefiner, DiffusionModelInpainter, DiffusionModelVAE, DiffusionModelScheduler, DiffusionModelLora, DiffusionModelLycoris, DiffusionModelInversion, DiffusionInvocation, DiffusionModelDefaultConfiguration, ConfigurationItem, DiffusionModelRefinerVAE, DiffusionModelInpainterVAE # Silence importchecker
+EnfugueObjectBase, DiffusionModel, DiffusionModelRefiner, DiffusionModelInpainter, DiffusionModelVAE, DiffusionModelScheduler, DiffusionModelLora, DiffusionModelLycoris, DiffusionModelInversion, DiffusionInvocation, DiffusionModelDefaultConfiguration, ConfigurationItem, DiffusionModelRefinerVAE, DiffusionModelInpainterVAE, DiffusionModelMotionModule # Silence importchecker
__all__ = [
"EnfugueObjectBase",
@@ -32,4 +33,5 @@
"DiffusionModelLycoris",
"DiffusionModelInversion",
"DiffusionModelDefaultConfiguration",
+ "DiffusionModelMotionModule",
]
diff --git a/src/python/enfugue/database/models.py b/src/python/enfugue/database/models.py
index e9a696df..aa49ffa4 100644
--- a/src/python/enfugue/database/models.py
+++ b/src/python/enfugue/database/models.py
@@ -14,6 +14,7 @@
"DiffusionModelLora",
"DiffusionModelLycoris",
"DiffusionModelInversion",
+ "DiffusionModelMotionModule",
]
@@ -140,3 +141,13 @@ class DiffusionModelInversion(EnfugueObjectBase):
model = Column(String(256), primary_key=True)
diffusion_model = DiffusionModel.Relationship(backref="inversion")
+
+
+class DiffusionModelMotionModule(EnfugueObjectBase):
+ __tablename__ = "model_motion_module"
+
+ diffusion_model_name = Column(
+ DiffusionModel.ForeignKey("name", ondelete="CASCADE", onupdate="CASCADE"), primary_key=True, unique=True
+ )
+ name = Column(String(256), nullable=False)
+ diffusion_model = DiffusionModel.Relationship(backref="motion_module", uselist=False)
diff --git a/src/python/enfugue/diffusion/README.md b/src/python/enfugue/diffusion/README.md
index 2a5c8658..196b4c72 100644
--- a/src/python/enfugue/diffusion/README.md
+++ b/src/python/enfugue/diffusion/README.md
@@ -31,7 +31,9 @@ def __init__(
engine_size: int = 512,
chunking_size: int = 64,
chunking_mask_type: Literal['constant', 'bilinear', 'gaussian'] = 'bilinear',
- chunking_mask_kwargs: Dict[str, Any] = {}
+ chunking_mask_kwargs: Dict[str, Any] = {},
+ temporal_engine_size: int = 16,
+ temporal_chunking_size: int = 4
) -> None:
```
@@ -56,6 +58,8 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
chunking_size: Optional[int] = None,
+ temporal_engine_size: Optional[int] = None,
+ temporal_chunking_size: Optional[int] = None,
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
ip_adapter_scale: Optional[float] = None,
@@ -63,7 +67,7 @@ def __call__(
strength: Optional[float] = 0.8,
num_inference_steps: int = 40,
guidance_scale: float = 7.5,
- num_images_per_prompt: int = 1,
+ num_results_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
diff --git a/src/python/enfugue/diffusion/animate/__init__.py b/src/python/enfugue/diffusion/animate/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/python/enfugue/diffusion/animate/diff/__init__.py b/src/python/enfugue/diffusion/animate/diff/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/python/enfugue/diffusion/animate/diff/attention.py b/src/python/enfugue/diffusion/animate/diff/attention.py
new file mode 100644
index 00000000..160d58e9
--- /dev/null
+++ b/src/python/enfugue/diffusion/animate/diff/attention.py
@@ -0,0 +1,420 @@
+# type: ignore
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention_processor import Attention
+from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero
+from einops import rearrange, repeat
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from dataclasses import dataclass
+
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+ xformers.ops # quiet importchecker
+else:
+ xformers = None
+
+
+class Transformer3DModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # Define input layers
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ # Define transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ if use_linear_projection:
+ self.proj_out = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self,
+ hidden_states,
+ encoder_hidden_states=None,
+ timestep=None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True):
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # Input
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
+
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ video_length=video_length,
+ encoder_attention_mask=encoder_attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ # Output
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+
+ output = hidden_states + residual
+
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+ self.enabled = True
+
+ def forward(self, x, objs):
+ if not self.enabled:
+ return x
+
+ n_visual = x.shape[1]
+ objs = self.linear(objs)
+
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+ return x
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (`int`, *optional*):
+ The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Configure if the attentions should contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm",
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ unet_use_cross_frame_attention = None,
+ unet_use_temporal_attention = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
+ self.unet_use_temporal_attention = unet_use_temporal_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = (
+ AdaLayerNorm(dim, num_embeds_ada_norm)
+ if self.use_ada_layer_norm
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ )
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ # 4. Temp-Attn
+ assert unet_use_temporal_attention is not None
+ if unet_use_temporal_attention:
+ self.attn_temp = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ # 5. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ video_length: Optional[int] = None,
+ ):
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 1. Self-Attention
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 0. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ # 1.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+ # 1.5 ends
+
+ # 2. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ # 4. Temporal-Attention
+ if self.unet_use_temporal_attention:
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+ norm_hidden_states = (
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
+ )
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+ return hidden_states
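The file above leans on two einops reshapes: spatial layers fold the frame axis into the batch axis so every frame is attended to independently, while the optional temporal attention in `BasicTransformerBlock` regroups tokens so the sequence axis becomes the frame axis. A small self-contained illustration (shapes are arbitrary examples, not values from this diff):

```python
import torch
from einops import rearrange

b, c, f, h, w = 1, 320, 16, 32, 32
video = torch.randn(b, c, f, h, w)                            # 5D video latents

# Spatial path: each frame becomes its own batch item for the 2D attention blocks.
spatial = rearrange(video, "b c f h w -> (b f) c h w")        # (16, 320, 32, 32)
tokens = rearrange(spatial, "bf c h w -> bf (h w) c")         # (16, 1024, 320)

# Temporal path: regroup so attention runs across frames at every spatial location,
# the same "(b f) d c -> (b d) f c" pattern used near the end of BasicTransformerBlock.forward.
temporal = rearrange(tokens, "(b f) d c -> (b d) f c", f=f)   # (1024, 16, 320)
assert temporal.shape == (b * h * w, f, c)

# And back again, restoring the spatial token layout unchanged.
restored = rearrange(temporal, "(b d) f c -> (b f) d c", d=h * w)
assert torch.equal(restored, tokens)
```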
diff --git a/src/python/enfugue/diffusion/animate/diff/motion_module.py b/src/python/enfugue/diffusion/animate/diff/motion_module.py
new file mode 100644
index 00000000..15f3972b
--- /dev/null
+++ b/src/python/enfugue/diffusion/animate/diff/motion_module.py
@@ -0,0 +1,375 @@
+# type: ignore
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import Attention, FeedForward
+
+from einops import rearrange, repeat
+import math
+
+
+def zero_module(module):
+ # Zero out the parameters of a module and return it.
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+@dataclass
+class TemporalTransformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+ xformers.ops # silence importchecker
+else:
+ xformers = None
+
+
+def get_motion_module(
+ in_channels,
+ motion_module_type: str,
+ motion_module_kwargs: dict
+):
+ if motion_module_type == "Vanilla":
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
+ else:
+ raise ValueError(f"Unknown motion_module_type: {motion_module_type}")
+
+class VanillaTemporalModule(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ num_attention_heads = 8,
+ num_transformer_block = 2,
+ attention_block_types = ("Temporal_Self", "Temporal_Self"),
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 24,
+ temporal_attention_dim_div = 1,
+ attention_scale_multiplier = 1.0,
+ zero_initialize = True,
+ ):
+ super().__init__()
+
+ self.temporal_transformer = TemporalTransformer3DModel(
+ in_channels=in_channels,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
+ num_layers=num_transformer_block,
+ attention_block_types=attention_block_types,
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ attention_scale_multiplier=attention_scale_multiplier,
+ temporal_position_encoding=temporal_position_encoding,
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+ )
+
+ if zero_initialize:
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
+
+ def set_attention_scale_multiplier(self, attention_scale: float = 1.0) -> None:
+ self.temporal_transformer.set_attention_scale_multiplier(attention_scale)
+
+ def reset_attention_scale_multiplier(self) -> None:
+ self.temporal_transformer.reset_attention_scale_multiplier()
+
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
+ hidden_states = input_tensor
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
+
+ output = hidden_states
+ return output
+
+class TemporalTransformer3DModel(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ num_attention_heads,
+ attention_head_dim,
+
+ num_layers,
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
+ dropout = 0.0,
+ norm_num_groups = 32,
+ cross_attention_dim = 768,
+ activation_fn = "geglu",
+ attention_bias = False,
+ upcast_attention = False,
+
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 24,
+ attention_scale_multiplier = 1.0,
+ ):
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ TemporalTransformerBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ attention_block_types=attention_block_types,
+ dropout=dropout,
+ norm_num_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ upcast_attention=upcast_attention,
+ attention_scale_multiplier=attention_scale_multiplier,
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_position_encoding=temporal_position_encoding,
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+ )
+ for d in range(num_layers)
+ ]
+ )
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+
+ def set_attention_scale_multiplier(self, attention_scale: float = 1.0) -> None:
+ for block in self.transformer_blocks:
+ block.set_attention_scale_multiplier(attention_scale)
+
+ def reset_attention_scale_multiplier(self) -> None:
+ for block in self.transformer_blocks:
+ block.reset_attention_scale_multiplier()
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Transformer Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
+
+ # output
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+ output = hidden_states + residual
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+
+ return output
+
+
+class TemporalTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_attention_heads,
+ attention_head_dim,
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
+ dropout = 0.0,
+ norm_num_groups = 32,
+ cross_attention_dim = 768,
+ activation_fn = "geglu",
+ attention_bias = False,
+ upcast_attention = False,
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 24,
+ attention_scale_multiplier = 1.0,
+ ):
+ super().__init__()
+
+ attention_blocks = []
+ norms = []
+
+ for block_name in attention_block_types:
+ attention_blocks.append(
+ VersatileAttention(
+ attention_mode=block_name.split("_")[0],
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
+
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ attention_scale_multiplier=attention_scale_multiplier,
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_position_encoding=temporal_position_encoding,
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+ )
+ )
+ norms.append(nn.LayerNorm(dim))
+
+ self.attention_blocks = nn.ModuleList(attention_blocks)
+ self.norms = nn.ModuleList(norms)
+
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.ff_norm = nn.LayerNorm(dim)
+
+ def set_attention_scale_multiplier(self, attention_scale: float = 1.0) -> None:
+ for block in self.attention_blocks:
+ block.set_scale_multiplier(attention_scale)
+
+ def reset_attention_scale_multiplier(self) -> None:
+ for block in self.attention_blocks:
+ block.reset_scale_multiplier()
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
+ norm_hidden_states = norm(hidden_states)
+ hidden_states = attention_block(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
+ video_length=video_length,
+ ) + hidden_states
+
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
+
+ output = hidden_states
+ return output
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ dropout = 0.,
+ max_len = 24
+ ):
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout)
+ position = torch.arange(max_len).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+ pe = torch.zeros(1, max_len, d_model)
+ pe[0, :, 0::2] = torch.sin(position * div_term)
+ pe[0, :, 1::2] = torch.cos(position * div_term)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ x = x + self.pe[:, :x.size(1)]
+ return self.dropout(x)
+
+
+class VersatileAttention(Attention):
+ def __init__(
+ self,
+ attention_mode = None,
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 24,
+ attention_scale_multiplier = 1.0,
+ *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ assert attention_mode == "Temporal"
+
+ self.attention_mode = attention_mode
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
+
+ self.pos_encoder = PositionalEncoding(
+ kwargs["query_dim"],
+ dropout=0.,
+ max_len=temporal_position_encoding_max_len
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
+
+ def set_scale_multiplier(self, multiplier: float = 1.0) -> None:
+ if not hasattr(self, "_default_scale"):
+ self._default_scale = self.scale
+ self.scale = math.sqrt((math.log(24) / math.log(24//4)) / (self.inner_dim // self.heads)) * multiplier
+
+ def reset_scale_multiplier(self) -> None:
+ if hasattr(self, "_default_scale"):
+ self.scale = self._default_scale
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def extra_repr(self):
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ if self.attention_mode == "Temporal":
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+
+ if self.pos_encoder is not None:
+ hidden_states = self.pos_encoder(hidden_states)
+
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
+ else:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ # if self._use_memory_efficient_attention_xformers:
+ # hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ # hidden_states = hidden_states.to(query.dtype)
+ # else:
+ # if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ # hidden_states = self._attention(query, key, value, attention_mask)
+ # else:
+ # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+
+ attention_probs = self.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = self.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+
+ if self.attention_mode == "Temporal":
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+ return hidden_states
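A minimal usage sketch of the motion module factory defined above, assuming the `enfugue.diffusion.animate.diff` package is importable exactly as laid out in this diff; the channel count and kwargs are illustrative values, and with the default `zero_initialize=True` the module acts as an identity mapping until trained motion weights are loaded.

```python
import torch
from enfugue.diffusion.animate.diff.motion_module import get_motion_module

motion_module = get_motion_module(
    in_channels=320,
    motion_module_type="Vanilla",
    motion_module_kwargs=dict(
        num_attention_heads=8,
        num_transformer_block=1,
        temporal_position_encoding=True,
        temporal_position_encoding_max_len=24,
    ),
)

# Inputs are 5D: (batch, channels, frames, height, width).
latents = torch.randn(1, 320, 16, 32, 32)
out = motion_module(latents, temb=None, encoder_hidden_states=None)
print(out.shape)  # torch.Size([1, 320, 16, 32, 32])
```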
diff --git a/src/python/enfugue/diffusion/animate/diff/resnet.py b/src/python/enfugue/diffusion/animate/diff/resnet.py
new file mode 100644
index 00000000..843e294b
--- /dev/null
+++ b/src/python/enfugue/diffusion/animate/diff/resnet.py
@@ -0,0 +1,219 @@
+# type: ignore
+# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/resnet.py
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+
+
+class InflatedConv3d(nn.Conv2d):
+ def forward(self, x):
+ video_length = x.shape[2]
+
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ x = super().forward(x)
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+ return x
+
+
+class InflatedGroupNorm(nn.GroupNorm):
+ def forward(self, x):
+ video_length = x.shape[2]
+
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ x = super().forward(x)
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+ return x
+
+
+class Upsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ raise NotImplementedError
+ elif use_conv:
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, hidden_states, output_size=None):
+ assert hidden_states.shape[1] == self.channels
+
+ if self.use_conv_transpose:
+ raise NotImplementedError
+
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
+ if output_size is None:
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ # if self.use_conv:
+ # if self.name == "conv":
+ # hidden_states = self.conv(hidden_states)
+ # else:
+ # hidden_states = self.Conv2d_0(hidden_states)
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class Downsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ raise NotImplementedError
+
+ def forward(self, hidden_states):
+ assert hidden_states.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ raise NotImplementedError
+
+ assert hidden_states.shape[1] == self.channels
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class ResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ use_inflated_groupnorm=None,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ assert use_inflated_groupnorm is not None
+ if use_inflated_groupnorm:
+ self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+ else:
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ time_emb_proj_out_channels = out_channels
+ elif self.time_embedding_norm == "scale_shift":
+ time_emb_proj_out_channels = out_channels * 2
+ else:
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
+ else:
+ self.time_emb_proj = None
+
+ if use_inflated_groupnorm:
+ self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ else:
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+
+class Mish(torch.nn.Module):
+ def forward(self, hidden_states):
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
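The `Inflated*` layers above all use the same trick: fold the frame axis into the batch axis, run the ordinary 2D op frame by frame with shared weights, then unfold. A self-contained sketch of that pattern with a plain `nn.Conv2d` (values are arbitrary):

```python
import torch
import torch.nn as nn
from einops import rearrange

conv2d = nn.Conv2d(4, 8, kernel_size=3, padding=1)

x = torch.randn(2, 4, 16, 64, 64)              # (batch, channels, frames, height, width)
frames = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")   # every frame becomes a batch item
x = conv2d(x)                                  # 2D conv weights are shared across frames
x = rearrange(x, "(b f) c h w -> b c f h w", f=frames)
print(x.shape)                                 # torch.Size([2, 8, 16, 64, 64])
```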
diff --git a/src/python/enfugue/diffusion/animate/diff/unet.py b/src/python/enfugue/diffusion/animate/diff/unet.py
new file mode 100644
index 00000000..7330c87f
--- /dev/null
+++ b/src/python/enfugue/diffusion/animate/diff/unet.py
@@ -0,0 +1,627 @@
+# type: ignore
+from __future__ import annotations
+# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet.py
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union, Any, Dict
+
+import os
+import json
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput, logging
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+ from diffusers.models.attention_processor import AttentionProcessor
+from enfugue.diffusion.animate.diff.unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+)
+from enfugue.diffusion.animate.diff.resnet import InflatedConv3d, InflatedGroupNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+class UNet3DConditionModel(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D",
+ ),
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
+ up_block_types: Tuple[str] = (
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+
+ use_inflated_groupnorm=False,
+
+ # Additional
+ use_motion_module = False,
+ motion_module_resolutions = ( 1,2,4,8 ),
+ motion_module_mid_block = False,
+ motion_module_decoder_only = False,
+ motion_module_type = None,
+ motion_module_kwargs = {},
+ unet_use_cross_frame_attention = None,
+ unet_use_temporal_attention = None,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ res = 2 ** i
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ use_inflated_groupnorm=use_inflated_groupnorm,
+
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
+ self.mid_block = UNetMidBlock3DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ use_inflated_groupnorm=use_inflated_groupnorm,
+
+ use_motion_module=use_motion_module and motion_module_mid_block,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the videos
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ res = 2 ** (3 - i)
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ use_inflated_groupnorm=use_inflated_groupnorm,
+
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if use_inflated_groupnorm:
+ self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ else:
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_slicable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_slicable_dims(module)
+
+ num_slicable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_slicable_layers * [1]
+
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ module.gradient_checkpointing = value
+
+ def set_motion_attention_scale(self, scale: float = 1.0) -> None:
+ for block in self.down_blocks + self.up_blocks + [self.mid_block]:
+ block.set_motion_module_attention_scale(scale)
+
+ def reset_motion_attention_scale(self) -> None:
+ for block in self.down_blocks + self.up_blocks + [self.mid_block]:
+ block.reset_motion_module_attention_scale()
+
+ def enable_freeu(self, s1, s2, b1, b2):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ **kwargs: Any
+ ) -> Union[UNet3DConditionOutput, Tuple]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ forward_upsample_size = True
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ # pre-process
+ sample = self.conv_in(sample)
+
+ # down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
+
+ down_block_res_samples += res_samples
+
+ # down controlnet
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples = ()
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+ down_block_res_samples = new_down_block_res_samples
+
+ # mid
+ sample = self.mid_block(
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+
+ # mid controlnet
+ if mid_block_additional_residual is not None:
+ sample = sample + mid_block_additional_residual
+
+ # up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
+ )
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet3DConditionOutput(sample=sample)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+ config["_class_name"] = cls.__name__
+ config["down_block_types"] = [
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D"
+ ]
+ config["up_block_types"] = [
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ]
+
+ from diffusers.utils import WEIGHTS_NAME
+ model = cls.from_config(config, **(unet_additional_kwargs or {}))
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ state_dict = torch.load(model_file, map_location="cpu")
+
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
+
+ params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
+ print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
+
+ return model
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ return self.get_attn_processors(include_temporal_layers=False)
+
+ def get_attn_processors(self, include_temporal_layers=True) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+
+ if not include_temporal_layers:
+ if 'temporal' in name:
+ return processors
+
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
+ include_temporal_layers=False):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.get_attn_processors(include_temporal_layers=include_temporal_layers).keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+
+ if not include_temporal_layers:
+ if "temporal" in name:
+ return
+
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
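A hedged sketch of how `from_pretrained_2d` is meant to be used: load a 2D Stable Diffusion UNet's config and weights non-strictly into this inflated 3D UNet, with the motion-module options supplied through `unet_additional_kwargs`. The checkpoint path is hypothetical and the kwargs are illustrative values; note that `unet_use_temporal_attention` must be set explicitly (even to `False`) because `BasicTransformerBlock` asserts it is not `None`.

```python
from enfugue.diffusion.animate.diff.unet import UNet3DConditionModel

def load_animate_unet(checkpoint_dir: str) -> UNet3DConditionModel:
    """Inflate a diffusers-format 2D UNet checkpoint into the 3D UNet defined above."""
    return UNet3DConditionModel.from_pretrained_2d(
        checkpoint_dir,                  # e.g. a local Stable Diffusion 1.5 checkout
        subfolder="unet",
        unet_additional_kwargs=dict(
            use_motion_module=True,
            motion_module_resolutions=(1, 2, 4, 8),
            motion_module_type="Vanilla",
            motion_module_kwargs=dict(
                num_attention_heads=8,
                temporal_position_encoding=True,
                temporal_position_encoding_max_len=24,
            ),
            unet_use_temporal_attention=False,   # must not be None; see the assert in attention.py
            unet_use_cross_frame_attention=False,
            use_inflated_groupnorm=True,
        ),
    )

# The motion-module weights are absent from the 2D checkpoint (load_state_dict runs with
# strict=False), so they would need to be loaded separately before sampling video latents.
```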
diff --git a/src/python/enfugue/diffusion/animate/diff/unet_blocks.py b/src/python/enfugue/diffusion/animate/diff/unet_blocks.py
new file mode 100644
index 00000000..420341e0
--- /dev/null
+++ b/src/python/enfugue/diffusion/animate/diff/unet_blocks.py
@@ -0,0 +1,853 @@
+# type: ignore
+# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
+
+import torch
+from torch import nn
+from diffusers.utils.torch_utils import apply_freeu
+
+from enfugue.diffusion.animate.diff.attention import Transformer3DModel
+from enfugue.diffusion.animate.diff.resnet import Downsample3D, ResnetBlock3D, Upsample3D
+from enfugue.diffusion.animate.diff.motion_module import get_motion_module
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+ use_inflated_groupnorm=None,
+
+ use_motion_module=None,
+
+ motion_module_type=None,
+ motion_module_kwargs=None,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
+
+ use_motion_module=use_motion_module,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ elif down_block_type == "CrossAttnDownBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+ return CrossAttnDownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ use_inflated_groupnorm=use_inflated_groupnorm,
+
+ use_motion_module=use_motion_module,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+ use_inflated_groupnorm=None,
+
+ use_motion_module=None,
+ motion_module_type=None,
+ motion_module_kwargs=None,
+ resolution_idx=None,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock3D":
+ return UpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
+
+ use_motion_module=use_motion_module,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ resolution_idx=resolution_idx,
+ )
+ elif up_block_type == "CrossAttnUpBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+ return CrossAttnUpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ use_inflated_groupnorm=use_inflated_groupnorm,
+
+ use_motion_module=use_motion_module,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ resolution_idx=resolution_idx,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock3DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+ use_inflated_groupnorm=None,
+
+ use_motion_module=None,
+
+ motion_module_type=None,
+ motion_module_kwargs=None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
+ )
+ ]
+ attentions = []
+ motion_modules = []
+
+ for _ in range(num_layers):
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ )
+ )
+ motion_modules.append(
+ get_motion_module(
+ in_channels=in_channels,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ ) if use_motion_module else None
+ )
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ def set_motion_module_attention_scale(self, scale: float = 1.0) -> None:
+ for motion_module in self.motion_modules:
+ if motion_module is not None:
+ motion_module.set_attention_scale_multiplier(scale)
+
+ def reset_motion_module_attention_scale(self) -> None:
+ for motion_module in self.motion_modules:
+ if motion_module is not None:
+ motion_module.reset_attention_scale_multiplier()
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class CrossAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+ use_inflated_groupnorm=None,
+
+ use_motion_module=None,
+
+ motion_module_type=None,
+ motion_module_kwargs=None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ motion_modules = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ )
+ )
+ motion_modules.append(
+ get_motion_module(
+ in_channels=out_channels,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ ) if use_motion_module else None
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def set_motion_module_attention_scale(self, scale: float = 1.0) -> None:
+ for motion_module in self.motion_modules:
+ if motion_module is not None:
+ motion_module.set_attention_scale_multiplier(scale)
+
+ def reset_motion_module_attention_scale(self) -> None:
+ for motion_module in self.motion_modules:
+ if motion_module is not None:
+ motion_module.reset_attention_scale_multiplier()
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ output_states = ()
+
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ if motion_module is not None:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
+
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ # add motion module
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+
+ use_inflated_groupnorm=None,
+
+ use_motion_module=None,
+ motion_module_type=None,
+ motion_module_kwargs=None,
+ ):
+ super().__init__()
+ resnets = []
+ motion_modules = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
+ )
+ )
+ motion_modules.append(
+ get_motion_module(
+ in_channels=out_channels,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ ) if use_motion_module else None
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def set_motion_module_attention_scale(self, scale: float = 1.0) -> None:
+ for motion_module in self.motion_modules:
+ if motion_module is not None:
+ motion_module.set_attention_scale_multiplier(scale)
+
+ def reset_motion_module_attention_scale(self) -> None:
+ for motion_module in self.motion_modules:
+ if motion_module is not None:
+ motion_module.reset_attention_scale_multiplier()
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ output_states = ()
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
+ if self.training and self.gradient_checkpointing:
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ if motion_module is not None:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ # add motion module
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+ use_inflated_groupnorm=None,
+
+ use_motion_module=None,
+
+ motion_module_type=None,
+ motion_module_kwargs=None,
+ resolution_idx=None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ motion_modules = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ )
+ )
+ motion_modules.append(
+ get_motion_module(
+ in_channels=out_channels,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ ) if use_motion_module else None
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def set_motion_module_attention_scale(self, scale: float = 1.0) -> None:
+ for motion_module in self.motion_modules:
+ if motion_module is not None:
+ motion_module.set_attention_scale_multiplier(scale)
+
+ def reset_motion_module_attention_scale(self) -> None:
+ for motion_module in self.motion_modules:
+ if motion_module is not None:
+ motion_module.reset_attention_scale_multiplier()
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ attention_mask=None,
+ ):
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ if motion_module is not None:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
+
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ # add motion module
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+
+ use_inflated_groupnorm=None,
+
+ use_motion_module=None,
+ motion_module_type=None,
+ motion_module_kwargs=None,
+ resolution_idx=None
+ ):
+ super().__init__()
+ resnets = []
+ motion_modules = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+
+ use_inflated_groupnorm=use_inflated_groupnorm,
+ )
+ )
+ motion_modules.append(
+ get_motion_module(
+ in_channels=out_channels,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ ) if use_motion_module else None
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def set_motion_module_attention_scale(self, scale: float = 1.0) -> None:
+ for motion_module in self.motion_modules:
+ if motion_module is not None:
+ motion_module.set_attention_scale_multiplier(scale)
+
+ def reset_motion_module_attention_scale(self) -> None:
+ for motion_module in self.motion_modules:
+ if motion_module is not None:
+ motion_module.reset_attention_scale_multiplier()
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None):
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ if motion_module is not None:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
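For reference, a rough sketch of how the factory above dispatches on the block-type string (stripping an optional "UNetRes" prefix) and how motion modules can be disabled per block. The parameter values are illustrative, not taken from a real config, and the forward pass assumes the resnet module in this diff behaves like its upstream AnimateDiff counterpart for 5-D latents.

```python
import torch
from enfugue.diffusion.animate.diff.unet_blocks import get_down_block, DownBlock3D

block = get_down_block(
    down_block_type="UNetResDownBlock3D",  # prefix is stripped, resolving to DownBlock3D
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    attn_num_head_channels=8,
    resnet_groups=32,
    use_inflated_groupnorm=False,
    use_motion_module=False,               # no motion modules: entries stay None
)
assert isinstance(block, DownBlock3D)
assert all(m is None for m in block.motion_modules)

# Video latents are shaped (batch, channels, frames, height, width).
sample = torch.randn(1, 320, 8, 32, 32)
temb = torch.randn(1, 1280)
hidden_states, skip_states = block(sample, temb=temb)
```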
diff --git a/src/python/enfugue/diffusion/animate/hotshot/__init__.py b/src/python/enfugue/diffusion/animate/hotshot/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/python/enfugue/diffusion/animate/hotshot/resnet.py b/src/python/enfugue/diffusion/animate/hotshot/resnet.py
new file mode 100644
index 00000000..818c8ea9
--- /dev/null
+++ b/src/python/enfugue/diffusion/animate/hotshot/resnet.py
@@ -0,0 +1,135 @@
+# type: ignore
+# Copyright 2023 Natural Synthetics Inc. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+
+import torch
+import torch.nn as nn
+from diffusers.models.resnet import Upsample2D, Downsample2D, LoRACompatibleConv
+from einops import rearrange
+
+
+class Upsample3D(Upsample2D):
+ def forward(self, hidden_states, output_size=None, scale: float = 1.0):
+ f = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ hidden_states = super(Upsample3D, self).forward(hidden_states, output_size, scale)
+ return rearrange(hidden_states, "(b f) c h w -> b c f h w", f=f)
+
+
+class Downsample3D(Downsample2D):
+
+ def forward(self, hidden_states, scale: float = 1.0):
+ f = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ hidden_states = super(Downsample3D, self).forward(hidden_states, scale)
+ return rearrange(hidden_states, "(b f) c h w -> b c f h w", f=f)
+
+
+class Conv3d(LoRACompatibleConv):
+ def forward(self, hidden_states, scale: float = 1.0):
+ f = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ hidden_states = super().forward(hidden_states, scale)
+ return rearrange(hidden_states, "(b f) c h w -> b c f h w", f=f)
+
+
+class ResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="silu",
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ conv_shortcut_bias: bool = True,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+ self.conv1 = Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ time_emb_proj_out_channels = out_channels
+ elif self.time_embedding_norm == "scale_shift":
+ time_emb_proj_out_channels = out_channels * 2
+ else:
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ assert non_linearity == "silu"
+
+ self.nonlinearity = nn.SiLU()
+
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = Conv3d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
+ )
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.nonlinearity(temb)
+ temb = self.time_emb_proj(temb)[:, :, None, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
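The wrappers above all rely on the same pseudo-3D trick: fold the frame axis into the batch axis, run the existing 2-D layer, then unfold. The snippet below is a self-contained illustration of that pattern with a plain `Conv2d`, so it runs without the diffusers wrappers; it is not part of the diff.

```python
import torch
import torch.nn as nn
from einops import rearrange

batch, channels, frames, height, width = 2, 8, 16, 32, 32
video = torch.randn(batch, channels, frames, height, width)

conv2d = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

# Fold frames into the batch, run the 2-D op, then restore the frame axis.
flat = rearrange(video, "b c f h w -> (b f) c h w")
out = conv2d(flat)
out = rearrange(out, "(b f) c h w -> b c f h w", f=frames)

assert out.shape == video.shape  # spatial ops never mix information across frames
```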
diff --git a/src/python/enfugue/diffusion/animate/hotshot/transformer_3d.py b/src/python/enfugue/diffusion/animate/hotshot/transformer_3d.py
new file mode 100644
index 00000000..41b2579d
--- /dev/null
+++ b/src/python/enfugue/diffusion/animate/hotshot/transformer_3d.py
@@ -0,0 +1,76 @@
+# type: ignore
+# Copyright 2023 Natural Synthetics Inc. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+
+from dataclasses import dataclass
+from typing import Optional
+import torch
+from torch import nn
+from diffusers.utils import BaseOutput
+from diffusers.models.transformer_2d import Transformer2DModel
+from einops import rearrange, repeat
+from typing import Dict, Any
+
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ """
+ The output of [`Transformer3DModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on the `encoder_hidden_states` input.
+ """
+
+ sample: torch.FloatTensor
+
+
+class Transformer3DModel(Transformer2DModel):
+
+ def __init__(self, *args, **kwargs):
+ super(Transformer3DModel, self).__init__(*args, **kwargs)
+ nn.init.zeros_(self.proj_out.weight.data)
+ nn.init.zeros_(self.proj_out.bias.data)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ enable_temporal_layers: bool = True,
+ positional_embedding: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ):
+
+ is_video = len(hidden_states.shape) == 5
+
+ if is_video:
+ f = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=f)
+
+ hidden_states = super(Transformer3DModel, self).forward(hidden_states,
+ encoder_hidden_states,
+ timestep,
+ class_labels,
+ cross_attention_kwargs,
+ attention_mask,
+ encoder_attention_mask,
+ return_dict=False)[0]
+
+ if is_video:
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=f)
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return Transformer3DModelOutput(sample=hidden_states)
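A rough usage sketch of the class above; the layer sizes are illustrative and it assumes a diffusers version compatible with the positional arguments forwarded to `Transformer2DModel.forward`. Five-dimensional video latents are folded into the batch axis, the text context is repeated per frame, and the output is restored to video shape.

```python
import torch
from enfugue.diffusion.animate.hotshot.transformer_3d import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=16,
    norm_num_groups=4,
    cross_attention_dim=32,
)
video = torch.randn(1, 16, 4, 8, 8)   # (batch, channels, frames, height, width)
context = torch.randn(1, 7, 32)       # (batch, tokens, cross_attention_dim)
out = model(video, encoder_hidden_states=context).sample
assert out.shape == video.shape       # frames are folded into the batch and restored
```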
diff --git a/src/python/enfugue/diffusion/animate/hotshot/transformer_temporal.py b/src/python/enfugue/diffusion/animate/hotshot/transformer_temporal.py
new file mode 100644
index 00000000..3459001b
--- /dev/null
+++ b/src/python/enfugue/diffusion/animate/hotshot/transformer_temporal.py
@@ -0,0 +1,228 @@
+# type: ignore
+# Copyright 2023 Natural Synthetics Inc. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+
+import torch
+import math
+from dataclasses import dataclass
+from torch import nn
+from diffusers.utils import BaseOutput
+from diffusers.models.attention import Attention, FeedForward
+from einops import rearrange, repeat
+from typing import Optional
+
+
+class PositionalEncoding(nn.Module):
+ """
+ Implements positional encoding as described in "Attention Is All You Need".
+ Adds sinusoid-based positional encodings to the input tensor.
+ """
+
+ _SCALE_FACTOR = 10000.0 # Scale factor used in the positional encoding computation.
+
+ def __init__(self, dim: int, dropout: float = 0.0, max_length: int = 24):
+ super(PositionalEncoding, self).__init__()
+
+ self.dropout = nn.Dropout(p=dropout)
+
+ # The size is (1, max_length, dim) to allow easy addition to input tensors.
+ positional_encoding = torch.zeros(1, max_length, dim)
+
+ # Position and dim are used in the sinusoidal computation.
+ position = torch.arange(max_length).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(self._SCALE_FACTOR) / dim))
+
+ positional_encoding[0, :, 0::2] = torch.sin(position * div_term)
+ positional_encoding[0, :, 1::2] = torch.cos(position * div_term)
+
+ # Register the positional encoding matrix as a buffer,
+ # so it's part of the model's state but not the parameters.
+ self.register_buffer('positional_encoding', positional_encoding)
+
+ def forward(self, hidden_states: torch.Tensor, length: int) -> torch.Tensor:
+ hidden_states = hidden_states + self.positional_encoding[:, :length]
+ return self.dropout(hidden_states)
+
+
+class TemporalAttention(Attention):
+ def __init__(
+ self,
+ positional_encoding_max_length: int = 24,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.pos_encoder = PositionalEncoding(
+ kwargs["query_dim"],
+ dropout=0,
+ max_length=positional_encoding_max_length
+ )
+
+ def set_scale_multiplier(self, multiplier: float = 1.0) -> None:
+ self.scale = math.sqrt((math.log(24) / math.log(24//4)) / (self.inner_dim // self.heads)) * multiplier
+
+ def reset_scale_multiplier(self) -> None:
+ self.scale = (self.inner_dim // self.heads) ** -0.5
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, number_of_frames=8):
+ sequence_length = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) s c -> (b s) f c", f=number_of_frames)
+ hidden_states = self.pos_encoder(hidden_states, length=number_of_frames)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b s) n c", s=sequence_length)
+
+ hidden_states = super().forward(hidden_states, encoder_hidden_states, attention_mask=attention_mask)
+
+ return rearrange(hidden_states, "(b s) f c -> (b f) s c", s=sequence_length)
+
+
+@dataclass
+class TransformerTemporalOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+class TransformerTemporal(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ in_channels: int,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ activation_fn: str = "geglu",
+ positional_encoding_max_length: int = 24,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ TransformerBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ upcast_attention=upcast_attention,
+ cross_attention_dim=cross_attention_dim,
+ positional_encoding_max_length=positional_encoding_max_length
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+
+ def set_attention_scale_multiplier(self, attention_scale: float = 1.0) -> None:
+ for block in self.transformer_blocks:
+ block.set_attention_scale_multiplier(attention_scale)
+
+ def reset_attention_scale_multiplier(self) -> None:
+ for block in self.transformer_blocks:
+ block.reset_attention_scale_multiplier()
+
+ def forward(self, hidden_states, encoder_hidden_states=None):
+ _, num_channels, f, height, width = hidden_states.shape
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+
+ skip = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ hidden_states = rearrange(hidden_states, "bf c h w -> bf (h w) c")
+ hidden_states = self.proj_in(hidden_states)
+
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, number_of_frames=f)
+
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = rearrange(hidden_states, "bf (h w) c -> bf c h w", h=height, w=width).contiguous()
+
+ output = hidden_states + skip
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=f)
+
+ return output
+
+
+class TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=0.0,
+ activation_fn="geglu",
+ attention_bias=False,
+ upcast_attention=False,
+ depth=2,
+ positional_encoding_max_length=24,
+ cross_attention_dim: Optional[int] = None
+ ):
+ super().__init__()
+
+ self.is_cross = cross_attention_dim is not None
+
+ attention_blocks = []
+ norms = []
+
+ for _ in range(depth):
+ attention_blocks.append(
+ TemporalAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ positional_encoding_max_length=positional_encoding_max_length
+ )
+ )
+ norms.append(nn.LayerNorm(dim))
+
+ self.attention_blocks = nn.ModuleList(attention_blocks)
+ self.norms = nn.ModuleList(norms)
+
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.ff_norm = nn.LayerNorm(dim)
+
+ def set_attention_scale_multiplier(self, attention_scale: float = 1.0) -> None:
+ for block in self.attention_blocks:
+ block.set_scale_multiplier(attention_scale)
+
+ def reset_attention_scale_multiplier(self) -> None:
+ for block in self.attention_blocks:
+ block.reset_scale_multiplier()
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, number_of_frames=None):
+
+ if not self.is_cross:
+ encoder_hidden_states = None
+
+ for block, norm in zip(self.attention_blocks, self.norms):
+ norm_hidden_states = norm(hidden_states)
+ hidden_states = block(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ number_of_frames=number_of_frames
+ ) + hidden_states
+
+ norm_hidden_states = self.ff_norm(hidden_states)
+ hidden_states = self.ff(norm_hidden_states) + hidden_states
+
+ output = hidden_states
+ return output
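To make the data flow concrete, here is a small sketch of the temporal transformer above applied to a batch of video latents; the sizes are illustrative. Each spatial location attends across the frame axis (with the sinusoidal positional encoding added first), which is why the output shape matches the input exactly.

```python
import torch
from enfugue.diffusion.animate.hotshot.transformer_temporal import TransformerTemporal

temporal = TransformerTemporal(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=16,
    norm_num_groups=4,
    positional_encoding_max_length=24,
)
video = torch.randn(1, 16, 8, 16, 16)  # (batch, channels, frames, height, width)
out = temporal(video)                   # self-attention across the 8 frames at each pixel
assert out.shape == video.shape
```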
diff --git a/src/python/enfugue/diffusion/animate/hotshot/unet.py b/src/python/enfugue/diffusion/animate/hotshot/unet.py
new file mode 100644
index 00000000..067c948b
--- /dev/null
+++ b/src/python/enfugue/diffusion/animate/hotshot/unet.py
@@ -0,0 +1,1012 @@
+# type: ignore
+from __future__ import annotations
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Modifications:
+# Copyright 2023 Natural Synthetics Inc. All rights reserved.
+# - Unet now supports SDXL
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+
+from diffusers.models.modeling_utils import ModelMixin
+from enfugue.diffusion.animate.hotshot.unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+)
+
+from enfugue.diffusion.animate.hotshot.resnet import Conv3d
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet3DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock3DCrossAttn",
+ up_block_types: Tuple[str] = (
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: int = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ positional_encoding_max_length=24,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+
+ self.conv_in = Conv3d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ res = 2 ** i
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ positional_encoding_max_length=positional_encoding_max_length,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
+ self.mid_block = UNetMidBlock3DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ raise ValueError("UNetMidBlock2DSimpleCrossAttn not supported")
+
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ res = 2 ** (len(up_block_types) - 1 - i)
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ positional_encoding_max_length=positional_encoding_max_length,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+
+ self.conv_out = Conv3d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel,
+ padding=conv_out_padding)
+
+ def set_motion_attention_scale(self, scale: float = 1.0) -> None:
+ for block in self.down_blocks + self.up_blocks + [self.mid_block]:
+ if not isinstance(block, UNetMidBlock3DCrossAttn):
+ block.set_temporal_attention_scale(scale)
+
+ def reset_motion_attention_scale(self) -> None:
+ for block in self.down_blocks + self.up_blocks + [self.mid_block]:
+ if not isinstance(block, UNetMidBlock3DCrossAttn):
+ block.set_temporal_attention_scale()
+
+ def enable_freeu(self, s1, s2, b1, b2):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
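+    # A minimal usage sketch for `enable_freeu`/`disable_freeu` above; the factor values
+    # are illustrative assumptions only (see the FreeU repository for tuned settings):
+    #
+    #     unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+    #     ...  # run denoising as usual
+    #     unet.disable_freeu()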
+ def temporal_parameters(self) -> list:
+ output = []
+ all_blocks = self.down_blocks + self.up_blocks + [self.mid_block]
+ for block in all_blocks:
+ output.extend(block.temporal_parameters())
+ return output
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ return self.get_attn_processors(include_temporal_layers=False)
+
+ def get_attn_processors(self, include_temporal_layers=True) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+
+ if not include_temporal_layers:
+ if 'temporal' in name or 'motion' in name:
+ return processors
+
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
+ include_temporal_layers=False):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.get_attn_processors(include_temporal_layers=include_temporal_layers).keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+
+ if not include_temporal_layers:
+ if "temporal" in name or "motion" in name:
+ return
+
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
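+    # Hedged example of the two forms accepted by `set_attn_processor` above
+    # (`AttnProcessor` is the class already used by this module; any processor with
+    # the same interface from diffusers.models.attention_processor would also work):
+    #
+    #     unet.set_attn_processor(AttnProcessor())  # one processor for every layer
+    #     unet.set_attn_processor(
+    #         {name: AttnProcessor() for name in unet.attn_processors}
+    #     )  # or a dict keyed by the names returned by `attn_processors`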
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ self.set_attn_processor(AttnProcessor())
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
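+    # Hedged usage sketch for the slicing modes documented above:
+    #
+    #     unet.set_attention_slice("auto")  # halve each sliceable head dimension
+    #     unet.set_attention_slice("max")   # one slice at a time (lowest memory)
+    #     unet.set_attention_slice(4)       # the same explicit slice size for every layer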
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ enable_temporal_attentions: bool = True
+ ) -> Union[UNet3DConditionOutput, Tuple]:
+ r"""
+        The [`UNet3DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain
+ tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+ added_cond_kwargs: (`dict`, *optional*):
+                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+
+ Returns:
+            [`UNet3DConditionOutput`] or `tuple`:
+                If `return_dict` is True, a [`UNet3DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2 ** self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ # 2. pre-process
+
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ enable_temporal_attentions=enable_temporal_attentions
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ enable_temporal_attentions=enable_temporal_attentions)
+
+ down_block_res_samples += res_samples
+
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ enable_temporal_attentions=enable_temporal_attentions
+ )
+
+ if mid_block_additional_residual is not None:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ enable_temporal_attentions=enable_temporal_attentions
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ encoder_hidden_states=encoder_hidden_states,
+ enable_temporal_attentions=enable_temporal_attentions
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet3DConditionOutput(sample=sample)
+
+ @classmethod
+ def from_pretrained_spatial(cls, pretrained_model_path, subfolder=None):
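+        """
+        Instantiates a UNet3DConditionModel from a pretrained 2D (spatial) checkpoint:
+        the saved config is rewritten to use the 3D block types, then the spatial weights
+        are loaded non-strictly so the temporal layers keep their fresh initialization.
+        """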
+
+ import os
+ import json
+
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ config["_class_name"] = "UNet3DConditionModel"
+
+ config["down_block_types"] = [
+ "DownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ ]
+ config["up_block_types"] = [
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "UpBlock3D"
+ ]
+
+ config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
+
+ from diffusers.utils import WEIGHTS_NAME
+ model = cls.from_config(config)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+
+ state_dict = torch.load(model_file, map_location="cpu")
+
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
diff --git a/src/python/enfugue/diffusion/animate/hotshot/unet_blocks.py b/src/python/enfugue/diffusion/animate/hotshot/unet_blocks.py
new file mode 100644
index 00000000..5bedf1d4
--- /dev/null
+++ b/src/python/enfugue/diffusion/animate/hotshot/unet_blocks.py
@@ -0,0 +1,844 @@
+# type: ignore
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Modifications:
+# Copyright 2023 Natural Synthetics Inc. All rights reserved.
+# - Add temporal transformers to unet blocks
+
+import torch
+from torch import nn
+
+from enfugue.diffusion.animate.hotshot.transformer_3d import Transformer3DModel
+from enfugue.diffusion.animate.hotshot.resnet import Downsample3D, ResnetBlock3D, Upsample3D
+from enfugue.diffusion.animate.hotshot.transformer_temporal import TransformerTemporal
+
+# `apply_freeu` is used by the FreeU branches in the up-block forward passes below;
+# recent releases of diffusers expose it from `diffusers.utils.torch_utils`.
+from diffusers.utils.torch_utils import apply_freeu
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ transformer_layers_per_block=1,
+ num_attention_heads=None,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ resnet_skip_time_act=False,
+ resnet_out_scale_factor=1.0,
+ cross_attention_norm=None,
+ attention_head_dim=None,
+ downsample_type=None,
+ positional_encoding_max_length=24,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ positional_encoding_max_length=positional_encoding_max_length,
+ )
+ elif down_block_type == "CrossAttnDownBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+
+ return CrossAttnDownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ transformer_layers_per_block=transformer_layers_per_block,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ positional_encoding_max_length=positional_encoding_max_length,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ transformer_layers_per_block=1,
+ num_attention_heads=None,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ resnet_skip_time_act=False,
+ resnet_out_scale_factor=1.0,
+ cross_attention_norm=None,
+ attention_head_dim=None,
+ upsample_type=None,
+ positional_encoding_max_length=24,
+ resolution_idx=None,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock3D":
+ return UpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ positional_encoding_max_length=positional_encoding_max_length,
+ resolution_idx=resolution_idx,
+ )
+ elif up_block_type == "CrossAttnUpBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+ return CrossAttnUpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ transformer_layers_per_block=transformer_layers_per_block,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ positional_encoding_max_length=positional_encoding_max_length,
+ resolution_idx=resolution_idx,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock3DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ )
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ cross_attention_kwargs=None,
+ enable_temporal_attentions: bool = True
+ ):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+ def temporal_parameters(self) -> list:
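+        # The mid block defines no temporal/motion layers in this implementation,
+        # so there is nothing to expose for temporal fine-tuning.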
+ return []
+
+
+class CrossAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ positional_encoding_max_length=24,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ temporal_attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
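+            # The temporal transformer always uses 8 attention heads with
+            # head_dim = out_channels // 8, independent of the spatial attention
+            # configuration above.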
+ temporal_attentions.append(
+ TransformerTemporal(
+ num_attention_heads=8,
+ attention_head_dim=out_channels // 8,
+ in_channels=out_channels,
+ cross_attention_dim=None,
+ positional_encoding_max_length=positional_encoding_max_length,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.temporal_attentions = nn.ModuleList(temporal_attentions)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None,
+ cross_attention_kwargs=None, enable_temporal_attentions: bool = True):
+ output_states = ()
+
+ for resnet, attn, temporal_attention \
+ in zip(self.resnets, self.attentions, self.temporal_attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb,
+ use_reentrant=False)
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ use_reentrant=False
+ )[0]
+ if enable_temporal_attentions and temporal_attention is not None:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temporal_attention),
+ hidden_states, encoder_hidden_states,
+ use_reentrant=False)
+
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ if temporal_attention and enable_temporal_attentions:
+ hidden_states = temporal_attention(hidden_states,
+ encoder_hidden_states=encoder_hidden_states)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+ def set_temporal_attention_scale(self, scale: float = 1.0) -> None:
+ for temporal_attention in self.temporal_attentions:
+ temporal_attention.set_attention_scale_multiplier(scale)
+
+ def reset_temporal_attention_scale(self) -> None:
+ for temporal_attention in self.temporal_attentions:
+ temporal_attention.reset_attention_scale_multiplier()
+
+ def temporal_parameters(self) -> list:
+ output = []
+ for block in self.temporal_attentions:
+ if block:
+ output.extend(block.parameters())
+ return output
+
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ positional_encoding_max_length=24,
+ ):
+ super().__init__()
+ resnets = []
+ temporal_attentions = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temporal_attentions.append(
+ TransformerTemporal(
+ num_attention_heads=8,
+ attention_head_dim=out_channels // 8,
+ in_channels=out_channels,
+ cross_attention_dim=None,
+ positional_encoding_max_length=positional_encoding_max_length
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.temporal_attentions = nn.ModuleList(temporal_attentions)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def set_temporal_attention_scale(self, scale: float = 1.0) -> None:
+ for temporal_attention in self.temporal_attentions:
+ temporal_attention.set_attention_scale_multiplier(scale)
+
+ def reset_temporal_attention_scale(self) -> None:
+ for temporal_attention in self.temporal_attentions:
+ temporal_attention.reset_attention_scale_multiplier()
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, enable_temporal_attentions: bool = True):
+ output_states = ()
+
+ for resnet, temporal_attention in zip(self.resnets, self.temporal_attentions):
+ if self.training and self.gradient_checkpointing:
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb,
+ use_reentrant=False)
+ if enable_temporal_attentions and temporal_attention is not None:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temporal_attention),
+ hidden_states, encoder_hidden_states,
+ use_reentrant=False)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if enable_temporal_attentions and temporal_attention:
+ hidden_states = temporal_attention(hidden_states, encoder_hidden_states=encoder_hidden_states)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+ def temporal_parameters(self) -> list:
+ output = []
+ for block in self.temporal_attentions:
+ if block:
+ output.extend(block.parameters())
+ return output
+
+
+class CrossAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ positional_encoding_max_length=24,
+ resolution_idx=None
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ temporal_attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ temporal_attentions.append(
+ TransformerTemporal(
+ num_attention_heads=8,
+ attention_head_dim=out_channels // 8,
+ in_channels=out_channels,
+ cross_attention_dim=None,
+ positional_encoding_max_length=positional_encoding_max_length
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.temporal_attentions = nn.ModuleList(temporal_attentions)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def set_temporal_attention_scale(self, scale: float = 1.0) -> None:
+ for temporal_attention in self.temporal_attentions:
+ temporal_attention.set_attention_scale_multiplier(scale)
+
+ def reset_temporal_attention_scale(self) -> None:
+ for temporal_attention in self.temporal_attentions:
+ temporal_attention.reset_attention_scale_multiplier()
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ cross_attention_kwargs=None,
+ attention_mask=None,
+ enable_temporal_attentions: bool = True
+ ):
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+ for resnet, attn, temporal_attention \
+ in zip(self.resnets, self.attentions, self.temporal_attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb,
+ use_reentrant=False)
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ use_reentrant=False,
+ )[0]
+ if enable_temporal_attentions and temporal_attention is not None:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temporal_attention),
+ hidden_states, encoder_hidden_states,
+ use_reentrant=False)
+
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ if enable_temporal_attentions and temporal_attention:
+ hidden_states = temporal_attention(hidden_states,
+ encoder_hidden_states=encoder_hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def temporal_parameters(self) -> list:
+ output = []
+ for block in self.temporal_attentions:
+ if block:
+ output.extend(block.parameters())
+ return output
+
+
+class UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ positional_encoding_max_length=24,
+ resolution_idx=None
+ ):
+ super().__init__()
+ resnets = []
+ temporal_attentions = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temporal_attentions.append(
+ TransformerTemporal(
+ num_attention_heads=8,
+ attention_head_dim=out_channels // 8,
+ in_channels=out_channels,
+ cross_attention_dim=None,
+ positional_encoding_max_length=positional_encoding_max_length
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.temporal_attentions = nn.ModuleList(temporal_attentions)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def set_temporal_attention_scale(self, scale: float = 1.0) -> None:
+ for temporal_attention in self.temporal_attentions:
+ temporal_attention.set_attention_scale_multiplier(scale)
+
+ def reset_temporal_attention_scale(self) -> None:
+ for temporal_attention in self.temporal_attentions:
+ temporal_attention.reset_attention_scale_multiplier()
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ upsample_size=None,
+ encoder_hidden_states=None,
+ enable_temporal_attentions: bool = True
+ ):
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+ for resnet, temporal_attention in zip(self.resnets, self.temporal_attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb,
+ use_reentrant=False)
+ if enable_temporal_attentions and temporal_attention is not None:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temporal_attention),
+ hidden_states, encoder_hidden_states,
+ use_reentrant=False)
+ else:
+                hidden_states = resnet(hidden_states, temb)
+                if enable_temporal_attentions and temporal_attention is not None:
+                    hidden_states = temporal_attention(hidden_states, encoder_hidden_states=encoder_hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def temporal_parameters(self) -> list:
+ output = []
+ for block in self.temporal_attentions:
+ if block:
+ output.extend(block.parameters())
+ return output
diff --git a/src/python/enfugue/diffusion/animate/pipeline.py b/src/python/enfugue/diffusion/animate/pipeline.py
new file mode 100644
index 00000000..9b37dc45
--- /dev/null
+++ b/src/python/enfugue/diffusion/animate/pipeline.py
@@ -0,0 +1,546 @@
+# Inspired by https://github.com/guoyww/AnimateDiff/blob/main/animatediff/pipelines/pipeline_animation.py
+from __future__ import annotations
+
+import os
+import torch
+import torch.nn.functional as F
+
+from typing import Optional, Dict, Any, Union, Callable, List, TYPE_CHECKING
+
+from pibble.util.files import load_json
+
+from diffusers.utils import WEIGHTS_NAME, DIFFUSERS_CACHE
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.schedulers import EulerDiscreteScheduler
+
+from einops import rearrange
+
+from enfugue.util import check_download_to_dir, logger
+from enfugue.diffusion.pipeline import EnfugueStableDiffusionPipeline
+
+from enfugue.diffusion.animate.diff.unet import UNet3DConditionModel as AnimateDiffUNet # type: ignore[attr-defined]
+from enfugue.diffusion.animate.hotshot.unet import UNet3DConditionModel as HotshotUNet # type: ignore[attr-defined]
+
+if TYPE_CHECKING:
+ from transformers import (
+ CLIPTokenizer,
+ CLIPTextModel,
+ CLIPImageProcessor,
+ CLIPTextModelWithProjection,
+ )
+ from diffusers.models import (
+ AutoencoderKL,
+ AutoencoderTiny,
+ ControlNetModel,
+ UNet2DConditionModel,
+ )
+ from diffusers.pipelines.stable_diffusion import (
+ StableDiffusionSafetyChecker
+ )
+ from diffusers.schedulers import KarrasDiffusionSchedulers
+ from enfugue.diffusion.support.ip import IPAdapter
+ from enfugue.diffusion.util import Chunker, MaskWeightBuilder
+ from enfugue.diffusion.constants import MASK_TYPE_LITERAL
+
+class EnfugueAnimateStableDiffusionPipeline(EnfugueStableDiffusionPipeline):
+ unet_3d: Optional[Union[AnimateDiffUNet, HotshotUNet]]
+ vae: AutoencoderKL
+
+ STATIC_SCHEDULER_KWARGS = {
+ "num_train_timesteps": 1000,
+ "beta_start": 0.00085,
+ "beta_end": 0.01005,
+ "beta_schedule": "linear"
+ }
+
+ HOTSHOT_XL_PATH = "hotshotco/Hotshot-XL"
+ MOTION_MODULE_V2 = "https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v15_v2.ckpt"
+ MOTION_MODULE = "https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v15.ckpt"
+ MOTION_MODULE_PE_KEY = "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe"
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ vae_preview: AutoencoderTiny,
+ text_encoder: Optional[CLIPTextModel],
+ text_encoder_2: Optional[CLIPTextModelWithProjection],
+ tokenizer: Optional[CLIPTokenizer],
+ tokenizer_2: Optional[CLIPTokenizer],
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ force_zeros_for_empty_prompt: bool = True,
+ requires_aesthetic_score: bool = False,
+ force_full_precision_vae: bool = False,
+ controlnets: Optional[Dict[str, ControlNetModel]] = None,
+ ip_adapter: Optional[IPAdapter] = None,
+ engine_size: int = 512, # Recommended even for machines that can handle more
+ tiling_size: Optional[int] = None,
+ tiling_stride: Optional[int] = 32,
+ tiling_mask_type: MASK_TYPE_LITERAL = "bilinear",
+ tiling_mask_kwargs: Dict[str, Any] = {},
+ frame_window_size: Optional[int] = 16,
+ frame_window_stride: Optional[int] = 4,
+ override_scheduler_config: bool = True,
+ ) -> None:
+ super(EnfugueAnimateStableDiffusionPipeline, self).__init__(
+ vae=vae,
+ vae_preview=vae_preview,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=requires_safety_checker,
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
+ force_full_precision_vae=force_full_precision_vae,
+ requires_aesthetic_score=requires_aesthetic_score,
+ controlnets=controlnets,
+ ip_adapter=ip_adapter,
+ engine_size=engine_size,
+ tiling_stride=tiling_stride,
+ tiling_size=tiling_size,
+ tiling_mask_type=tiling_mask_type,
+ tiling_mask_kwargs=tiling_mask_kwargs,
+ frame_window_size=frame_window_size,
+ frame_window_stride=frame_window_stride
+ )
+
+ if override_scheduler_config:
+ self.scheduler_config = {
+ **self.scheduler_config,
+ **EnfugueAnimateStableDiffusionPipeline.STATIC_SCHEDULER_KWARGS
+ }
+ self.scheduler.register_to_config( # type: ignore[attr-defined]
+ **EnfugueAnimateStableDiffusionPipeline.STATIC_SCHEDULER_KWARGS
+ )
+
+ if not self.is_sdxl and not isinstance(self.scheduler, EulerDiscreteScheduler):
+ logger.debug(f"Animation pipeline changing default scheduler from {type(self.scheduler).__name__} to Euler Discrete")
+ self.scheduler = EulerDiscreteScheduler.from_config(self.scheduler_config)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ motion_module: Optional[str] = None,
+ position_encoding_truncate_length: Optional[int]=None,
+ position_encoding_scale_length: Optional[int]=None,
+ task_callback: Optional[Callable[[str], None]]=None,
+ **kwargs: Any
+ ) -> EnfugueAnimateStableDiffusionPipeline:
+ """
+ Override from_pretrained to reload the unet as a 3D condition model instead.
+ """
+ pipe = super(EnfugueAnimateStableDiffusionPipeline, cls).from_pretrained(pretrained_model_name_or_path, **kwargs)
+ unet_dir = os.path.join(pretrained_model_name_or_path, "unet") # type: ignore[arg-type]
+ unet_config = os.path.join(unet_dir, "config.json")
+ unet_weights = os.path.join(unet_dir, WEIGHTS_NAME)
+
+ is_sdxl = os.path.exists(os.path.join(pretrained_model_name_or_path, "text_encoder_2")) # type: ignore[arg-type]
+
+ if not os.path.exists(unet_config):
+ raise IOError(f"Couldn't find UNet config at {unet_config}")
+ if not os.path.exists(unet_weights):
+ # Check for safetensors version
+ safetensors_weights = os.path.join(unet_dir, "diffusion_pytorch_model.safetensors")
+ if os.path.exists(safetensors_weights):
+ unet_weights = safetensors_weights
+ else:
+ raise IOError(f"Couldn't find UNet weights at {unet_weights} or {safetensors_weights}")
+
+ unet = cls.create_unet(
+ load_json(unet_config),
+ kwargs.get("cache_dir", DIFFUSERS_CACHE),
+ is_sdxl=is_sdxl,
+ is_inpainter=False,
+ motion_module=motion_module,
+ position_encoding_truncate_length=position_encoding_truncate_length,
+ position_encoding_scale_length=position_encoding_scale_length,
+ task_callback=task_callback,
+ )
+
+ from enfugue.diffusion.util.torch_util import load_state_dict
+ state_dict = load_state_dict(unet_weights)
+
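+        # Drop any motion/temporal entries from the 2D checkpoint so that only the
+        # spatial weights are copied below; the motion weights are loaded separately
+        # when the 3D UNet is created in `create_unet`.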
+ for key in list(state_dict.keys()):
+ if "motion" in key or "temporal" in key:
+ state_dict.pop(key)
+
+ unet.load_state_dict(state_dict, strict=False)
+
+ if "torch_dtype" in kwargs:
+ unet = unet.to(kwargs["torch_dtype"])
+
+ pipe.unet = unet
+ return pipe
+
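+    # Hedged usage sketch; the checkpoint path is a placeholder, and the keyword
+    # arguments shown are ones this override actually inspects:
+    #
+    #     pipe = EnfugueAnimateStableDiffusionPipeline.from_pretrained(
+    #         "/path/to/diffusers-checkpoint",
+    #         motion_module=None,          # None falls back to the built-in defaults
+    #         torch_dtype=torch.float16,   # also applied to the rebuilt 3D UNet
+    #     )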
+ @classmethod
+ def create_unet(
+ cls,
+ config: Dict[str, Any],
+ cache_dir: str,
+ is_sdxl: bool,
+ is_inpainter: bool,
+ task_callback: Optional[Callable[[str], None]]=None,
+ **unet_additional_kwargs: Any
+ ) -> ModelMixin:
+ """
+ Creates the 3D Unet
+ """
+ use_mm_v2: bool = unet_additional_kwargs.pop("use_mm_v2", True)
+ motion_module: Optional[str] = unet_additional_kwargs.pop("motion_module", None)
+ position_encoding_truncate_length: Optional[int] = unet_additional_kwargs.pop("position_encoding_truncate_length", None)
+ position_encoding_scale_length: Optional[int] = unet_additional_kwargs.pop("position_encoding_scale_length", None)
+ if config.get("sample_size", 64) == 128:
+ # SDXL, instantiate Hotshot XL UNet
+ return cls.create_hotshot_unet(
+ config=config,
+ cache_dir=cache_dir,
+ motion_module=motion_module,
+ task_callback=task_callback,
+ position_encoding_truncate_length=position_encoding_truncate_length,
+ position_encoding_scale_length=position_encoding_scale_length,
+ **unet_additional_kwargs
+ )
+ return cls.create_diff_unet(
+ config=config,
+ cache_dir=cache_dir,
+ use_mm_v2=use_mm_v2,
+ motion_module=motion_module,
+ task_callback=task_callback,
+ position_encoding_truncate_length=position_encoding_truncate_length,
+ position_encoding_scale_length=position_encoding_scale_length,
+ **unet_additional_kwargs
+ )
+
+ @classmethod
+ def create_hotshot_unet(
+ cls,
+ config: Dict[str, Any],
+ cache_dir: str,
+ motion_module: Optional[str]=None,
+ task_callback: Optional[Callable[[str], None]]=None,
+ position_encoding_truncate_length: Optional[int]=None,
+ position_encoding_scale_length: Optional[int]=None,
+ **unet_additional_kwargs: Any
+ ) -> ModelMixin:
+ """
+ Creates a UNet3DConditionModel, then loads the Hotshot XL motion weights into it
+ """
+ config["_class_name"] = "UNet3DConditionModel"
+ config["down_block_types"] = [
+ "DownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ ]
+ config["up_block_types"] = [
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "UpBlock3D"
+ ]
+ if position_encoding_scale_length:
+ config["positional_encoding_max_length"] = position_encoding_scale_length
+
+ # Instantiate from 2D model config
+ model = HotshotUNet.from_config(config)
+
+ # Load motion weights into it
+ cls.load_hotshot_state_dict(
+ unet=model,
+ cache_dir=cache_dir,
+ motion_module=motion_module,
+ task_callback=task_callback,
+ position_encoding_truncate_length=position_encoding_truncate_length,
+ position_encoding_scale_length=position_encoding_scale_length,
+ )
+ return model
+
+ @classmethod
+ def load_hotshot_state_dict(
+ cls,
+ unet: HotshotUNet,
+ cache_dir: str,
+ motion_module: Optional[str]=None,
+ task_callback: Optional[Callable[[str], None]]=None,
+ position_encoding_truncate_length: Optional[int]=None,
+ position_encoding_scale_length: Optional[int]=None,
+ ) -> None:
+ """
+ Loads pretrained hotshot weights into the UNet
+ """
+ if motion_module is None:
+ motion_module = cls.HOTSHOT_XL_PATH
+
+ if task_callback is not None:
+ task_callback(f"Loading HotshotXL repository {cls.HOTSHOT_XL_PATH}")
+
+ hotshot_unet = HotshotUNet.from_pretrained(
+ motion_module,
+ subfolder="unet",
+ cache_dir=cache_dir,
+ )
+
+ logger.debug(f"Loading HotShot XL motion module {motion_module} with truncate length '{position_encoding_truncate_length}' and scale length '{position_encoding_scale_length}'")
+
+ hotshot_state_dict = hotshot_unet.state_dict()
+ for key in list(hotshot_state_dict.keys()):
+ if "temporal" not in key:
+ hotshot_state_dict.pop(key)
+ elif key.endswith(".positional_encoding"):
+ if position_encoding_truncate_length is not None:
+ hotshot_state_dict[key] = hotshot_state_dict[key][:, :position_encoding_truncate_length]
+ if position_encoding_scale_length is not None:
+ tensor_shape = hotshot_state_dict[key].shape
+ tensor = rearrange(hotshot_state_dict[key], "(t b) f d -> t b f d", t=1)
+ tensor = F.interpolate(tensor, size=(position_encoding_scale_length, tensor_shape[-1]), mode="bilinear")
+ hotshot_state_dict[key] = rearrange(tensor, "t b f d -> (t b) f d")
+ del tensor
+
+ num_motion_keys = len(hotshot_state_dict.keys())
+ logger.debug(f"Loading {num_motion_keys} keys into UNet state dict (non-strict)")
+ unet.load_state_dict(hotshot_state_dict, strict=False)
+ del hotshot_state_dict
+ del hotshot_unet
+
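+ # Illustrative sketch of the positional-encoding resizing performed by
+ # load_hotshot_state_dict above, using a hypothetical (1, 24, 320) tensor;
+ # shapes and lengths are examples only:
+ #
+ # pe = torch.zeros(1, 24, 320) # (batch, length, dim)
+ # pe = pe[:, :16] # truncate to 16 positions
+ # pe = rearrange(pe, "(t b) f d -> t b f d", t=1) # (1, 1, 16, 320) for interpolate
+ # pe = F.interpolate(pe, size=(32, 320), mode="bilinear") # scale length to 32
+ # pe = rearrange(pe, "t b f d -> (t b) f d") # back to (1, 32, 320)
+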
+ @classmethod
+ def create_diff_unet(
+ cls,
+ config: Dict[str, Any],
+ cache_dir: str,
+ use_mm_v2: bool=True,
+ motion_module: Optional[str]=None,
+ task_callback: Optional[Callable[[str], None]]=None,
+ position_encoding_truncate_length: Optional[int]=None,
+ position_encoding_scale_length: Optional[int]=None,
+ **unet_additional_kwargs: Any
+ ) -> ModelMixin:
+ """
+ Creates a UNet3DConditionModel, then loads the AnimateDiff motion module (MM) into it
+ """
+ config["_class_name"] = "UNet3DConditionModel"
+ config["down_block_types"] = [
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D"
+ ]
+ config["up_block_types"] = [
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ]
+
+ if motion_module is not None:
+ # Detect MM version
+ from enfugue.diffusion.util import load_state_dict
+ logger.debug(f"Loading motion module {motion_module} to detect MMV1/2")
+ state_dict = load_state_dict(motion_module)
+
+ if cls.MOTION_MODULE_PE_KEY in state_dict:
+ position_tensor: torch.Tensor = state_dict[cls.MOTION_MODULE_PE_KEY] # type: ignore[assignment]
+ if position_tensor.shape[1] == 24:
+ use_mm_v2 = False
+ logger.debug("Detected MMV1")
+ elif position_tensor.shape[1] == 32:
+ use_mm_v2 = True
+ logger.debug("Detected MMV2")
+ else:
+ raise ValueError(f"Position encoder tensor has unsupported length {position_tensor.shape[1]}")
+ else:
+ raise ValueError(f"Couldn't detect motion module version from {motion_module}. It may be an unsupported format.")
+
+ motion_module = state_dict # type: ignore[assignment]
+
+ config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
+ default_position_encoding_len = 32 if use_mm_v2 else 24
+ position_encoding_len = default_position_encoding_len
+ if position_encoding_scale_length:
+ position_encoding_len = position_encoding_scale_length
+
+ unet_additional_kwargs["use_inflated_groupnorm"] = use_mm_v2
+ unet_additional_kwargs["unet_use_cross_frame_attention"] = False
+ unet_additional_kwargs["unet_use_temporal_attention"] = False
+ unet_additional_kwargs["use_motion_module"] = True
+ unet_additional_kwargs["motion_module_resolutions"] = [1, 2, 4, 8]
+ unet_additional_kwargs["motion_module_mid_block"] = use_mm_v2
+ unet_additional_kwargs["motion_module_decoder_only"] = False
+ unet_additional_kwargs["motion_module_type"] = "Vanilla"
+ unet_additional_kwargs["motion_module_kwargs"] = {
+ "num_attention_heads": 8,
+ "num_transformer_block": 1,
+ "attention_block_types": [
+ "Temporal_Self",
+ "Temporal_Self"
+ ],
+ "temporal_position_encoding": True,
+ "temporal_position_encoding_max_len": position_encoding_len,
+ "temporal_attention_dim_div": 1
+ }
+
+ model = AnimateDiffUNet.from_config(config, **unet_additional_kwargs)
+ cls.load_diff_state_dict(
+ unet=model,
+ cache_dir=cache_dir,
+ use_mm_v2=use_mm_v2,
+ motion_module=motion_module,
+ task_callback=task_callback,
+ position_encoding_truncate_length=position_encoding_truncate_length,
+ position_encoding_scale_length=position_encoding_scale_length,
+ )
+ return model
+
+ @classmethod
+ def load_diff_state_dict(
+ cls,
+ unet: AnimateDiffUNet,
+ cache_dir: str,
+ use_mm_v2: bool=True,
+ motion_module: Optional[Union[str, Dict[str, torch.Tensor]]]=None,
+ task_callback: Optional[Callable[[str], None]]=None,
+ position_encoding_truncate_length: Optional[int]=None,
+ position_encoding_scale_length: Optional[int]=None,
+ ) -> None:
+ """
+ Loads an AnimateDiff motion module state dict into an AnimateDiff UNet
+ """
+ if motion_module is None:
+ motion_module = cls.MOTION_MODULE_V2 if use_mm_v2 else cls.MOTION_MODULE
+
+ if task_callback is not None:
+ if not os.path.exists(os.path.join(cache_dir, os.path.basename(motion_module))):
+ task_callback(f"Downloading {motion_module}")
+
+ motion_module = check_download_to_dir(motion_module, cache_dir)
+
+ if isinstance(motion_module, dict):
+ logger.debug(f"Loading AnimateDiff motion module with truncate length '{position_encoding_truncate_length}' and scale length '{position_encoding_scale_length}'")
+
+ state_dict = motion_module
+ else:
+ logger.debug(f"Loading AnimateDiff motion module {motion_module} with truncate length '{position_encoding_truncate_length}' and scale length '{position_encoding_scale_length}'")
+
+ from enfugue.diffusion.util.torch_util import load_state_dict
+ state_dict = load_state_dict(motion_module) # type: ignore[assignment]
+
+ if position_encoding_truncate_length is not None or position_encoding_scale_length is not None:
+ for key in state_dict:
+ if key.endswith(".pe"):
+ if position_encoding_truncate_length is not None:
+ state_dict[key] = state_dict[key][:, :position_encoding_truncate_length] # type: ignore[index]
+ if position_encoding_scale_length is not None:
+ tensor_shape = state_dict[key].shape # type: ignore[union-attr]
+ tensor = rearrange(state_dict[key], "(t b) f d -> t b f d", t=1)
+ tensor = F.interpolate(tensor, size=(position_encoding_scale_length, tensor_shape[-1]), mode="bilinear")
+ state_dict[key] = rearrange(tensor, "t b f d -> (t b) f d") # type: ignore[assignment]
+ del tensor
+
+ num_motion_keys = len(list(state_dict.keys()))
+ unet_version = "V2" if use_mm_v2 else "V1"
+ logger.debug(f"Loading {num_motion_keys} keys into AnimateDiff UNet {unet_version} state dict (non-strict)")
+ unet.load_state_dict(state_dict, strict=False)
+ del state_dict
+
+ def load_motion_module_weights(
+ self,
+ cache_dir: str,
+ use_mm_v2: bool=True,
+ motion_module: Optional[str]=None,
+ task_callback: Optional[Callable[[str], None]]=None,
+ position_encoding_truncate_length: Optional[int]=None,
+ position_encoding_scale_length: Optional[int]=None,
+ ) -> None:
+ """
+ Loads motion module weights after-the-fact
+ """
+ if self.is_sdxl:
+ self.load_hotshot_state_dict(
+ unet=self.unet,
+ motion_module=motion_module,
+ cache_dir=cache_dir,
+ task_callback=task_callback,
+ position_encoding_truncate_length=position_encoding_truncate_length,
+ position_encoding_scale_length=position_encoding_scale_length
+ )
+ else:
+ self.load_diff_state_dict(
+ unet=self.unet,
+ motion_module=motion_module,
+ cache_dir=cache_dir,
+ task_callback=task_callback,
+ use_mm_v2=use_mm_v2,
+ position_encoding_truncate_length=position_encoding_truncate_length,
+ position_encoding_scale_length=position_encoding_scale_length
+ )
+
+ def load_motion_lora_weights(
+ self,
+ state_dict: Dict[str, torch.Tensor],
+ multiplier: float = 1.0,
+ dtype: torch.dtype = torch.float32
+ ) -> None:
+ """
+ Loads motion LoRA checkpoint into the unet
+ """
+ for key in state_dict:
+ if "up." in key:
+ continue
+ up_key = key.replace(".down.", ".up.")
+ model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
+ model_key = model_key.replace("to_out.", "to_out.0.")
+ layer_infos = model_key.split(".")[:-1]
+
+ curr_layer = self.unet
+ while len(layer_infos) > 0:
+ temp_name = layer_infos.pop(0)
+ curr_layer = curr_layer.__getattr__(temp_name)
+
+ weight_down = state_dict[key].to(dtype)
+ weight_up = state_dict[up_key].to(dtype)
+ curr_layer.weight.data += multiplier * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
+
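+ # Illustrative sketch of the update applied by load_motion_lora_weights above:
+ # each down/up pair forms a low-rank delta that is added to the matching
+ # attention weight. Shapes are hypothetical (rank 8, dim 320):
+ #
+ # weight_up, weight_down = torch.rand(320, 8), torch.rand(8, 320)
+ # delta = multiplier * torch.mm(weight_up, weight_down) # (320, 320)
+ # curr_layer.weight.data += delta.to(curr_layer.weight.data.device)
+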
+ def decode_latents(
+ self,
+ latents: torch.Tensor,
+ device: Union[str, torch.device],
+ chunker: Chunker,
+ weight_builder: MaskWeightBuilder,
+ progress_callback: Optional[Callable[[bool], None]]=None,
+ scale_latents: bool=True
+ ) -> torch.Tensor:
+ """
+ Decodes each video frame individually.
+ """
+ animation_frames = latents.shape[2]
+ if scale_latents:
+ latents = 1 / self.vae.config.scaling_factor * latents # type: ignore[attr-defined]
+
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
+ dtype = latents.dtype
+ # Optionally force a full-precision VAE (currently disabled)
+ #self.vae = self.vae.to(torch.float32)
+ #latents = latents.to(torch.float32)
+ video: List[torch.Tensor] = []
+ for frame_index in range(latents.shape[0]):
+ video.append(
+ super(EnfugueAnimateStableDiffusionPipeline, self).decode_latents(
+ latents=latents[frame_index:frame_index+1],
+ device=device,
+ weight_builder=weight_builder,
+ chunker=chunker,
+ progress_callback=progress_callback,
+ scale_latents=False
+ )
+ )
+ video = torch.cat(video) # type: ignore
+ video = rearrange(video, "(b f) c h w -> b c f h w", f = animation_frames) # type: ignore
+ video = (video / 2 + 0.5).clamp(0, 1) # type: ignore
+ video = video.cpu().float() # type: ignore
+ #self.vae.to(dtype)
+ return video # type: ignore
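+
+ # Shape sketch for the per-frame decode above (illustrative sizes): latents of
+ # shape (1, 4, 16, 64, 64) = (batch, channels, frames, height, width) become
+ # (16, 4, 64, 64) via rearrange, each frame is decoded individually through the
+ # parent pipeline, and the result is stacked back to (1, 3, 16, 512, 512).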
diff --git a/src/python/enfugue/diffusion/constants.py b/src/python/enfugue/diffusion/constants.py
index f9bc85cb..a1eb9563 100644
--- a/src/python/enfugue/diffusion/constants.py
+++ b/src/python/enfugue/diffusion/constants.py
@@ -1,6 +1,33 @@
-from typing import Literal
+from __future__ import annotations
+
+import os
+
+from typing import (
+ Optional,
+ Union,
+ Tuple,
+ Literal,
+ List,
+ TYPE_CHECKING
+)
+from typing_extensions import (
+ TypedDict,
+ NotRequired
+)
+from enfugue.util import (
+ IMAGE_FIT_LITERAL,
+ IMAGE_ANCHOR_LITERAL
+)
+
+if TYPE_CHECKING:
+ from PIL.Image import Image
__all__ = [
+ "UpscaleStepDict",
+ "ImageDict",
+ "IPAdapterImageDict",
+ "ControlImageDict",
+ "PromptDict",
"DEFAULT_MODEL",
"DEFAULT_INPAINTING_MODEL",
"DEFAULT_SDXL_MODEL",
@@ -24,18 +51,65 @@
"CONTROLNET_POSE",
"CONTROLNET_POSE_XL",
"CONTROLNET_PIDI",
+ "CONTROLNET_PIDI_XL",
"CONTROLNET_LINE",
"CONTROLNET_ANIME",
- "CONTROLNET_LITERAL",
"CONTROLNET_TEMPORAL",
"CONTROLNET_QR",
+ "CONTROLNET_QR_XL",
+ "CONTROLNET_LITERAL",
+ "MOTION_LORA_ZOOM_IN",
+ "MOTION_LORA_ZOOM_OUT",
+ "MOTION_LORA_PAN_LEFT",
+ "MOTION_LORA_PAN_RIGHT",
+ "MOTION_LORA_ROLL_CLOCKWISE",
+ "MOTION_LORA_ROLL_ANTI_CLOCKWISE",
+ "MOTION_LORA_TILT_UP",
+ "MOTION_LORA_TILT_DOWN",
+ "MOTION_LORA_LITERAL",
"SCHEDULER_LITERAL",
"DEVICE_LITERAL",
"PIPELINE_SWITCH_MODE_LITERAL",
"UPSCALE_LITERAL",
"MASK_TYPE_LITERAL",
+ "LOOP_TYPE_LITERAL",
+ "DEFAULT_CHECKPOINT_DIR",
+ "DEFAULT_INVERSION_DIR",
+ "DEFAULT_LORA_DIR",
+ "DEFAULT_LYCORIS_DIR",
+ "DEFAULT_TENSORRT_DIR",
+ "DEFAULT_CACHE_DIR",
+ "DEFAULT_OTHER_DIR",
+ "DEFAULT_DIFFUSERS_DIR",
+ "DEFAULT_SIZE",
+ "DEFAULT_SDXL_SIZE",
+ "DEFAULT_TEMPORAL_SIZE",
+ "DEFAULT_TILING_SIZE",
+ "DEFAULT_TILING_STRIDE",
+ "DEFAULT_TEMPORAL_TILING_SIZE",
+ "DEFAULT_IMAGE_CALLBACK_STEPS",
+ "DEFAULT_CONDITIONING_SCALE",
+ "DEFAULT_IMG2IMG_STRENGTH",
+ "DEFAULT_INFERENCE_STEPS",
+ "DEFAULT_GUIDANCE_SCALE",
+ "DEFAULT_UPSCALE_PROMPT",
+ "DEFAULT_UPSCALE_INFERENCE_STEPS",
+ "DEFAULT_UPSCALE_GUIDANCE_SCALE",
+ "DEFAULT_UPSCALE_TILING_SIZE",
+ "DEFAULT_UPSCALE_TILING_STRIDE",
+ "DEFAULT_REFINER_START",
+ "DEFAULT_REFINER_STRENGTH",
+ "DEFAULT_REFINER_GUIDANCE_SCALE",
+ "DEFAULT_AESTHETIC_SCORE",
+ "DEFAULT_NEGATIVE_AESTHETIC_SCORE",
+ "MODEL_PROMPT_WEIGHT",
+ "GLOBAL_PROMPT_STEP_WEIGHT",
+ "GLOBAL_PROMPT_UPSCALE_WEIGHT",
+ "UPSCALE_PROMPT_STEP_WEIGHT",
+ "MAX_IMAGE_SCALE",
"LATENT_BLEND_METHOD_LITERAL",
- "NOISE_METHOD_LITERAL"
+ "NOISE_METHOD_LITERAL",
+ "IP_ADAPTER_LITERAL"
]
DEFAULT_MODEL = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt"
@@ -44,6 +118,100 @@
DEFAULT_SDXL_REFINER = "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors"
DEFAULT_SDXL_INPAINTING_MODEL = "https://huggingface.co/benjamin-paine/sd-xl-alternative-bases/resolve/main/sd_xl_base_1.0_inpainting_0.1.safetensors"
+DEFAULT_CHECKPOINT_DIR = os.path.expanduser("~/.cache/enfugue/checkpoint")
+DEFAULT_INVERSION_DIR = os.path.expanduser("~/.cache/enfugue/inversion")
+DEFAULT_TENSORRT_DIR = os.path.expanduser("~/.cache/enfugue/tensorrt")
+DEFAULT_LORA_DIR = os.path.expanduser("~/.cache/enfugue/lora")
+DEFAULT_LYCORIS_DIR = os.path.expanduser("~/.cache/enfugue/lycoris")
+DEFAULT_CACHE_DIR = os.path.expanduser("~/.cache/enfugue/cache")
+DEFAULT_DIFFUSERS_DIR = os.path.expanduser("~/.cache/enfugue/diffusers")
+DEFAULT_OTHER_DIR = os.path.expanduser("~/.cache/enfugue/other")
+
+DEFAULT_SIZE = 512
+DEFAULT_TILING_SIZE = 512
+DEFAULT_TILING_STRIDE = 32
+DEFAULT_TILING_MASK = "bilinear"
+DEFAULT_TEMPORAL_SIZE = 16
+DEFAULT_TEMPORAL_TILING_SIZE = 12
+DEFAULT_SDXL_SIZE = 1024
+DEFAULT_IMAGE_CALLBACK_STEPS = 10
+DEFAULT_CONDITIONING_SCALE = 1.0
+DEFAULT_IMG2IMG_STRENGTH = 0.8
+DEFAULT_INFERENCE_STEPS = 40
+DEFAULT_GUIDANCE_SCALE = 7.5
+DEFAULT_UPSCALE_PROMPT = "highly detailed, ultra-detailed, intricate detail, high definition, HD, 4k, 8k UHD"
+DEFAULT_UPSCALE_INFERENCE_STEPS = 100
+DEFAULT_UPSCALE_GUIDANCE_SCALE = 12
+DEFAULT_UPSCALE_TILING_STRIDE = 128
+DEFAULT_UPSCALE_TILING_SIZE = DEFAULT_TILING_SIZE
+
+DEFAULT_REFINER_START = 0.85
+DEFAULT_REFINER_STRENGTH = 0.3
+DEFAULT_REFINER_GUIDANCE_SCALE = 5.0
+DEFAULT_AESTHETIC_SCORE = 6.0
+DEFAULT_NEGATIVE_AESTHETIC_SCORE = 2.5
+
+MODEL_PROMPT_WEIGHT = 0.2
+GLOBAL_PROMPT_STEP_WEIGHT = 0.4
+GLOBAL_PROMPT_UPSCALE_WEIGHT = 0.4
+UPSCALE_PROMPT_STEP_WEIGHT = 0.1
+MAX_IMAGE_SCALE = 3.0
+
+CACHE_MODE_LITERAL = ["always", "xl", "tensorrt"]
+VAE_LITERAL = Literal["ema", "mse", "xl", "xl16"]
+DEVICE_LITERAL = Literal["cpu", "cuda", "dml", "mps"]
+PIPELINE_SWITCH_MODE_LITERAL = Literal["offload", "unload"]
+SCHEDULER_LITERAL = Literal[
+ "ddim", "ddpm", "deis",
+ "dpmsm", "dpmsms", "dpmsmk", "dpmsmka",
+ "dpmss", "dpmssk", "heun",
+ "dpmd", "dpmdk", "adpmd",
+ "adpmdk", "dpmsde", "unipc",
+ "lmsd", "lmsdk", "pndm",
+ "eds", "eads"
+]
+UPSCALE_LITERAL = Literal[
+ "esrgan", "esrganime", "gfpgan",
+ "lanczos", "bilinear", "bicubic",
+ "nearest"
+]
+CONTROLNET_LITERAL = Literal[
+ "canny", "mlsd", "hed",
+ "scribble", "tile", "inpaint",
+ "depth", "normal", "pose",
+ "pidi", "line", "anime",
+ "temporal", "qr"
+]
+MOTION_LORA_LITERAL = [
+ "pan-left", "pan-right",
+ "roll-clockwise", "roll-anti-clockwise",
+ "tilt-up", "tilt-down",
+ "zoom-in", "zoom-out"
+]
+LOOP_TYPE_LITERAL = Literal[
+ "loop", "reflect"
+]
+MASK_TYPE_LITERAL = Literal[
+ "constant", "bilinear", "gaussian"
+]
+LATENT_BLEND_METHOD_LITERAL = Literal[
+ "add", "bislerp", "cosine", "cubic",
+ "difference", "inject", "lerp", "slerp",
+ "exclusion", "subtract", "multiply", "overlay",
+ "screen", "color_dodge", "linear_dodge", "glow",
+ "pin_light", "hard_light", "linear_light", "vivid_light"
+]
+NOISE_METHOD_LITERAL = Literal[
+ "default", "crosshatch", "simplex",
+ "perlin", "brownian_fractal", "white",
+ "grey", "pink", "blue", "green",
+ "velvet", "violet", "random_mix"
+]
+IP_ADAPTER_LITERAL = Literal[
+ "default", "plus", "plus-face"
+]
+
+# VAE repos/files
VAE_EMA = (
"stabilityai/sd-vae-ft-ema",
"vae-ft-ema-560000-ema-pruned",
@@ -68,9 +236,7 @@
"sdxl_vae_fp16_fix",
"sdxl_vae_fp16"
)
-
-VAE_LITERAL = Literal["ema", "mse", "xl", "xl16"]
-
+# ControlNet repos/files
CONTROLNET_CANNY = (
"lllyasviel/control_v11p_sd15_canny",
"control_v11p_sd15_canny",
@@ -120,6 +286,10 @@
"lllyasviel/control_v11p_sd15_softedge",
"control_v11p_sd15_softedge",
)
+CONTROLNET_PIDI_XL = (
+ "SargeZT/controlnet-sd-xl-1.0-softedge-dexined",
+ "controlnet-sd-xl-1.0-softedge-dexined"
+)
CONTROLNET_LINE = (
"lllyasviel/control_v11p_sd15_lineart",
"control_v11p_sd15_lineart",
@@ -137,7 +307,16 @@
"control_v1p_sd15_qrcode_monster_v2",
"control_v1p_sd15_qrcode_monster",
)
-# Recommend XL files come from https://huggingface.co/lllyasviel/sd_control_collection/tree/main
+CONTROLNET_QR_XL = (
+ "https://huggingface.co/Nacholmo/controlnet-qr-pattern-sdxl/resolve/main/automatic1111/control_v10e_sdxl_opticalpattern.safetensors",
+ "control_v10e_sdxl_opticalpattern",
+ "control_v04u_sdxl_opticalpattern-half",
+ "control_v03u_sdxl_opticalpattern",
+ "control_v02u_sdxl_qrpattern",
+ "control_v02u_sdxl_opticalpattern",
+ "control_v01u_sdxl_opticalpattern",
+ "control_v01u_sdxl_qrpattern",
+)
CONTROLNET_CANNY_XL = (
"diffusers/controlnet-canny-sdxl-1.0",
"diffusers_xl_canny_full",
@@ -153,43 +332,76 @@
"OpenPoseXL2",
"controlnet-openpose-sdxl-1.0",
)
-CONTROLNET_LITERAL = Literal[
- "canny", "mlsd", "hed",
- "scribble", "tile", "inpaint",
- "depth", "normal", "pose",
- "pidi", "line", "anime",
- "temporal", "qr"
-]
-SCHEDULER_LITERAL = Literal[
- "ddim", "ddpm", "deis",
- "dpmsm", "dpmsmk", "dpmsmka",
- "dpmss", "dpmssk", "heun",
- "dpmd", "dpmdk", "adpmd",
- "adpmdk", "dpmsde", "unipc",
- "lmsd", "lmsdk", "pndm",
- "eds", "eads"
+MOTION_LORA_PAN_LEFT = "https://huggingface.co/guoyww/animatediff/resolve/main/v2_lora_PanLeft.ckpt"
+MOTION_LORA_PAN_RIGHT = "https://huggingface.co/guoyww/animatediff/resolve/main/v2_lora_PanRight.ckpt"
+MOTION_LORA_ROLL_CLOCKWISE = "https://huggingface.co/guoyww/animatediff/resolve/main/v2_lora_RollingClockwise.ckpt"
+MOTION_LORA_ROLL_ANTI_CLOCKWISE = "https://huggingface.co/guoyww/animatediff/resolve/main/v2_lora_RollingAnticlockwise.ckpt"
+MOTION_LORA_TILT_UP = "https://huggingface.co/guoyww/animatediff/resolve/main/v2_lora_TiltUp.ckpt"
+MOTION_LORA_TILT_DOWN = "https://huggingface.co/guoyww/animatediff/resolve/main/v2_lora_TiltDown.ckpt"
+MOTION_LORA_ZOOM_IN = "https://huggingface.co/guoyww/animatediff/resolve/main/v2_lora_ZoomIn.ckpt"
+MOTION_LORA_ZOOM_OUT = "https://huggingface.co/guoyww/animatediff/resolve/main/v2_lora_ZoomOut.ckpt"
+
+MultiModelType = Union[str, List[str]]
+WeightedMultiModelType = Union[
+ str, Tuple[str, float],
+ List[Union[str, Tuple[str, float]]]
]
-DEVICE_LITERAL = Literal["cpu", "cuda", "dml", "mps"]
+class ImageDict(TypedDict):
+ """
+ An image or video with optional fitting details
+ """
+ image: Union[str, Image, List[Image]]
+ skip_frames: NotRequired[Optional[int]]
+ divide_frames: NotRequired[Optional[int]]
+ fit: NotRequired[Optional[IMAGE_FIT_LITERAL]]
+ anchor: NotRequired[Optional[IMAGE_ANCHOR_LITERAL]]
+ invert: NotRequired[bool]
-PIPELINE_SWITCH_MODE_LITERAL = Literal["offload", "unload"]
+class ControlImageDict(ImageDict):
+ """
+ Extends the image dict additionally with controlnet details
+ """
+ controlnet: CONTROLNET_LITERAL
+ scale: NotRequired[float]
+ start: NotRequired[Optional[float]]
+ end: NotRequired[Optional[float]]
+ process: NotRequired[bool]
-UPSCALE_LITERAL = Literal["esrgan", "esrganime", "gfpgan", "lanczos", "bilinear", "bicubic", "nearest"]
+class IPAdapterImageDict(ImageDict):
+ """
+ Extends the image dict additionally with IP adapter scale
+ """
+ scale: NotRequired[float]
-MASK_TYPE_LITERAL = Literal["constant", "bilinear", "gaussian"]
+class UpscaleStepDict(TypedDict):
+ """
+ All the options for each upscale step
+ """
+ method: UPSCALE_LITERAL
+ amount: Union[int, float]
+ strength: NotRequired[float]
+ num_inference_steps: NotRequired[int]
+ scheduler: NotRequired[SCHEDULER_LITERAL]
+ guidance_scale: NotRequired[float]
+ controlnets: NotRequired[List[Union[CONTROLNET_LITERAL, Tuple[CONTROLNET_LITERAL, float]]]]
+ prompt: NotRequired[str]
+ prompt_2: NotRequired[str]
+ negative_prompt: NotRequired[str]
+ negative_prompt_2: NotRequired[str]
+ chunking_size: NotRequired[Optional[int]]
+ chunking_frames: NotRequired[Optional[int]]
+ chunking_mask: NotRequired[MASK_TYPE_LITERAL]
-LATENT_BLEND_METHOD_LITERAL = Literal[
- "add", "bislerp", "cosine", "cubic",
- "difference", "inject", "lerp", "slerp",
- "exclusion", "subtract", "multiply", "overlay",
- "screen", "color_dodge", "linear_dodge", "glow",
- "pin_light", "hard_light", "linear_light", "vivid_light"
-]
-
-NOISE_METHOD_LITERAL = Literal[
- "default", "crosshatch", "simplex",
- "perlin", "brownian_fractal", "white",
- "grey", "pink", "blue", "green",
- "velvet", "violet", "random_mix"
-]
+class PromptDict(TypedDict):
+ """
+ A prompt step, optionally with frame details
+ """
+ positive: str
+ negative: NotRequired[Optional[str]]
+ positive_2: NotRequired[Optional[str]]
+ negative_2: NotRequired[Optional[str]]
+ weight: NotRequired[Optional[float]]
+ start: NotRequired[Optional[int]]
+ end: NotRequired[Optional[int]]
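+
+# Illustrative examples of the dictionaries above (hypothetical values):
+#
+# prompt_step: PromptDict = {"positive": "a watercolor landscape", "start": 0, "end": 16}
+# upscale_step: UpscaleStepDict = {"method": "esrgan", "amount": 2, "strength": 0.2}
+# control_image: ControlImageDict = {"image": "/path/to/image.png", "controlnet": "canny", "scale": 0.8}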
diff --git a/src/python/enfugue/diffusion/engine.py b/src/python/enfugue/diffusion/engine.py
index 87b721d9..e1720844 100644
--- a/src/python/enfugue/diffusion/engine.py
+++ b/src/python/enfugue/diffusion/engine.py
@@ -3,7 +3,8 @@
import time
import datetime
-from typing import Optional, Any, Union, Dict, TYPE_CHECKING
+from typing import Optional, Any, Union, Dict, Type, List, TYPE_CHECKING
+from typing_extensions import Self
from multiprocessing import Queue as MakeQueue
from multiprocessing.queues import Queue
from queue import Empty
@@ -12,20 +13,23 @@
from pibble.util.strings import Serializer
from pibble.util.helpers import resolve
-from enfugue.diffusion.process import DiffusionEngineProcess
from enfugue.util import logger
+from enfugue.diffusion.process import (
+ EngineProcess,
+ DiffusionEngineProcess
+)
if TYPE_CHECKING:
- from enfugue.diffusion.plan import DiffusionPlan
+ from enfugue.diffusion.invocation import LayeredInvocation
+__all__ = [
+ "Engine",
+ "DiffusionEngine"
+]
-__all__ = ["DiffusionEngine"]
-
-
-class DiffusionEngine:
+class Engine:
LOGGING_DELAY_MS = 10000
-
- process: DiffusionEngineProcess
+ process: EngineProcess
def __init__(self, configuration: Optional[APIConfiguration] = None):
self.configuration = APIConfiguration()
@@ -34,7 +38,14 @@ def __init__(self, configuration: Optional[APIConfiguration] = None):
if configuration is not None:
self.configuration = configuration
- def __enter__(self) -> DiffusionEngine:
+ @property
+ def process_class(self) -> Type[EngineProcess]:
+ """
+ Gets the class of the process
+ """
+ return EngineProcess
+
+ def __enter__(self) -> Self:
"""
When entering the engine via context manager, start the process.
"""
@@ -62,8 +73,6 @@ def check_delete_queue(self, queue_name: str) -> None:
Deletes a queue.
"""
if hasattr(self, f"_{queue_name}"):
- if self.is_alive():
- raise IOError(f"Attempted to close queue {queue_name} while process is still alive")
delattr(self, f"_{queue_name}")
@property
@@ -80,6 +89,20 @@ def instructions(self) -> None:
"""
self.check_delete_queue("instructions")
+ @property
+ def intermediates(self) -> Queue:
+ """
+ Gets the intermediates queue
+ """
+ return self.check_get_queue("intermediates")
+
+ @intermediates.deleter
+ def intermediates(self) -> None:
+ """
+ Deletes the intermediates queue
+ """
+ self.check_delete_queue("intermediates")
+
@property
def results(self) -> Queue:
"""
@@ -94,19 +117,28 @@ def results(self) -> None:
"""
self.check_delete_queue("results")
- @property
- def intermediates(self) -> Queue:
+ def get_queues(self) -> List[Queue]:
"""
- Gets the intemerdiates queue
+ Gets queues to pass to the process
"""
- return self.check_get_queue("intermediates")
+ return [self.instructions, self.results, self.intermediates]
- @intermediates.deleter
- def intermediates(self) -> None:
+ def delete_queues(self) -> None:
"""
- Deletes the intermediates queue
+ Deletes queues after killing a process
"""
- self.check_delete_queue("intermediates")
+ try:
+ del self.instructions
+ except:
+ pass
+ try:
+ del self.results
+ except:
+ pass
+ try:
+ del self.intermediates
+ except:
+ pass
def spawn_process(
self,
@@ -122,9 +154,8 @@ def spawn_process(
return
logger.debug("No current engine process, creating.")
- self.process = DiffusionEngineProcess(self.instructions, self.results, self.intermediates, self.configuration)
-
- poll_delay_seconds = DiffusionEngineProcess.POLLING_DELAY_MS / 250
+ poll_delay_seconds = self.process_class.POLLING_DELAY_MS / 250
+ self.process = self.process_class(self.configuration, *self.get_queues())
try:
logger.debug("Starting process.")
@@ -158,7 +189,7 @@ def terminate_process(self, timeout: Optional[Union[int, float]] = 10) -> None:
if self.process.is_alive():
start = datetime.datetime.now()
- sleep_time = DiffusionEngineProcess.POLLING_DELAY_MS / 500
+ sleep_time = self.process_class.POLLING_DELAY_MS / 500
self.dispatch("stop")
time.sleep(sleep_time)
while self.process.is_alive():
@@ -169,29 +200,17 @@ def terminate_process(self, timeout: Optional[Union[int, float]] = 10) -> None:
time.sleep(sleep_time)
break
time.sleep(sleep_time)
+
if hasattr(self, "process") and self.process.is_alive():
logger.debug("Sending term one more time...")
self.process.terminate()
time.sleep(sleep_time)
- if hasattr(self, "process") and self.process.is_alive():
- raise IOError("Couldn't terminate process")
- try:
- del self.process
- except AttributeError:
- pass
- try:
- del self.intermediates
- except AttributeError:
- pass
- try:
- del self.results
- except AttributeError:
- pass
- try:
- del self.instructions
- except AttributeError:
- pass
+ if hasattr(self, "process"):
+ if self.process.is_alive():
+ raise IOError("Couldn't terminate process")
+ delattr(self, "process")
+ self.delete_queues()
def keepalive(self, timeout: Union[int, float] = 0.2) -> bool:
"""
@@ -211,7 +230,12 @@ def keepalive(self, timeout: Union[int, float] = 0.2) -> bool:
raise IOError(f"Incorrect ping response {ping_response}")
return True
- def dispatch(self, action: str, payload: Any = None, spawn_process: bool = True) -> Any:
+ def dispatch(
+ self,
+ action: str,
+ payload: Any = None,
+ spawn_process: bool = True
+ ) -> Any:
"""
Sends a payload, does not wait for a response.
"""
@@ -241,7 +265,7 @@ def last_intermediate(self, id: int) -> Any:
if step_deserialized["id"] == id:
if intermediate_data is None:
intermediate_data = {"id": id}
- for key in ["step", "total", "rate", "images", "task"]:
+ for key in ["step", "total", "rate", "images", "task", "video"]:
if key in step_deserialized:
intermediate_data[key] = step_deserialized[key]
else:
@@ -291,17 +315,37 @@ def wait(self, id: int, timeout: Optional[Union[int, float]] = None) -> Any:
if timeout is not None and (datetime.datetime.now() - start).total_seconds() > timeout:
raise TimeoutError("Timed out waiting for response.")
- def invoke(self, action: str, payload: Any = None, timeout: Optional[Union[int, float]] = None) -> Any:
+ def invoke(
+ self,
+ action: str,
+ payload: Any = None,
+ timeout: Optional[Union[int, float]] = None
+ ) -> Any:
"""
Issue a single request synchronously using arg syntax.
"""
return self.wait(self.dispatch(action, payload), timeout)
- def execute(self, plan: DiffusionPlan, timeout: Optional[Union[int, float]] = None, wait: bool = False) -> Any:
+class DiffusionEngine(Engine):
+ """
+ The extension of the base engine used for the Stable Diffusion process
+ """
+ @property
+ def process_class(self) -> Type[EngineProcess]:
+ """
+ Overrides the process class to use the diffusion engine process
+ """
+ return DiffusionEngineProcess
+
+ def execute(
+ self,
+ plan: LayeredInvocation,
+ timeout: Optional[Union[int, float]] = None, wait: bool = False
+ ) -> Any:
"""
This is a helpful method to just serialize and execute a plan.
"""
- id = self.dispatch("plan", plan.get_serialization_dict())
+ id = self.dispatch("plan", plan.serialize())
if wait:
return self.wait(id, timeout)
return id
diff --git a/src/python/enfugue/diffusion/interpolate/__init__.py b/src/python/enfugue/diffusion/interpolate/__init__.py
new file mode 100644
index 00000000..8201ee7e
--- /dev/null
+++ b/src/python/enfugue/diffusion/interpolate/__init__.py
@@ -0,0 +1,5 @@
+from enfugue.diffusion.interpolate.interpolator import InterpolationEngine
+
+InterpolationEngine # Silence importcheck
+
+__all__ = ["InterpolationEngine"]
diff --git a/src/python/enfugue/diffusion/interpolate/interpolator.py b/src/python/enfugue/diffusion/interpolate/interpolator.py
new file mode 100644
index 00000000..87196b0b
--- /dev/null
+++ b/src/python/enfugue/diffusion/interpolate/interpolator.py
@@ -0,0 +1,422 @@
+from __future__ import annotations
+
+import os
+import numpy as np
+
+from typing import (
+ Any,
+ Iterator,
+ Type,
+ Union,
+ List,
+ Tuple,
+ Iterable,
+ Dict,
+ Optional,
+ Callable,
+ TYPE_CHECKING
+)
+
+if TYPE_CHECKING:
+ from enfugue.diffusion.interpolate.model import InterpolatorModel # type: ignore[attr-defined]
+
+from PIL import Image
+from datetime import datetime
+
+from enfugue.diffusion.engine import Engine
+from enfugue.diffusion.process import EngineProcess
+from enfugue.util import (
+ get_frames_or_image_from_file,
+ get_frames_or_image,
+ check_make_directory,
+ logger
+)
+
+_UINT8_MAX_F = float(np.iinfo(np.uint8).max)
+
+__all__ = ["InterpolationEngine"]
+
+class InterpolatorEngineProcess(EngineProcess):
+ """
+ Captures the interpolator in its own process, since TensorFlow keeps a lot of global state.
+ """
+ interpolator_path: str
+
+ @property
+ def interpolator(self) -> InterpolatorModel:
+ """
+ Gets or instantiates the interpolator
+ """
+ if not hasattr(self, "_interpolator"):
+ if getattr(self, "interpolator_path", None) is None:
+ raise IOError("Can't get interpolator - path was never sent to process.")
+ logger.debug(f"Loading interpolator from {self.interpolator_path}")
+ from enfugue.diffusion.interpolate.model import InterpolatorModel # type: ignore[attr-defined]
+ self._interpolator = InterpolatorModel(self.interpolator_path)
+ return self._interpolator
+
+ def interpolate_recursive(
+ self,
+ frames: Iterable[Image.Image],
+ multiplier: Union[int, Tuple[int, ...]] = 2,
+ ) -> Iterator[Image]:
+ """
+ Provides a generator for interpolating between multiple frames.
+ """
+ if isinstance(multiplier, tuple) or isinstance(multiplier, list): # type: ignore[unreachable]
+ if len(multiplier) == 1:
+ multiplier = multiplier[0]
+ else:
+ this_multiplier = multiplier[0]
+ recursed_multiplier = multiplier[1:]
+ for frame in self.interpolate_recursive(
+ frames=self.interpolate_recursive(
+ frames=frames, # type: ignore[arg-type]
+ multiplier=recursed_multiplier,
+ ),
+ multiplier=this_multiplier,
+ ):
+ yield frame
+ return
+
+ previous_frame = None
+ frame_index = 0
+ for frame in frames:
+ frame_index += 1
+ if previous_frame is not None:
+ for i in range(multiplier - 1): # type: ignore[unreachable]
+ yield self.interpolate(
+ previous_frame,
+ frame,
+ (i + 1) / multiplier
+ )
+ yield frame
+ previous_frame = frame
+
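+ # Example of the multipliers handled by interpolate_recursive above
+ # (illustrative): with frames A, B, C and multiplier=2 the generator yields
+ # A, AB, B, BC, C (5 frames); a tuple such as (2, 2) applies the doubling
+ # twice, yielding 2 * (2 * 3 - 1) - 1 = 9 frames in total.
+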
+ def loop(
+ self,
+ frames: Iterable[Image.Image],
+ ease_frames: int = 2,
+ double_ease_frames: int = 1,
+ hold_frames: int = 0,
+ trigger_callback: Optional[Callable[[Image.Image], Image]] = None,
+ ) -> Iterable[Image]:
+ """
+ Takes a video and creates a gently-looping version of it.
+ """
+ if trigger_callback is None:
+ trigger_callback = lambda image: image
+
+ # Memoized frames
+ frame_list: List[Image.Image] = [frame for frame in frames]
+
+ if double_ease_frames:
+ double_ease_start_frames, frame_list = frame_list[:double_ease_frames], frame_list[double_ease_frames:]
+ else:
+ double_ease_start_frames = []
+ if ease_frames:
+ ease_start_frames, frame_list = frame_list[:ease_frames], frame_list[ease_frames:]
+ else:
+ ease_start_frames = []
+
+ if double_ease_frames:
+ frame_list, double_ease_end_frames = frame_list[:-double_ease_frames], frame_list[-double_ease_frames:]
+ else:
+ double_ease_end_frames = []
+ if ease_frames:
+ frame_list, ease_end_frames = frame_list[:-ease_frames], frame_list[-ease_frames:]
+ else:
+ ease_end_frames = []
+
+ # Interpolate frames
+ double_ease_start_frames = [
+ trigger_callback(frame) for frame in self.interpolate_recursive(
+ frames=double_ease_start_frames,
+ multiplier=(2,2),
+ )
+ ]
+ ease_start_frames = [
+ trigger_callback(frame) for frame in self.interpolate_recursive(
+ frames=ease_start_frames,
+ multiplier=2,
+ )
+ ]
+ ease_end_frames = [
+ trigger_callback(frame) for frame in self.interpolate_recursive(
+ frames=ease_end_frames,
+ multiplier=2,
+ )
+ ]
+ double_ease_end_frames = [
+ trigger_callback(frame) for frame in self.interpolate_recursive(
+ frames=double_ease_end_frames,
+ multiplier=(2,2),
+ )
+ ]
+
+ # Return to one list
+ frame_list = double_ease_start_frames + ease_start_frames + frame_list + ease_end_frames + double_ease_end_frames
+
+ # Iterate
+ for frame in frame_list:
+ yield frame
+
+ # Hold on final frame
+ for i in range(hold_frames):
+ yield frame_list[-1]
+
+ # Reverse the frames
+ frame_list.reverse()
+ for frame in frame_list[1:-1]:
+ yield frame
+
+ # Hold on first frame
+ for i in range(hold_frames):
+ yield frame_list[-1]
+
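+ # Worked example for loop above (illustrative): with frames [f0..f5],
+ # double_ease_frames=1, ease_frames=1 and hold_frames=0, the single-frame ease
+ # segments gain no in-betweens, so the forward pass is f0..f5 followed by the
+ # reverse pass f4, f3, f2, f1, giving a seamless ping-pong loop. Larger ease
+ # segments are interpolated 2x (ease) or 4x (double ease) to slow the ends.
+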
+ def handle_plan(
+ self,
+ instruction_id: int,
+ instruction_payload: Dict[str, Any]
+ ) -> Union[str, List[Image]]:
+ """
+ Handles an entire video potentially with recursion
+ """
+ interpolate_frames = instruction_payload["frames"]
+ if isinstance(interpolate_frames, list):
+ interpolate_frames = tuple(interpolate_frames)
+
+ images = instruction_payload["images"]
+ if isinstance(images, str):
+ images = get_frames_or_image_from_file(images)
+ elif isinstance(images, Image.Image):
+ images = get_frames_or_image(images)
+
+ image_count = len(images)
+ interpolated_count = image_count
+ if isinstance(interpolate_frames, tuple):
+ for multiplier in interpolate_frames:
+ interpolated_count *= multiplier
+ else:
+ interpolated_count *= interpolate_frames
+
+ reflect = instruction_payload.get("reflect", False)
+ ease_frames = instruction_payload.get("ease_frames", 2)
+ double_ease_frames = instruction_payload.get("double_ease_frames", 1)
+ hold_frames = instruction_payload.get("hold_frames", 0)
+
+ frame_complete = 0
+ frame_start = datetime.now()
+ frame_times: List[float] = []
+
+ if reflect:
+ interpolated_count += ease_frames + (2 * double_ease_frames)
+
+ def trigger_callback(image: Image.Image) -> Image.Image:
+ """
+ Triggers the callback, which sends progress back up the line
+ """
+ nonlocal frame_complete
+ nonlocal frame_start
+ nonlocal frame_times
+ frame_time = datetime.now()
+ frame_seconds = (frame_time - frame_start).total_seconds()
+ frame_times.append(frame_seconds)
+ frame_complete += 1
+ frame_rate = (sum(frame_times[-8:]) / min(frame_complete, 8))
+
+ if frame_complete % 8 == 0:
+ logger.debug(f"Completed {frame_complete}/{interpolated_count} frames (average {frame_rate} sec/frame)")
+
+ self.intermediates.put_nowait({
+ "id": instruction_id,
+ "step": frame_complete,
+ "total": interpolated_count,
+ "rate": None if not frame_rate else 1.0 / frame_rate,
+ "task": "Interpolating"
+ })
+ frame_start = frame_time
+ return image
+
+ if interpolate_frames:
+ logger.debug(f"Beginning interpolation - will interpolate {image_count} frames with interpolation amount(s) [{interpolate_frames}] (a total of {interpolated_count} frames")
+ images = [
+ trigger_callback(img) for img in
+ self.interpolate_recursive(
+ frames=images,
+ multiplier=interpolate_frames,
+ )
+ ]
+ elif reflect:
+ interpolated_count -= image_count # Only reflection adds new frames; don't count source frames
+
+ if reflect:
+ logger.debug(f"Beginning reflection, will interpolate {double_ease_frames} frame(s) twice, {ease_frames} frame(s) once and hold {hold_frames} frame(s).")
+ images = [
+ frame for frame in self.loop(
+ frames=images,
+ ease_frames=ease_frames,
+ double_ease_frames=double_ease_frames,
+ hold_frames=hold_frames,
+ trigger_callback=trigger_callback
+ )
+ ] # Materialize the generator so progress is reported and the result can be returned or saved
+
+ if "save_path" in instruction_payload:
+ from enfugue.diffusion.util.video_util import Video
+ Video(images).save(
+ instruction_payload["save_path"],
+ rate=instruction_payload.get("video_rate", 8.0),
+ encoder=instruction_payload.get("video_codec", "avc1"),
+ overwrite=True
+ )
+ return instruction_payload["save_path"]
+ return images
+
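+ # Illustrative payload for handle_plan above (hypothetical paths and values):
+ #
+ # {
+ # "images": "/tmp/frames.mp4", # file path, PIL image, or list of frames
+ # "frames": [2, 2], # interpolation multipliers
+ # "reflect": True, # also build the gentle loop
+ # "save_path": "/tmp/looped.mp4",
+ # "video_rate": 16.0,
+ # }
+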
+ def interpolate(
+ self,
+ left: Image,
+ right: Image,
+ alpha: float
+ ) -> Image:
+ """
+ Executes an individual interpolation.
+ """
+ left_data = np.asarray(left.convert("RGB")).astype(np.float32) / _UINT8_MAX_F
+ right_data = np.asarray(right.convert("RGB")).astype(np.float32) / _UINT8_MAX_F
+ mid_data = self.interpolator(
+ np.expand_dims(left_data, axis=0),
+ np.expand_dims(right_data, axis=0),
+ np.full(shape=(1,), fill_value=alpha, dtype=np.float32)
+ )[0]
+ mid_data = (
+ np.clip(mid_data * _UINT8_MAX_F, 0.0, _UINT8_MAX_F) + 0.5
+ ).astype(np.uint8)
+ return Image.fromarray(mid_data)
+
+ def handle(
+ self,
+ instruction_id: int,
+ instruction_action: str,
+ instruction_payload: Any
+ ) -> Any:
+ """
+ Processes an interpolation instruction: a full plan, a single image pair, or a batch of pairs.
+ """
+ if not isinstance(instruction_payload, dict):
+ raise IOError("Expected dictionary payload")
+
+ if "path" in instruction_payload:
+ self.interpolator_path = instruction_payload["path"]
+
+ if instruction_action == "plan":
+ return self.handle_plan(
+ instruction_id=instruction_id,
+ instruction_payload=instruction_payload
+ )
+
+ to_process = []
+
+ if instruction_action == "process":
+ left = instruction_payload["left"]
+ right = instruction_payload["right"]
+ alpha = instruction_payload["alpha"]
+ to_process.append((left, right, alpha))
+ elif instruction_action == "batch":
+ for image_dict in instruction_payload["batch"]:
+ left = image_dict["left"]
+ right = image_dict["right"]
+ alpha = image_dict["alpha"]
+ to_process.append((left, right, alpha))
+
+ results = []
+ for left, right, alpha in to_process:
+ if isinstance(alpha, list):
+ results.append([
+ self.interpolate(left, right, a)
+ for a in alpha
+ ])
+ else:
+ results.append(self.interpolate(left, right, alpha))
+
+ return results
+
+class InterpolationEngine(Engine):
+ """
+ Manages the interpolator in a sub-process
+ """
+
+ STYLE_MODEL_REPO = "akhaliq/frame-interpolation-film-style"
+
+ @property
+ def process_class(self) -> Type[EngineProcess]:
+ """
+ Override to pass interpolator process
+ """
+ return InterpolatorEngineProcess
+
+ @property
+ def model_dir(self) -> str:
+ """
+ Gets the model directory from config
+ """
+ path = self.configuration.get("enfugue.engine.other", "~/.cache/enfugue/other")
+ if path.startswith("~"):
+ path = os.path.expanduser(path)
+ path = os.path.realpath(path)
+ check_make_directory(path)
+ return path
+
+ @property
+ def style_model_path(self) -> str:
+ """
+ Gets the style model path
+ """
+ if not hasattr(self, "_style_model_path"):
+ from huggingface_hub import snapshot_download
+ self._style_model_path = os.path.join(
+ self.model_dir,
+ "models--" + self.STYLE_MODEL_REPO.replace("/", "--")
+ )
+ if not os.path.exists(self._style_model_path):
+ os.makedirs(self._style_model_path)
+ if not os.path.exists(os.path.join(self._style_model_path, "saved_model.pb")):
+ snapshot_download(
+ self.STYLE_MODEL_REPO,
+ local_dir=self._style_model_path,
+ local_dir_use_symlinks=False
+ )
+ return self._style_model_path
+
+ def dispatch(
+ self,
+ action: str,
+ payload: Any = None,
+ spawn_process: bool = True
+ ) -> Any:
+ """
+ Intercept dispatch to inject path
+ """
+ if isinstance(payload, dict) and "path" not in payload:
+ payload["path"] = self.style_model_path
+ return super(InterpolationEngine, self).dispatch(
+ action,
+ payload,
+ spawn_process
+ )
+
+ def __call__(
+ self,
+ images: List[Image],
+ interpolate_frames: Union[int, Tuple[int]],
+ ) -> List[Image]:
+ """
+ Executes interpolation.
+ """
+ return self.invoke(
+ "recursive",
+ {
+ "path": self.style_model_path,
+ "images": images,
+ "frames": interpolate_frames
+ }
+ )
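+
+# Illustrative usage (hypothetical frames and configuration): the engine is
+# used as a context manager, which spawns the TensorFlow process on entry.
+#
+# with InterpolationEngine(configuration) as engine:
+# interpolated = engine(images=frames, interpolate_frames=2)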
diff --git a/src/python/enfugue/diffusion/interpolate/model.py b/src/python/enfugue/diffusion/interpolate/model.py
new file mode 100644
index 00000000..c46bd1ce
--- /dev/null
+++ b/src/python/enfugue/diffusion/interpolate/model.py
@@ -0,0 +1,210 @@
+# type: ignore
+# Copyright 2022 Google LLC
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# https://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A wrapper class for running a frame interpolation TF2 saved model.
+
+Usage:
+ model_path='/tmp/saved_model/'
+ it = Interpolator(model_path)
+ result_batch = it.interpolate(image_batch_0, image_batch_1, batch_dt)
+
+ Where image_batch_1 and image_batch_2 are numpy tensors with TF standard
+ (B,H,W,C) layout, batch_dt is the sub-frame time in range [0,1], (B,) layout.
+"""
+from typing import List, Optional
+import numpy as np
+import tensorflow as tf
+
+
+def _pad_to_align(x, align):
+ """Pad image batch x so width and height divide by align.
+
+ Args:
+ x: Image batch to align.
+ align: Number to align to.
+
+ Returns:
+ 1) An image padded so width % align == 0 and height % align == 0.
+ 2) A bounding box that can be fed readily to tf.image.crop_to_bounding_box
+ to undo the padding.
+ """
+ # Input checking.
+ assert np.ndim(x) == 4
+ assert align > 0, 'align must be a positive number.'
+
+ height, width = x.shape[-3:-1]
+ height_to_pad = (align - height % align) if height % align != 0 else 0
+ width_to_pad = (align - width % align) if width % align != 0 else 0
+
+ bbox_to_pad = {
+ 'offset_height': height_to_pad // 2,
+ 'offset_width': width_to_pad // 2,
+ 'target_height': height + height_to_pad,
+ 'target_width': width + width_to_pad
+ }
+ padded_x = tf.image.pad_to_bounding_box(x, **bbox_to_pad)
+ bbox_to_crop = {
+ 'offset_height': height_to_pad // 2,
+ 'offset_width': width_to_pad // 2,
+ 'target_height': height,
+ 'target_width': width
+ }
+ return padded_x, bbox_to_crop
+
+
+def image_to_patches(image: np.ndarray, block_shape: List[int]) -> np.ndarray:
+ """Folds an image into patches and stacks along the batch dimension.
+
+ Args:
+ image: The input image of shape [B, H, W, C].
+ block_shape: The number of patches along the height and width to extract.
+ Each patch is shaped (H/block_shape[0], W/block_shape[1])
+
+ Returns:
+ The extracted patches shaped [num_blocks, patch_height, patch_width,...],
+ with num_blocks = block_shape[0] * block_shape[1].
+ """
+ block_height, block_width = block_shape
+ num_blocks = block_height * block_width
+
+ height, width, channel = image.shape[-3:]
+ patch_height, patch_width = height//block_height, width//block_width
+
+ assert height == (
+ patch_height * block_height
+ ), 'block_height=%d should evenly divide height=%d.'%(block_height, height)
+ assert width == (
+ patch_width * block_width
+ ), 'block_width=%d should evenly divide width=%d.'%(block_width, width)
+
+ patch_size = patch_height * patch_width
+ paddings = 2*[[0, 0]]
+
+ patches = tf.space_to_batch(image, [patch_height, patch_width], paddings)
+ patches = tf.split(patches, patch_size, 0)
+ patches = tf.stack(patches, axis=3)
+ patches = tf.reshape(patches,
+ [num_blocks, patch_height, patch_width, channel])
+ return patches.numpy()
+
+
+def patches_to_image(patches: np.ndarray, block_shape: List[int]) -> np.ndarray:
+ """Unfolds patches (stacked along batch) into an image.
+
+ Args:
+ patches: The input patches, shaped [num_patches, patch_H, patch_W, C].
+ block_shape: The number of patches along the height and width to unfold.
+ Each patch assumed to be shaped (H/block_shape[0], W/block_shape[1]).
+
+ Returns:
+ The unfolded image shaped [B, H, W, C].
+ """
+ block_height, block_width = block_shape
+ paddings = 2 * [[0, 0]]
+
+ patch_height, patch_width, channel = patches.shape[-3:]
+ patch_size = patch_height * patch_width
+
+ patches = tf.reshape(patches,
+ [1, block_height, block_width, patch_size, channel])
+ patches = tf.split(patches, patch_size, axis=3)
+ patches = tf.stack(patches, axis=0)
+ patches = tf.reshape(patches,
+ [patch_size, block_height, block_width, channel])
+ image = tf.batch_to_space(patches, [patch_height, patch_width], paddings)
+ return image.numpy()
+
+
+class InterpolatorModel:
+ """A class for generating interpolated frames between two input frames.
+
+ Uses TF2 saved model format.
+ """
+
+ def __init__(self, model_path: str,
+ align: Optional[int] = None,
+ block_shape: Optional[List[int]] = None) -> None:
+ """Loads a saved model.
+
+ Args:
+ model_path: Path to the saved model.
+ align: If >1, pad the input size so it divides with this before
+ inference.
+ block_shape: Number of patches along the (height, width) to sub-divide
+ input images.
+ """
+ self._model = tf.compat.v2.saved_model.load(model_path)
+ self._align = align or None
+ self._block_shape = block_shape or None
+
+ def interpolate(self, x0: np.ndarray, x1: np.ndarray,
+ dt: np.ndarray) -> np.ndarray:
+ """Generates an interpolated frame between given two batches of frames.
+
+ All input tensors should be np.float32 datatype.
+
+ Args:
+ x0: First image batch. Dimensions: (batch_size, height, width, channels)
+ x1: Second image batch. Dimensions: (batch_size, height, width, channels)
+ dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)
+
+ Returns:
+ The result with dimensions (batch_size, height, width, channels).
+ """
+ if self._align is not None:
+ x0, bbox_to_crop = _pad_to_align(x0, self._align)
+ x1, _ = _pad_to_align(x1, self._align)
+
+ inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
+ result = self._model(inputs, training=False)
+ image = result['image']
+
+ if self._align is not None:
+ image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
+ return image.numpy()
+
+ def __call__(self, x0: np.ndarray, x1: np.ndarray,
+ dt: np.ndarray) -> np.ndarray:
+ """Generates an interpolated frame between given two batches of frames.
+
+ All input tensors should be np.float32 datatype.
+
+ Args:
+ x0: First image batch. Dimensions: (batch_size, height, width, channels)
+ x1: Second image batch. Dimensions: (batch_size, height, width, channels)
+ dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)
+
+ Returns:
+ The result with dimensions (batch_size, height, width, channels).
+ """
+ if self._block_shape is not None and np.prod(self._block_shape) > 1:
+ # Subdivide high-res images into manageable non-overlapping patches.
+ x0_patches = image_to_patches(x0, self._block_shape)
+ x1_patches = image_to_patches(x1, self._block_shape)
+
+ # Run the interpolator on each patch pair.
+ output_patches = []
+ for image_0, image_1 in zip(x0_patches, x1_patches):
+ mid_patch = self.interpolate(image_0[np.newaxis, ...],
+ image_1[np.newaxis, ...], dt)
+ output_patches.append(mid_patch)
+
+ # Reconstruct interpolated image by stitching interpolated patches.
+ output_patches = np.concatenate(output_patches, axis=0)
+ return patches_to_image(output_patches, self._block_shape)
+ else:
+ # Invoke the interpolator once.
+ return self.interpolate(x0, x1, dt)
diff --git a/src/python/enfugue/diffusion/invocation.py b/src/python/enfugue/diffusion/invocation.py
new file mode 100644
index 00000000..93681d43
--- /dev/null
+++ b/src/python/enfugue/diffusion/invocation.py
@@ -0,0 +1,1721 @@
+from __future__ import annotations
+
+import io
+import inspect
+
+from contextlib import contextmanager, ExitStack
+from datetime import datetime
+from PIL.PngImagePlugin import PngInfo
+
+from dataclasses import (
+ dataclass,
+ asdict,
+ field,
+)
+
+from typing import (
+ Optional,
+ Dict,
+ Any,
+ Union,
+ Tuple,
+ List,
+ Callable,
+ Iterator,
+ TYPE_CHECKING,
+)
+from random import randint
+from pibble.util.strings import Serializer
+
+from enfugue.util import (
+ logger,
+ feather_mask,
+ fit_image,
+ get_frames_or_image,
+ get_frames_or_image_from_file,
+ save_frames_or_image,
+ redact_images_from_metadata,
+ merge_tokens,
+)
+
+from enfugue.diffusion.constants import *
+
+if TYPE_CHECKING:
+ from PIL.Image import Image
+ from enfugue.diffusion.manager import DiffusionPipelineManager
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+ from enfugue.util import IMAGE_FIT_LITERAL, IMAGE_ANCHOR_LITERAL
+
+__all__ = ["LayeredInvocation"]
+
+@dataclass
+class LayeredInvocation:
+ """
+ A serializable class holding all vars for an invocation
+ """
+ # Dimensions, required
+ width: int
+ height: int
+ # Model args
+ model: Optional[str]=None
+ refiner: Optional[str]=None
+ inpainter: Optional[str]=None
+ vae: Optional[str]=None
+ refiner_vae: Optional[str]=None
+ inpainter_vae: Optional[str]=None
+ lora: Optional[Union[str, List[str], Tuple[str, float], List[Union[str, Tuple[str, float]]]]]=None
+ lycoris: Optional[Union[str, List[str], Tuple[str, float], List[Union[str, Tuple[str, float]]]]]=None
+ inversion: Optional[Union[str, List[str]]]=None
+ scheduler: Optional[SCHEDULER_LITERAL]=None
+ ip_adapter_model: Optional[IP_ADAPTER_LITERAL]=None
+ # Custom model args
+ model_prompt: Optional[str]=None
+ model_prompt_2: Optional[str]=None
+ model_negative_prompt: Optional[str]=None
+ model_negative_prompt_2: Optional[str]=None
+ # Invocation
+ prompts: Optional[List[PromptDict]]=None
+ prompt: Optional[str]=None
+ prompt_2: Optional[str]=None
+ negative_prompt: Optional[str]=None
+ negative_prompt_2: Optional[str]=None
+ clip_skip: Optional[int]=None
+ tiling_size: Optional[int]=None
+ tiling_stride: Optional[int]=None
+ tiling_mask_type: Optional[MASK_TYPE_LITERAL]=None
+ tiling_mask_kwargs: Optional[Dict[str, Any]]=None
+ # Layers
+ layers: List[Dict[str, Any]]=field(default_factory=list) #TODO: stronger type
+ # Generation
+ samples: int=1
+ iterations: int=1
+ seed: Optional[int]=None
+ tile: Union[bool, Tuple[bool, bool], List[bool]]=False
+ # Tweaks
+ freeu_factors: Optional[Tuple[float, float, float, float]]=None
+ guidance_scale: Optional[float]=None
+ num_inference_steps: Optional[int]=None
+ noise_offset: Optional[float]=None
+ noise_method: NOISE_METHOD_LITERAL="perlin"
+ noise_blend_method: LATENT_BLEND_METHOD_LITERAL="inject"
+ # Animation
+ animation_frames: Optional[int]=None
+ animation_rate: int=8
+ frame_window_size: Optional[int]=16
+ frame_window_stride: Optional[int]=4
+ loop: bool=False
+ motion_module: Optional[str]=None
+ motion_scale: Optional[float]=None
+ position_encoding_truncate_length: Optional[int]=None
+ position_encoding_scale_length: Optional[int]=None
+ # img2img
+ strength: Optional[float]=None
+ # Inpainting
+ mask: Optional[Union[str, Image, List[Image]]]=None
+ crop_inpaint: bool=True
+ inpaint_feather: int=32
+ outpaint: bool=True
+ # Refining
+ refiner_start: Optional[float]=None
+ refiner_strength: Optional[float]=None
+ refiner_guidance_scale: float=DEFAULT_REFINER_GUIDANCE_SCALE
+ refiner_aesthetic_score: float=DEFAULT_AESTHETIC_SCORE
+ refiner_negative_aesthetic_score: float=DEFAULT_NEGATIVE_AESTHETIC_SCORE
+ refiner_prompt: Optional[str]=None
+ refiner_prompt_2: Optional[str]=None
+ refiner_negative_prompt: Optional[str]=None
+ refiner_negative_prompt_2: Optional[str]=None
+ # Flags
+ build_tensorrt: bool=False
+ # Post-processing
+ upscale: Optional[Union[UpscaleStepDict, List[UpscaleStepDict]]]=None
+ interpolate_frames: Optional[Union[int, Tuple[int, ...], List[int]]]=None
+ reflect: bool=False
+
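+ # Illustrative construction of a LayeredInvocation (hypothetical values);
+ # only width and height are required, everything else falls back to the
+ # defaults above:
+ #
+ # invocation = LayeredInvocation(
+ # width=512,
+ # height=512,
+ # prompt="a lighthouse at dusk",
+ # animation_frames=16,
+ # seed=12345,
+ # )
+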
+ @staticmethod
+ def merge_prompts(*args: Tuple[Optional[str], float]) -> Optional[str]:
+ """
+ Merges prompts if they are not null
+ """
+ if all([not prompt for prompt, weight in args]):
+ return None
+ return merge_tokens(**dict([
+ (prompt, weight)
+ for prompt, weight in args
+ if prompt
+ ]))
+
+ @classmethod
+ def get_image_bounding_box(
+ cls,
+ mask: Image,
+ size: int,
+ feather: int=32
+ ) -> List[Tuple[int, int]]:
+ """
+ Gets the feathered bounding box for an image
+ """
+ width, height = mask.size
+ x0, y0, x1, y1 = mask.getbbox()
+
+ # Add feather
+ x0 = max(0, x0 - feather)
+ x1 = min(width, x1 + feather)
+ y0 = max(0, y0 - feather)
+ y1 = min(height, y1 + feather)
+
+ # Create centered frame about the bounding box
+ bbox_width = x1 - x0
+ bbox_height = y1 - y0
+
+ if bbox_width < size:
+ x0 = max(0, x0 - ((size - bbox_width) // 2))
+ x1 = min(width, x0 + size)
+ x0 = max(0, x1 - size)
+ if bbox_height < size:
+ y0 = max(0, y0 - ((size - bbox_height) // 2))
+ y1 = min(height, y0 + size)
+ y0 = max(0, y1 - size)
+
+ return [(x0, y0), (x1, y1)]
+
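+ # Worked example for get_image_bounding_box above (illustrative numbers): for
+ # a 512x512 mask whose content bbox is (100, 100, 150, 150), feather=32
+ # expands it to (68, 68, 182, 182); with size=256 the 114px box is then grown
+ # to the minimum engine size and clamped to the image, returning
+ # [(0, 0), (256, 256)].
+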
+ @classmethod
+ def get_inpaint_bounding_box(
+ cls,
+ mask: Union[Image, List[Image]],
+ size: int,
+ feather: int=32
+ ) -> List[Tuple[int, int]]:
+ """
+ Gets the bounding box of the inpainted regions
+ """
+ # Find bounding box
+ (x0, y0) = (0, 0)
+ (x1, y1) = (0, 0)
+ for mask_image in (mask if isinstance(mask, list) else [mask]):
+ (frame_x0, frame_y0), (frame_x1, frame_y1) = cls.get_image_bounding_box(
+ mask=mask_image,
+ size=size,
+ feather=feather,
+ )
+ x0 = max(x0, frame_x0)
+ y0 = max(y0, frame_y0)
+ x1 = max(x1, frame_x1)
+ y1 = max(y1, frame_y1)
+
+ return [(x0, y0), (x1, y1)]
+
+ @classmethod
+ def paste_inpaint_image(
+ cls,
+ background: Image,
+ foreground: Image,
+ position: Tuple[int, int],
+ inpaint_feather: int=32,
+ ) -> Image:
+ """
+ Pastes the inpaint image on the background with an appropriately feathered mask.
+ """
+ from PIL import Image
+ image = background.copy()
+
+ width, height = image.size
+ foreground_width, foreground_height = foreground.size
+ left, top = position[:2]
+ right, bottom = left + foreground_width, top + foreground_height
+
+ feather_left = left > 0
+ feather_top = top > 0
+ feather_right = right < width
+ feather_bottom = bottom < height
+
+ mask = Image.new("L", (foreground_width, foreground_height), 255)
+ for i in range(inpaint_feather):
+ multiplier = (i + 1) / (inpaint_feather + 1)
+ pixels = []
+ if feather_left:
+ pixels.extend([(i, j) for j in range(foreground_height)])
+ if feather_top:
+ pixels.extend([(j, i) for j in range(foreground_width)])
+ if feather_right:
+ pixels.extend([(foreground_width - i - 1, j) for j in range(foreground_height)])
+ if feather_bottom:
+ pixels.extend([(j, foreground_height - i - 1) for j in range(foreground_width)])
+ for x, y in pixels:
+ mask.putpixel((x, y), int(mask.getpixel((x, y)) * multiplier))
+
+ image.paste(foreground, position, mask=mask)
+ return image
+
+ @property
+ def upscale_steps(self) -> Iterator[UpscaleStepDict]:
+ """
+ Iterates over upscale steps.
+ """
+ if self.upscale is not None:
+ if isinstance(self.upscale, list):
+ for step in self.upscale:
+ yield step
+ else:
+ yield self.upscale
+
+ @property
+ def kwargs(self) -> Dict[str, Any]:
+ """
+ Returns the keyword arguments that will be passed to the pipeline invocation.
+ """
+ return {
+ "width": self.width,
+ "height": self.height,
+ "strength": self.strength,
+ "animation_frames": self.animation_frames,
+ "tile": tuple(self.tile[:2]) if isinstance(self.tile, list) else self.tile,
+ "freeu_factors": self.freeu_factors,
+ "num_inference_steps": self.num_inference_steps,
+ "num_results_per_prompt": self.samples,
+ "noise_offset": self.noise_offset,
+ "noise_method": self.noise_method,
+ "noise_blend_method": self.noise_blend_method,
+ "loop": self.loop,
+ "tiling_size": self.tiling_size,
+ "tiling_stride": self.tiling_stride,
+ "tiling_mask_type": self.tiling_mask_type,
+ "motion_scale": self.motion_scale,
+ "frame_window_size": self.frame_window_size,
+ "frame_window_stride": self.frame_window_stride,
+ "guidance_scale": self.guidance_scale,
+ "refiner_start": self.refiner_start,
+ "refiner_strength": self.refiner_strength,
+ "refiner_guidance_scale": self.refiner_guidance_scale,
+ "refiner_aesthetic_score": self.refiner_aesthetic_score,
+ "refiner_negative_aesthetic_score": self.refiner_negative_aesthetic_score,
+ "refiner_prompt": self.refiner_prompt,
+ "refiner_prompt_2": self.refiner_prompt_2,
+ "refiner_negative_prompt": self.refiner_negative_prompt,
+ "refiner_negative_prompt_2": self.refiner_negative_prompt_2,
+ "ip_adapter_model": self.ip_adapter_model,
+ }
+
+ @classmethod
+ def prepare_image(
+ cls,
+ width: int,
+ height: int,
+ image: Union[str, Image, List[Image]],
+ animation_frames: Optional[int]=None,
+ fit: Optional[IMAGE_FIT_LITERAL]=None,
+ anchor: Optional[IMAGE_ANCHOR_LITERAL]=None,
+ skip_frames: Optional[int]=None,
+ divide_frames: Optional[int]=None,
+ x: Optional[int]=None,
+ y: Optional[int]=None,
+ w: Optional[int]=None,
+ h: Optional[int]=None,
+ return_mask: bool=True,
+ ) -> Union[Image, List[Image], Tuple[Image, Image], Tuple[List[Image], List[Image]]]:
+ """
+ Fits an image on the canvas and returns it along with its alpha mask
+ """
+ from PIL import Image
+
+ if isinstance(image, str):
+ image = get_frames_or_image_from_file(image)
+
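+ # Optionally drop leading frames and keep only every n-th remaining frame before fitting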
+ if skip_frames:
+ image = image[skip_frames:]
+ if divide_frames:
+ image = [
+ img for i, img in enumerate(image)
+ if i % divide_frames == 0
+ ]
+
+ if w is not None and h is not None:
+ fitted_image = fit_image(image, w, h, fit, anchor)
+ else:
+ fitted_image = fit_image(image, width, height, fit, anchor)
+
+ if x is not None and y is not None:
+ if isinstance(fitted_image, list):
+ for i, img in enumerate(fitted_image):
+ blank_image = Image.new("RGBA", (width, height), (0,0,0,0))
+ if img.mode == "RGBA":
+ blank_image.paste(img, (x, y), img)
+ else:
+ blank_image.paste(img, (x, y))
+ fitted_image[i] = blank_image
+ else:
+ blank_image = Image.new("RGBA", (width, height), (0,0,0,0))
+ if fitted_image.mode == "RGBA":
+ blank_image.paste(fitted_image, (x, y), fitted_image)
+ else:
+ blank_image.paste(fitted_image, (x, y))
+ fitted_image = blank_image
+
+ if isinstance(fitted_image, list):
+ if not animation_frames:
+ fitted_image = fitted_image[0]
+ else:
+ fitted_image = fitted_image[:animation_frames]
+
+ if not return_mask:
+ return fitted_image
+
+ if isinstance(fitted_image, list):
+ image_mask = [
+ Image.new("1", (width, height), (1))
+ for i in range(len(fitted_image))
+ ]
+ else:
+ image_mask = Image.new("1", (width, height), (1))
+
+ black = Image.new("1", (width, height), (0))
+
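+ # The mask starts fully white; black is pasted wherever the fitted image is transparent
+ # (alpha <= 128), so white marks only the pixels actually covered by the image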
+ if isinstance(fitted_image, list):
+ for i, img in enumerate(fitted_image):
+ fitted_alpha = img.split()[-1]
+ fitted_alpha_inverse_clamp = Image.eval(fitted_alpha, lambda a: 0 if a > 128 else 255)
+ image_mask[i].paste(black, mask=fitted_alpha_inverse_clamp)
+ else:
+ fitted_alpha = fitted_image.split()[-1]
+ fitted_alpha_inverse_clamp = Image.eval(fitted_alpha, lambda a: 0 if a > 128 else 255)
+ image_mask.paste(black, mask=fitted_alpha_inverse_clamp) # type: ignore[attr-defined]
+
+ return fitted_image, image_mask
+
+ @classmethod
+ def assemble(
+ cls,
+ size: int=512,
+ image: Optional[Union[str, Image, List[Image], ImageDict]]=None,
+ ip_adapter_images: Optional[List[IPAdapterImageDict]]=None,
+ control_images: Optional[List[ControlImageDict]]=None,
+ loop: Union[bool, str]=False,
+ **kwargs: Any
+ ) -> LayeredInvocation:
+ """
+ Assembles an invocation from layers, standardizing arguments
+ """
+ invocation_kwargs = dict([
+ (k, v) for k, v in kwargs.items()
+ if k in inspect.signature(cls).parameters
+ ])
+ ignored_kwargs = set(list(kwargs.keys())) - set(list(invocation_kwargs.keys()))
+
+ # Add directly passed images to layers
+ added_layers = []
+ if image:
+ if isinstance(image, dict):
+ added_layers.append(image)
+ else:
+ added_layers.append({"image": image})
+ if ip_adapter_images:
+ for ip_adapter_image in ip_adapter_images:
+ added_layers.append({
+ "image": ip_adapter_image["image"],
+ "ip_adapter_scale": ip_adapter_image.get("scale", 1.0),
+ "fit": ip_adapter_image.get("fit", None),
+ "anchor": ip_adapter_image.get("anchor", None),
+ })
+ if control_images:
+ for control_image in control_images:
+ added_layers.append({
+ "image": control_image["image"],
+ "fit": control_image.get("fit", None),
+ "anchor": control_image.get("anchor", None),
+ "control_units": [
+ {
+ "controlnet": control_image["controlnet"],
+ "scale": control_image.get("scale", 1.0),
+ "start": control_image.get("start", None),
+ "end": control_image.get("end", None),
+ "process": control_image.get("process", True),
+ }
+ ]
+ })
+
+ # Reassign layers
+ if "layers" in invocation_kwargs:
+ invocation_kwargs["layers"].extend(added_layers)
+ else:
+ invocation_kwargs["layers"] = added_layers
+
+ # Gather size of images for defaults and trim video
+ animation_frames = invocation_kwargs.get("animation_frames", None)
+ image_width, image_height = 0, 0
+ for layer in invocation_kwargs["layers"]:
+ # Standardize images
+ if isinstance(layer["image"], str):
+ layer["image"] = get_frames_or_image_from_file(layer["image"])
+
+ elif not isinstance(layer["image"], list):
+ layer["image"] = get_frames_or_image(layer["image"])
+
+ skip_frames = layer.pop("skip_frames", None)
+ divide_frames = layer.pop("divide_frames", None)
+
+ if skip_frames and isinstance(layer["image"], list):
+ layer["image"] = layer["image"][skip_frames:]
+
+ if divide_frames and isinstance(layer["image"], list):
+ layer["image"] = [
+ img for i, img in enumerate(layer["image"])
+ if i % divide_frames == 0
+ ]
+
+ if isinstance(layer["image"], list):
+ if animation_frames:
+ layer["image"] = layer["image"][:animation_frames]
+ else:
+ layer["image"] = layer["image"][0]
+
+ # Check if this image is visible
+ if (
+ layer.get("image", None) is not None and
+ (
+ layer.get("denoise", False) or
+ (
+ not layer.get("ip_adapter_scale", None) and
+ not layer.get("control_units", [])
+ )
+ )
+ ):
+ layer_x = layer.get("x", 0)
+ layer_y = layer.get("y", 0)
+
+ if isinstance(layer["image"], list):
+ image_w, image_h = layer["image"][0].size
+ else:
+ image_w, image_h = layer["image"].size
+
+ layer_w = layer.get("w", image_w)
+ layer_h = layer.get("h", image_h)
+ image_width = max(image_width, layer_x + layer_w)
+ image_height = max(image_height, layer_y + layer_h)
+
+ # Check sizes
+ if not invocation_kwargs.get("width", None):
+ invocation_kwargs["width"] = image_width if image_width else size
+ if not invocation_kwargs.get("height", None):
+ invocation_kwargs["height"] = image_height if image_height else size
+
+ # Add seed if not set
+ if not invocation_kwargs.get("seed", None):
+ invocation_kwargs["seed"] = randint(0,2**32)
+
+ # Check loop
+ if isinstance(loop, bool):
+ invocation_kwargs["loop"] = loop
+ elif isinstance(loop, str):
+ invocation_kwargs["loop"] = loop == "loop"
+ invocation_kwargs["reflect"] = loop == "reflect"
+
+ if ignored_kwargs:
+ logger.warning(f"Ignored keyword arguments: {ignored_kwargs}")
+
+ return cls(**invocation_kwargs)
+
+ @classmethod
+ def minimize_dict(
+ cls,
+ kwargs: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """
+ Pops unnecessary variables from an invocation dict
+ """
+ all_keys = list(kwargs.keys())
+ minimal_keys = []
+
+ has_refiner = bool(kwargs.get("refiner", None))
+ has_noise = bool(kwargs.get("noise_offset", None))
+ will_inpaint = bool(kwargs.get("mask", None))
+ will_animate = bool(kwargs.get("animation_frames", None))
+
+ has_ip_adapter = bool(kwargs.get("ip_adapter_images", None))
+ has_ip_adapter = has_ip_adapter or any([
+ bool(layer.get("ip_adapter_scale", None))
+ for layer in kwargs.get("layers", [])
+ ])
+
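+ # Keep only keys relevant to this invocation: e.g. drop refiner arguments when no
+ # refiner is set and animation arguments when not animating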
+ for key in all_keys:
+ value = kwargs[key]
+ if value is None:
+ continue
+ if "layers" == key and not value:
+ continue
+ if "tile" == key and value == False:
+ continue
+ if "refiner" in key and not has_refiner:
+ continue
+ if "inpaint" in key and not will_inpaint:
+ continue
+ if "ip_adapter" in key and not has_ip_adapter:
+ continue
+ if (
+ (
+ "motion" in key or
+ "temporal" in key or
+ "animation" in key or
+ "frame" in key or
+ "loop" in key or
+ "reflect" in key
+ )
+ and not will_animate
+ ):
+ continue
+
+ if "noise" in key and not has_noise:
+ continue
+
+ minimal_keys.append(key)
+
+ return dict([
+ (key, kwargs[key])
+ for key in minimal_keys
+ ])
+
+ @classmethod
+ def format_serialization_dict(
+ cls,
+ save_directory: Optional[str]=None,
+ save_name: Optional[str]=None,
+ mask: Optional[Union[str, Image, List[Image]]]=None,
+ layers: Optional[List[Dict]]=None,
+ **kwargs: Any
+ ) -> Dict[str, Any]:
+ """
+ Formats kwargs to remove images and instead return temporary paths if possible
+ """
+ kwargs["mask"] = mask
+ kwargs["layers"] = layers
+ if save_directory is not None:
+ if mask is not None:
+ if isinstance(mask, dict):
+ mask["image"] = save_frames_or_image(
+ image=mask["image"],
+ directory=save_directory,
+ name=save_name
+ )
+ else:
+ kwargs["mask"] = save_frames_or_image(
+ image=mask,
+ directory=save_directory,
+ name=f"{save_name}_mask" if save_name is not None else None
+ )
+
+ if layers is not None:
+ for i, layer in enumerate(layers):
+ layer_image = layer.get("image", None)
+ if layer_image:
+ layer["image"] = save_frames_or_image(
+ image=layer_image,
+ directory=save_directory,
+ name=f"{save_name}_layer_{i}" if save_name is not None else None
+ )
+
+ return cls.minimize_dict(kwargs)
+
+ def serialize(
+ self,
+ save_directory: Optional[str]=None,
+ save_name: Optional[str]=None,
+ ) -> Dict[str, Any]:
+ """
+ Assembles self into a serializable dict
+ """
+ return self.format_serialization_dict(
+ save_directory=save_directory,
+ save_name=save_name,
+ **asdict(self)
+ )
+
+ @contextmanager
+ def preprocessors(
+ self,
+ pipeline: DiffusionPipelineManager
+ ) -> Iterator[Dict[str, Callable[[Image], Image]]]:
+ """
+ Gets all preprocessors needed for this invocation
+ """
+ needs_background_remover = False
+ needs_control_processors = []
+ to_check: List[Dict[str, Any]] = []
+
+ if self.layers is not None:
+ for layer in self.layers:
+ if layer.get("image", None) is not None:
+ to_check.append(layer)
+ for image_dict in to_check:
+ if image_dict.get("remove_background", False):
+ needs_background_remover = True
+ for control_dict in image_dict.get("control_units", []):
+ if control_dict.get("process", True) and control_dict.get("controlnet", None) is not None:
+ needs_control_processors.append(control_dict["controlnet"])
+
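+ # The ExitStack keeps the background remover context open for the duration of the yield;
+ # control processors are entered in their own nested context below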
+ with ExitStack() as stack:
+ processors: Dict[str, Callable[[Image], Image]] = {}
+ if needs_background_remover:
+ processors["background_remover"] = stack.enter_context(
+ pipeline.background_remover.remover()
+ )
+ if needs_control_processors:
+ processor_names = list(set(needs_control_processors))
+ with pipeline.control_image_processor.processors(*processor_names) as processor_callables:
+ processors = {**processors, **dict(zip(processor_names, processor_callables))}
+ yield processors
+ else:
+ yield processors
+
+ def preprocess(
+ self,
+ pipeline: DiffusionPipelineManager,
+ intermediate_dir: Optional[str]=None,
+ raise_when_unused: bool=True,
+ task_callback: Optional[Callable[[str], None]]=None,
+ progress_callback: Optional[Callable[[int, int, float], None]]=None,
+ **kwargs: Any
+ ) -> Dict[str, Any]:
+ """
+ Processes layers, masks and prompts into the keyword arguments used for pipeline invocation
+ """
+ from PIL import Image, ImageOps
+ from enfugue.diffusion.util.prompt_util import Prompt
+
+ # Gather images for preprocessing
+ control_images: Dict[str, List[Dict]] = {}
+ ip_adapter_images = []
+ invocation_mask = None
+ invocation_image = None
+ no_inference = False
+
+ if self.layers:
+ if task_callback is not None:
+ task_callback("Pre-processing layers")
+
+ # Blank images used for merging
+ black = Image.new("1", (self.width, self.height), (0))
+ white = Image.new("1", (self.width, self.height), (1))
+
+ mask = self.mask
+
+ # Standardize mask
+ if mask is not None:
+ if isinstance(mask, dict):
+ invert = mask.get("invert", False)
+ mask = mask.get("image", None)
+ if isinstance(mask, str):
+ mask = get_frames_or_image_from_file(mask)
+ if not mask:
+ raise ValueError("Expected mask dictionary to have 'image' key")
+ if invert:
+ from PIL import ImageOps
+ if isinstance(mask, list):
+ mask = [
+ ImageOps.invert(img) for img in mask
+ ]
+ else:
+ mask = ImageOps.invert(mask)
+ elif isinstance(mask, str):
+ mask = get_frames_or_image_from_file(mask)
+
+ mask = self.prepare_image(
+ width=self.width,
+ height=self.height,
+ image=mask,
+ animation_frames=self.animation_frames,
+ return_mask=False
+ )
+
+ if self.animation_frames:
+ invocation_mask = [
+ white.copy()
+ for i in range(self.animation_frames)
+ ]
+ invocation_image = [
+ Image.new("RGBA", (self.width, self.height), (0,0,0,0))
+ for i in range(self.animation_frames)
+ ]
+ else:
+ invocation_image = Image.new("RGBA", (self.width, self.height), (0,0,0,0))
+ invocation_mask = white.copy()
+
+ has_invocation_image = False
+
+ # Get a count of preprocesses required
+ images_to_preprocess = 0
+ for i, layer in enumerate(self.layers):
+ layer_image = layer.get("image", None)
+ layer_skip_frames = layer.get("skip_frames", None)
+ layer_divide_frames = layer.get("divide_frames", None)
+
+ if isinstance(layer_image, str):
+ layer_image = get_frames_or_image_from_file(layer_image)
+ layer["image"] = layer_image
+
+ image_count = len(layer_image) if isinstance(layer_image, list) else 1
+ if layer_skip_frames:
+ image_count -= layer_skip_frames
+ if layer_divide_frames:
+ image_count = image_count // layer_divide_frames
+
+ if self.animation_frames:
+ image_count = min(image_count, self.animation_frames)
+
+ if layer.get("remove_background", False):
+ images_to_preprocess += image_count
+
+ control_units = layer.get("control_units", None)
+ if control_units:
+ for control_unit in control_units:
+ if control_unit.get("process", True):
+ images_to_preprocess += image_count
+
+ images_preprocessed = 0
+ last_frame_time = datetime.now()
+ frame_times = []
+
+ def trigger_preprocess_callback(image: Image) -> Image:
+ """
+ Triggers the preprocessor callback
+ """
+ nonlocal last_frame_time
+ nonlocal images_preprocessed
+ if progress_callback is not None:
+ images_preprocessed += 1
+ frame_time = datetime.now()
+ frame_seconds = (frame_time - last_frame_time).total_seconds()
+ frame_times.append(frame_seconds)
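+ # Report progress with a rate derived from a rolling average of (at most) the last 8 frame times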
+ frame_time_samples = min(len(frame_times), 8)
+ frame_time_average = sum(frame_times[-8:]) / frame_time_samples
+
+ progress_callback(images_preprocessed, images_to_preprocess, 1 / frame_time_average)
+
+ last_frame_time = frame_time
+ return image
+
+ # Preprocess images
+ if images_to_preprocess:
+ logger.debug(f"Pre-processing layers, with {images_to_preprocess} image processing step(s)")
+
+ with self.preprocessors(pipeline) as processors:
+ # Iterate through layers
+ for i, layer in enumerate(self.layers):
+ # Basic information for layer
+ w = layer.get("w", None)
+ h = layer.get("h", None)
+ x = layer.get("x", None)
+ y = layer.get("y", None)
+
+ layer_image = layer.get("image", None)
+ fit = layer.get("fit", None)
+ anchor = layer.get("anchor", None)
+ remove_background = layer.get("remove_background", None)
+ skip_frames = layer.get("skip_frames", None)
+ divide_frames = layer.get("divide_frames", None)
+
+ # Capabilities of layer
+ denoise = layer.get("denoise", False)
+ prompt_scale = layer.get("ip_adapter_scale", False)
+ control_units = layer.get("control_units", [])
+
+ if not layer_image:
+ logger.warning(f"No image, skipping laying {i}")
+ continue
+
+ if isinstance(layer_image, str):
+ layer_image = get_frames_or_image_from_file(layer_image)
+
+ if remove_background:
+ if isinstance(layer_image, list):
+ layer_image = [
+ trigger_preprocess_callback(processors["background_remover"](img))
+ for img in layer_image
+ ]
+ else:
+ layer_image = trigger_preprocess_callback(processors["background_remover"](layer_image))
+
+ fit_layer_image, fit_layer_mask = self.prepare_image(
+ width=self.width,
+ height=self.height,
+ image=layer_image,
+ fit=fit,
+ anchor=anchor,
+ animation_frames=self.animation_frames,
+ divide_frames=divide_frames,
+ skip_frames=skip_frames,
+ w=w,
+ h=h,
+ x=x,
+ y=y
+ )
+
+ if isinstance(fit_layer_mask, list):
+ inverse_fit_layer_mask = [
+ ImageOps.invert(img)
+ for img in fit_layer_mask
+ ]
+ else:
+ inverse_fit_layer_mask = ImageOps.invert(fit_layer_mask)
+
+ is_passthrough = not denoise and not prompt_scale and not control_units
+
+ if denoise or is_passthrough:
+ has_invocation_image = True
+
+ if isinstance(fit_layer_image, list):
+ for i in range(len(invocation_image)): # type: ignore[arg-type]
+ invocation_image[i].paste( # type: ignore[index]
+ fit_layer_image[i] if i < len(fit_layer_image) else fit_layer_image[-1],
+ mask=fit_layer_mask[i] if i < len(fit_layer_mask) else fit_layer_mask[-1]
+ )
+ if is_passthrough:
+ invocation_mask[i].paste( # type: ignore[index]
+ black,
+ mask=fit_layer_mask[i] if i < len(fit_layer_mask) else fit_layer_mask[-1]
+ )
+ elif isinstance(invocation_image, list):
+ for i in range(len(invocation_image)):
+ invocation_image[i].paste(fit_layer_image, mask=fit_layer_mask)
+ if is_passthrough:
+ invocation_mask[i].paste(black, mask=fit_layer_mask) # type: ignore[index]
+ else:
+ invocation_image.paste(fit_layer_image, mask=fit_layer_mask) # type: ignore[attr-defined]
+ if is_passthrough:
+ invocation_mask.paste(black, mask=fit_layer_mask) # type: ignore[union-attr]
+
+ if prompt_scale:
+ # ip adapter
+ ip_adapter_images.append({
+ "image": layer_image,
+ "scale": float(prompt_scale)
+ })
+
+ if control_units:
+ for control_unit in control_units:
+ controlnet = control_unit["controlnet"]
+
+ if controlnet not in control_images:
+ control_images[controlnet] = []
+
+ if control_unit.get("process", True):
+ if isinstance(fit_layer_image, list):
+ control_image = [
+ trigger_preprocess_callback(processors[controlnet](img))
+ for img in fit_layer_image
+ ]
+ else:
+ control_image = trigger_preprocess_callback(processors[controlnet](fit_layer_image))
+ elif control_unit.get("invert", False):
+ if isinstance(fit_layer_image, list):
+ control_image = [
+ ImageOps.invert(img)
+ for img in fit_layer_image
+ ]
+ else:
+ control_image = ImageOps.invert(fit_layer_image)
+ else:
+ control_image = fit_layer_image
+
+ control_images[controlnet].append({
+ "start": control_unit.get("start", 0.0),
+ "end": control_unit.get("end", 1.0),
+ "scale": control_unit.get("scale", 1.0),
+ "image": control_image
+ })
+
+ if not has_invocation_image:
+ invocation_image = None
+ invocation_mask = None
+ else:
+ if mask:
+ if isinstance(invocation_mask, list):
+ if not isinstance(mask, list):
+ mask = [mask.copy() for i in range(len(invocation_mask))] # type: ignore[union-attr]
+ for i in range(len(mask)):
+ mask[i].paste(
+ white,
+ mask=Image.eval(
+ feather_mask(invocation_mask[i]),
+ lambda a: 0 if a < 128 else 255
+ )
+ )
+ invocation_mask = [
+ img.convert("L")
+ for img in mask
+ ]
+ else:
+ # Final mask merge
+ mask.paste( # type: ignore[union-attr]
+ white,
+ mask=Image.eval(
+ feather_mask(invocation_mask),
+ lambda a: 0 if a < 128 else 255
+ )
+ )
+ invocation_mask = mask.convert("L") # type: ignore[union-attr]
+ else:
+ if isinstance(invocation_mask, list):
+ invocation_mask = [
+ feather_mask(img).convert("L") # type: ignore[union-attr]
+ for img in invocation_mask
+ ]
+ else:
+ invocation_mask = feather_mask(invocation_mask).convert("L") # type: ignore[union-attr]
+
+ # Evaluate mask
+ mask_max, mask_min = None, None
+ if isinstance(invocation_mask, list):
+ for img in invocation_mask:
+ this_max, this_min = img.getextrema()
+ if mask_max is None:
+ mask_max = this_max
+ else:
+ mask_max = max(mask_max, this_max) # type: ignore[unreachable]
+ if mask_min is None:
+ mask_min = this_min
+ else:
+ mask_min = min(mask_min, this_min) # type: ignore[unreachable]
+ else:
+ mask_max, mask_min = invocation_mask.getextrema()
+
+ if mask_max == mask_min == 0:
+ # Nothing to do
+ if raise_when_unused:
+ raise IOError("Nothing to do - canvas is covered by non-denoised images. Either modify the canvas such that there is blank space to be filled, enable denoising on an image on the canvas, or add inpainting.")
+ # Might have no invocation
+ invocation_mask = None
+ no_inference = True
+ elif mask_max == mask_min == 255:
+ # No inpainting
+ invocation_mask = None
+ elif not self.outpaint and not mask:
+ # Disabled outpainting
+ invocation_mask = None
+
+ # Evaluate prompts
+ prompts = self.prompts
+ if prompts:
+ # Prompt travel
+ prompts = [
+ Prompt( # type: ignore[misc]
+ positive=self.merge_prompts(
+ (prompt["positive"], 1.0),
+ (self.model_prompt, MODEL_PROMPT_WEIGHT)
+ ),
+ positive_2=self.merge_prompts(
+ (prompt.get("positive_2",None), 1.0),
+ (self.model_prompt_2, MODEL_PROMPT_WEIGHT)
+ ),
+ negative=self.merge_prompts(
+ (prompt.get("negative",None), 1.0),
+ (self.model_negative_prompt, MODEL_PROMPT_WEIGHT)
+ ),
+ negative_2=self.merge_prompts(
+ (prompt.get("negative_2",None), 1.0),
+ (self.model_negative_prompt_2, MODEL_PROMPT_WEIGHT)
+ ),
+ start=prompt.get("start",None),
+ end=prompt.get("end",None),
+ weight=prompt.get("weight", 1.0)
+ )
+ for prompt in prompts
+ ]
+ elif self.prompt is not None:
+ prompts = [
+ Prompt( # type: ignore[list-item]
+ positive=self.merge_prompts(
+ (self.prompt, 1.0),
+ (self.model_prompt, MODEL_PROMPT_WEIGHT)
+ ),
+ positive_2=self.merge_prompts(
+ (self.prompt_2, 1.0),
+ (self.model_prompt_2, MODEL_PROMPT_WEIGHT)
+ ),
+ negative=self.merge_prompts(
+ (self.negative_prompt, 1.0),
+ (self.model_negative_prompt, MODEL_PROMPT_WEIGHT)
+ ),
+ negative_2=self.merge_prompts(
+ (self.negative_prompt_2, 1.0),
+ (self.model_negative_prompt_2, MODEL_PROMPT_WEIGHT)
+ )
+ )
+ ]
+
+ # Refiner prompts
+ refiner_prompts = {
+ "refiner_prompt": self.merge_prompts(
+ (self.refiner_prompt, 1.0),
+ (self.model_prompt, MODEL_PROMPT_WEIGHT)
+ ),
+ "refiner_prompt_2": self.merge_prompts(
+ (self.refiner_prompt_2, 1.0),
+ (self.model_prompt_2, MODEL_PROMPT_WEIGHT)
+ ),
+ "refiner_negative_prompt": self.merge_prompts(
+ (self.refiner_negative_prompt, 1.0),
+ (self.model_negative_prompt, MODEL_PROMPT_WEIGHT)
+ ),
+ "refiner_negative_prompt_2": self.merge_prompts(
+ (self.refiner_negative_prompt_2, 1.0),
+ (self.model_negative_prompt_2, MODEL_PROMPT_WEIGHT)
+ )
+ }
+
+ # Completed pre-processing
+ results_dict: Dict[str, Any] = {
+ "no_inference": no_inference or (self.layers and (not invocation_mask and not self.strength and not control_images and not ip_adapter_images))
+ }
+
+ if invocation_image:
+ results_dict["image"] = invocation_image
+ if invocation_mask:
+ results_dict["mask"] = invocation_mask
+ if control_images:
+ results_dict["control_images"] = control_images
+ if ip_adapter_images:
+ results_dict["ip_adapter_images"] = ip_adapter_images
+ if prompts:
+ results_dict["prompts"] = prompts
+
+ results_dict = self.minimize_dict({
+ **self.kwargs,
+ **refiner_prompts,
+ **results_dict
+ })
+ return results_dict
+
+ def execute(
+ self,
+ pipeline: DiffusionPipelineManager,
+ task_callback: Optional[Callable[[str], None]] = None,
+ progress_callback: Optional[Callable[[int, int, float], None]] = None,
+ image_callback: Optional[Callable[[List[Image]], None]] = None,
+ image_callback_steps: Optional[int] = None,
+ ) -> StableDiffusionPipelineOutput:
+ """
+ This is the main interface for execution.
+
+ The first step will be the one that executes with the selected number of samples,
+ and then each subsequent step will be performed on the number of outputs from the
+ first step.
+ """
+ # We import here so this file can be imported by processes without initializing torch
+ from diffusers.utils.pil_utils import PIL_INTERPOLATION
+
+ if task_callback is None:
+ task_callback = lambda arg: None
+
+ # Set up the pipeline
+ pipeline._task_callback = task_callback
+ self.prepare_pipeline(pipeline)
+
+ cropped_inpaint_position = None
+ background = None
+ has_post_processing = bool(self.upscale)
+ if self.animation_frames:
+ has_post_processing = has_post_processing or bool(self.interpolate_frames) or self.reflect
+
+ invocation_kwargs = self.preprocess(
+ pipeline,
+ raise_when_unused=not has_post_processing,
+ task_callback=task_callback,
+ progress_callback=progress_callback,
+ )
+
+ if invocation_kwargs.pop("no_inference", False):
+ if "image" not in invocation_kwargs:
+ raise IOError("No inference and no images.")
+ images = invocation_kwargs["image"]
+ if not isinstance(images, list):
+ images = [images]
+ nsfw = [False] * len(images)
+ else:
+ # Determine if we're doing cropped inpainting
+ if "mask" in invocation_kwargs and self.crop_inpaint:
+ (x0, y0), (x1, y1) = self.get_inpaint_bounding_box(
+ invocation_kwargs["mask"],
+ size=self.tiling_size if self.tiling_size else 1024 if pipeline.inpainter_is_sdxl else 512,
+ feather=self.inpaint_feather
+ )
+ if isinstance(invocation_kwargs["mask"], list):
+ mask_width, mask_height = invocation_kwargs["mask"][0].size
+ else:
+ mask_width, mask_height = invocation_kwargs["mask"].size
+ bbox_width = x1 - x0
+ bbox_height = y1 - y0
+ pixel_ratio = (bbox_height * bbox_width) / (mask_width * mask_height)
+ pixel_savings = (1.0 - pixel_ratio) * 100
+
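+ # Only crop when the feathered mask bounding box covers less than 75% of the canvas;
+ # otherwise the savings are considered insufficient to justify the crop-and-paste overhead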
+ if pixel_ratio < 0.75:
+ logger.debug(f"Calculated pixel area savings of {pixel_savings:.1f}% by cropping to ({x0}, {y0}), ({x1}, {y1}) ({bbox_width}px by {bbox_height}px)")
+ cropped_inpaint_position = (x0, y0, x1, y1)
+ else:
+ logger.debug(
+ f"Calculated pixel area savings of {pixel_savings:.1f}% are insufficient, will not crop"
+ )
+
+ if cropped_inpaint_position is not None:
+ # Get copies prior to crop
+ if isinstance(invocation_kwargs["image"], list):
+ background = [
+ img.copy()
+ for img in invocation_kwargs["image"]
+ ]
+ else:
+ background = invocation_kwargs["image"].copy()
+
+ # First wrap callbacks if needed
+ if image_callback is not None:
+ # Hijack image callback to paste onto background
+ original_image_callback = image_callback
+
+ def pasted_image_callback(images: List[Image]) -> None:
+ """
+ Paste the images then callback.
+ """
+ if isinstance(background, list):
+ images = [
+ self.paste_inpaint_image((background[i] if i < len(background) else background[-1]), image, cropped_inpaint_position) # type: ignore
+ for i, image in enumerate(images)
+ ]
+ else:
+ images = [
+ self.paste_inpaint_image(background, image, cropped_inpaint_position) # type: ignore
+ for image in images
+ ]
+
+ original_image_callback(images)
+
+ image_callback = pasted_image_callback
+
+ # Now crop images
+ if isinstance(invocation_kwargs["image"], list):
+ invocation_kwargs["image"] = [
+ img.crop(cropped_inpaint_position)
+ for img in invocation_kwargs["image"]
+ ]
+ invocation_kwargs["mask"] = [
+ img.crop(cropped_inpaint_position)
+ for img in invocation_kwargs["mask"]
+ ]
+ else:
+ invocation_kwargs["image"] = invocation_kwargs["image"].crop(cropped_inpaint_position)
+ invocation_kwargs["mask"] = invocation_kwargs["mask"].crop(cropped_inpaint_position)
+
+ # Also crop control images
+ if "control_images" in invocation_kwargs:
+ for controlnet in invocation_kwargs["control_images"]:
+ for image_dict in invocation_kwargs["control_images"][controlnet]:
+ image_dict["image"] = image_dict["image"].crop(cropped_inpaint_position)
+
+ # Assign height and width
+ x0, y0, x1, y1 = cropped_inpaint_position
+ invocation_kwargs["width"] = x1 - x0
+ invocation_kwargs["height"] = y1 - y0
+
+ # Execute primary inference
+ images, nsfw = self.execute_inference(
+ pipeline,
+ task_callback=task_callback,
+ progress_callback=progress_callback,
+ image_callback=image_callback,
+ image_callback_steps=image_callback_steps,
+ invocation_kwargs=invocation_kwargs
+ )
+
+ if background is not None and cropped_inpaint_position is not None:
+ # Paste the image back onto the background
+ for i, image in enumerate(images):
+ images[i] = self.paste_inpaint_image(
+ background[i] if isinstance(background, list) else background,
+ image,
+ cropped_inpaint_position[:2]
+ )
+
+ # Execute upscale, if requested
+ images, nsfw = self.execute_upscale(
+ pipeline,
+ images=images,
+ nsfw=nsfw,
+ task_callback=task_callback,
+ progress_callback=progress_callback,
+ image_callback=image_callback,
+ invocation_kwargs=invocation_kwargs
+ )
+
+ pipeline.stop_keepalive() # Make sure this is stopped
+ pipeline.clear_memory()
+ return self.format_output(images, nsfw)
+
+ def prepare_pipeline(self, pipeline: DiffusionPipelineManager) -> None:
+ """
+ Assigns pipeline-level variables.
+ """
+ pipeline.start_keepalive() # Make sure this is going
+
+ if self.animation_frames is not None and self.animation_frames > 0:
+ pipeline.animator = self.model
+ pipeline.animator_vae = self.vae
+ pipeline.frame_window_size = self.frame_window_size
+ pipeline.frame_window_stride = self.frame_window_stride
+ pipeline.position_encoding_truncate_length = self.position_encoding_truncate_length
+ pipeline.position_encoding_scale_length = self.position_encoding_scale_length
+ pipeline.motion_module = self.motion_module
+ else:
+ pipeline.model = self.model
+ pipeline.vae = self.vae
+
+ pipeline.tiling_size = self.tiling_size
+ pipeline.tiling_stride = self.tiling_stride
+ pipeline.tiling_mask_type = self.tiling_mask_type
+
+ pipeline.refiner = self.refiner
+ pipeline.refiner_vae = self.refiner_vae
+
+ pipeline.inpainter = self.inpainter
+ pipeline.inpainter_vae = self.inpainter_vae
+
+ pipeline.lora = self.lora
+ pipeline.lycoris = self.lycoris
+ pipeline.inversion = self.inversion
+ pipeline.scheduler = self.scheduler
+
+ if self.build_tensorrt:
+ pipeline.build_tensorrt = True
+
+ def execute_inference(
+ self,
+ pipeline: DiffusionPipelineManager,
+ task_callback: Optional[Callable[[str], None]] = None,
+ progress_callback: Optional[Callable[[int, int, float], None]] = None,
+ image_callback: Optional[Callable[[List[Image]], None]] = None,
+ image_callback_steps: Optional[int] = None,
+ invocation_kwargs: Dict[str, Any] = {}
+ ) -> Tuple[List[Image], List[bool]]:
+ """
+ Executes primary inference
+ """
+ from PIL import Image, ImageDraw
+
+ # Define progress and latent callback kwargs, we'll add task callbacks ourself later
+ callback_kwargs = {
+ "progress_callback": progress_callback,
+ "latent_callback_steps": image_callback_steps,
+ "latent_callback_type": "pil",
+ "task_callback": task_callback
+ }
+
+ if self.seed is not None:
+ # Set up the RNG
+ pipeline.seed = self.seed
+
+ total_images = self.iterations
+ if self.animation_frames is not None and self.animation_frames > 0:
+ total_images *= self.animation_frames
+ else:
+ total_images *= self.samples
+
+ width = pipeline.size if self.width is None else self.width
+ height = pipeline.size if self.height is None else self.height
+
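+ # Seed the working image list: copies of the input image(s) when provided, blank canvases otherwise;
+ # it holds every frame per iteration when animating, or `samples` images per iteration otherwise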
+ if "image" in invocation_kwargs:
+ if isinstance(invocation_kwargs["image"], list):
+ images = [
+ invocation_kwargs["image"][i].copy() if len(invocation_kwargs["image"]) > i else invocation_kwargs["image"][-1].copy()
+ for i in range(total_images)
+ ]
+ else:
+ images = [
+ invocation_kwargs["image"].copy()
+ for i in range(total_images)
+ ]
+ else:
+ images = [
+ Image.new("RGBA", (width, height))
+ for i in range(total_images)
+ ]
+ image_draw = [
+ ImageDraw.Draw(image)
+ for image in images
+ ]
+ nsfw_content_detected = [False] * total_images
+
+ # Trigger the callback with base images after scaling and processing
+ if image_callback is not None and invocation_kwargs.get("image", None):
+ image_callback(images)
+
+ # Determine what controlnets to use
+ controlnets = (
+ None if not invocation_kwargs.get("control_images", None)
+ else list(invocation_kwargs["control_images"].keys())
+ )
+
+ for it in range(self.iterations):
+ if image_callback is not None:
+ def iteration_image_callback(callback_images: List[Image]) -> None:
+ """
+ Wrap the original image callback so each iteration's results are written into the correct slots of the full image list
+ """
+ for j, callback_image in enumerate(callback_images):
+ image_index = (it * self.samples) + j
+ images[image_index] = callback_image
+ image_callback(images) # type: ignore
+ else:
+ iteration_image_callback = None # type: ignore
+
+ if invocation_kwargs.get("animation_frames", None):
+ pipeline.animator_controlnets = controlnets
+ elif invocation_kwargs.get("mask", None):
+ pipeline.inpainter_controlnets = controlnets
+ else:
+ pipeline.controlnets = controlnets
+
+ result = pipeline(
+ latent_callback=iteration_image_callback,
+ **invocation_kwargs,
+ **callback_kwargs
+ )
+
+ for j, image in enumerate(result["images"]):
+ image_index = (it * self.samples) + j
+ images[image_index] = image
+ nsfw_content_detected[image_index] = nsfw_content_detected[image_index] or (
+ "nsfw_content_detected" in result and result["nsfw_content_detected"][j]
+ )
+
+ # Call the callback
+ if image_callback is not None:
+ image_callback(images)
+
+ return images, nsfw_content_detected
+
+ def execute_upscale(
+ self,
+ pipeline: DiffusionPipelineManager,
+ images: List[Image],
+ nsfw: List[bool],
+ task_callback: Optional[Callable[[str], None]] = None,
+ progress_callback: Optional[Callable[[int, int, float], None]] = None,
+ image_callback: Optional[Callable[[List[Image]], None]] = None,
+ image_callback_steps: Optional[int] = None,
+ invocation_kwargs: Dict[str, Any] = {}
+ ) -> Tuple[List[Image], List[bool]]:
+ """
+ Executes upscale steps
+ """
+ from diffusers.utils.pil_utils import PIL_INTERPOLATION
+ animation_frames = invocation_kwargs.get("animation_frames", None)
+
+ for upscale_step in self.upscale_steps:
+ method = upscale_step["method"]
+ amount = upscale_step["amount"]
+ num_inference_steps = upscale_step.get("num_inference_steps", DEFAULT_UPSCALE_INFERENCE_STEPS)
+ guidance_scale = upscale_step.get("guidance_scale", DEFAULT_UPSCALE_GUIDANCE_SCALE)
+ prompt = upscale_step.get("prompt", DEFAULT_UPSCALE_PROMPT)
+ prompt_2 = upscale_step.get("prompt_2", None)
+ negative_prompt = upscale_step.get("negative_prompt", None)
+ negative_prompt_2 = upscale_step.get("negative_prompt_2", None)
+ strength = upscale_step.get("strength", None)
+ controlnets = upscale_step.get("controlnets", None)
+ scheduler = upscale_step.get("scheduler", self.scheduler)
+ tiling_stride = upscale_step.get("tiling_stride", DEFAULT_UPSCALE_TILING_STRIDE)
+ tiling_size = upscale_step.get("tiling_size", DEFAULT_UPSCALE_TILING_SIZE)
+ tiling_mask_type = upscale_step.get("tiling_mask_type", None)
+ tiling_mask_kwargs = upscale_step.get("tiling_mask_kwargs", None)
+ noise_offset = upscale_step.get("noise_offset", None)
+ noise_method = upscale_step.get("noise_method", None)
+ noise_blend_method = upscale_step.get("noise_blend_method", None)
+ refiner = self.refiner is not None and upscale_step.get("refiner", True)
+
+ prompt = self.merge_prompts( # type: ignore[assignment]
+ (prompt, 1.0),
+ (self.prompt, GLOBAL_PROMPT_UPSCALE_WEIGHT),
+ (self.model_prompt, MODEL_PROMPT_WEIGHT),
+ (self.refiner_prompt, MODEL_PROMPT_WEIGHT),
+ *[
+ (prompt_dict["positive"], GLOBAL_PROMPT_UPSCALE_WEIGHT)
+ for prompt_dict in (self.prompts if self.prompts is not None else [])
+ ]
+ )
+
+ prompt_2 = self.merge_prompts(
+ (prompt_2, 1.0),
+ (self.prompt_2, GLOBAL_PROMPT_UPSCALE_WEIGHT),
+ (self.model_prompt_2, MODEL_PROMPT_WEIGHT),
+ (self.refiner_prompt_2, MODEL_PROMPT_WEIGHT),
+ *[
+ (prompt_dict.get("positive_2", None), GLOBAL_PROMPT_UPSCALE_WEIGHT)
+ for prompt_dict in (self.prompts if self.prompts is not None else [])
+ ]
+ )
+
+ negative_prompt = self.merge_prompts(
+ (negative_prompt, 1.0),
+ (self.negative_prompt, GLOBAL_PROMPT_UPSCALE_WEIGHT),
+ (self.model_negative_prompt, MODEL_PROMPT_WEIGHT),
+ (self.refiner_negative_prompt, MODEL_PROMPT_WEIGHT),
+ *[
+ (prompt_dict.get("negative", None), GLOBAL_PROMPT_UPSCALE_WEIGHT)
+ for prompt_dict in (self.prompts if self.prompts is not None else [])
+ ]
+ )
+
+ negative_prompt_2 = self.merge_prompts(
+ (negative_prompt_2, 1.0),
+ (self.negative_prompt_2, GLOBAL_PROMPT_UPSCALE_WEIGHT),
+ (self.model_negative_prompt_2, MODEL_PROMPT_WEIGHT),
+ (self.refiner_negative_prompt_2, MODEL_PROMPT_WEIGHT),
+ *[
+ (prompt_dict.get("negative_2", None), GLOBAL_PROMPT_UPSCALE_WEIGHT)
+ for prompt_dict in (self.prompts if self.prompts is not None else [])
+ ]
+ )
+
+ @contextmanager
+ def get_upscale_image() -> Iterator[Callable[[Image], Image]]:
+ if method in ["esrgan", "esrganime", "gfpgan"]:
+ if refiner:
+ pipeline.unload_pipeline("clearing memory for upscaler")
+ pipeline.unload_inpainter("clearing memory for upscaler")
+ pipeline.offload_refiner()
+ else:
+ pipeline.offload_pipeline()
+ pipeline.offload_animator()
+ pipeline.offload_inpainter()
+ pipeline.unload_refiner("clearing memory for upscaler")
+ if method == "gfpgan":
+ with pipeline.upscaler.gfpgan(tile=512) as upscale:
+ yield upscale
+ else:
+ with pipeline.upscaler.esrgan(tile=512, anime=method=="esrganime") as upscale:
+ yield upscale
+ elif method in PIL_INTERPOLATION:
+ def pil_resize(image: Image) -> Image:
+ # Resize relative to the image's own dimensions; the outer width/height
+ # locals are not yet assigned when this runs
+ image_width, image_height = image.size
+ return image.resize(
+ (int(image_width * amount), int(image_height * amount)),
+ resample=PIL_INTERPOLATION[method]
+ )
+ yield pil_resize
+ else:
+ logger.error(f"Unknown method {method}")
+ def no_resize(image: Image) -> Image:
+ return image
+ yield no_resize
+
+ if task_callback:
+ task_callback(f"Upscaling samples")
+
+ with get_upscale_image() as upscale_image:
+ if progress_callback is not None:
+ progress_callback(0, len(images), 0.0)
+ for i, image in enumerate(images):
+ upscale_start = datetime.now()
+ if nsfw is not None and nsfw[i]:
+ logger.debug(f"Image {i} had NSFW content, not upscaling.")
+ continue
+ logger.debug(f"Upscaling sample {i} by {amount} using {method}")
+ images[i] = upscale_image(image)
+ upscale_time = (datetime.now() - upscale_start).total_seconds()
+ if progress_callback:
+ progress_callback(i+1, len(images), 1/upscale_time)
+
+ if image_callback:
+ image_callback(images)
+
+ if strength is not None and strength > 0:
+ if task_callback:
+ task_callback("Preparing upscale pipeline")
+
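+ # In either branch below, tiles are kept at least as large as the target pipeline's
+ # trained resolution, with a stride of at most half that size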
+ if refiner:
+ # Refiners have safety disabled from the jump
+ logger.debug("Using refiner for upscaling.")
+ re_enable_safety = False
+ tiling_size = max(tiling_size, pipeline.refiner_size)
+ tiling_stride = min(tiling_stride, pipeline.refiner_size // 2)
+ else:
+ # Disable pipeline safety here, it gives many false positives when upscaling.
+ # We'll re-enable it after.
+ logger.debug("Using base pipeline for upscaling.")
+ re_enable_safety = pipeline.safe
+ if animation_frames:
+ tiling_size = max(tiling_size, pipeline.animator_size)
+ tiling_stride = min(tiling_stride, pipeline.animator_size // 2)
+ else:
+ tiling_size = max(tiling_size, pipeline.size)
+ tiling_stride = min(tiling_stride, pipeline.size // 2)
+ pipeline.safe = False
+
+ if scheduler is not None:
+ pipeline.scheduler = scheduler
+
+ if animation_frames:
+ upscaled_images = [images]
+ else:
+ upscaled_images = images
+
+ for i, image in enumerate(upscaled_images):
+ if nsfw is not None and nsfw[i]:
+ logger.debug(f"Image {i} had NSFW content, not upscaling.")
+ continue
+
+ if isinstance(image, list):
+ width, height = image[0].size
+ else:
+ width, height = image.size
+
+ kwargs = {
+ "width": width,
+ "height": height,
+ "image": image,
+ "num_results_per_prompt": 1,
+ "prompt": prompt,
+ "prompt_2": prompt_2,
+ "negative_prompt": negative_prompt,
+ "negative_prompt_2": negative_prompt_2,
+ "strength": strength,
+ "num_inference_steps": num_inference_steps,
+ "guidance_scale": guidance_scale,
+ "tiling_size": tiling_size,
+ "tiling_stride": tiling_stride,
+ "tiling_mask_type": tiling_mask_type,
+ "tiling_mask_kwargs": tiling_mask_kwargs,
+ "progress_callback": progress_callback,
+ "latent_callback": image_callback,
+ "latent_callback_type": "pil",
+ "latent_callback_steps": image_callback_steps,
+ "noise_offset": noise_offset,
+ "noise_method": noise_method,
+ "noise_blend_method": noise_blend_method,
+ "animation_frames": animation_frames,
+ "motion_scale": invocation_kwargs.get("motion_scale", None),
+ "tile": invocation_kwargs.get("tile", None),
+ "loop": invocation_kwargs.get("loop", False)
+ }
+
+ if controlnets is not None:
+ if not isinstance(controlnets, list):
+ controlnets = [controlnets] # type: ignore[unreachable]
+
+ controlnet_names = []
+ controlnet_weights = []
+
+ for controlnet in controlnets:
+ if isinstance(controlnet, tuple):
+ controlnet, weight = controlnet
+ else:
+ weight = 1.0
+ if controlnet not in controlnet_names:
+ controlnet_names.append(controlnet)
+ controlnet_weights.append(weight)
+
+ logger.debug(f"Enabling controlnet(s) {controlnet_names} for upscaling")
+
+ if refiner:
+ pipeline.refiner_controlnets = controlnet_names
+ upscale_pipeline = pipeline.refiner_pipeline
+ is_sdxl = pipeline.refiner_is_sdxl
+ elif animation_frames:
+ pipeline.animator_controlnets = controlnet_names
+ upscale_pipeline = pipeline.animator_pipeline
+ is_sdxl = pipeline.animator_is_sdxl
+ else:
+ pipeline.controlnets = controlnet_names
+ upscale_pipeline = pipeline.pipeline
+ is_sdxl = pipeline.is_sdxl
+
+ controlnet_unique_names = list(set(controlnet_names))
+
+ with pipeline.control_image_processor.processors(*controlnet_unique_names) as controlnet_processors:
+ controlnet_processor_dict = dict(zip(controlnet_unique_names, controlnet_processors))
+
+ def get_processed_image(controlnet: str) -> Union[Image, List[Image]]:
+ if isinstance(image, list):
+ return [
+ controlnet_processor_dict[controlnet](img)
+ for img in image
+ ]
+ else:
+ return controlnet_processor_dict[controlnet](image)
+
+ kwargs["control_images"] = dict([
+ (
+ controlnet_name,
+ [(
+ get_processed_image(controlnet_name),
+ controlnet_weight
+ )]
+ )
+ for controlnet_name, controlnet_weight in zip(controlnet_names, controlnet_weights)
+ ])
+
+ elif refiner:
+ pipeline.refiner_controlnets = None
+ upscale_pipeline = pipeline.refiner_pipeline
+ elif animation_frames:
+ pipeline.animator_controlnets = None
+ upscale_pipeline = pipeline.animator_pipeline
+ else:
+ pipeline.controlnets = None
+ upscale_pipeline = pipeline.pipeline
+
+ logger.debug(f"Upscaling sample {i} with arguments {kwargs}")
+ pipeline.stop_keepalive() # Stop here to kill during upscale diffusion
+ if task_callback:
+ task_callback(f"Re-diffusing Upscaled Sample {i+1}")
+
+ image = upscale_pipeline(
+ generator=pipeline.generator,
+ device=pipeline.device,
+ offload_models=pipeline.pipeline_sequential_onload,
+ **kwargs
+ ).images
+ pipeline.start_keepalive() # Return keepalive between iterations
+
+ if animation_frames:
+ images = image
+ else:
+ images[i] = image[0]
+ if image_callback is not None:
+ image_callback(images)
+
+ if re_enable_safety:
+ pipeline.safe = True
+ if refiner:
+ logger.debug("Offloading refiner for next inference.")
+ pipeline.refiner_controlnets = None
+ pipeline.offload_refiner()
+ elif animation_frames:
+ pipeline.animator_controlnets = None # Make sure we reset controlnets
+ else:
+ pipeline.controlnets = None # Make sure we reset controlnets
+
+ return images, nsfw
+
+ def format_output(
+ self,
+ images: List[Image],
+ nsfw: List[bool]
+ ) -> StableDiffusionPipelineOutput:
+ """
+ Adds Enfugue metadata to an image result
+ """
+ from PIL import Image
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+
+ metadata_dict = self.serialize()
+ redact_images_from_metadata(metadata_dict)
+ formatted_images = []
+
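+ # Round-trip each image through an in-memory PNG so the serialized generation data
+ # is embedded as a PNG text chunk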
+ for i, image in enumerate(images):
+ byte_io = io.BytesIO()
+ metadata = PngInfo()
+ metadata.add_text("EnfugueGenerationData", Serializer.serialize(metadata_dict))
+ image.save(byte_io, format="PNG", pnginfo=metadata)
+ formatted_images.append(Image.open(byte_io))
+
+ return StableDiffusionPipelineOutput(
+ images=formatted_images,
+ nsfw_content_detected=nsfw
+ )
diff --git a/src/python/enfugue/diffusion/manager.py b/src/python/enfugue/diffusion/manager.py
index 80d0a417..52944ef8 100644
--- a/src/python/enfugue/diffusion/manager.py
+++ b/src/python/enfugue/diffusion/manager.py
@@ -16,6 +16,7 @@
from pibble.api.configuration import APIConfiguration
from pibble.api.exceptions import ConfigurationError
from pibble.util.files import dump_json, load_json
+from pibble.util.numeric import human_size
from enfugue.util import logger, check_download, check_make_directory, find_file_in_directory
from enfugue.diffusion.constants import *
@@ -35,6 +36,7 @@
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
from enfugue.diffusion.pipeline import EnfugueStableDiffusionPipeline
from enfugue.diffusion.support import ControlImageProcessor, Upscaler, IPAdapter, BackgroundRemover
+ from enfugue.diffusion.animate.pipeline import EnfugueAnimateStableDiffusionPipeline
def noop(*args: Any) -> None:
"""
@@ -110,6 +112,11 @@ class DiffusionPipelineManager:
DEFAULT_CHUNK = 64
DEFAULT_SIZE = 512
+ DEFAULT_TEMPORAL_CHUNK = 4
+ DEFAULT_TEMPORAL_SIZE = 16
+
+ is_default_inpainter: bool = False
+ is_default_animator: bool = False
_keepalive_thread: KeepaliveThread
_keepalive_callback: Callable[[], None]
@@ -117,17 +124,20 @@ class DiffusionPipelineManager:
_pipeline: EnfugueStableDiffusionPipeline
_refiner_pipeline: EnfugueStableDiffusionPipeline
_inpainter_pipeline: EnfugueStableDiffusionPipeline
- _size: int
- _refiner_size: int
- _inpainter_size: int
+ _animator_pipeline: EnfugueAnimateStableDiffusionPipeline
_task_callback: Optional[Callable[[str], None]] = None
- def __init__(self, configuration: Optional[APIConfiguration] = None, optimize: bool = True) -> None:
+ def __init__(
+ self,
+ configuration: Optional[APIConfiguration] = None,
+ optimize: bool = True
+ ) -> None:
self.configuration = APIConfiguration()
if configuration:
self.configuration = configuration
if optimize:
self.optimize_configuration()
+ self.hijack_downloads()
def optimize_configuration(self) -> None:
"""
@@ -147,6 +157,28 @@ def task_callback(self, message: str) -> None:
"""
if getattr(self, "_task_callback", None) is not None:
self._task_callback(message) # type: ignore[misc]
+ else:
+ logger.debug(message)
+
+ def hijack_downloads(self) -> None:
+ """
+ Steals huggingface hub HTTP GET to report back to the UI.
+ """
+ import huggingface_hub
+ import huggingface_hub.file_download
+
+ huggingface_http_get = huggingface_hub.file_download.http_get
+
+ def http_get(url: str, *args: Any, **kwargs: Any) -> Any:
+ """
+ Call the task callback then execute the standard function
+ """
+ if self.offline:
+ raise ValueError(f"Offline mode enabled, but need to download {url}. Exiting.")
+ self.task_callback(f"Downloading {url}")
+ return huggingface_http_get(url, *args, **kwargs)
+
+ huggingface_hub.file_download.http_get = http_get
def check_download_model(self, local_dir: str, remote_url: str) -> str:
"""
@@ -160,7 +192,16 @@ def check_download_model(self, local_dir: str, remote_url: str) -> str:
if self.offline:
raise ValueError(f"File {output_file} does not exist in {local_dir} and offline mode is enabled, refusing to download from {remote_url}")
self.task_callback(f"Downloading {remote_url}")
- check_download(remote_url, output_path)
+
+ def progress_callback(written_bytes: int, total_bytes: int) -> None:
+ percentage = (written_bytes / total_bytes) * 100.0
+ self.task_callback(f"Downloading {remote_url}: {percentage:0.1f}% ({human_size(written_bytes)}/{human_size(total_bytes)})")
+
+ check_download(
+ remote_url,
+ output_path,
+ progress_callback=progress_callback
+ )
return output_path
@classmethod
@@ -201,7 +242,21 @@ def safe(self, val: bool) -> None:
"""
if val != getattr(self, "_safe", None):
self._safe = val
- self.unload_pipeline("safety checking enabled or disabled")
+ if hasattr(self, "_pipeline"):
+ if self._pipeline.safety_checker is None and val:
+ self.unload_pipeline("safety checking enabled")
+ else:
+ self._pipeline.safety_checking_disabled = not val
+ if hasattr(self, "_inpainter_pipeline"):
+ if self._inpainter_pipeline.safety_checker is None and val:
+ self.unload_inpainter("safety checking enabled")
+ else:
+ self._inpainter_pipeline.safety_checking_disabled = not val
+ if hasattr(self, "_animator_pipeline"):
+ if self._animator_pipeline.safety_checker is None and val:
+ self.unload_animator("safety checking enabled")
+ else:
+ self._animator_pipeline.safety_checking_disabled = not val
@property
def device(self) -> torch.device:
@@ -240,19 +295,19 @@ def clear_memory(self) -> None:
if self.device.type == "cuda":
import torch
import torch.cuda
-
torch.cuda.empty_cache()
+ torch.cuda.synchronize()
elif self.device.type == "mps":
import torch
import torch.mps
-
torch.mps.empty_cache()
+ torch.mps.synchronize()
gc.collect()
@property
def seed(self) -> int:
"""
- Gets the seed. If there is none, creates a random one once.
+ Gets the seed. If there is None, creates a random one once.
"""
if not hasattr(self, "_seed"):
self._seed = self.configuration.get("enfugue.seed", random.randint(0, 2**63 - 1))
@@ -368,12 +423,12 @@ def get_scheduler_class(
kwargs: Dict[str, Any] = {}
if not scheduler:
return None
- elif scheduler in ["dpmsm", "dpmsmk", "dpmsmka"]:
+ elif scheduler in ["dpmsm", "dpmsms", "dpmsmk", "dpmsmka"]:
from diffusers.schedulers import DPMSolverMultistepScheduler
+ if scheduler in ["dpmsms", "dpmsmka"]:
+ kwargs["algorithm_type"] = "sde-dpmsolver++"
if scheduler in ["dpmsmk", "dpmsmka"]:
kwargs["use_karras_sigmas"] = True
- if scheduler == "dpmsmka":
- kwargs["algorithm_type"] = "sde-dpmsolver++"
return (DPMSolverMultistepScheduler, kwargs)
elif scheduler in ["dpmss", "dpmssk"]:
from diffusers.schedulers import DPMSolverSinglestepScheduler
@@ -449,7 +504,18 @@ def scheduler(
if not new_scheduler:
if hasattr(self, "_scheduler"):
delattr(self, "_scheduler")
- self.unload_pipeline("returning to default scheduler")
+ if hasattr(self, "_pipeline"):
+ logger.debug("Reverting pipeline scheduler to default.")
+ self._pipeline.revert_scheduler()
+ if hasattr(self, "_inpainter_pipeline"):
+ logger.debug("Reverting inpainter pipeline scheduler to default.")
+ self._inpainter_pipeline.revert_scheduler()
+ if hasattr(self, "_refiner_pipeline"):
+ logger.debug("Reverting refiner pipeline scheduler to default.")
+ self._refiner_pipeline.revert_scheduler()
+ if hasattr(self, "_animator_pipeline"):
+ logger.debug("Reverting animator pipeline scheduler to default.")
+ self._animator_pipeline.revert_scheduler()
return
scheduler_class = self.get_scheduler_class(new_scheduler)
scheduler_config: Dict[str, Any] = {}
@@ -470,6 +536,9 @@ def scheduler(
if hasattr(self, "_refiner_pipeline"):
logger.debug(f"Hot-swapping refiner pipeline scheduler.")
self._refiner_pipeline.scheduler = self.scheduler.from_config({**self._refiner_pipeline.scheduler_config, **self.scheduler_config}) # type: ignore
+ if hasattr(self, "_animator_pipeline"):
+ logger.debug(f"Hot-swapping animator pipeline scheduler.")
+ self._animator_pipeline.scheduler = self.scheduler.from_config({**self._animator_pipeline.scheduler_config, **self.scheduler_config}) # type: ignore
def get_vae_path(self, vae: Optional[str] = None) -> Optional[Union[str, Tuple[str, ...]]]:
"""
@@ -562,13 +631,6 @@ def get_vae(
logger.debug(f"Received KeyError on '{ex}' when instantiating VAE from single file, trying to use XL VAE loader fix.")
result = self.get_xl_vae(vae)
else:
- expected_vae_location = os.path.join(self.engine_cache_dir, "models--" + vae.replace("/", "--"))
-
- if not os.path.exists(expected_vae_location):
- if self.offline:
- raise IOError(f"Offline mode enabled, cannot download {vae} to {expected_vae_location}")
- self.task_callback(f"Downloading VAE weights from repository {vae}")
- logger.info(f"VAE {vae} does not exist in cache directory {self.engine_cache_dir}, it will be downloaded.")
result = AutoencoderKL.from_pretrained(
vae,
torch_dtype=self.dtype,
@@ -584,11 +646,6 @@ def get_vae_preview(self, use_xl: bool) -> AutoencoderTiny:
"""
from diffusers.models import AutoencoderTiny
repo = "madebyollin/taesdxl" if use_xl else "madebyollin/taesd"
- expected_path = os.path.join(self.engine_cache_dir, "models--{0}".format(repo.replace("/", "--")))
- if not os.path.exists(expected_path):
- if self.offline:
- raise IOError(f"Offline mode enabled, cannot download {repo} to {expected_path}")
- self.task_callback(f"Downloading preview VAE weights from repository {repo}")
return AutoencoderTiny.from_pretrained(
repo,
cache_dir=self.engine_cache_dir,
@@ -598,7 +655,7 @@ def get_vae_preview(self, use_xl: bool) -> AutoencoderTiny:
@property
def vae(self) -> Optional[AutoencoderKL]:
"""
- Gets the configured VAE (or none.)
+ Gets the configured VAE (or None.)
"""
if not hasattr(self, "_vae"):
self._vae = self.get_vae(self.vae_name)
@@ -649,7 +706,7 @@ def vae_name(self) -> Optional[str]:
@property
def refiner_vae(self) -> Optional[AutoencoderKL]:
"""
- Gets the configured refiner VAE (or none.)
+ Gets the configured refiner VAE (or None.)
"""
if not hasattr(self, "_refiner_vae"):
self._refiner_vae = self.get_vae(self.refiner_vae_name)
@@ -700,7 +757,7 @@ def refiner_vae_name(self) -> Optional[str]:
@property
def inpainter_vae(self) -> Optional[AutoencoderKL]:
"""
- Gets the configured inpainter VAE (or none.)
+ Gets the configured inpainter VAE (or None.)
"""
if not hasattr(self, "_inpainter_vae"):
self._inpainter_vae = self.get_vae(self.inpainter_vae_name)
@@ -748,73 +805,79 @@ def inpainter_vae_name(self) -> Optional[str]:
self._inpainter_vae_name = self.configuration.get("enfugue.vae.inpainter", None)
return self._inpainter_vae_name
+ @property
+ def animator_vae(self) -> Optional[AutoencoderKL]:
+ """
+ Gets the configured animator VAE (or None.)
+ """
+ if not hasattr(self, "_animator_vae"):
+ self._animator_vae = self.get_vae(self.animator_vae_name)
+ return self._animator_vae
+
+ @animator_vae.setter
+ def animator_vae(
+ self,
+ new_vae: Optional[str],
+ ) -> None:
+ """
+ Sets a new animator vae.
+ """
+ pretrained_path = self.get_vae_path(new_vae)
+ existing_vae = getattr(self, "_animator_vae", None)
+
+ if (
+ (not existing_vae and new_vae)
+ or (existing_vae and not new_vae)
+ or (existing_vae and new_vae and self.animator_vae_name != new_vae)
+ ):
+ if not new_vae:
+ self._animator_vae_name = None # type: ignore
+ self._animator_vae = None
+ self.unload_animator("VAE resetting to default")
+ else:
+ self._animator_vae_name = new_vae
+ self._animator_vae = self.get_vae(pretrained_path)
+ if self.animator_tensorrt_is_ready and "vae" in self.TENSORRT_STAGES:
+ self.unload_animator("VAE changing")
+ elif hasattr(self, "_animator_pipeline"):
+ logger.debug(f"Hot-swapping animator pipeline VAE to {new_vae}")
+ self._animator_pipeline.vae = self._animator_vae # type: ignore [assignment]
+ if self.animator_is_sdxl:
+ self._animator_pipeline.register_to_config( # type: ignore[attr-defined]
+ force_full_precision_vae = new_vae in ["xl", "stabilityai/sdxl-vae"]
+ )
+
+ @property
+ def animator_vae_name(self) -> Optional[str]:
+ """
+ Gets the name of the VAE, if one was set.
+ """
+ if not hasattr(self, "_animator_vae_name"):
+ self._animator_vae_name = self.configuration.get("enfugue.vae.animator", None)
+ return self._animator_vae_name
+
@property
def size(self) -> int:
"""
- Gets the base engine size in pixels when chunking (default always.)
+ Gets the trained size of the model
"""
if not hasattr(self, "_size"):
return 1024 if self.is_sdxl else 512
return self._size
- @size.setter
- def size(self, new_size: Optional[int]) -> None:
- """
- Sets the base engine size in pixels.
- """
- if new_size is None:
- if hasattr(self, "_size"):
- check_reload_size = self._size
- delattr(self, "_size")
- if check_reload_size != self.size and self.tensorrt_is_ready:
- self.unload_pipeline("engine size changing")
- elif hasattr(self, "_pipeline"):
- logger.debug("setting pipeline engine size in place.")
- self._pipeline.engine_size = self.size
- return
- if hasattr(self, "_size") and self._size != new_size:
- if self.tensorrt_is_ready:
- self.unload_pipeline("engine size changing")
- elif hasattr(self, "_pipeline"):
- logger.debug("Setting pipeline engine size in-place.")
- self._pipeline.engine_size = new_size
- self._size = new_size
-
@property
def refiner_size(self) -> int:
"""
- Gets the refiner engine size in pixels when chunking (default always.)
+ Gets the trained size of the refiner
"""
if not hasattr(self, "_refiner_size"):
return 1024 if self.refiner_is_sdxl else 512
return self._refiner_size
- @refiner_size.setter
- def refiner_size(self, new_refiner_size: Optional[int]) -> None:
- """
- Sets the refiner engine size in pixels.
- """
- if new_refiner_size is None:
- if hasattr(self, "_refiner_size"):
- if self._refiner_size != self.size and self.refiner_tensorrt_is_ready:
- self.unload_refiner("engine size changing")
- elif hasattr(self, "_refiner_pipeline"):
- logger.debug("Setting refiner engine size in-place.")
- self._refiner_pipeline.engine_size = self.size
- delattr(self, "_refiner_size")
- elif hasattr(self, "_refiner_size") and self._refiner_size != new_refiner_size:
- if self.refiner_tensorrt_is_ready:
- self.unload_refiner("engine size changing")
- elif hasattr(self, "_refiner_pipeline"):
- logger.debug("Setting refiner engine size in-place.")
- self._refiner_pipeline.engine_size = new_refiner_size
- if new_refiner_size is not None:
- self._refiner_size = new_refiner_size
-
@property
def inpainter_size(self) -> int:
"""
- Gets the inpainter engine size in pixels when chunking (default always.)
+ Gets the trained size of the inpainter
"""
if not hasattr(self, "_inpainter_size"):
if self.inpainter:
@@ -822,45 +885,149 @@ def inpainter_size(self) -> int:
return self.size
return self._inpainter_size
- @inpainter_size.setter
- def inpainter_size(self, new_inpainter_size: Optional[int]) -> None:
+ @property
+ def animator_size(self) -> int:
"""
- Sets the inpainter engine size in pixels.
+ Gets the trained size of the animator
"""
- if new_inpainter_size is None:
- if hasattr(self, "_inpainter_size"):
- if self._inpainter_size != self.size and self.inpainter_tensorrt_is_ready:
- self.unload_inpainter("engine size changing")
- elif hasattr(self, "_inpainter_pipeline"):
- logger.debug("Setting inpainter engine size in-place.")
- self._inpainter_pipeline.engine_size = self.size
- delattr(self, "_inpainter_size")
- elif hasattr(self, "_inpainter_size") and self._inpainter_size != new_inpainter_size:
- if self.inpainter_tensorrt_is_ready:
- self.unload_inpainter("engine size changing")
- elif hasattr(self, "_inpainter_pipeline"):
- logger.debug("Setting inpainter engine size in-place.")
- self._inpainter_pipeline.engine_size = new_inpainter_size
- if new_inpainter_size is not None:
- self._inpainter_size = new_inpainter_size
+ if not hasattr(self, "_animator_size"):
+ if self.animator:
+ return 1024 if self.animator_is_sdxl else 512
+ return self.size
+ return self._animator_size
+
+ @property
+ def tiling_size(self) -> Optional[int]:
+ """
+ Gets the tiling size in pixels.
+ """
+ if not hasattr(self, "_tiling_size"):
+ self._tiling_size = self.configuration.get("enfugue.tile.size", None)
+ return self._tiling_size
+
+ @tiling_size.setter
+ def tiling_size(self, new_tiling_size: Optional[int]) -> None:
+ """
+ Sets the new tiling size. This requires reloading any pipelines that are using TensorRT.
+ """
+ if (
+ (self.tiling_size is None and new_tiling_size is not None) or
+ (self.tiling_size is not None and new_tiling_size is None) or
+ (self.tiling_size is not None and new_tiling_size is not None and self.tiling_size != new_tiling_size)
+ ):
+ if hasattr(self, "_pipeline") and self.tensorrt_is_ready:
+ self.unload_pipeline("engine tiling size changing")
+ if hasattr(self, "_inpainter_pipeline") and self.inpainter_tensorrt_is_ready:
+ self.unload_inpainter("engine tiling size changing")
+ if hasattr(self, "_refiner_pipeline") and self.refiner_tensorrt_is_ready:
+ self.unload_refiner("engine tiling size changing")
+ if hasattr(self, "_animator_pipeline") and self.animator_tensorrt_is_ready:
+ self.unload_animator("engine tiling size changing")
+
+ @property
+ def tiling_stride(self) -> int:
+ """
+ Gets the tiling stride in pixels.
+ """
+ if not hasattr(self, "_tiling_stride"):
+ self._tiling_stride = int(
+ self.configuration.get("enfugue.tile.stride", self.size // 4)
+ )
+ return self._tiling_stride
+
+ @tiling_stride.setter
+ def tiling_stride(self, new_tiling_stride: int) -> None:
+ """
+ Sets the new tiling stride. This doesn't require a restart.
+ """
+ self._tiling_stride = new_tiling_stride
+
+ @property
+ def tensorrt_size(self) -> int:
+ """
+ Gets the size of an active tensorrt engine.
+ """
+ if self.tiling_size is not None:
+ return self.tiling_size
+ return self.size
@property
- def chunking_size(self) -> int:
+ def inpainter_tensorrt_size(self) -> int:
+ """
+ Gets the size of an active inpainter tensorrt engine.
+ """
+ if self.tiling_size is not None:
+ return self.tiling_size
+ return self.inpainter_size
+
+ @property
+ def refiner_tensorrt_size(self) -> int:
+ """
+ Gets the size of an active refiner tensorrt engine.
+ """
+ if self.tiling_size is not None:
+ return self.tiling_size
+ return self.refiner_size
+
+ @property
+ def animator_tensorrt_size(self) -> int:
+ """
+ Gets the size of an active animator tensorrt engine.
+ """
+ if self.tiling_size is not None:
+ return self.tiling_size
+ return self.animator_size
+
+ @property
+ def frame_window_size(self) -> int:
+ """
+ Gets the animator frame window size in frames used when chunking across the temporal dimension.
+ """
+ if not hasattr(self, "_frame_window_size"):
+ self._frame_window_size = self.configuration.get("enfugue.frames", DiffusionPipelineManager.DEFAULT_TEMPORAL_SIZE)
+ return self._frame_window_size
+
+ @frame_window_size.setter
+ def frame_window_size(self, new_frame_window_size: Optional[int]) -> None:
+ """
+ Sets the animator engine size in pixels.
+ """
+ if new_frame_window_size is None:
+ if hasattr(self, "_frame_window_size"):
+ if self._frame_window_size != self.frame_window_size and self.tensorrt_is_ready:
+ self.unload_animator("engine frame window size changing")
+ elif hasattr(self, "_animator_pipeline"):
+ logger.debug("Setting animator engine size in-place.")
+ self._animator_pipeline.frame_window_size = new_frame_window_size # type: ignore[assignment]
+ delattr(self, "_frame_window_size")
+ elif hasattr(self, "_frame_window_size") and self._frame_window_size != new_frame_window_size:
+ if self.tensorrt_is_ready:
+ self.unload_animator("engine size changing")
+ elif hasattr(self, "_animator_pipeline"):
+ logger.debug("Setting animator frame window engine size in-place.")
+ self._animator_pipeline.frame_window_size = new_frame_window_size
+ if new_frame_window_size is not None:
+ self._frame_window_size = new_frame_window_size
+
+ @property
+ def frame_window_stride(self) -> Optional[int]:
"""
- Gets the chunking size in pixels.
+ Gets the frame window stride in frames.
"""
- if not hasattr(self, "_chunking_size"):
- self._chunking_size = int(
- self.configuration.get("enfugue.chunk.size", DiffusionPipelineManager.DEFAULT_CHUNK)
+ if not hasattr(self, "_frame_window_stride"):
+ self._frame_window_stride = int(
+ self.configuration.get("enfugue.temporal.size", DiffusionPipelineManager.DEFAULT_TEMPORAL_CHUNK)
)
- return self._chunking_size
+ return self._frame_window_stride
- @chunking_size.setter
- def chunking_size(self, new_chunking_size: int) -> None:
+ @frame_window_stride.setter
+ def frame_window_stride(self, new_frame_window_stride: Optional[int]) -> None:
"""
- Sets the new chunking size. This doesn't require a restart.
+ Sets the new frame window stride. This doesn't require a restart.
"""
- self._chunking_size = new_chunking_size
+ self._frame_window_stride = new_frame_window_stride # type: ignore[assignment]
+ if hasattr(self, "_animator_pipeline"):
+ self._animator_pipeline.frame_window_stride = new_frame_window_stride # type: ignore[assignment]
@property
def engine_root(self) -> str:
@@ -958,6 +1125,18 @@ def engine_tensorrt_dir(self) -> str:
check_make_directory(path)
return path
+ @property
+ def engine_motion_dir(self) -> str:
+ """
+ Gets where motion modules are saved.
+ """
+ path = self.configuration.get("enfugue.engine.motion", "~/.cache/enfugue/motion")
+ if path.startswith("~"):
+ path = os.path.expanduser(path)
+ path = os.path.realpath(path)
+ check_make_directory(path)
+ return path
+
@property
def model_tensorrt_dir(self) -> str:
"""
@@ -989,6 +1168,17 @@ def inpainter_tensorrt_dir(self) -> str:
check_make_directory(path)
return path
+ @property
+ def animator_tensorrt_dir(self) -> str:
+ """
+ Gets where tensorrt engines will be built per animator.
+ """
+ if not self.animator_name:
+ raise ValueError("No animator set")
+ path = os.path.join(self.engine_tensorrt_dir, self.animator_name)
+ check_make_directory(path)
+ return path
+
@property
def engine_diffusers_dir(self) -> str:
"""
@@ -1032,6 +1222,17 @@ def inpainter_diffusers_dir(self) -> str:
check_make_directory(path)
return path
+ @property
+ def animator_diffusers_dir(self) -> str:
+ """
+ Gets where the diffusers cache will be for the current animator.
+ """
+ if not self.animator_name:
+ raise ValueError("No animator set")
+ path = os.path.join(self.engine_diffusers_dir, f"{self.animator_name}-animator")
+ check_make_directory(path)
+ return path
+
@property
def engine_onnx_dir(self) -> str:
"""
@@ -1075,6 +1276,17 @@ def inpainter_onnx_dir(self) -> str:
check_make_directory(path)
return path
+ @property
+ def animator_onnx_dir(self) -> str:
+ """
+ Gets where the onnx cache will be for the current animator.
+ """
+ if not self.animator_name:
+ raise ValueError("No animator set")
+ path = os.path.join(self.engine_onnx_dir, self.animator_name)
+ check_make_directory(path)
+ return path
+
@staticmethod
def get_clip_key(
size: int, lora: List[Tuple[str, float]], lycoris: List[Tuple[str, float]], inversion: List[str], **kwargs: Any
@@ -1213,6 +1425,42 @@ def inpainter_onnx_clip_dir(self) -> str:
self.write_model_metadata(metadata_path)
return path
+ @property
+ def animator_clip_key(self) -> str:
+ """
+ Gets the CLIP key for the current configuration.
+ """
+ return DiffusionPipelineManager.get_clip_key(
+ size=self.animator_size,
+ lora=[],
+ lycoris=[],
+ inversion=[]
+ )
+
+ @property
+ def animator_tensorrt_clip_dir(self) -> str:
+ """
+ Gets where the tensorrt CLIP engine will be stored.
+ """
+ path = os.path.join(self.animator_tensorrt_dir, "clip", self.animator_clip_key)
+ check_make_directory(path)
+ metadata_path = os.path.join(path, "metadata.json")
+ if not os.path.exists(metadata_path):
+ self.write_model_metadata(metadata_path)
+ return path
+
+ @property
+ def animator_onnx_clip_dir(self) -> str:
+ """
+ Gets where the onnx CLIP engine will be stored.
+ """
+ path = os.path.join(self.animator_onnx_dir, "clip", self.animator_clip_key)
+ check_make_directory(path)
+ metadata_path = os.path.join(path, "metadata.json")
+ if not os.path.exists(metadata_path):
+ self.write_model_metadata(metadata_path)
+ return path
+
@staticmethod
def get_unet_key(
size: int,
@@ -1253,7 +1501,7 @@ def model_unet_key(self) -> str:
Gets the UNET key for the current configuration.
"""
return DiffusionPipelineManager.get_unet_key(
- size=self.size,
+ size=self.tensorrt_size,
lora=self.lora_names_weights,
lycoris=self.lycoris_names_weights,
inversion=self.inversion_names,
@@ -1289,7 +1537,7 @@ def refiner_unet_key(self) -> str:
Gets the UNET key for the current configuration.
"""
return DiffusionPipelineManager.get_unet_key(
- size=self.refiner_size,
+ size=self.refiner_tensorrt_size,
lora=[],
lycoris=[],
inversion=[]
@@ -1325,7 +1573,7 @@ def inpainter_unet_key(self) -> str:
Gets the UNET key for the current configuration.
"""
return DiffusionPipelineManager.get_unet_key(
- size=self.inpainter_size,
+ size=self.inpainter_tensorrt_size,
lora=[],
lycoris=[],
inversion=[]
@@ -1355,6 +1603,42 @@ def inpainter_onnx_unet_dir(self) -> str:
self.write_model_metadata(metadata_path)
return path
+ @property
+ def animator_unet_key(self) -> str:
+ """
+ Gets the UNET key for the current configuration.
+ """
+ return DiffusionPipelineManager.get_unet_key(
+ size=self.animator_tensorrt_size,
+ lora=[],
+ lycoris=[],
+ inversion=[]
+ )
+
+ @property
+ def animator_tensorrt_unet_dir(self) -> str:
+ """
+ Gets where the tensorrt UNET engine will be stored for the animator.
+ """
+ path = os.path.join(self.animator_tensorrt_dir, "unet", self.animator_unet_key)
+ check_make_directory(path)
+ metadata_path = os.path.join(path, "metadata.json")
+ if not os.path.exists(metadata_path):
+ self.write_model_metadata(metadata_path)
+ return path
+
+ @property
+ def animator_onnx_unet_dir(self) -> str:
+ """
+ Gets where the onnx UNET engine will be stored for the animator.
+ """
+ path = os.path.join(self.animator_onnx_dir, "unet", self.animator_unet_key)
+ check_make_directory(path)
+ metadata_path = os.path.join(path, "metadata.json")
+ if not os.path.exists(metadata_path):
+ self.write_model_metadata(metadata_path)
+ return path
+
@staticmethod
def get_controlled_unet_key(
size: int,
@@ -1469,7 +1753,7 @@ def inpainter_controlled_unet_key(self) -> str:
Gets the UNET key for the current configuration.
"""
return DiffusionPipelineManager.get_controlled_unet_key(
- size=self.inpainter_size,
+ size=self.inpainter_tensorrt_size,
lora=[],
lycoris=[],
inversion=[]
@@ -1501,6 +1785,44 @@ def inpainter_onnx_controlled_unet_dir(self) -> str:
self.write_model_metadata(metadata_path)
return path
+ @property
+ def animator_controlled_unet_key(self) -> str:
+ """
+ Gets the UNET key for the current configuration.
+ """
+ return DiffusionPipelineManager.get_controlled_unet_key(
+ size=self.animator_tensorrt_size,
+ lora=[],
+ lycoris=[],
+ inversion=[]
+ )
+
+ @property
+ def animator_tensorrt_controlled_unet_dir(self) -> str:
+ """
+ Gets where the tensorrt Controlled UNet engine will be stored for the animator.
+ TODO: determine if this should exist.
+ """
+ path = os.path.join(self.animator_tensorrt_dir, "controlledunet", self.animator_controlled_unet_key)
+ check_make_directory(path)
+ metadata_path = os.path.join(path, "metadata.json")
+ if not os.path.exists(metadata_path):
+ self.write_model_metadata(metadata_path)
+ return path
+
+ @property
+ def animator_onnx_controlled_unet_dir(self) -> str:
+ """
+ Gets where the onnx Controlled UNet engine will be stored for the animator.
+ TODO: determine if this should exist.
+ """
+ path = os.path.join(self.animator_onnx_dir, "controlledunet", self.animator_controlled_unet_key)
+ check_make_directory(path)
+ metadata_path = os.path.join(path, "metadata.json")
+ if not os.path.exists(metadata_path):
+ self.write_model_metadata(metadata_path)
+ return path
+
@staticmethod
def get_vae_key(size: int, **kwargs: Any) -> str:
"""
@@ -1604,6 +1926,37 @@ def inpainter_onnx_vae_dir(self) -> str:
self.write_model_metadata(metadata_path)
return path
+ @property
+ def animator_vae_key(self) -> str:
+ """
+ Gets the VAE key for the current configuration.
+ """
+ return DiffusionPipelineManager.get_vae_key(size=self.animator_size)
+
+ @property
+ def animator_tensorrt_vae_dir(self) -> str:
+ """
+ Gets where the tensorrt VAE engine will be stored for the animator.
+ """
+ path = os.path.join(self.animator_tensorrt_dir, "vae", self.animator_vae_key)
+ check_make_directory(path)
+ metadata_path = os.path.join(path, "metadata.json")
+ if not os.path.exists(metadata_path):
+ self.write_model_metadata(metadata_path)
+ return path
+
+ @property
+ def animator_onnx_vae_dir(self) -> str:
+ """
+ Gets where the onnx VAE engine will be stored for the animator.
+ """
+ path = os.path.join(self.animator_onnx_dir, "vae", self.animator_vae_key)
+ check_make_directory(path)
+ metadata_path = os.path.join(path, "metadata.json")
+ if not os.path.exists(metadata_path):
+ self.write_model_metadata(metadata_path)
+ return path
+
@property
def tensorrt_is_supported(self) -> bool:
"""
@@ -1632,6 +1985,8 @@ def tensorrt_is_enabled(self, new_enabled: bool) -> None:
self.unload_pipeline("TensorRT enabled or disabled")
if new_enabled != self.tensorrt_is_enabled and self.inpainter_tensorrt_is_ready:
self.unload_inpainter("TensorRT enabled or disabled")
+ if new_enabled != self.tensorrt_is_enabled and self.animator_tensorrt_is_ready:
+ self.unload_animator("TensorRT enabled or disabled")
if new_enabled != self.tensorrt_is_enabled and self.refiner_tensorrt_is_ready:
self.unload_refiner("TensorRT enabled or disabled")
self._tensorrt_enabled = new_enabled
@@ -1709,6 +2064,31 @@ def inpainter_tensorrt_is_ready(self) -> bool:
trt_ready = trt_ready and os.path.exists(Engine.get_engine_path(self.inpainter_tensorrt_unet_dir))
return trt_ready
+ @property
+ def animator_tensorrt_is_ready(self) -> bool:
+ """
+ Checks whether TensorRT is ready for the animator, based on the existence of built engines.
+ """
+ if not self.tensorrt_is_supported:
+ return False
+ if self.animator is None:
+ return False
+ from enfugue.diffusion.rt.engine import Engine
+
+ trt_ready = True
+ if "vae" in self.TENSORRT_STAGES:
+ trt_ready = trt_ready and os.path.exists(Engine.get_engine_path(self.animator_tensorrt_vae_dir))
+ if "clip" in self.TENSORRT_STAGES:
+ trt_ready = trt_ready and os.path.exists(Engine.get_engine_path(self.animator_tensorrt_clip_dir))
+ if self.animator_controlnets or self.TENSORRT_ALWAYS_USE_CONTROLLED_UNET:
+ if "unet" in self.TENSORRT_STAGES:
+ trt_ready = trt_ready and os.path.exists(
+ Engine.get_engine_path(self.animator_tensorrt_controlled_unet_dir)
+ )
+ elif "unet" in self.TENSORRT_STAGES:
+ trt_ready = trt_ready and os.path.exists(Engine.get_engine_path(self.animator_tensorrt_unet_dir))
+ return trt_ready
+
@property
def build_tensorrt(self) -> bool:
"""
@@ -1731,6 +2111,8 @@ def build_tensorrt(self, new_build: bool) -> None:
self.unload_pipeline("preparing for TensorRT build")
if not self.inpainter_tensorrt_is_ready and self.tensorrt_is_supported:
self.unload_inpainter("preparing for TensorRT build")
+ if not self.animator_tensorrt_is_ready and self.tensorrt_is_supported:
+ self.unload_animator("preparing for TensorRT build")
if not self.refiner_tensorrt_is_ready and self.tensorrt_is_supported:
self.unload_refiner("preparing for TensorRT build")
@@ -1755,6 +2137,14 @@ def inpainter_use_tensorrt(self) -> bool:
"""
return (self.inpainter_tensorrt_is_ready or self.build_tensorrt) and self.tensorrt_is_enabled
+ @property
+ def animator_use_tensorrt(self) -> bool:
+ """
+ Gets the ultimate decision on whether the tensorrt pipeline should be used for the animator.
+ """
+ return False
+ # return (self.animator_tensorrt_is_ready or self.build_tensorrt) and self.tensorrt_is_enabled
+
@property
def use_directml(self) -> bool:
"""
@@ -1804,6 +2194,16 @@ def create_inpainter(self) -> bool:
"""
return bool(self.configuration.get("enfugue.pipeline.inpainter", True))
+ @property
+ def create_animator(self) -> bool:
+ """
+ Defines whether a default animator pipeline may be created when none is explicitly configured (defaults to enabled for non-SDXL models).
+ """
+ configured = self.configuration.get("enfugue.pipeline.animator", None)
+ if configured is None:
+ return not self.is_sdxl
+ return configured
+
@property
def refiner_strength(self) -> float:
"""
@@ -1928,6 +2328,17 @@ def inpainter_pipeline_class(self) -> Type:
return EnfugueStableDiffusionPipeline
+ @property
+ def animator_pipeline_class(self) -> Type:
+ """
+ Gets the pipeline class to use.
+ """
+ if self.animator_use_tensorrt:
+ raise RuntimeError("No TensorRT animation pipeline exists yet.")
+ else:
+ from enfugue.diffusion.animate.pipeline import EnfugueAnimateStableDiffusionPipeline
+ return EnfugueAnimateStableDiffusionPipeline
+
def check_get_default_model(self, model: str) -> str:
"""
Checks if a model is a default model, in which case the remote URL is returned
@@ -1971,11 +2382,16 @@ def model(self, new_model: Optional[str]) -> None:
model = find_file_in_directory(self.engine_checkpoints_dir, model)
if not model:
raise ValueError(f"Cannot find model {new_model}")
+
model_name, _ = os.path.splitext(os.path.basename(model))
if self.model_name != model_name:
self.unload_pipeline("model changing")
- if not hasattr(self, "_inpainter") and getattr(self, "_inpainter_pipeline", None) is not None:
+ if self.is_default_animator and getattr(self, "_animator_pipeline", None) is not None:
+ self.unload_animator("base model changing")
+ self.is_default_animator = False
+ if self.is_default_inpainter and getattr(self, "_inpainter_pipeline", None) is not None:
self.unload_inpainter("base model changing")
+ self.is_default_inpainter = False
self._model = model
@property
@@ -2078,6 +2494,53 @@ def inpainter_name(self) -> Optional[str]:
if self.inpainter is None:
return None
return os.path.splitext(os.path.basename(self.inpainter))[0]
+
+ @property
+ def animator(self) -> Optional[str]:
+ """
+ Gets the configured animator.
+ """
+ if not hasattr(self, "_animator"):
+ self._animator = self.configuration.get("enfugue.animator", None)
+ return self._animator
+
+ @animator.setter
+ def animator(self, new_animator: Optional[str]) -> None:
+ """
+ Sets a new animator. Destroys the animator pipeline.
+ """
+ if new_animator is None:
+ self._animator = None
+ return
+ animator = self.check_get_default_model(new_animator)
+ if animator.startswith("http"):
+ animator = self.check_download_model(self.engine_checkpoints_dir, animator)
+ elif not os.path.isabs(animator):
+ animator = find_file_in_directory(self.engine_checkpoints_dir, animator) # type: ignore[assignment]
+ if not animator:
+ raise ValueError(f"Cannot find animator {new_animator}")
+
+ animator_name, _ = os.path.splitext(os.path.basename(animator))
+ if self.animator_name != animator_name:
+ self.unload_animator("model changing")
+
+ self._animator = animator
+
+ @property
+ def animator_name(self) -> Optional[str]:
+ """
+ Gets just the basename of the animator
+ """
+ if self.animator is None:
+ return None
+ return os.path.splitext(os.path.basename(self.animator))[0]
+
+ @property
+ def has_animator(self) -> bool:
+ """
+ Returns true if the animator is set.
+ """
+ return self.animator is not None
@property
def dtype(self) -> torch.dtype:
@@ -2092,9 +2555,6 @@ def dtype(self) -> torch.dtype:
if device_type == "cpu":
logger.debug("Inferencing on cpu, must use dtype bfloat16")
self._torch_dtype = torch.bfloat16
- elif device_type == "mps":
- logger.debug("Inferencing on mps, defaulting to dtype float16")
- self._torch_dtype = torch.float16
elif device_type == "cuda" and torch.version.hip:
logger.debug("Inferencing on rocm, must use dtype float32") # type: ignore[unreachable]
self._torch_dtype = torch.float
@@ -2138,6 +2598,7 @@ def dtype(self, new_dtype: Union[str, torch.dtype]) -> None:
self.unload_pipeline("data type changing")
self.unload_refiner("data type changing")
self.unload_inpainter("data type changing")
+ self.unload_animator("data type changing")
self._torch_dtype = new_dtype
@@ -2175,12 +2636,14 @@ def lora(self, new_lora: Optional[Union[str, List[str], Tuple[str, float], List[
for i, (model, weight) in enumerate(lora):
if model.startswith("http"):
- model = self.check_download_model(self.engine_lora_dir, model)
+ find_model = self.check_download_model(self.engine_lora_dir, model)
elif not os.path.isabs(model):
- model = find_file_in_directory(self.engine_lora_dir, model) # type: ignore[assignment]
- if not model:
+ find_model = find_file_in_directory(self.engine_lora_dir, model) # type: ignore[assignment]
+ else:
+ find_model = model
+ if not find_model:
raise ValueError(f"Cannot find LoRA model {model}")
- lora[i] = (model, weight)
+ lora[i] = (find_model, weight)
if getattr(self, "_lora", []) != lora:
self.unload_pipeline("LoRA changing")
@@ -2284,6 +2747,96 @@ def inversion_names(self) -> List[str]:
"""
return [os.path.splitext(os.path.basename(inversion))[0] for inversion in self.inversion]
+ @property
+ def reload_motion_module(self) -> bool:
+ """
+ Returns true if the motion module should be reloaded.
+ """
+ return getattr(self, "_reload_motion_module", False)
+
+ @reload_motion_module.setter
+ def reload_motion_module(self, reload: bool) -> None:
+ """
+ Sets if the motion module should be reloaded.
+ """
+ self._reload_motion_module = reload
+
+ @property
+ def motion_module(self) -> Optional[str]:
+ """
+ Gets optional configured non-default motion module.
+ """
+ return getattr(self, "_motion_module", None)
+
+ @motion_module.setter
+ def motion_module(self, new_module: Optional[str]) -> None:
+ """
+ Sets a new motion module or reverts to default.
+ """
+ if (
+ self.motion_module is None and new_module is not None or
+ self.motion_module is not None and new_module is None or
+ (
+ self.motion_module is not None and
+ new_module is not None and
+ self.motion_module != new_module
+ )
+ ):
+ self.reload_motion_module = True
+ if new_module is not None and not os.path.isabs(new_module):
+ new_module = os.path.join(self.engine_motion_dir, new_module)
+ if new_module is not None and not os.path.exists(new_module):
+ raise IOError(f"Cannot find or access motion module at {new_module}")
+ self._motion_module = new_module
+
+ @property
+ def position_encoding_truncate_length(self) -> Optional[int]:
+ """
+ An optional length (frames) to truncate position encoder tensors to
+ """
+ return getattr(self, "_position_encoding_truncate_length", None)
+
+ @position_encoding_truncate_length.setter
+ def position_encoding_truncate_length(self, new_length: Optional[int]) -> None:
+ """
+ Sets position encoder truncate length.
+ """
+ if (
+ self.position_encoding_truncate_length is None and new_length is not None or
+ self.position_encoding_truncate_length is not None and new_length is None or
+ (
+ self.position_encoding_truncate_length is not None and
+ new_length is not None and
+ self.position_encoding_truncate_length != new_length
+ )
+ ):
+ self.reload_motion_module = True
+ self._position_encoding_truncate_length = new_length
+
+ @property
+ def position_encoding_scale_length(self) -> Optional[int]:
+ """
+ An optional length (frames) to scale position encoder tensors to
+ """
+ return getattr(self, "_position_encoding_scale_length", None)
+
+ @position_encoding_scale_length.setter
+ def position_encoding_scale_length(self, new_length: Optional[int]) -> None:
+ """
+ Sets position encoder scale length.
+ """
+ if (
+ self.position_encoding_scale_length is None and new_length is not None or
+ self.position_encoding_scale_length is not None and new_length is None or
+ (
+ self.position_encoding_scale_length is not None and
+ new_length is not None and
+ self.position_encoding_scale_length != new_length
+ )
+ ):
+ self.reload_motion_module = True
+ self._position_encoding_scale_length = new_length
+
@property
def model_diffusers_cache_dir(self) -> Optional[str]:
"""
@@ -2338,6 +2891,24 @@ def inpainter_engine_cache_exists(self) -> bool:
"""
return self.inpainter_diffusers_cache_dir is not None
+ @property
+ def animator_diffusers_cache_dir(self) -> Optional[str]:
+ """
+ Gets where the diffusers cache directory is saved for this animator, if there is any.
+ """
+ if os.path.exists(os.path.join(self.animator_diffusers_dir, "model_index.json")):
+ return self.animator_diffusers_dir
+ elif os.path.exists(os.path.join(self.animator_tensorrt_dir, "model_index.json")):
+ return self.animator_tensorrt_dir
+ return None
+
+ @property
+ def animator_engine_cache_exists(self) -> bool:
+ """
+ Gets whether or not the diffusers cache exists.
+ """
+ return self.animator_diffusers_cache_dir is not None
+
@property
def should_cache(self) -> bool:
"""
@@ -2358,6 +2929,16 @@ def should_cache_inpainter(self) -> bool:
return self.inpainter_is_sdxl
return configured in ["always", True]
+ @property
+ def should_cache_animator(self) -> bool:
+ """
+ Returns true if the animator model should always be cached.
+ """
+ configured = self.configuration.get("enfugue.pipeline.cache", None)
+ if configured == "xl":
+ return self.animator_is_sdxl
+ return configured in ["always", True]
+
@property
def should_cache_refiner(self) -> bool:
"""
@@ -2415,13 +2996,27 @@ def inpainter_is_sdxl(self) -> bool:
Otherwise, we guess by file name.
"""
if not self.inpainter_name:
- return False
+ return self.is_sdxl
if getattr(self, "_inpainter_pipeline", None) is not None:
return self._inpainter_pipeline.is_sdxl
if self.inpainter_diffusers_cache_dir is not None:
return os.path.exists(os.path.join(self.inpainter_diffusers_cache_dir, "text_encoder_2")) # type: ignore[arg-type]
return "xl" in self.inpainter_name.lower()
+ @property
+ def animator_is_sdxl(self) -> bool:
+ """
+ If the animator model is cached, we can know for sure by checking for sdxl-exclusive models.
+ Otherwise, we guess by file name.
+ """
+ if not self.animator_name:
+ return False
+ if getattr(self, "_animator_pipeline", None) is not None:
+ return self._animator_pipeline.is_sdxl
+ if self.animator_diffusers_cache_dir is not None:
+ return os.path.exists(os.path.join(self.animator_diffusers_cache_dir, "text_encoder_2")) # type: ignore[arg-type]
+ return "xl" in self.animator_name.lower()
+
def check_create_engine_cache(self) -> None:
"""
Converts a .ckpt file to the directory structure from diffusers
@@ -2477,6 +3072,25 @@ def check_create_inpainter_engine_cache(self) -> None:
del pipe
self.clear_memory()
+ def check_create_animator_engine_cache(self) -> None:
+ """
+ Converts a .ckpt or .safetensors checkpoint to the diffusers cache format
+ """
+ if not self.animator_engine_cache_exists and self.animator:
+ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+ download_from_original_stable_diffusion_ckpt,
+ )
+
+ _, ext = os.path.splitext(self.animator)
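+ # Inpainting checkpoints use a 9-channel UNet input (latents, masked-image latents and mask); standard models use 4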
+ pipe = download_from_original_stable_diffusion_ckpt(
+ self.animator,
+ num_in_channels=9 if "inpaint" in self.animator.lower() else 4,
+ from_safetensors=ext == ".safetensors"
+ ).to(torch_dtype=self.dtype)
+ pipe.save_pretrained(self.animator_diffusers_dir)
+ del pipe
+ self.clear_memory()
+
def swap_pipelines(self, to_gpu: EnfugueStableDiffusionPipeline, to_cpu: EnfugueStableDiffusionPipeline) -> None:
"""
Swaps pipelines in and out of GPU.
@@ -2503,8 +3117,8 @@ def pipeline(self) -> EnfugueStableDiffusionPipeline:
kwargs = {
"cache_dir": self.engine_cache_dir,
- "engine_size": self.size,
- "chunking_size": self.chunking_size,
+ "engine_size": self.tensorrt_size,
+ "tiling_stride": self.tiling_stride,
"requires_safety_checker": self.safe,
"torch_dtype": self.dtype,
"cache_dir": self.engine_cache_dir,
@@ -2542,7 +3156,10 @@ def pipeline(self) -> EnfugueStableDiffusionPipeline:
kwargs["tokenizer_2"] = None
kwargs["text_encoder_2"] = None
- kwargs["build_half"] = "16" in str(self.dtype)
+ if "16" in str(self.dtype):
+ kwargs["build_half"] = True
+ kwargs["variant"] = "fp16"
+
logger.debug(
f"Initializing TensorRT pipeline from diffusers cache directory at {self.model_diffusers_cache_dir}. Arguments are {redact(kwargs)}"
)
@@ -2560,9 +3177,14 @@ def pipeline(self) -> EnfugueStableDiffusionPipeline:
kwargs["text_encoder_2"] = None
if vae is not None:
kwargs["vae"] = vae
+
+ if "16" in str(self.dtype):
+ kwargs["variant"] = "fp16"
+
logger.debug(
f"Initializing pipeline from diffusers cache directory at {self.model_diffusers_cache_dir}. Arguments are {redact(kwargs)}"
)
+
pipeline = self.pipeline_class.from_pretrained(
self.model_diffusers_cache_dir,
local_files_only=self.offline,
@@ -2583,17 +3205,17 @@ def pipeline(self) -> EnfugueStableDiffusionPipeline:
# We can fix that here, though, by forcing full precision VAE
pipeline.register_to_config(force_full_precision_vae=True)
if self.should_cache:
- logger.debug("Saving pipeline to pretrained.")
+ self.task_callback("Saving pipeline to pretrained cache.")
pipeline.save_pretrained(self.model_diffusers_dir)
if not self.tensorrt_is_ready:
for lora, weight in self.lora:
- logger.debug(f"Adding LoRA {lora} to pipeline with weight {weight}")
+ self.task_callback(f"Adding LoRA {os.path.basename(lora)} to pipeline with weight {weight}")
pipeline.load_lora_weights(lora, multiplier=weight)
for lycoris, weight in self.lycoris:
- logger.debug(f"Adding lycoris {lycoris} to pipeline")
+ self.task_callback(f"Adding lycoris {os.path.basename(lycoris)} to pipeline")
pipeline.load_lycoris_weights(lycoris, multiplier=weight)
for inversion in self.inversion:
- logger.debug(f"Adding textual inversion {inversion} to pipeline")
+ self.task_callback(f"Adding textual inversion {os.path.basename(inversion)} to pipeline")
pipeline.load_textual_inversion(inversion)
# load scheduler
@@ -2629,8 +3251,8 @@ def refiner_pipeline(self) -> EnfugueStableDiffusionPipeline:
kwargs = {
"cache_dir": self.engine_cache_dir,
- "engine_size": self.refiner_size,
- "chunking_size": self.chunking_size,
+ "engine_size": self.refiner_tensorrt_size,
+ "tiling_stride": self.tiling_stride,
"torch_dtype": self.dtype,
"requires_safety_checker": False,
"force_full_precision_vae": self.refiner_is_sdxl and "16" not in self.refiner and (
@@ -2669,7 +3291,10 @@ def refiner_pipeline(self) -> EnfugueStableDiffusionPipeline:
kwargs["text_encoder_2"] = None
kwargs["tokenizer_2"] = None
- kwargs["build_half"] = "16" in str(self.dtype)
+ if "16" in str(self.dtype):
+ kwargs["build_half"] = True
+ kwargs["variant"] = "fp16"
+
logger.debug(
f"Initializing refiner TensorRT pipeline from diffusers cache directory at {self.refiner_diffusers_cache_dir}. Arguments are {redact(kwargs)}"
)
@@ -2690,11 +3315,17 @@ def refiner_pipeline(self) -> EnfugueStableDiffusionPipeline:
else:
kwargs["text_encoder_2"] = None
kwargs["tokenizer_2"] = None
+
if vae is not None:
kwargs["vae"] = vae
+
+ if "16" in str(self.dtype):
+ kwargs["variant"] = "fp16"
+
logger.debug(
f"Initializing refiner pipeline from diffusers cache directory at {self.refiner_diffusers_cache_dir}. Arguments are {redact(kwargs)}"
)
+
refiner_pipeline = self.refiner_pipeline_class.from_pretrained(
self.refiner_diffusers_cache_dir,
safety_checker=None,
@@ -2714,11 +3345,13 @@ def refiner_pipeline(self) -> EnfugueStableDiffusionPipeline:
load_safety_checker=False,
**kwargs,
)
+
if refiner_pipeline.is_sdxl and "16" not in self.refiner and (self.refiner_vae_name is None or "16" not in self.refiner_vae_name):
refiner_pipeline.register_to_config(force_full_precision_vae=True)
if self.should_cache_refiner:
- logger.debug("Saving pipeline to pretrained.")
+ self.task_callback("Saving pipeline to pretrained.")
refiner_pipeline.save_pretrained(self.refiner_diffusers_dir)
+
# load scheduler
if self.scheduler is not None:
logger.debug(f"Setting refiner scheduler to {self.scheduler.__name__}") # type: ignore[attr-defined]
@@ -2778,14 +3411,17 @@ def inpainter_pipeline(self) -> EnfugueStableDiffusionPipeline:
else:
raise ConfigurationError(f"No target inpainter, creation is disabled, and default inpainter does not exist at {target_checkpoint_path}")
self.inpainter = target_checkpoint_path
+ self.is_default_inpainter = True
+ else:
+ self.is_default_inpainter = False
if self.inpainter.startswith("http"):
self.inpainter = self.check_download_model(self.engine_checkpoints_dir, self.inpainter)
kwargs = {
"cache_dir": self.engine_cache_dir,
- "engine_size": self.inpainter_size,
- "chunking_size": self.chunking_size,
+ "engine_size": self.inpainter_tensorrt_size,
+ "tiling_stride": self.tiling_stride,
"torch_dtype": self.dtype,
"requires_safety_checker": self.safe,
"requires_aesthetic_score": False,
@@ -2826,10 +3462,14 @@ def inpainter_pipeline(self) -> EnfugueStableDiffusionPipeline:
kwargs["text_encoder_2"] = None
kwargs["tokenizer_2"] = None
- kwargs["build_half"] = "16" in str(self.dtype)
+ if "16" in str(self.dtype):
+ kwargs["variant"] = "fp16"
+ kwargs["build_half"] = True
+
logger.debug(
f"Initializing inpainter TensorRT pipeline from diffusers cache directory at {self.inpainter_diffusers_cache_dir}. Arguments are {redact(kwargs)}"
)
+
inpainter_pipeline = self.inpainter_pipeline_class.from_pretrained(
self.inpainter_diffusers_cache_dir,
local_files_only=self.offline,
@@ -2846,6 +3486,8 @@ def inpainter_pipeline(self) -> EnfugueStableDiffusionPipeline:
kwargs["tokenizer_2"] = None
if vae is not None:
kwargs["vae"] = vae
+ if "16" in str(self.dtype):
+ kwargs["variant"] = "fp16"
logger.debug(
f"Initializing inpainter pipeline from diffusers cache directory at {self.inpainter_diffusers_cache_dir}. Arguments are {redact(kwargs)}"
@@ -2872,18 +3514,19 @@ def inpainter_pipeline(self) -> EnfugueStableDiffusionPipeline:
if inpainter_pipeline.is_sdxl and "16" not in self.inpainter and (self.inpainter_vae_name is None or "16" not in self.inpainter_vae_name):
inpainter_pipeline.register_to_config(force_full_precision_vae=True)
if self.should_cache_inpainter:
- logger.debug("Saving inpainter pipeline to pretrained cache.")
+ self.task_callback("Saving inpainter pipeline to pretrained cache.")
inpainter_pipeline.save_pretrained(self.inpainter_diffusers_dir)
if not self.inpainter_tensorrt_is_ready:
for lora, weight in self.lora:
- logger.debug(f"Adding LoRA {lora} to inpainter pipeline with weight {weight}")
+ self.task_callback(f"Adding LoRA {os.path.basename(lora)} to inpainter pipeline with weight {weight}")
inpainter_pipeline.load_lora_weights(lora, multiplier=weight)
for lycoris, weight in self.lycoris:
- logger.debug(f"Adding lycoris {lycoris} to inpainter pipeline")
+ self.task_callback(f"Adding lycoris {os.path.basename(lycoris)} to inpainter pipeline")
inpainter_pipeline.load_lycoris_weights(lycoris, multiplier=weight)
for inversion in self.inversion:
- logger.debug(f"Adding textual inversion {inversion} to inpainter pipeline")
+ self.task_callback(f"Adding textual inversion {os.path.basename(inversion)} to inpainter pipeline")
inpainter_pipeline.load_textual_inversion(inversion)
+
# load scheduler
if self.scheduler is not None:
logger.debug(f"Setting inpainter scheduler to {self.scheduler.__name__}") # type: ignore[attr-defined]
@@ -2900,8 +3543,146 @@ def inpainter_pipeline(self) -> None:
logger.debug("Deleting inpainter pipeline.")
del self._inpainter_pipeline
self.clear_memory()
+
+ @property
+ def animator_pipeline(self) -> EnfugueAnimateStableDiffusionPipeline:
+ """
+ Instantiates the animator pipeline.
+ """
+ if not hasattr(self, "_animator_pipeline"):
+ if self.animator is None:
+ logger.info("No animator explicitly set, using base model for animator.")
+ self.animator = self.model
+ self.is_default_animator = True
+ else:
+ self.is_default_animator = False
+
+ if self.animator.startswith("http"):
+ self.animator = self.check_download_model(self.engine_checkpoints_dir, self.animator)
+
+ # Disable reloading if it was set
+ self.reload_motion_module = False
+
+ kwargs = {
+ "cache_dir": self.engine_cache_dir,
+ "engine_size": self.animator_tensorrt_size,
+ "tiling_stride": self.tiling_stride,
+ "frame_window_size": self.frame_window_size,
+ "frame_window_stride": self.frame_window_stride,
+ "torch_dtype": self.dtype,
+ "requires_safety_checker": self.safe,
+ "requires_aesthetic_score": False,
+ "controlnets": self.animator_controlnets,
+ "force_full_precision_vae": self.animator_is_sdxl and self.animator_vae_name not in ["xl16", VAE_XL16],
+ "ip_adapter": self.ip_adapter,
+ "task_callback": getattr(self, "_task_callback", None),
+ "motion_module": self.motion_module,
+ "position_encoding_truncate_length": self.position_encoding_truncate_length,
+ "position_encoding_scale_length": self.position_encoding_scale_length,
+ }
+
+ vae = self.animator_vae
+
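+ # Three load paths follow: prebuilt TensorRT engines, an existing diffusers cache, or a single-file checkpoint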
+ if self.animator_use_tensorrt:
+ if self.animator_is_sdxl: # Not possible yet
+ raise ValueError(f"Sorry, TensorRT is not yet supported for SDXL.")
+
+ if "unet" in self.TENSORRT_STAGES:
+ if not self.animator_controlnets and not self.TENSORRT_ALWAYS_USE_CONTROLLED_UNET:
+ kwargs["unet_engine_dir"] = self.animator_tensorrt_unet_dir
+ else:
+ kwargs["controlled_unet_engine_dir"] = self.animator_tensorrt_controlled_unet_dir
+
+ if "vae" in self.TENSORRT_STAGES:
+ kwargs["vae_engine_dir"] = self.animator_tensorrt_vae_dir
+ elif vae is not None:
+ kwargs["vae"] = vae
+
+ if "clip" in self.TENSORRT_STAGES:
+ kwargs["clip_engine_dir"] = self.animator_tensorrt_clip_dir
+
+ self.check_create_animator_engine_cache()
+
+ if not self.safe:
+ kwargs["safety_checker"] = None
+ if not self.animator_is_sdxl:
+ kwargs["text_encoder_2"] = None
+ kwargs["tokenizer_2"] = None
+ if "16" in str(self.dtype):
+ kwargs["variant"] = "fp16"
+ kwargs["build_half"] = True
+
+ logger.debug(
+ f"Initializing animator TensorRT pipeline from diffusers cache directory at {self.animator_diffusers_cache_dir}. Arguments are {redact(kwargs)}"
+ )
+
+ animator_pipeline = self.animator_pipeline_class.from_pretrained(
+ self.animator_diffusers_cache_dir, **kwargs
+ )
+ elif self.animator_engine_cache_exists:
+ if not self.safe:
+ kwargs["safety_checker"] = None
+ if not self.animator_is_sdxl:
+ kwargs["text_encoder_2"] = None
+ kwargs["tokenizer_2"] = None
+ kwargs["text_encoder_2"] = None
+ kwargs["tokenizer_2"] = None
+ if vae is not None:
+ kwargs["vae"] = vae
+ if "16" in str(self.dtype):
+ kwargs["variant"] = "fp16"
+
+ logger.debug(
+ f"Initializing animator pipeline from diffusers cache directory at {self.animator_diffusers_cache_dir}. Arguments are {redact(kwargs)}"
+ )
+
+ animator_pipeline = self.animator_pipeline_class.from_pretrained(
+ self.animator_diffusers_cache_dir, **kwargs
+ )
+ else:
+ if self.animator_vae_name is not None:
+ kwargs["vae_path"] = self.find_vae_path(self.animator_vae_name)
+
+ logger.debug(
+ f"Initializing animator pipeline from checkpoint at {self.animator}. Arguments are {redact(kwargs)}"
+ )
+
+ animator_pipeline = self.animator_pipeline_class.from_ckpt(
+ self.animator, load_safety_checker=self.safe, **kwargs
+ )
+ if animator_pipeline.is_sdxl and self.animator_vae_name not in ["xl16", VAE_XL16]:
+ animator_pipeline.register_to_config(force_full_precision_vae=True)
+ if self.should_cache_animator:
+ self.task_callback("Saving animator pipeline to pretrained cache.")
+ animator_pipeline.save_pretrained(self.animator_diffusers_dir)
+ if not self.animator_tensorrt_is_ready:
+ for lora, weight in self.lora:
+ self.task_callback(f"Adding LoRA {os.path.basename(lora)} to animator pipeline with weight {weight}")
+ animator_pipeline.load_lora_weights(lora, multiplier=weight)
+ for lycoris, weight in self.lycoris:
+ self.task_callback(f"Adding lycoris {os.path.basename(lycoris)} to animator pipeline")
+ animator_pipeline.load_lycoris_weights(lycoris, multiplier=weight)
+ for inversion in self.inversion:
+ self.task_callback(f"Adding textual inversion {os.path.basename(inversion)} to animator pipeline")
+ animator_pipeline.load_textual_inversion(inversion)
+ # load scheduler
+ if self.scheduler is not None:
+ logger.debug(f"Setting animator scheduler to {self.scheduler.__name__}") # type: ignore [attr-defined]
+ animator_pipeline.scheduler = self.scheduler.from_config({**animator_pipeline.scheduler_config, **self.scheduler_config}) # type: ignore[attr-defined]
+ self._animator_pipeline = animator_pipeline.to(self.device)
+ return self._animator_pipeline
+
+ @animator_pipeline.deleter
+ def animator_pipeline(self) -> None:
+ """
+ Unloads the animator pipeline if present.
+ """
+ if hasattr(self, "_animator_pipeline"):
+ logger.debug("Deleting animator pipeline.")
+ del self._animator_pipeline
+ self.clear_memory()
- def unload_pipeline(self, reason: str = "none") -> None:
+ def unload_pipeline(self, reason: str = "None") -> None:
"""
Calls the pipeline deleter.
"""
@@ -2931,7 +3712,7 @@ def offload_pipeline(self, intention: Optional[Literal["inpainting", "refining"]
self._pipeline = self._pipeline.to("cpu") # type: ignore[attr-defined]
self.clear_memory()
- def unload_refiner(self, reason: str = "none") -> None:
+ def unload_refiner(self, reason: str = "None") -> None:
"""
Calls the refiner deleter.
"""
@@ -2961,7 +3742,7 @@ def offload_refiner(self, intention: Optional[Literal["inpainting", "inference"]
self._refiner_pipeline = self._refiner_pipeline.to("cpu") # type: ignore[attr-defined]
self.clear_memory()
- def unload_inpainter(self, reason: str = "none") -> None:
+ def unload_inpainter(self, reason: str = "None") -> None:
"""
Calls the inpainter deleter.
"""
@@ -2993,6 +3774,41 @@ def offload_inpainter(self, intention: Optional[Literal["inference", "refining"]
self._inpainter_pipeline = self._inpainter_pipeline.to("cpu") # type: ignore[attr-defined]
self.clear_memory()
+ def unload_animator(self, reason: str = "None") -> None:
+ """
+ Calls the animator deleter.
+ """
+ if hasattr(self, "_animator_pipeline"):
+ logger.debug(f'Unloading animator pipeline for reason "{reason}"')
+ del self.animator_pipeline
+
+ def offload_animator(self, intention: Optional[Literal["inference", "inpainting", "refining"]] = None) -> None:
+ """
+ Offloads the animator pipeline to CPU if present.
+ """
+ if hasattr(self, "_animator_pipeline"):
+ import torch
+
+ if self.pipeline_switch_mode == "unload":
+ logger.debug("Offloading is disabled, unloading animator pipeline.")
+ self.unload_animator("switching modes" if not intention else f"switching to {intention}")
+ elif self.pipeline_switch_mode is None:
+ logger.debug("Offloading is disabled, keeping animator pipeline in memory.")
+ elif intention == "inference" and hasattr(self, "_pipeline"):
+ logger.debug("Swapping pipeline out of CPU and animator into CPU")
+ self.swap_pipelines(self._pipeline, self._animator_pipeline)
+ elif intention == "inpainting" and hasattr(self, "_inpainter_pipeline"):
+ logger.debug("Swapping inpainter out of CPU and animator into CPU")
+ self.swap_pipelines(self._inpainter_pipeline, self._animator_pipeline)
+ elif intention == "refining" and hasattr(self, "_refiner_pipeline"):
+ logger.debug("Swapping refiner out of CPU and animator into CPU")
+ self.swap_pipelines(self._refiner_pipeline, self._animator_pipeline)
+ else:
+ logger.debug("Offloading animator to CPU")
+ self._animator_pipeline = self._animator_pipeline.to("cpu", torch_dtype=torch.float32) # type: ignore[attr-defined]
+ self.clear_memory()
+
@property
def upscaler(self) -> Upscaler:
"""
@@ -3099,14 +3915,6 @@ def get_controlnet(self, controlnet: Optional[str] = None) -> Optional[ControlNe
logger.debug(f"Received KeyError on '{ex}' when instantiating controlnet from single file, trying to use XL ControlNet loader fix.")
return self.get_xl_controlnet(expected_controlnet_location)
else:
- expected_controlnet_location = os.path.join(self.engine_cache_dir, "models--" + controlnet.replace("/", "--"))
- if not os.path.exists(expected_controlnet_location):
- if self.offline:
- raise IOError(f"Offline mode enabled, cannot find requested ControlNet at {expected_controlnet_location}")
- logger.info(
- f"Controlnet {controlnet} does not exist in cache directory {self.engine_cache_dir}, it will be downloaded."
- )
- self.task_callback(f"Downloading {controlnet} model weights")
result = ControlNetModel.from_pretrained(
controlnet,
torch_dtype=torch.half,
@@ -3129,8 +3937,12 @@ def get_default_controlnet_path_by_name(
return CONTROLNET_CANNY_XL
elif name == "depth":
return CONTROLNET_DEPTH_XL
+ elif name == "pidi":
+ return CONTROLNET_PIDI_XL
elif name == "pose":
return CONTROLNET_POSE_XL
+ elif name == "qr":
+ return CONTROLNET_QR_XL
else:
raise ValueError(f"Sorry, ControlNet “{name}” is not yet supported by SDXL. Check back soon!")
else:
@@ -3254,7 +4066,6 @@ def controlnets(
if getattr(self, "_pipeline", None) is not None:
self._pipeline.controlnets = self.controlnets
-
@property
def inpainter_controlnets(self) -> Dict[str, ControlNetModel]:
"""
@@ -3323,6 +4134,74 @@ def inpainter_controlnets(
if getattr(self, "_inpainter_pipeline", None) is not None:
self._inpainter_pipeline.controlnets = self.inpainter_controlnets
+ @property
+ def animator_controlnets(self) -> Dict[str, ControlNetModel]:
+ """
+ Gets the configured controlnets for the animator
+ """
+ if not hasattr(self, "_animator_controlnets"):
+ self._animator_controlnets = {}
+
+ for controlnet_name in self.animator_controlnet_names:
+ self._animator_controlnets[controlnet_name] = self.get_controlnet(
+ self.get_controlnet_path_by_name(controlnet_name, self.is_sdxl)
+ )
+ return self._animator_controlnets # type: ignore[return-value]
+
+ @animator_controlnets.deleter
+ def animator_controlnets(self) -> None:
+ """
+ Removes current animator controlnets and clears memory
+ """
+ if hasattr(self, "_animator_controlnets"):
+ del self._animator_controlnets
+ self.clear_memory()
+
+ @animator_controlnets.setter
+ def animator_controlnets(
+ self,
+ *new_animator_controlnets: Optional[Union[CONTROLNET_LITERAL, List[CONTROLNET_LITERAL], Set[CONTROLNET_LITERAL]]],
+ ) -> None:
+ """
+ Sets a new list of animator controlnets (optional)
+ """
+ controlnet_names: Set[CONTROLNET_LITERAL] = set()
+
+ for arg in new_animator_controlnets:
+ if arg is None:
+ break
+ if isinstance(arg, str):
+ controlnet_names.add(arg)
+ else:
+ controlnet_names = controlnet_names.union(arg) # type: ignore[arg-type]
+
+ existing_controlnet_names = self.animator_controlnet_names
+
+ if controlnet_names == existing_controlnet_names:
+ return # No changes
+
+ logger.debug(f"Setting animator pipeline ControlNet(s) to {controlnet_names} from {existing_controlnet_names}")
+ self._animator_controlnet_names = controlnet_names
+
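+ # Unload the animator when ControlNet is enabled or disabled outright; otherwise patch the loaded set in place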
+ if (not controlnet_names and existing_controlnet_names):
+ self.unload_animator("Disabling ControlNet")
+ del self.animator_controlnets
+ elif (controlnet_names and not existing_controlnet_names):
+ self.unload_animator("Enabling ControlNet")
+ del self.animator_controlnets
+ elif controlnet_names and existing_controlnet_names:
+ logger.debug("Altering existing animator controlnets")
+ if hasattr(self, "_animator_controlnets"):
+ for controlnet_name in controlnet_names.union(existing_controlnet_names):
+ if controlnet_name not in controlnet_names:
+ self._animator_controlnets.pop(controlnet_name, None)
+ elif controlnet_name not in self._animator_controlnets:
+ self._animator_controlnets[controlnet_name] = self.get_controlnet(
+ self.get_controlnet_path_by_name(controlnet_name, self.is_sdxl)
+ )
+ if getattr(self, "_animator_pipeline", None) is not None:
+ self._animator_pipeline.controlnets = self.animator_controlnets
+
@property
def refiner_controlnets(self) -> Dict[str, ControlNetModel]:
"""
@@ -3405,6 +4284,13 @@ def inpainter_controlnet_names(self) -> Set[CONTROLNET_LITERAL]:
"""
return getattr(self, "_inpainter_controlnet_names", set())
+ @property
+ def animator_controlnet_names(self) -> Set[CONTROLNET_LITERAL]:
+ """
+ Gets the names of the configured animator ControlNets, if any were set.
+ """
+ return getattr(self, "_animator_controlnet_names", set())
+
@property
def refiner_controlnet_names(self) -> Set[CONTROLNET_LITERAL]:
"""
@@ -3425,7 +4311,8 @@ def __call__(
refiner_negative_prompt_2: Optional[str] = None,
scale_to_refiner_size: bool = True,
task_callback: Optional[Callable[[str], None]] = None,
- next_intention: Optional[Literal["inpainting", "inference", "refining", "upscaling"]] = None,
+ next_intention: Optional[Literal["inpainting", "animation", "inference", "refining", "upscaling"]] = None,
+ scheduler: Optional[SCHEDULER_LITERAL] = None,
**kwargs: Any,
) -> StableDiffusionPipelineOutput:
"""
@@ -3449,26 +4336,45 @@ def memoize_callback(images: List[PIL.Image.Image]) -> None:
previous_callback(images)
latent_callback = memoize_callback # type: ignore[assignment]
kwargs["latent_callback"] = memoize_callback
+
self.start_keepalive()
+
try:
+ animating = bool(kwargs.get("animation_frames", None))
inpainting = kwargs.get("mask", None) is not None
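+ # A pure refining pass: an input image with no base denoising, no IP adapter, and an enabled refiner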
refining = (
kwargs.get("image", None) is not None and
kwargs.get("strength", 0) in [0, None] and
- kwargs.get("ip_adapter_images", None) is None and
+ kwargs.get("ip_adapter_scale", None) is None and
refiner_strength != 0 and
refiner_start != 1 and
self.refiner is not None
)
- intention = "inpainting" if inpainting else "refining" if refining else "inference"
+
+ if animating:
+ intention = "animation"
+ elif inpainting:
+ intention = "inpainting"
+ elif refining:
+ intention = "refining"
+ else:
+ intention = "inference"
+
task_callback(f"Preparing {intention.title()} Pipeline")
- if inpainting and (self.has_inpainter or self.create_inpainter):
+
+ if animating and self.has_animator:
+ size = self.animator_size
+ elif inpainting and (self.has_inpainter or self.create_inpainter):
size = self.inpainter_size
elif refining:
size = self.refiner_size
else:
size = self.size
+ if scheduler is not None:
+ # Allow overriding scheduler
+ self.scheduler = scheduler # type: ignore[assignment]
+
if refining:
# Set result here to passed image
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -3477,12 +4383,13 @@ def memoize_callback(images: List[PIL.Image.Image]) -> None:
images=[kwargs["image"]] * samples,
nsfw_content_detected=[False] * samples
)
+ self.offload_animator(intention) # type: ignore
self.offload_pipeline(intention) # type: ignore
self.offload_inpainter(intention) # type: ignore
else:
called_width = kwargs.get("width", size)
called_height = kwargs.get("height", size)
- chunk_size = kwargs.get("chunking_size", self.chunking_size)
+ tiling_stride = kwargs.get("tiling_stride", self.tiling_stride)
# Check sizes
if called_width < size:
@@ -3491,7 +4398,7 @@ def memoize_callback(images: List[PIL.Image.Image]) -> None:
elif called_height < size:
self.tensorrt_is_enabled = False
logger.info(f"height ({called_height}) less than configured height ({size}), disabling TensorRT")
- elif (called_width != size or called_height != size) and not chunk_size:
+ elif (called_width != size or called_height != size) and not tiling_stride:
logger.info(f"Dimensions do not match size of engine and chunking is disabled, disabling TensorRT")
self.tensorrt_is_enabled = False
else:
@@ -3502,16 +4409,47 @@ def memoize_callback(images: List[PIL.Image.Image]) -> None:
logger.info(f"IP adapter requested, TensorRT is not compatible, disabling.")
self.tensorrt_is_enabled = False
- if inpainting and (self.has_inpainter or self.create_inpainter):
+ if animating:
+ if not self.has_animator:
+ logger.debug(f"Animation requested but no animator set, setting animator to the same as the base model")
+ self.animator = self.model
+
+ self.offload_pipeline(intention) # type: ignore
+ self.offload_refiner(intention) # type: ignore
+ self.offload_inpainter(intention) # type: ignore
+
+ pipe = self.animator_pipeline
+
+ if self.reload_motion_module:
+ if task_callback is not None:
+ task_callback("Reloading motion module")
+ try:
+ pipe.load_motion_module_weights(
+ cache_dir=self.engine_cache_dir,
+ motion_module=self.motion_module,
+ task_callback=task_callback,
+ position_encoding_truncate_length=self.position_encoding_truncate_length,
+ position_encoding_scale_length=self.position_encoding_scale_length,
+ )
+ except Exception as ex:
+ logger.warning(f"Received Exception {ex} when loading motion module weights, will try to reload the entire pipeline.")
+ del pipe
+ self.reload_motion_module = False
+ self.unload_animator("Re-initializing Pipeline")
+ pipe = self.animator_pipeline # Re-initialize; will raise if loading fails again
+ elif inpainting and (self.has_inpainter or self.create_inpainter):
self.offload_pipeline(intention) # type: ignore
self.offload_refiner(intention) # type: ignore
- pipe = self.inpainter_pipeline
+ self.offload_animator(intention) # type: ignore
+ pipe = self.inpainter_pipeline # type: ignore
else:
if inpainting:
logger.info(f"No inpainter set and creation is disabled; using base pipeline for inpainting.")
self.offload_refiner(intention) # type: ignore
self.offload_inpainter(intention) # type: ignore
- pipe = self.pipeline
+ self.offload_animator(intention) # type: ignore
+
+ pipe = self.pipeline # type: ignore
# Check refining settings
if self.refiner is not None and refiner_strength != 0:
@@ -3521,9 +4459,11 @@ def memoize_callback(images: List[PIL.Image.Image]) -> None:
kwargs["output_type"] = "latent"
# Check IP adapter for downloads
- if kwargs.get("ip_adapter_scale", None) is not None:
+ if kwargs.get("ip_adapter_images", None) is not None:
self.ip_adapter.check_download(
is_sdxl=pipe.is_sdxl,
+ model=kwargs.get("ip_adapter_model", None),
+ task_callback=task_callback,
)
self.stop_keepalive()
@@ -3597,7 +4537,7 @@ def memoize_callback(images: List[PIL.Image.Image]) -> None:
kwargs["negative_prompt_2"] = refiner_negative_prompt_2
logger.debug(f"Refining results with arguments {redact(kwargs)}")
- pipe = self.refiner_pipeline # Instantiate if needed
+ pipe = self.refiner_pipeline # type: ignore
self.stop_keepalive() # This checks, we can call it all we want
task_callback(f"Refining")
@@ -3629,12 +4569,12 @@ def write_model_metadata(self, path: str) -> None:
Writes metadata for TensorRT to a json file
"""
if "controlnet" in path:
- dump_json(path, {"size": self.size, "controlnets": self.controlnet_names})
+ dump_json(path, {"size": self.tensorrt_size, "controlnets": self.controlnet_names})
else:
dump_json(
path,
{
- "size": self.size,
+ "size": self.tensorrt_size,
"lora": self.lora_names_weights,
"lycoris": self.lycoris_names_weights,
"inversion": self.inversion_names,
diff --git a/src/python/enfugue/diffusion/pipeline.py b/src/python/enfugue/diffusion/pipeline.py
index 641f1a58..e31921c8 100644
--- a/src/python/enfugue/diffusion/pipeline.py
+++ b/src/python/enfugue/diffusion/pipeline.py
@@ -26,6 +26,7 @@
import torch
import inspect
import datetime
+import torchvision
import numpy as np
import safetensors.torch
@@ -50,6 +51,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
+from diffusers.models.modeling_utils import ModelMixin
from diffusers.models import (
AutoencoderKL,
AutoencoderTiny,
@@ -81,6 +83,24 @@
from diffusers.utils.torch_utils import randn_tensor
from diffusers.image_processor import VaeImageProcessor
+from enfugue.diffusion.constants import *
+from enfugue.diffusion.util import (
+ MaskWeightBuilder,
+ Prompt,
+ Chunker,
+ EncodedPrompt,
+ EncodedPrompts,
+ Video,
+ load_state_dict
+)
+from enfugue.util import (
+ logger,
+ check_download,
+ check_download_to_dir,
+ TokenMerger,
+)
+
+from einops import rearrange
from pibble.util.files import load_json
from enfugue.util import logger, check_download, check_download_to_dir, TokenMerger
@@ -92,41 +112,41 @@
LatentScaler,
MaskWeightBuilder
)
+
if TYPE_CHECKING:
from enfugue.diffusers.support.ip import IPAdapter
- from enfugue.diffusion.constants import (
- MASK_TYPE_LITERAL,
- NOISE_METHOD_LITERAL,
- LATENT_BLEND_METHOD_LITERAL
- )
# This is ~64k×64k. Absurd, but I don't judge
PIL.Image.MAX_IMAGE_PIXELS = 2**32
+# Image arg accepted arguments
+ImageArgType = Union[str, PIL.Image.Image, List[PIL.Image.Image]]
+
# IP image accepted arguments
class ImagePromptArgDict(TypedDict):
- image: Union[str, PIL.Image.Image]
+ image: ImageArgType
scale: NotRequired[float]
ImagePromptType = Union[
- Union[str, PIL.Image.Image], # Image
- Tuple[Union[str, PIL.Image.Image], float], # Image, Scale
- ImagePromptArgDict
+ ImageArgType, # Image
+ Tuple[ImageArgType, float], # Image, Scale
+ ImagePromptArgDict,
]
+
ImagePromptArgType = Optional[Union[ImagePromptType, List[ImagePromptType]]]
# Control image accepted arguments
class ControlImageArgDict(TypedDict):
- image: Union[str, PIL.Image.Image]
+ image: ImageArgType
scale: NotRequired[float]
start: NotRequired[float]
end: NotRequired[float]
ControlImageType = Union[
- Union[str, PIL.Image.Image], # Image
- Tuple[Union[str, PIL.Image.Image], float], # Image, Scale
- Tuple[Union[str, PIL.Image.Image], float, float], # Image, Scale, End
- Tuple[Union[str, PIL.Image.Image], float, float, float], # Image, Scale, Start, End
+ ImageArgType, # Image
+ Tuple[ImageArgType, float], # Image, Scale
+ Tuple[ImageArgType, float, float], # Image, Scale, End Denoising
+ Tuple[ImageArgType, float, float, float], # Image, Scale, Start Denoising, End Denoising
ControlImageArgDict
]
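As a quick sketch of the argument shapes these unions accept (keys follow the TypedDicts above; the file name is hypothetical):

    # Each of these is a valid ControlImageType per the union above.
    control_by_path = "pose.png"                        # ImageArgType (a path)
    control_with_scale = ("pose.png", 0.75)             # (image, scale)
    control_with_window = ("pose.png", 0.75, 0.2, 0.9)  # (image, scale, start denoising, end denoising)
    control_as_dict = {"image": "pose.png", "scale": 0.75, "start": 0.2, "end": 0.9}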
@@ -135,14 +155,14 @@ class ControlImageArgDict(TypedDict):
class EnfugueStableDiffusionPipeline(StableDiffusionPipeline):
"""
- This pipeline merges all of the following:
+ This pipeline merges all of the following, for all versions of SD:
1. txt2img
2. img2img
3. inpainting/outpainting
- 4. controlnet
- 5. tensorrt
+ 4. controlnet/multicontrolnet
+ 5. ip adapter
+ 6. animatediff
"""
-
controlnets: Optional[Dict[str, ControlNetModel]]
unet: UNet2DConditionModel
scheduler: KarrasDiffusionSchedulers
@@ -153,8 +173,9 @@ class EnfugueStableDiffusionPipeline(StableDiffusionPipeline):
text_encoder: Optional[CLIPTextModel]
text_encoder_2: Optional[CLIPTextModelWithProjection]
vae_scale_factor: int
- safety_checker: StableDiffusionSafetyChecker
+ safety_checker: Optional[StableDiffusionSafetyChecker]
config: OmegaConf
+ safety_checking_disabled: bool = False
def __init__(
self,
@@ -168,16 +189,19 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: CLIPImageProcessor,
- controlnets: Optional[Dict[str, ControlNetModel]] = None,
- requires_safety_checker: bool = True,
- force_zeros_for_empty_prompt: bool = True,
- requires_aesthetic_score: bool = False,
- force_full_precision_vae: bool = False,
- ip_adapter: Optional[IPAdapter] = None,
- engine_size: int = 512,
- chunking_size: int = 64,
- chunking_mask_type: MASK_TYPE_LITERAL = "bilinear",
- chunking_mask_kwargs: Dict[str, Any] = {}
+ controlnets: Optional[Dict[str, ControlNetModel]]=None,
+ requires_safety_checker: bool=True,
+ force_zeros_for_empty_prompt: bool=True,
+ requires_aesthetic_score: bool=False,
+ force_full_precision_vae: bool=False,
+ ip_adapter: Optional[IPAdapter]=None,
+ engine_size: int=512,
+ tiling_size: Optional[int]=None,
+ tiling_stride: Optional[int]=64,
+ tiling_mask_type: MASK_TYPE_LITERAL="bilinear",
+ tiling_mask_kwargs: Dict[str, Any]={},
+ frame_window_size: Optional[int]=16,
+ frame_window_stride: Optional[int]=4
) -> None:
super(EnfugueStableDiffusionPipeline, self).__init__(
vae,
@@ -191,13 +215,17 @@ def __init__(
)
# Save scheduler config for hotswapping
+ self.scheduler_class = type(scheduler)
self.scheduler_config = {**dict(scheduler.config)} # type: ignore[attr-defined]
# Enfugue engine settings
self.engine_size = engine_size
- self.chunking_size = chunking_size
- self.chunking_mask_type = chunking_mask_type
- self.chunking_mask_kwargs = chunking_mask_kwargs
+ self.tiling_size = tiling_size
+ self.tiling_stride = tiling_stride
+ self.tiling_mask_type = tiling_mask_type
+ self.tiling_mask_kwargs = tiling_mask_kwargs
+ self.frame_window_size = frame_window_size
+ self.frame_window_stride = frame_window_stride
# Hide tqdm
self.set_progress_bar_config(disable=True) # type: ignore[attr-defined]
@@ -242,8 +270,19 @@ def debug_tensors(cls, **kwargs: Union[Dict, List, torch.Tensor]) -> None:
elif isinstance(value, dict):
for k in value:
cls.debug_tensors(**{f"{key}_{k}": value[k]})
- else:
- logger.debug(f"{key} = {value.shape} ({value.dtype})")
+ elif isinstance(value, torch.Tensor):
+ logger.debug(f"{key} = {value.shape} ({value.dtype}) on {value.device}")
+
+ @classmethod
+ def open_image(cls, path: str) -> List[PIL.Image.Image]:
+ """
+ Opens an image or video and standardizes to a list of images
+ """
+ ext = os.path.splitext(path)[1]
+ if ext in [".gif", ".webp", ".mp4", ".mkv", ".mp4", ".avi", ".mov", ".apng"]:
+ return list(Video.file_to_frames(path))
+ else:
+ return [PIL.Image.open(path)]
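A minimal usage sketch of the open_image helper above; the file names are hypothetical, and Video.file_to_frames comes from enfugue.diffusion.util as imported earlier in this file.

    from enfugue.diffusion.pipeline import EnfugueStableDiffusionPipeline

    # Animated formats expand to a list of frames; stills become a one-item list.
    frames = EnfugueStableDiffusionPipeline.open_image("walk-cycle.gif")  # [frame_0, frame_1, ...]
    stills = EnfugueStableDiffusionPipeline.open_image("portrait.png")    # [PIL.Image.Image]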
@classmethod
def create_unet(
@@ -252,8 +291,9 @@ def create_unet(
cache_dir: str,
is_sdxl: bool,
is_inpainter: bool,
+ task_callback: Optional[Callable[[str], None]]=None,
**kwargs: Any
- ) -> UNet2DConditionModel:
+ ) -> ModelMixin:
"""
Instantiates the UNet from config
"""
@@ -268,18 +308,22 @@ def from_ckpt(
cls,
checkpoint_path: str,
cache_dir: str,
- prediction_type: Optional[str] = None,
- image_size: int = 512,
- scheduler_type: Literal["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", "ddim"] = "ddim",
- vae_path: Optional[str] = None,
- vae_preview_path: Optional[str] = None,
- load_safety_checker: bool = True,
- torch_dtype: Optional[torch.dtype] = None,
- upcast_attention: Optional[bool] = None,
- extract_ema: Optional[bool] = None,
- offload_models: bool = False,
+ prediction_type: Optional[str]=None,
+ image_size: int=512,
+ scheduler_type: Literal["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", "ddim"]="ddim",
+ vae_path: Optional[str]=None,
+ vae_preview_path: Optional[str]=None,
+ load_safety_checker: bool=True,
+ torch_dtype: Optional[torch.dtype]=None,
+ upcast_attention: Optional[bool]=None,
+ extract_ema: Optional[bool]=None,
+ motion_module: Optional[str]=None,
+ unet_kwargs: Dict[str, Any]={},
+ offload_models: bool=False,
is_inpainter=False,
task_callback: Optional[Callable[[str], None]]=None,
+ position_encoding_truncate_length: Optional[int]=None,
+ position_encoding_scale_length: Optional[int]=None,
**kwargs: Any,
) -> EnfugueStableDiffusionPipeline:
"""
@@ -288,7 +332,10 @@ def from_ckpt(
That's why we override it for this method - most of this is copied from
https://github.com/huggingface/diffusers/blob/49949f321d9b034440b52e54937fd2df3027bf0a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
"""
- logger.debug(f"Reading checkpoint file {checkpoint_path}")
+ if task_callback is None:
+ task_callback = lambda msg: logger.debug(msg)
+
+ task_callback(f"Loading checkpoint file {os.path.basename(checkpoint_path)}")
checkpoint = load_state_dict(checkpoint_path)
# Sometimes models don't have the global_step item
@@ -455,15 +502,22 @@ def from_ckpt(
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
+ task_callback("Loading UNet")
+
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["upcast_attention"] = upcast_attention
unet = cls.create_unet(
unet_config,
cache_dir=cache_dir,
+ motion_module=motion_module,
is_sdxl=isinstance(model_type, str) and model_type.startswith("SDXL"),
is_inpainter=is_inpainter,
- )
+ task_callback=task_callback,
+ position_encoding_truncate_length=position_encoding_truncate_length,
+ position_encoding_scale_length=position_encoding_scale_length,
+ **unet_kwargs
+ )
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint,
@@ -471,16 +525,20 @@ def from_ckpt(
path=checkpoint_path,
extract_ema=extract_ema
)
+ unet_keys = len(list(converted_unet_checkpoint.keys()))
+ logger.debug(f"Loading {unet_keys} keys into UNet state dict (non-strict)")
- unet.load_state_dict(converted_unet_checkpoint)
+ unet.load_state_dict(converted_unet_checkpoint, strict=False)
if offload_models:
+ logger.debug("Offloading enabled; sending UNet to CPU")
unet.to("cpu")
empty_cache()
# Convert the VAE model.
if vae_path is None:
try:
+ task_callback("Loading Default VAE")
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
@@ -495,15 +553,13 @@ def from_ckpt(
vae_config["scaling_factor"] = vae_scale_factor
vae = AutoencoderKL(**vae_config)
+ vae_keys = len(list(converted_vae_checkpoint.keys()))
+ logger.debug(f"Loading {vae_keys} keys into Autoencoder state dict (strict). Autoencoder scale is {vae_scale_factor}")
vae.load_state_dict(converted_vae_checkpoint)
except KeyError as ex:
default_path = "stabilityai/sdxl-vae" if model_type in ["SDXL", "SDXL-Refiner"] else "stabilityai/sd-vae-ft-ema"
logger.error(f"Malformed VAE state dictionary detected; missing required key '{ex}'. Reverting to default model {default_path}")
- pretrained_save_path = os.path.join(cache_dir, "models--{0}".format(
- default_path.replace("/", "--")
- ))
- if not os.path.exists(pretrained_save_path) and task_callback is not None:
- task_callback(f"Downloading default VAE weights from repository {default_path}")
+ task_callback(f"Loading VAE {default_path}")
vae = AutoencoderKL.from_pretrained(default_path, cache_dir=cache_dir)
elif os.path.exists(vae_path):
if model_type in ["SDXL", "SDXL-Refiner"]:
@@ -513,47 +569,55 @@ def from_ckpt(
vae_config_path,
check_size=False
)
+ task_callback("Loading VAE")
vae = AutoencoderKL.from_config(
AutoencoderKL._dict_from_json_file(vae_config_path)
)
+ vae_state_dict = load_state_dict(vae_path)
+ vae_keys = len(list(vae_state_dict.keys()))
+ logger.debug(f"Loading {vae_keys} keys into Autoencoder state dict (non-strict)")
vae.load_state_dict(load_state_dict(vae_path), strict=False)
else:
+ logger.debug(f"Initializing Autoencoder from file {vae_path}")
+ task_callback("Loading VAE")
vae = AutoencoderKL.from_single_file(
vae_path,
cache_dir=cache_dir,
from_safetensors = "safetensors" in vae_path
)
else:
+ logger.debug(f"Initializing autoencoder from repository {vae_path}")
vae = AutoencoderKL.from_pretrained(vae_path, cache_dir=cache_dir)
+ if offload_models:
+ logger.debug("Offloading enabled; sending VAE to CPU")
+ vae.to("cpu")
+ empty_cache()
+
if vae_preview_path is None:
if model_type in ["SDXL", "SDXL-Refiner"]:
vae_preview_path = "madebyollin/taesdxl"
else:
vae_preview_path = "madebyollin/taesd"
- vae_preview_local_path = os.path.join(cache_dir, "models--{0}".format(vae_preview_path.replace("/", "--")))
- if not os.path.exists(vae_preview_local_path) and task_callback is not None:
- task_callback("Downloading preview VAE weights from repository {vae_preview_path}")
- vae_preview = AutoencoderTiny.from_pretrained(vae_preview_path, cache_dir=cache_dir)
-
- if offload_models:
- vae.to("cpu")
- empty_cache()
+ task_callback(f"Loading preview VAE {vae_preview_path}")
+ vae_preview = AutoencoderTiny.from_pretrained(
+ vae_preview_path,
+ cache_dir=cache_dir
+ )
if load_safety_checker:
safety_checker_path = "CompVis/stable-diffusion-safety-checker"
- safety_checker_local_path = os.path.join(cache_dir, "models--{0}".format(safety_checker_path.replace("/", "--")))
- if not os.path.exists(safety_checker_local_path) and task_callback is not None:
- task_callback(f"Downloading safety checker weights from repository {safety_checker_path}")
-
+ task_callback(f"Loading safety checker {safety_checker_path}")
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_checker_path,
cache_dir=cache_dir
)
if offload_models:
+ logger.debug("Offloading enabled; sending safety checker to CPU")
safety_checker.to("cpu")
empty_cache()
+ task_callback(f"Initializing feature extractor from repository {safety_checker_path}")
feature_extractor = AutoFeatureExtractor.from_pretrained(
safety_checker_path,
cache_dir=cache_dir
@@ -566,14 +630,12 @@ def from_ckpt(
if model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint)
if offload_models:
+ logger.debug("Offloading enabled; sending text encoder to CPU")
text_model.to("cpu")
empty_cache()
tokenizer_path = "openai/clip-vit-large-patch14"
- tokenizer_local_path = os.path.join(cache_dir, "models--{0}".format(tokenizer_path.replace("/", "--")))
- if not os.path.exists(tokenizer_local_path) and task_callback is not None:
- task_callback(f"Downloading tokenizer weights from repository {tokenizer_path}")
-
+ task_callback(f"Loading tokenizer {tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(
tokenizer_path,
cache_dir=cache_dir
@@ -586,7 +648,7 @@ def from_ckpt(
vae_preview=vae_preview,
text_encoder=text_model,
tokenizer=tokenizer,
- unet=unet,
+ unet=unet, # type: ignore[arg-type]
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
@@ -594,20 +656,16 @@ def from_ckpt(
)
elif model_type == "SDXL":
tokenizer_path = "openai/clip-vit-large-patch14"
- tokenizer_local_path = os.path.join(cache_dir, "models--{0}".format(tokenizer_path.replace("/", "--")))
- if not os.path.exists(tokenizer_local_path) and task_callback is not None:
- task_callback(f"Downloading tokenizer weights from repository {tokenizer_path}")
-
+ task_callback(f"Loading tokenizer 1 {tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(
tokenizer_path,
cache_dir=cache_dir
)
+
text_encoder = convert_ldm_clip_checkpoint(checkpoint)
tokenizer_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
- tokenizer_2_local_path = os.path.join(cache_dir, "models--{0}".format(tokenizer_2_path.replace("/", "--")))
- if not os.path.exists(tokenizer_local_path) and task_callback is not None:
- task_callback(f"Downloading tokenizer 2 weights from repository {tokenizer_2_path}")
+ task_callback(f"Loading tokenizer 2 {tokenizer_2_path}")
tokenizer_2 = CLIPTokenizer.from_pretrained(
tokenizer_2_path,
@@ -623,6 +681,7 @@ def from_ckpt(
)
if offload_models:
+ logger.debug("Offloading enabled; sending text encoder 1 and 2 to CPU")
text_encoder.to("cpu")
text_encoder_2.to("cpu")
empty_cache()
@@ -634,7 +693,7 @@ def from_ckpt(
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
- unet=unet,
+ unet=unet, # type: ignore[arg-type]
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
@@ -643,10 +702,7 @@ def from_ckpt(
)
elif model_type == "SDXL-Refiner":
tokenizer_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
- tokenizer_2_local_path = os.path.join(cache_dir, "models--{0}".format(tokenizer_2_path.replace("/", "--")))
- if not os.path.exists(tokenizer_2_local_path) and task_callback is not None:
- task_callback(f"Downloading tokenizer 2 weights from repository {tokenizer_2_path}")
-
+ task_callback(f"Loading tokenizer {tokenizer_2_path}")
tokenizer_2 = CLIPTokenizer.from_pretrained(
tokenizer_2_path,
cache_dir=cache_dir,
@@ -661,6 +717,7 @@ def from_ckpt(
)
if offload_models:
+ logger.debug("Offloading enabled; sending text encoder 2 to CPU")
text_encoder_2.to("cpu")
empty_cache()
@@ -671,7 +728,7 @@ def from_ckpt(
text_encoder_2=text_encoder_2,
tokenizer=None,
tokenizer_2=tokenizer_2,
- unet=unet,
+ unet=unet, # type: ignore[arg-type]
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
@@ -720,6 +777,19 @@ def module_size(self) -> int:
size += buffer.nelement() * buffer.element_size()
return size
+ @property
+ def is_inpainting_unet(self) -> bool:
+ """
+ Returns true if this is an inpainting UNet (9-channel)
+ """
+ return self.unet.config.in_channels == 9 # type: ignore[attr-defined]
+
+ def revert_scheduler(self) -> None:
+ """
+ Reverts the scheduler back to whatever the original was.
+ """
+ self.scheduler = self.scheduler_class.from_config(self.scheduler_config) # type: ignore[attr-defined]
+
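A short sketch of why in_channels == 9 identifies a fine-tuned inpainting UNet: per the concatenation later in this diff, the UNet input is the noisy latents plus the single-channel mask plus the 4-channel masked-image latents, stacked along the channel axis. Shapes here are illustrative.

    import torch

    latents = torch.randn(1, 4, 64, 64)              # noisy latents
    mask = torch.randn(1, 1, 64, 64)                 # inpainting mask
    mask_image_latents = torch.randn(1, 4, 64, 64)   # encoded masked image
    unet_input = torch.cat([latents, mask, mask_image_latents], dim=1)
    assert unet_input.shape[1] == 9                  # 4 + 1 + 4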
def get_size_from_module(self, module: torch.nn.Module) -> int:
"""
Gets the size of a module in bytes
@@ -748,6 +818,8 @@ def align_unet(
self,
device: torch.device,
dtype: torch.dtype,
+ animation_frames: Optional[int] = None,
+ motion_scale: Optional[float] = None,
freeu_factors: Optional[Tuple[float, float, float, float]] = None,
offload_models: bool = False
) -> None:
@@ -765,8 +837,18 @@ def align_unet(
self.unet.disable_freeu()
else:
s1, s2, b1, b2 = freeu_factors
+ logger.debug(f"Enabling FreeU with factors {s1=} {s2=} {b1=} {b2=}")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
self._freeu_enabled = True
+ if animation_frames:
+ try:
+ if motion_scale:
+ logger.debug(f"Setting motion attention scale to {motion_scale}")
+ self.unet.set_motion_attention_scale(motion_scale)
+ else:
+ self.unet.reset_motion_attention_scale()
+ except AttributeError:
+ raise RuntimeError("Couldn't set motion attention scale - was this pipeline initialized with the right UNet?")
self.unet.to(device=device, dtype=dtype)
def run_safety_checker(
@@ -778,6 +860,8 @@ def run_safety_checker(
"""
Override parent run_safety_checker to make sure safety checker is aligned
"""
+ if self.safety_checking_disabled:
+ return (output, [False] * len(output)) # Disabled after being enabled (likely temporary)
if self.safety_checker is not None:
self.safety_checker.to(device)
return super(EnfugueStableDiffusionPipeline, self).run_safety_checker(output, device, dtype) # type: ignore[misc]
@@ -786,21 +870,20 @@ def load_ip_adapter(
self,
device: Union[str, torch.device],
scale: float = 1.0,
- use_fine_grained: bool = False,
- use_face_model: bool = False,
- keepalive_callback: Optional[Callable[[], None]] = None
+ model: Optional[IP_ADAPTER_LITERAL]=None,
+ keepalive_callback: Optional[Callable[[], None]]=None
) -> None:
"""
Loads the IP Adapter
"""
if getattr(self, "ip_adapter", None) is None:
raise RuntimeError("Pipeline does not have an IP adapter")
+
if self.ip_adapter_loaded:
altered = self.ip_adapter.set_scale( # type: ignore[union-attr]
unet=self.unet,
- new_scale=scale,
- use_fine_grained=use_fine_grained,
- use_face_model=use_face_model,
+ scale=scale,
+ model=model,
keepalive_callback=keepalive_callback,
is_sdxl=self.is_sdxl,
controlnets=self.controlnets
@@ -809,22 +892,20 @@ def load_ip_adapter(
logger.error("IP adapter appeared loaded, but setting scale did not modify it.")
self.ip_adapter.load( # type: ignore[union-attr]
unet=self.unet,
- is_sdxl=self.is_sdxl,
scale=scale,
- use_fined_grained=use_fine_grained,
- use_face_model=use_face_model,
+ model=model,
keepalive_callback=keepalive_callback,
+ is_sdxl=self.is_sdxl,
controlnets=self.controlnets
)
else:
logger.debug("Loading IP adapter")
self.ip_adapter.load( # type: ignore[union-attr]
self.unet,
- is_sdxl=self.is_sdxl,
scale=scale,
+ model=model,
keepalive_callback=keepalive_callback,
- use_fine_grained=use_fine_grained,
- use_face_model=use_face_model,
+ is_sdxl=self.is_sdxl,
controlnets=self.controlnets
)
self.ip_adapter_loaded = True
@@ -834,7 +915,7 @@ def unload_ip_adapter(self) -> None:
Unloads the IP adapter by resetting attention processors to previous values
"""
if getattr(self, "ip_adapter", None) is None:
- raise RuntimeError("Pipeline does not have an IP adapter")
+ return
if self.ip_adapter_loaded:
logger.debug("Unloading IP adapter")
self.ip_adapter.unload(self.unet, self.controlnets) # type: ignore[union-attr]
@@ -843,7 +924,7 @@ def unload_ip_adapter(self) -> None:
def get_image_embeds(
self,
image: PIL.Image.Image,
- num_images_per_prompt: int
+ num_results_per_prompt: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Uses the IP adapter to get prompt embeddings from the image
@@ -852,17 +933,17 @@ def get_image_embeds(
raise RuntimeError("Pipeline does not have an IP adapter")
image_prompt_embeds, image_uncond_prompt_embeds = self.ip_adapter.probe(image) # type: ignore[union-attr]
bs_embed, seq_len, _ = image_prompt_embeds.shape
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
- image_uncond_prompt_embeds = image_uncond_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- image_uncond_prompt_embeds = image_uncond_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_results_per_prompt, 1)
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_results_per_prompt, seq_len, -1)
+ image_uncond_prompt_embeds = image_uncond_prompt_embeds.repeat(1, num_results_per_prompt, 1)
+ image_uncond_prompt_embeds = image_uncond_prompt_embeds.view(bs_embed * num_results_per_prompt, seq_len, -1)
return image_prompt_embeds, image_uncond_prompt_embeds
def encode_prompt(
self,
prompt: Optional[str],
device: torch.device,
- num_images_per_prompt: int = 1,
+ num_results_per_prompt: int = 1,
do_classifier_free_guidance: bool = False,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
@@ -972,8 +1053,8 @@ def encode_prompt(
bs_embed, seq_len, _ = prompt_embeds.shape # type: ignore
# duplicate text embeddings for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) # type: ignore
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_embeds = prompt_embeds.repeat(1, num_results_per_prompt, 1) # type: ignore
+ prompt_embeds = prompt_embeds.view(bs_embed * num_results_per_prompt, seq_len, -1)
if self.is_sdxl:
prompt_embeds_list.append(prompt_embeds)
@@ -1027,8 +1108,8 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) # type: ignore
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- negative_prompt_embeds = negative_prompt_embeds.view(num_images_per_prompt, seq_len, -1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_results_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(num_results_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
@@ -1041,11 +1122,11 @@ def encode_prompt(
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
if self.is_sdxl:
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
- bs_embed * num_images_per_prompt, -1
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_results_per_prompt).view(
+ bs_embed * num_results_per_prompt, -1
)
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
- bs_embed * num_images_per_prompt, -1
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_results_per_prompt).view(
+ bs_embed * num_results_per_prompt, -1
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # type: ignore
return prompt_embeds # type: ignore
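A standalone sketch of the repeat/view pattern used above to duplicate embeddings per requested result (the "mps friendly method"); shapes are illustrative.

    import torch

    num_results_per_prompt = 2
    prompt_embeds = torch.randn(1, 77, 768)  # (batch, sequence, hidden)
    bs_embed, seq_len, _ = prompt_embeds.shape
    # Repeat along the sequence axis, then fold the copies back into the batch axis.
    prompt_embeds = prompt_embeds.repeat(1, num_results_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_results_per_prompt, seq_len, -1)
    assert prompt_embeds.shape == (2, 77, 768)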
@@ -1054,10 +1135,10 @@ def encode_prompt(
def get_runtime_context(
self,
batch_size: int,
+ animation_frames: Optional[int],
device: Union[str, torch.device],
ip_adapter_scale: Optional[Union[List[float], float]] = None,
- ip_adapter_plus: bool = False,
- ip_adapter_face: bool = False,
+ ip_adapter_model: Optional[IP_ADAPTER_LITERAL] = None,
step_complete: Optional[Callable[[bool], None]] = None
) -> Iterator[None]:
"""
@@ -1069,8 +1150,7 @@ def get_runtime_context(
self.load_ip_adapter(
device=device,
scale=max(ip_adapter_scale) if isinstance(ip_adapter_scale, list) else ip_adapter_scale,
- use_fine_grained=ip_adapter_plus,
- use_face_model=ip_adapter_face,
+ model=ip_adapter_model,
keepalive_callback=None if step_complete is None else lambda: step_complete(False) # type: ignore[misc]
)
else:
@@ -1120,28 +1200,11 @@ def load_lora_weights(
Call the appropriate adapted fix based on pipeline class
"""
try:
- if self.is_sdxl:
- # Call SDXL fix
- return self.load_sdxl_lora_weights(
- pretrained_model_name_or_path_or_dict,
- multiplier=multiplier,
- dtype=dtype,
- **kwargs
- )
- elif (
- isinstance(pretrained_model_name_or_path_or_dict, str) and
- pretrained_model_name_or_path_or_dict.endswith(".safetensors")
- ):
- # Call safetensors fix
- return self.load_safetensors_lora_weights(
- pretrained_model_name_or_path_or_dict,
- multiplier=multiplier,
- dtype=dtype,
- **kwargs
- )
- # Return parent
- return super(EnfugueStableDiffusionPipeline, self).load_lora_weights( # type: ignore[misc]
- pretrained_model_name_or_path_or_dict, **kwargs
+ return self.load_flexible_lora_weights(
+ pretrained_model_name_or_path_or_dict,
+ multiplier=multiplier,
+ dtype=dtype,
+ **kwargs
)
except (AttributeError, KeyError) as ex:
if self.is_sdxl:
@@ -1164,29 +1227,61 @@ def load_sdxl_lora_weights(
unet_config=self.unet.config,
**kwargs,
)
- self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) # type: ignore[attr-defined]
+ self.load_lora_into_unet( # type: ignore[attr-defined]
+ state_dict,
+ network_alphas=network_alphas,
+ unet=self.unet,
+ _pipeline=self,
+ )
- text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
- if len(text_encoder_state_dict) > 0:
+ text_encoder_state_dict = dict([
+ (k, v)
+ for k, v in state_dict.items()
+ if "text_encoder." in k
+ ])
+ text_encoder_keys = len(text_encoder_state_dict)
+
+ if text_encoder_keys > 0:
+ logger.debug(f"Loading {text_encoder_keys} keys into primary text encoder with multiplier {multiplier}")
self.load_lora_into_text_encoder( # type: ignore[attr-defined]
text_encoder_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=multiplier,
+ _pipeline=self,
)
- text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
- if len(text_encoder_2_state_dict) > 0:
+ text_encoder_2_state_dict = dict([
+ (k, v)
+ for k, v in state_dict.items()
+ if "text_encoder_2." in k
+ ])
+ text_encoder_2_keys = len(text_encoder_2_state_dict)
+
+ if text_encoder_2_keys > 0:
+ logger.debug(f"Loading {text_encoder_2_keys} keys into secondary text encoder with multiplier {multiplier}")
self.load_lora_into_text_encoder( # type: ignore[attr-defined]
text_encoder_2_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=multiplier,
+ _pipeline=self,
)
- def load_safetensors_lora_weights(
+ def load_motion_lora_weights(
+ self,
+ state_dict: Dict[str, torch.Tensor],
+ multiplier: float = 1.0,
+ dtype: torch.dtype = torch.float32
+ ) -> None:
+ """
+ Don't do anything in base pipeline
+ """
+ logger.warning("Ignoring motion LoRA for non-animation pipeline")
+
+ def load_flexible_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
multiplier: float = 1.0,
@@ -1199,14 +1294,27 @@ def load_safetensors_lora_weights(
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
- # load LoRA weight from .safetensors
- state_dict = safetensors.torch.load_file(pretrained_model_name_or_path_or_dict, device="cpu") # type: ignore[arg-type]
+ state_dict = load_state_dict(pretrained_model_name_or_path_or_dict) # type: ignore[arg-type]
+ while "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"] # type: ignore[assignment]
+
+ if any(["motion_module" in key for key in state_dict.keys()]):
+ return self.load_motion_lora_weights(
+ state_dict, # type: ignore[arg-type]
+ multiplier=multiplier,
+ dtype=dtype
+ )
+ if self.is_sdxl:
+ return self.load_sdxl_lora_weights(
+ state_dict, # type: ignore[arg-type]
+ multiplier=multiplier,
+ dtype=dtype
+ )
updates: Mapping[str, Any] = defaultdict(dict)
for key, value in state_dict.items():
# it is suggested to print out the key, it usually will be something like below
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
-
layer, elem = key.split(".", 1)
updates[layer][elem] = value
@@ -1410,16 +1518,27 @@ def create_latents(
dtype: torch.dtype,
device: Union[str, torch.device],
generator: Optional[torch.Generator] = None,
+ animation_frames: Optional[int] = None,
) -> torch.Tensor:
"""
Creates random latents of a particular shape and type.
"""
- shape = (
- batch_size,
- num_channels_latents,
- height // self.vae_scale_factor,
- width // self.vae_scale_factor,
- )
+ if not animation_frames:
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ else:
+ shape = ( # type: ignore[assignment]
+ batch_size,
+ num_channels_latents,
+ animation_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+
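For concreteness, the shapes produced by the branch above, assuming the usual Stable Diffusion values of 4 latent channels and a VAE scale factor of 8:

    # 512x512 image, no animation: (batch, channels, height/8, width/8)
    still_shape = (1, 4, 512 // 8, 512 // 8)         # (1, 4, 64, 64)
    # 512x512 image, 16 animation frames: a temporal axis is inserted
    animated_shape = (1, 4, 16, 512 // 8, 512 // 8)  # (1, 4, 16, 64, 64)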
logger.debug(f"Creating random latents of shape {shape} and type {dtype}")
random_latents = randn_tensor(
shape,
@@ -1443,6 +1562,8 @@ def encode_image_unchunked(
if self.config.force_full_precision_vae: # type: ignore[attr-defined]
self.vae.to(dtype=torch.float32)
image = image.float()
+ else:
+ image = image.to(dtype=self.vae.dtype)
latents = self.vae.encode(image).latent_dist.sample(generator) * self.vae.config.scaling_factor # type: ignore[attr-defined]
if self.config.force_full_precision_vae: # type: ignore[attr-defined]
self.vae.to(dtype=dtype)
@@ -1453,6 +1574,7 @@ def encode_image(
image: torch.Tensor,
device: Union[str, torch.device],
dtype: torch.dtype,
+ chunker: Chunker,
weight_builder: MaskWeightBuilder,
generator: Optional[torch.Generator] = None,
progress_callback: Optional[Callable[[bool], None]] = None,
@@ -1461,8 +1583,12 @@ def encode_image(
Encodes an image in chunks using the VAE.
"""
_, _, height, width = image.shape
- chunks = self.get_chunks(height, width)
- total_steps = len(chunks)
+
+ # Disable tiling during encoding
+ tile = chunker.tile
+ chunker.tile = False
+
+ total_steps = chunker.num_chunks
# Align device
self.vae.to(device)
@@ -1471,8 +1597,11 @@ def encode_image(
result = self.encode_image_unchunked(image, dtype, generator)
if progress_callback is not None:
progress_callback(True)
+ # Re-enable tiling if asked for
+ chunker.tile = tile
return result
+ chunks = chunker.chunks
logger.debug(f"Encoding image in {total_steps} steps.")
latent_height = height // self.vae_scale_factor
@@ -1491,36 +1620,39 @@ def encode_image(
else:
self.vae.to(dtype=image.dtype)
- for i, (top, bottom, left, right) in enumerate(chunks):
- top_px = top * self.vae_scale_factor
- bottom_px = bottom * self.vae_scale_factor
- left_px = left * self.vae_scale_factor
- right_px = right * self.vae_scale_factor
+ with weight_builder:
+ for i, ((top, bottom), (left, right)) in enumerate(chunker.chunks):
+ top_px = top * self.vae_scale_factor
+ bottom_px = bottom * self.vae_scale_factor
+ left_px = left * self.vae_scale_factor
+ right_px = right * self.vae_scale_factor
- image_view = image[:, :, top_px:bottom_px, left_px:right_px]
+ image_view = image[:, :, top_px:bottom_px, left_px:right_px]
- encoded_image = self.vae.encode(image_view).latent_dist.sample(generator).to(device)
+ encoded_image = self.vae.encode(image_view).latent_dist.sample(generator).to(device)
- # Build weights
- multiplier = weight_builder(
- mask_type=self.chunking_mask_type,
- batch=1,
- dim=num_channels,
- width=engine_latent_size,
- height=engine_latent_size,
- unfeather_left=left==0,
- unfeather_top=top==0,
- unfeather_right=right==latent_width,
- unfeather_bottom=bottom==latent_height,
- **self.chunking_mask_kwargs
- )
+ # Build weights
+ multiplier = weight_builder(
+ mask_type=self.tiling_mask_type,
+ batch=1,
+ dim=num_channels,
+ width=right-left,
+ height=bottom-top,
+ unfeather_left=left==0,
+ unfeather_top=top==0,
+ unfeather_right=right==latent_width,
+ unfeather_bottom=bottom==latent_height,
+ **self.tiling_mask_kwargs
+ )
- value[:, :, top:bottom, left:right] += encoded_image * multiplier
- count[:, :, top:bottom, left:right] += multiplier
+ value[:, :, top:bottom, left:right] += encoded_image * multiplier
+ count[:, :, top:bottom, left:right] += multiplier
- if progress_callback is not None:
- progress_callback(True)
+ if progress_callback is not None:
+ progress_callback(True)
+ # Re-enable tiling if asked for
+ chunker.tile = tile
if self.config.force_full_precision_vae: # type: ignore[attr-defined]
self.vae.to(dtype=dtype)
weight_builder.dtype = dtype
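A toy sketch of the feathered-tile accumulation above: each tile's encoding is multiplied by its weight mask and summed into value, the weights are summed into count, and the blend is value divided by count (that normalization happens outside the lines shown in this hunk).

    import torch

    value = torch.zeros(1, 4, 8, 8)
    count = torch.zeros(1, 4, 8, 8)
    tile = torch.ones(1, 4, 8, 4)           # an encoded 4-column tile
    weight = torch.full((1, 4, 8, 4), 0.5)  # a (here constant) feathering mask
    value[:, :, :, 0:4] += tile * weight
    count[:, :, :, 0:4] += weight
    value[:, :, :, 2:6] += tile * weight    # overlapping second tile
    count[:, :, :, 2:6] += weight
    blended = value[:, :, :, 0:6] / count[:, :, :, 0:6]  # overlap cross-fades to the average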
@@ -1528,70 +1660,87 @@ def encode_image(
def prepare_image_latents(
self,
- image: Union[torch.Tensor, PIL.Image.Image],
+ image: torch.Tensor,
timestep: torch.Tensor,
batch_size: int,
dtype: torch.dtype,
device: Union[str, torch.device],
+ chunker: Chunker,
weight_builder: MaskWeightBuilder,
generator: Optional[torch.Generator] = None,
progress_callback: Optional[Callable[[bool], None]] = None,
add_noise: bool = True,
+ animation_frames: Optional[int] = None
) -> torch.Tensor:
"""
Prepares latents from an image, adding initial noise for img2img inference
"""
- image = image.to(device=device, dtype=dtype)
- if image.shape[1] == 4:
- init_latents = image
- else:
- init_latents = self.encode_image(
- image,
- device=device,
- generator=generator,
- dtype=dtype,
- weight_builder=weight_builder,
- progress_callback=progress_callback
- )
+ def encoded() -> Iterator[torch.Tensor]:
+ for i in image:
+ if i.shape[1] == 4:
+ yield i
+ else:
+ yield self.encode_image(
+ image=i.to(dtype=dtype, device=device),
+ device=device,
+ generator=generator,
+ dtype=dtype,
+ chunker=chunker,
+ weight_builder=weight_builder,
+ progress_callback=progress_callback
+ )
- if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # these should all be [1, 4, h, w], collapse along batch dim
+ latents = torch.cat(list(encoded()), dim=0).to(dtype) # type: ignore[assignment]
+
+ if animation_frames:
+ # Change from collapsing on batch dim to temporal dim
+ latents = rearrange(latents, 't c h w -> c t h w').unsqueeze(0)
+
+ if batch_size > latents.shape[0] and batch_size % latents.shape[0] == 0:
# duplicate images to match batch size
- additional_image_per_prompt = batch_size // init_latents.shape[0]
- init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
- elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ additional_image_per_prompt = batch_size // latents.shape[0]
+ latents = torch.cat([latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > latents.shape[0] and batch_size % latents.shape[0] != 0:
raise ValueError(
- f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
)
- else:
- init_latents = torch.cat([init_latents], dim=0)
+
+ if animation_frames and animation_frames > latents.shape[2]:
+ # duplicate last image to match animation length
+ latents = torch.cat([
+ latents,
+ latents.repeat(1, 1, animation_frames - latents.shape[2], 1, 1)
+ ], dim=2)
# add noise in accordance with timesteps
if add_noise:
- shape = init_latents.shape
+ shape = latents.shape
noise = randn_tensor(
shape,
generator=generator,
device=torch.device(device) if isinstance(device, str) else device,
dtype=dtype
)
- return self.scheduler.add_noise(init_latents, noise, timestep) # type: ignore[attr-defined]
+ return self.scheduler.add_noise(latents, noise, timestep) # type: ignore[attr-defined]
else:
- logger.debug("Not adding noise; starting from noised image.")
- return init_latents
+ return latents
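A minimal einops sketch of the batch-to-temporal reshuffle performed above when animation frames are requested; the shapes are illustrative.

    import torch
    from einops import rearrange

    frame_latents = torch.randn(16, 4, 64, 64)  # 16 per-frame latents stacked on the batch axis
    video_latents = rearrange(frame_latents, "t c h w -> c t h w").unsqueeze(0)
    assert video_latents.shape == (1, 4, 16, 64, 64)  # (batch, channels, frames, height, width)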
def prepare_mask_latents(
self,
- mask: Union[PIL.Image.Image, torch.Tensor],
- image: Union[PIL.Image.Image, torch.Tensor],
+ mask: torch.Tensor,
+ image: torch.Tensor,
batch_size: int,
height: int,
width: int,
dtype: torch.dtype,
device: Union[str, torch.device],
+ chunker: Chunker,
weight_builder: MaskWeightBuilder,
generator: Optional[torch.Generator] = None,
do_classifier_free_guidance: bool = False,
progress_callback: Optional[Callable[[bool], None]] = None,
+ animation_frames: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Prepares both mask and image latents for inpainting
@@ -1599,28 +1748,46 @@ def prepare_mask_latents(
tensor_height = height // self.vae_scale_factor
tensor_width = width // self.vae_scale_factor
tensor_size = (tensor_height, tensor_width)
- mask = torch.nn.functional.interpolate(mask, size=tensor_size)
- mask = mask.to(device=device, dtype=dtype)
- image = image.to(device=device, dtype=dtype)
- latents = self.encode_image(
- image,
- device=device,
- generator=generator,
- dtype=dtype,
- weight_builder=weight_builder,
- progress_callback=progress_callback,
- ).to(device=device)
+ mask_latents = torch.Tensor().to(device)
+ latents = torch.Tensor().to(device)
+
+ if mask.shape[0] != image.shape[0]:
+ # Should have been fixed by now, raise value error
+ raise ValueError("Mask and image should be the same length.")
+
+ for m, i in zip(mask, image):
+ m = torch.nn.functional.interpolate(m, size=tensor_size)
+ m = m.to(device=device, dtype=dtype)
+ mask_latents = torch.cat([mask_latents, m.unsqueeze(0)])
+
+ latents = torch.cat([
+ latents,
+ self.encode_image(
+ i,
+ device=device,
+ generator=generator,
+ dtype=dtype,
+ chunker=chunker,
+ weight_builder=weight_builder,
+ progress_callback=progress_callback
+ ).unsqueeze(0).to(device=device, dtype=dtype)
+ ])
+
+ if animation_frames:
+ latents = rearrange(latents, "t b c h w -> b c t h w")
+ mask_latents = rearrange(mask_latents, "t b c h w -> b c t h w")
# duplicate mask and latents for each generation per prompt, using mps friendly method
- if mask.shape[0] < batch_size:
- if not batch_size % mask.shape[0] == 0:
+ if mask_latents.shape[0] < batch_size:
+ if not batch_size % mask_latents.shape[0] == 0:
raise ValueError(
- "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
- f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
- " of masks that you pass is divisible by the total requested batch size."
+ "The passed mask_latents and the required batch size don't match. mask_latentss are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask_latents.shape[0]} mask_latentss were passed. Make sure the number"
+ " of mask_latentss that you pass is divisible by the total requested batch size."
)
- mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ repeat_axes = [1 for i in range(len(mask_latents.shape)-1)]
+ mask_latents = mask_latents.repeat(batch_size // mask_latents.shape[0], *repeat_axes)
if latents.shape[0] < batch_size:
if not batch_size % latents.shape[0] == 0:
raise ValueError(
@@ -1628,14 +1795,23 @@ def prepare_mask_latents(
f" to a total batch size of {batch_size}, but {latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
- latents = latents.repeat(batch_size // latents.shape[0], 1, 1, 1)
+ repeat_axes = [1 for i in range(len(latents.shape)-1)]
+ latents = latents.repeat(batch_size // latents.shape[0], *repeat_axes)
+
+ # Duplicate mask and latents to match animation length, using mps friendly method:
+ if animation_frames and mask_latents.shape[2] < animation_frames:
+ mask_latents = mask_latents.repeat(1, 1, animation_frames - mask_latents.shape[2], 1, 1)
+ if animation_frames and latents.shape[2] < animation_frames:
+ latents = latents.repeat(1, 1, animation_frames - latents.shape[2], 1, 1)
- mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ mask_latents = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# aligning device to prevent device errors when concating it with the latent model input
+ mask_latents = mask_latents.to(device=device, dtype=dtype)
latents = latents.to(device=device, dtype=dtype)
- return mask, latents
+
+ return mask_latents, latents
def get_timesteps(
self,
@@ -1738,6 +1914,7 @@ def predict_noise_residual(
kwargs = {}
if added_cond_kwargs is not None:
kwargs["added_cond_kwargs"] = added_cond_kwargs
+
return self.unet(
latents,
timestep,
@@ -1755,10 +1932,11 @@ def prepare_control_image(
width: int,
height: int,
batch_size: int,
- num_images_per_prompt: int,
+ num_results_per_prompt: int,
device: Union[str, torch.device],
dtype: torch.dtype,
do_classifier_free_guidance=False,
+ animation_frames: Optional[int] = None
):
"""
Prepares an image for controlnet conditioning.
@@ -1770,12 +1948,12 @@ def prepare_control_image(
if isinstance(image[0], PIL.Image.Image):
images = []
- for image_ in image:
- image_ = image_.convert("RGB")
- image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
- image_ = np.array(image_)
- image_ = image_[None, :]
- images.append(image_)
+ for i in image:
+ i = i.convert("RGB")
+ i = i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+ i = np.array(i)
+ i = i[None, :]
+ images.append(i)
image = images
@@ -1786,13 +1964,17 @@ def prepare_control_image(
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
+ if animation_frames:
+ # Expand batch to frames
+ image = rearrange(image, 't c h w -> c t h w').unsqueeze(0)
+
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
- repeat_by = num_images_per_prompt
+ repeat_by = num_results_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
if do_classifier_free_guidance:
@@ -1821,74 +2003,6 @@ def prepare_controlnet_inpaint_control_image(
return image.to(device=device, dtype=dtype)
- def get_minimal_chunks(self, height: int, width: int) -> List[Tuple[int, int, int, int]]:
- """
- Gets the minimum chunks that cover the shape in multiples of the engine width
- """
- latent_height = height // self.vae_scale_factor
- latent_width = width // self.vae_scale_factor
-
- latent_window_size = self.engine_size // self.vae_scale_factor
- horizontal_blocks = math.ceil(latent_width / latent_window_size)
- vertical_blocks = math.ceil(latent_height / latent_window_size)
- total_blocks = vertical_blocks * horizontal_blocks
- chunks = []
-
- for i in range(total_blocks):
- top = (i // horizontal_blocks) * latent_window_size
- bottom = top + latent_window_size
-
- left = (i % horizontal_blocks) * latent_window_size
- right = left + latent_window_size
-
- if bottom > latent_height:
- offset = bottom - latent_height
- bottom -= offset
- top -= offset
- if right > latent_width:
- offset = right - latent_width
- right -= offset
- left -= offset
- chunks.append((top, bottom, left, right))
- return chunks
-
- def get_chunks(self, height: int, width: int) -> List[Tuple[int, int, int, int]]:
- """
- Gets the chunked latent indices for multidiffusion
- """
- latent_height = height // self.vae_scale_factor
- latent_width = width // self.vae_scale_factor
-
- if not self.chunking_size:
- return [(0, latent_height, 0, latent_width)]
-
- latent_chunking_size = self.chunking_size // self.vae_scale_factor
- latent_window_size = self.engine_size // self.vae_scale_factor
-
- vertical_blocks = math.ceil((latent_height - latent_window_size) / latent_chunking_size + 1)
- horizontal_blocks = math.ceil((latent_width - latent_window_size) / latent_chunking_size + 1)
- total_blocks = vertical_blocks * horizontal_blocks
- chunks = []
-
- for i in range(int(total_blocks)):
- top = (i // horizontal_blocks) * latent_chunking_size
- bottom = top + latent_window_size
- left = (i % horizontal_blocks) * latent_chunking_size
- right = left + latent_window_size
-
- if bottom > latent_height:
- offset = bottom - latent_height
- bottom -= offset
- top -= offset
- if right > latent_width:
- offset = right - latent_width
- right -= offset
- left -= offset
-
- chunks.append((top, bottom, left, right))
-
- return chunks
-
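The removed get_chunks above computed overlapping windows of engine_size latent pixels advanced by the chunking stride and clamped to fit; a standalone reproduction of that arithmetic follows. The Chunker utility imported from enfugue.diffusion.util presumably encapsulates the same idea, extended with temporal (frame) chunks.

    import math

    def tile_spans(extent: int, window: int, stride: int) -> list:
        # Number of windows needed to cover `extent`, then clamp the last one to fit.
        blocks = math.ceil((extent - window) / stride + 1)
        spans = []
        for i in range(blocks):
            start = min(i * stride, extent - window)
            spans.append((start, start + window))
        return spans

    print(tile_spans(96, 64, 8))  # 768px at scale factor 8 -> [(0, 64), (8, 72), (16, 80), (24, 88), (32, 96)]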
def get_controlnet_conditioning_blocks(
self,
device: Union[str, torch.device],
@@ -1903,15 +2017,30 @@ def get_controlnet_conditioning_blocks(
"""
if not controlnet_conds or not self.controlnets:
return None, None
+
+ is_animation = len(latents.shape) == 5
+ if is_animation:
+ batch, channels, frames, height, width = latents.shape
+ # Compress frames to batch
+ latent_input = rearrange(latents, "b c f h w -> (b f) c h w")
+ hidden_state_input = encoder_hidden_states.repeat_interleave(frames, dim=0)
+ else:
+ batch, channels, height, width = latents.shape
+ frames = None
+ latent_input = latents
+ hidden_state_input = encoder_hidden_states
+
down_blocks, mid_block = None, None
for name in controlnet_conds:
if self.controlnets.get(name, None) is None:
raise RuntimeError(f"Conditioning image requested ControlNet {name}, but it's not loaded.")
for controlnet_cond, conditioning_scale in controlnet_conds[name]:
+ if is_animation:
+ controlnet_cond = rearrange(controlnet_cond, "b c f h w -> (b f) c h w")
down_samples, mid_sample = self.controlnets[name](
- latents,
+ latent_input,
timestep,
- encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states=hidden_state_input,
controlnet_cond=controlnet_cond,
conditioning_scale=conditioning_scale,
added_cond_kwargs=added_cond_kwargs,
@@ -1925,6 +2054,13 @@ def get_controlnet_conditioning_blocks(
for previous_block, current_block in zip(down_blocks, down_samples)
]
mid_block += mid_sample
+ if is_animation and down_blocks is not None and mid_block is not None:
+ # Expand batch back to frames
+ down_blocks = [
+ rearrange(block, "(b f) c h w -> b c f h w", b=batch, f=frames)
+ for block in down_blocks
+ ]
+ mid_block = rearrange(mid_block, "(b f) c h w -> b c f h w", b=batch, f=frames)
return down_blocks, mid_block
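The folding used above as a standalone sketch: frames are flattened into the batch axis so the 2D ControlNet can process them, and the residuals are unflattened afterwards. Shapes are illustrative.

    import torch
    from einops import rearrange

    batch, channels, frames, height, width = 1, 4, 16, 64, 64
    video_latents = torch.randn(batch, channels, frames, height, width)
    folded = rearrange(video_latents, "b c f h w -> (b f) c h w")      # frames into batch
    assert folded.shape == (batch * frames, channels, height, width)
    unfolded = rearrange(folded, "(b f) c h w -> b c f h w", b=batch, f=frames)
    assert unfolded.shape == video_latents.shape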
def denoise_unchunked(
@@ -1935,18 +2071,17 @@ def denoise_unchunked(
num_inference_steps: int,
timesteps: torch.Tensor,
latents: torch.Tensor,
- prompt_embeds: torch.Tensor,
+ encoded_prompts: EncodedPrompts,
weight_builder: MaskWeightBuilder,
guidance_scale: float,
do_classifier_free_guidance: bool = False,
- is_inpainting_unet: bool = False,
mask: Optional[torch.Tensor] = None,
mask_image: Optional[torch.Tensor] = None,
image: Optional[torch.Tensor] = None,
control_images: PreparedControlImageArgType = None,
progress_callback: Optional[Callable[[bool], None]] = None,
latent_callback: Optional[Callable[[Union[torch.Tensor, np.ndarray, List[PIL.Image.Image]]], None]] = None,
- latent_callback_steps: Optional[int] = 1,
+ latent_callback_steps: Optional[int] = None,
latent_callback_type: Literal["latent", "pt", "np", "pil"] = "latent",
extra_step_kwargs: Optional[Dict[str, Any]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1960,9 +2095,15 @@ def denoise_unchunked(
num_steps = len(timesteps)
num_warmup_steps = num_steps - num_inference_steps * self.scheduler.order # type: ignore[attr-defined]
-
+
+ if len(latents.shape) == 5:
+ samples, num_channels, num_frames, latent_height, latent_width = latents.shape
+ else:
+ samples, num_channels, latent_height, latent_width = latents.shape
+ num_frames = None
+
noise = None
- if mask is not None and mask_image is not None and not is_inpainting_unet:
+ if mask is not None and mask_image is not None and not self.is_inpainting_unet:
noise = latents.detach().clone() / self.scheduler.init_noise_sigma # type: ignore[attr-defined]
noise = noise.to(device=device)
@@ -1977,6 +2118,25 @@ def denoise_unchunked(
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # type: ignore[attr-defined]
+ # Get embeds
+ embeds = encoded_prompts.get_embeds()
+ if embeds is None:
+ logger.warning("No text embeds, using zeros")
+ if self.text_encoder:
+ embeds = torch.zeros(samples, 77, self.text_encoder.config.hidden_size).to(device)
+ elif self.text_encoder_2:
+ embeds = torch.zeros(samples, 77, self.text_encoder_2.config.hidden_size).to(device)
+ else:
+ raise IOError("No embeds and no text encoder.")
+ embeds = embeds.to(device=device)
+
+ # Get added embeds
+ add_text_embeds = encoded_prompts.get_add_text_embeds()
+ if add_text_embeds is not None:
+ if not added_cond_kwargs:
+ raise ValueError(f"Added condition arguments is empty, but received add text embeds. There should be time IDs prior to this point.")
+ added_cond_kwargs["text_embeds"] = add_text_embeds.to(device=device, dtype=embeds.dtype)
+
# Get controlnet input(s) if configured
if control_images is not None:
# Find which control image(s) to use
@@ -1996,6 +2156,7 @@ def denoise_unchunked(
controlnet_conds[controlnet_name] = []
controlnet_conds[controlnet_name].append((control_image, conditioning_scale))
+
if not controlnet_conds:
down_block, mid_block = None, None
else:
@@ -2003,7 +2164,7 @@ def denoise_unchunked(
device=device,
latents=latent_model_input,
timestep=t,
- encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states=embeds,
controlnet_conds=controlnet_conds,
added_cond_kwargs=added_cond_kwargs,
)
@@ -2011,7 +2172,7 @@ def denoise_unchunked(
down_block, mid_block = None, None
# add other dimensions to unet input if set
- if mask is not None and mask_image is not None and is_inpainting_unet:
+ if mask is not None and mask_image is not None and self.is_inpainting_unet:
latent_model_input = torch.cat(
[latent_model_input, mask, mask_image],
dim=1,
@@ -2019,13 +2180,13 @@ def denoise_unchunked(
# predict the noise residual
noise_pred = self.predict_noise_residual(
- latent_model_input,
- t,
- prompt_embeds,
- cross_attention_kwargs,
- added_cond_kwargs,
- down_block,
- mid_block,
+ latents=latent_model_input,
+ timestep=t,
+ embeddings=embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ down_block_additional_residuals=down_block,
+ mid_block_additional_residual=mid_block,
)
# perform guidance
@@ -2044,7 +2205,7 @@ def denoise_unchunked(
# If using mask and not using fine-tuned inpainting, then we calculate
# the same denoising on the image without unet and cross with the
# calculated unet input * mask
- if mask is not None and image is not None and not is_inpainting_unet:
+ if mask is not None and image is not None and not self.is_inpainting_unet:
init_latents = image[:1]
init_mask = mask[:1]
@@ -2079,10 +2240,17 @@ def denoise_unchunked(
device=device,
)
latent_callback_value = self.denormalize_latents(latent_callback_value)
- if latent_callback_type != "pt":
- latent_callback_value = self.image_processor.pt_to_numpy(latent_callback_value)
- if latent_callback_type == "pil":
- latent_callback_value = self.image_processor.numpy_to_pil(latent_callback_value)
+ if num_frames is not None:
+ output = [] # type: ignore[assignment]
+ for frame in self.decode_animation_frames(latent_callback_value):
+ output.extend(self.image_processor.numpy_to_pil(frame)) # type: ignore[attr-defined]
+ latent_callback_value = output # type: ignore[assignment]
+ else:
+ if latent_callback_type != "pt":
+ latent_callback_value = self.image_processor.pt_to_numpy(latent_callback_value)
+ if latent_callback_type == "pil":
+ latent_callback_value = self.image_processor.numpy_to_pil(latent_callback_value)
+
latent_callback(latent_callback_value)
return latents
@@ -2107,20 +2275,20 @@ def denoise(
width: int,
device: Union[str, torch.device],
num_inference_steps: int,
+ chunker: Chunker,
timesteps: torch.Tensor,
latents: torch.Tensor,
- prompt_embeds: torch.Tensor,
+ encoded_prompts: EncodedPrompts,
weight_builder: MaskWeightBuilder,
guidance_scale: float,
do_classifier_free_guidance: bool = False,
- is_inpainting_unet: bool = False,
mask: Optional[torch.Tensor] = None,
mask_image: Optional[torch.Tensor] = None,
image: Optional[torch.Tensor] = None,
control_images: PreparedControlImageArgType = None,
progress_callback: Optional[Callable[[bool], None]] = None,
latent_callback: Optional[Callable[[Union[torch.Tensor, np.ndarray, List[PIL.Image.Image]]], None]] = None,
- latent_callback_steps: Optional[int] = 1,
+ latent_callback_steps: Optional[int] = None,
latent_callback_type: Literal["latent", "pt", "np", "pil"] = "latent",
extra_step_kwargs: Optional[Dict[str, Any]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -2132,10 +2300,16 @@ def denoise(
if extra_step_kwargs is None:
extra_step_kwargs = {}
- chunks = self.get_chunks(height, width)
- num_chunks = len(chunks)
+ if len(latents.shape) == 5:
+ samples, num_channels, num_frames, latent_height, latent_width = latents.shape
+ else:
+ samples, num_channels, latent_height, latent_width = latents.shape
+ num_frames = None
+
+ num_chunks = chunker.num_chunks
+ num_temporal_chunks = chunker.num_frame_chunks
- if num_chunks <= 1:
+ if num_chunks <= 1 and num_temporal_chunks <= 1:
return self.denoise_unchunked(
height=height,
width=width,
@@ -2143,10 +2317,9 @@ def denoise(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
latents=latents,
- prompt_embeds=prompt_embeds,
+ encoded_prompts=encoded_prompts,
weight_builder=weight_builder,
guidance_scale=guidance_scale,
- is_inpainting_unet=is_inpainting_unet,
do_classifier_free_guidance=do_classifier_free_guidance,
mask=mask,
mask_image=mask_image,
@@ -2161,7 +2334,10 @@ def denoise(
added_cond_kwargs=added_cond_kwargs,
)
- chunk_scheduler_status = [self.get_scheduler_state()] * num_chunks
+ chunk_scheduler_status = []
+ for i in range(num_chunks * num_temporal_chunks):
+ chunk_scheduler_status.append(self.get_scheduler_state())
+
num_steps = len(timesteps)
num_warmup_steps = num_steps - num_inference_steps * self.scheduler.order # type: ignore[attr-defined]
@@ -2172,15 +2348,13 @@ def denoise(
count = torch.zeros_like(latents)
value = torch.zeros_like(latents)
- samples, num_channels, _, _ = latents.shape
-
- total_num_steps = num_steps * num_chunks
+ total_num_steps = num_steps * num_chunks * num_temporal_chunks
logger.debug(
- f"Denoising image in {total_num_steps} steps on {device} ({num_inference_steps} inference steps * {num_chunks} chunks)"
+ f"Denoising image in {total_num_steps} steps on {device} ({num_inference_steps} inference steps * {num_chunks} chunks * {num_temporal_chunks} temporal chunks)"
)
noise = None
- if mask is not None and mask_image is not None and not is_inpainting_unet:
+ if mask is not None and mask_image is not None and not self.is_inpainting_unet:
noise = latents.detach().clone() / self.scheduler.init_noise_sigma # type: ignore[attr-defined]
noise = noise.to(device=device)
@@ -2194,17 +2368,113 @@ def denoise(
value.zero_()
# iterate over chunks
- for j, (top, bottom, left, right) in enumerate(chunks):
+ for j, ((top, bottom), (left, right), (start, end)) in enumerate(chunker):
+ # Memoize wrap for later
+ wrap_x = right <= left
+ wrap_y = bottom <= top
+ wrap_t = start is not None and end is not None and end <= start
+
+ mask_width = (latent_width - left) + right if wrap_x else right - left
+ mask_height = (latent_height - top) + bottom if wrap_y else bottom - top
+ if num_frames is None or start is None or end is None:
+ mask_frames = None
+ else:
+ mask_frames = (num_frames - start) + end if wrap_t else end - start
+
+ # Define some helpers for chunked denoising
+ def slice_for_view(tensor: torch.Tensor, scale_factor: int = 1) -> torch.Tensor:
+ """
+ Copies and slices input tensors
+ """
+ left_idx = left * scale_factor
+ right_idx = right * scale_factor
+ top_idx = top * scale_factor
+ bottom_idx = bottom * scale_factor
+ height_idx = latent_height * scale_factor
+ width_idx = latent_width * scale_factor
+
+ tensor_for_view = torch.clone(tensor)
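+ # When a chunk wraps past an edge (spatial tiling or temporal looping), stitch the trailing and leading slices together.
+ # 4-D tensors are (batch, channel, height, width); 5-D tensors add a frame axis at dim 2.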
+
+ if wrap_x:
+ if num_frames is None:
+ tensor_for_view = torch.cat([tensor_for_view[:, :, :, left_idx:width_idx], tensor_for_view[:, :, :, :right_idx]], dim=3)
+ else:
+ tensor_for_view = torch.cat([tensor_for_view[:, :, :, :, left_idx:width_idx], tensor_for_view[:, :, :, :, :right_idx]], dim=4)
+ elif num_frames is None:
+ tensor_for_view = tensor_for_view[:, :, :, left_idx:right_idx]
+ else:
+ tensor_for_view = tensor_for_view[:, :, :, :, left_idx:right_idx]
+
+ if wrap_y:
+ if num_frames is None:
+ tensor_for_view = torch.cat([tensor_for_view[:, :, top_idx:height_idx, :], tensor_for_view[:, :, :bottom_idx, :]], dim=2)
+ else:
+ tensor_for_view = torch.cat([tensor_for_view[:, :, :, top_idx:height_idx, :], tensor_for_view[:, :, :, :bottom_idx, :]], dim=3)
+ elif num_frames is None:
+ tensor_for_view = tensor_for_view[:, :, top_idx:bottom_idx, :]
+ else:
+ tensor_for_view = tensor_for_view[:, :, :, top_idx:bottom_idx, :]
+
+ if wrap_t:
+ tensor_for_view = torch.cat([tensor_for_view[:, :, start:num_frames, :, :], tensor_for_view[:, :, :end, :, :]], dim=2)
+ elif num_frames is not None:
+ tensor_for_view = tensor_for_view[:, :, start:end, :, :]
+
+ return tensor_for_view
+
+ def fill_value(tensor: torch.Tensor, multiplier: torch.Tensor) -> None:
+ """
+ Fills the value and count tensors
+ """
+ nonlocal value, count
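+ # Scatter the weighted chunk back into the full-size accumulators, splitting any wrapped portion back to the start of its axis.
+ # count keeps the per-element weights so the result can be normalized after all chunks are summed.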
+ start_x = left
+ end_x = latent_width if wrap_x else right
+ initial_x = end_x - start_x
+ start_y = top
+ end_y = latent_height if wrap_y else bottom
+ initial_y = end_y - start_y
+ start_t = start
+ end_t = num_frames if wrap_t else end
+ initial_t = None if end_t is None or start_t is None else end_t - start_t
+
+ if num_frames is None:
+ value[:, :, start_y:end_y, start_x:end_x] += tensor[:, :, :initial_y, :initial_x]
+ count[:, :, start_y:end_y, start_x:end_x] += multiplier[:, :, :initial_y, :initial_x]
+ if wrap_x:
+ value[:, :, start_y:end_y, :right] += tensor[:, :, :initial_y, initial_x:]
+ count[:, :, start_y:end_y, :right] += multiplier[:, :, :initial_y, initial_x:]
+ if wrap_y:
+ value[:, :, :bottom, :right] += tensor[:, :, initial_y:, initial_x:]
+ count[:, :, :bottom, :right] += multiplier[:, :, initial_y:, initial_x:]
+ if wrap_y:
+ value[:, :, :bottom, start_x:end_x] += tensor[:, :, initial_y:, :initial_x]
+ count[:, :, :bottom, start_x:end_x] += multiplier[:, :, initial_y:, :initial_x]
+ else:
+ value[:, :, start_t:end_t, start_y:end_y, start_x:end_x] += tensor[:, :, :initial_t, :initial_y, :initial_x]
+ count[:, :, start_t:end_t, start_y:end_y, start_x:end_x] += multiplier[:, :, :initial_t, :initial_y, :initial_x]
+ if wrap_x:
+ value[:, :, start_t:end_t, start_y:end_y, :right] += tensor[:, :, :initial_t, :initial_y, initial_x:]
+ count[:, :, start_t:end_t, start_y:end_y, :right] += multiplier[:, :, :initial_t, :initial_y, initial_x:]
+ if wrap_y:
+ value[:, :, start_t:end_t, :bottom, :right] += tensor[:, :, :initial_t, initial_y:, initial_x:]
+ count[:, :, start_t:end_t, :bottom, :right] += multiplier[:, :, :initial_t, initial_y:, initial_x:]
+ if wrap_t:
+ value[:, :, :end, :bottom, :right] += tensor[:, :, initial_t:, initial_y:, initial_x:]
+ count[:, :, :end, :bottom, :right] += multiplier[:, :, initial_t:, initial_y:, initial_x:]
+ if wrap_t: # combined horizontal and temporal wrap (main vertical strip)
+ value[:, :, :end, start_y:end_y, :right] += tensor[:, :, initial_t:, :initial_y, initial_x:]
+ count[:, :, :end, start_y:end_y, :right] += multiplier[:, :, initial_t:, :initial_y, initial_x:]
+ if wrap_y:
+ value[:, :, start_t:end_t, :bottom, start_x:end_x] += tensor[:, :, :initial_t, initial_y:, :initial_x]
+ count[:, :, start_t:end_t, :bottom, start_x:end_x] += multiplier[:, :, :initial_t, initial_y:, :initial_x]
+ if wrap_t:
+ value[:, :, :end, :bottom, start_x:end_x] += tensor[:, :, initial_t:, initial_y:, :initial_x]
+ count[:, :, :end, :bottom, start_x:end_x] += multiplier[:, :, initial_t:, initial_y:, :initial_x]
+ if wrap_t:
+ value[:, :, :end, start_y:end_y, start_x:end_x] += tensor[:, :, initial_t:, :initial_y, :initial_x]
+ count[:, :, :end, start_y:end_y, start_x:end_x] += multiplier[:, :, initial_t:, :initial_y, :initial_x]
+
# Wrap IndexError to give a nice error about MultiDiff w/ some schedulers
try:
- # Get pixel indices
- top_px = top * self.vae_scale_factor
- bottom_px = bottom * self.vae_scale_factor
- left_px = left * self.vae_scale_factor
- right_px = right * self.vae_scale_factor
-
# Slice latents
- latents_for_view = latents[:, :, top:bottom, left:right]
+ latents_for_view = slice_for_view(latents)
# expand the latents if we are doing classifier free guidance
latent_model_input = (
@@ -2217,6 +2487,32 @@ def denoise(
# Scale model input
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # type: ignore[attr-defined]
+ # Get embeds
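+ # Build the frame indexes covered by this temporal chunk; when the animation loops, the window wraps from the tail frames back to the head frames.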
+ if wrap_t and start is not None and num_frames is not None and end is not None:
+ frame_indexes = list(range(start, num_frames)) + list(range(end))
+ elif num_frames is not None and start is not None and end is not None:
+ frame_indexes = list(range(start, end))
+ else:
+ frame_indexes = None
+
+ embeds = encoded_prompts.get_embeds(frame_indexes)
+
+ if embeds is None:
+ logger.warning(f"Warning: no prompts found for frame window {frame_indexes}")
+ if self.text_encoder:
+ embeds = torch.zeros(samples, 77, self.text_encoder.config.hidden_size).to(device)
+ elif self.text_encoder_2:
+ embeds = torch.zeros(samples, 77, self.text_encoder_2.config.hidden_size).to(device)
+ else:
+ raise IOError("No embeds and no text encoder.")
+
+ # Get added embeds
+ add_text_embeds = encoded_prompts.get_add_text_embeds(frame_indexes)
+ if add_text_embeds is not None:
+ if not added_cond_kwargs:
+ raise ValueError(f"Added condition arguments is empty, but received add text embeds. There should be time IDs prior to this point.")
+ added_cond_kwargs["text_embeds"] = add_text_embeds.to(device=device, dtype=embeds.dtype)
+
# Get controlnet input(s) if configured
if control_images is not None:
# Find which control image(s) to use
@@ -2234,9 +2530,8 @@ def denoise(
):
if controlnet_name not in controlnet_conds:
controlnet_conds[controlnet_name] = []
-
controlnet_conds[controlnet_name].append((
- control_image[:, :, top_px:bottom_px, left_px:right_px],
+ slice_for_view(control_image, self.vae_scale_factor),
conditioning_scale
))
@@ -2247,7 +2542,7 @@ def denoise(
device=device,
latents=latent_model_input,
timestep=t,
- encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states=embeds,
controlnet_conds=controlnet_conds,
added_cond_kwargs=added_cond_kwargs,
)
@@ -2255,12 +2550,12 @@ def denoise(
down_block, mid_block = None, None
# add other dimensions to unet input if set
- if mask is not None and mask_image is not None and is_inpainting_unet:
+ if mask is not None and mask_image is not None and self.is_inpainting_unet:
latent_model_input = torch.cat(
[
latent_model_input,
- mask[:, :, top:bottom, left:right],
- mask_image[:, :, top:bottom, left:right],
+ slice_for_view(mask),
+ slice_for_view(mask_image),
],
dim=1,
)
@@ -2269,7 +2564,7 @@ def denoise(
noise_pred = self.predict_noise_residual(
latents=latent_model_input,
timestep=t,
- embeddings=prompt_embeds,
+ embeddings=embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block,
@@ -2282,12 +2577,13 @@ def denoise(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
- denoised_latents = self.scheduler.step( # type: ignore[attr-defined]
+ denoised_latents = self.scheduler.step( #type: ignore[attr-defined]
noise_pred,
t,
latents_for_view,
**extra_step_kwargs,
).prev_sample
+
except IndexError:
raise RuntimeError(f"Received IndexError during denoising. It is likely that the scheduler you are using ({type(self.scheduler).__name__}) does not work with Multi-Diffusion, and you should avoid using this when chunking is enabled.")
@@ -2297,15 +2593,15 @@ def denoise(
# If using mask and not using fine-tuned inpainting, then we calculate
# the same denoising on the image without unet and cross with the
# calculated unet input * mask
- if mask is not None and image is not None and noise is not None and not is_inpainting_unet:
- init_latents = (image[:, :, top:bottom, left:right])[:1]
- init_mask = (mask[:, :, top:bottom, left:right])[:1]
+ if mask is not None and image is not None and noise is not None and not self.is_inpainting_unet:
+ init_latents = (slice_for_view(image))[:1]
+ init_mask = (slice_for_view(mask))[:1]
if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
init_latents = self.scheduler.add_noise( # type: ignore[attr-defined]
init_latents,
- noise[:, :, top:bottom, left:right],
+ slice_for_view(noise),
torch.tensor([noise_timestep])
)
@@ -2313,29 +2609,30 @@ def denoise(
# Build weights
multiplier = weight_builder(
- mask_type=self.chunking_mask_type,
+ mask_type=self.tiling_mask_type,
batch=samples,
dim=num_channels,
- width=engine_latent_size,
- height=engine_latent_size,
+ frames=mask_frames,
+ width=mask_width,
+ height=mask_height,
unfeather_left=left==0,
unfeather_top=top==0,
unfeather_right=right==latent_width,
unfeather_bottom=bottom==latent_height,
- **self.chunking_mask_kwargs
+ unfeather_start=False if num_frames is None else (start==0 and not chunker.loop),
+ unfeather_end=False if num_frames is None else (end==num_frames and not chunker.loop),
+ **self.tiling_mask_kwargs
)
- value[:, :, top:bottom, left:right] += denoised_latents * multiplier
- count[:, :, top:bottom, left:right] += multiplier
+ fill_value(denoised_latents * multiplier, multiplier)
- # Call the progress callback
if progress_callback is not None:
progress_callback(True)
# multidiffusion
latents = torch.where(count > 0, value / count, value)
- # call the latent_callback, if provided
+ # Call the latent callback, if provided
steps_since_last_callback += 1
if (
latent_callback is not None
@@ -2343,7 +2640,6 @@ def denoise(
and steps_since_last_callback >= latent_callback_steps
and (i == num_steps - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0)) # type: ignore[attr-defined]
):
- steps_since_last_callback = 0
latent_callback_value = latents
if latent_callback_type != "latent":
@@ -2353,10 +2649,17 @@ def denoise(
device=device,
)
latent_callback_value = self.denormalize_latents(latent_callback_value)
- if latent_callback_type != "pt":
- latent_callback_value = self.image_processor.pt_to_numpy(latent_callback_value)
- if latent_callback_type == "pil":
- latent_callback_value = self.image_processor.numpy_to_pil(latent_callback_value)
+ if num_frames is not None:
+ output = [] # type: ignore[assignment]
+ for frame in self.decode_animation_frames(latent_callback_value):
+ output.extend(self.image_processor.numpy_to_pil(frame)) # type: ignore[attr-defined]
+ latent_callback_value = output # type: ignore[assignment]
+ else:
+ if latent_callback_type != "pt":
+ latent_callback_value = self.image_processor.pt_to_numpy(latent_callback_value)
+ if latent_callback_type == "pil":
+ latent_callback_value = self.image_processor.numpy_to_pil(latent_callback_value)
+
latent_callback(latent_callback_value)
return latents
@@ -2372,8 +2675,14 @@ def decode_latent_preview(
Batches anything > 1024px (128 latent)
"""
from math import ceil
- batch = latents.shape[0]
- height, width = latents.shape[-2:]
+
+ shape = latents.shape
+ if len(shape) == 5:
+ batch, channels, frames, height, width = shape
+ else:
+ batch, channels, height, width = shape
+ frames = None
+
height_px = height * self.vae_scale_factor
width_px = width * self.vae_scale_factor
@@ -2381,54 +2690,69 @@ def decode_latent_preview(
overlap = 16
max_size_px = max_size * self.vae_scale_factor
- if height > max_size or width > max_size:
- # Do some chunking to avoid sharp lines, but don't follow global chunk
- width_chunks = ceil(width / (max_size - overlap))
- height_chunks = ceil(height / (max_size - overlap))
- decoded_preview = torch.zeros(
- (batch, 3, height*self.vae_scale_factor, width*self.vae_scale_factor),
- dtype=latents.dtype,
- device=device
- )
- multiplier = torch.zeros_like(decoded_preview)
- for i in range(height_chunks):
- start_h = max(0, i * (max_size - overlap))
- end_h = start_h + max_size
- if end_h > height:
- diff = end_h - height
- end_h -= diff
- start_h = max(0, start_h-diff)
- start_h_px = start_h * self.vae_scale_factor
- end_h_px = end_h * self.vae_scale_factor
- for j in range(width_chunks):
- start_w = max(0, j * (max_size - overlap))
- end_w = start_w + max_size
- if end_w > width:
- diff = end_w - width
- end_w -= diff
- start_w = max(0, start_w-diff)
- start_w_px = start_w * self.vae_scale_factor
- end_w_px = end_w * self.vae_scale_factor
- mask = weight_builder(
- mask_type="bilinear",
- batch=batch,
- dim=3,
- width=min(width_px, max_size_px),
- height=min(height_px, max_size_px),
- unfeather_left=start_w==0,
- unfeather_top=start_h==0,
- unfeather_right=end_w==width,
- unfeather_bottom=end_h==height,
- )
- decoded_view = self.vae_preview.decode(
- latents[:, :, start_h:end_h, start_w:end_w],
- return_dict=False
- )[0].to(device)
- decoded_preview[:, :, start_h_px:end_h_px, start_w_px:end_w_px] += decoded_view * mask
- multiplier[:, :, start_h_px:end_h_px, start_w_px:end_w_px] += mask
- return decoded_preview / multiplier
+ # Define function to decode a single frame
+ def decode_preview(tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Decodes a single frame
+ """
+ if height > max_size or width > max_size:
+ # Do some chunking to avoid sharp lines, but don't follow global chunk
+ width_chunks = ceil(width / (max_size - overlap))
+ height_chunks = ceil(height / (max_size - overlap))
+ decoded_preview = torch.zeros(
+ (batch, 3, height*self.vae_scale_factor, width*self.vae_scale_factor),
+ dtype=tensor.dtype,
+ device=device
+ )
+ multiplier = torch.zeros_like(decoded_preview)
+ for i in range(height_chunks):
+ start_h = max(0, i * (max_size - overlap))
+ end_h = start_h + max_size
+ if end_h > height:
+ diff = end_h - height
+ end_h -= diff
+ start_h = max(0, start_h-diff)
+ start_h_px = start_h * self.vae_scale_factor
+ end_h_px = end_h * self.vae_scale_factor
+ for j in range(width_chunks):
+ start_w = max(0, j * (max_size - overlap))
+ end_w = start_w + max_size
+ if end_w > width:
+ diff = end_w - width
+ end_w -= diff
+ start_w = max(0, start_w-diff)
+ start_w_px = start_w * self.vae_scale_factor
+ end_w_px = end_w * self.vae_scale_factor
+ mask = weight_builder(
+ mask_type="bilinear",
+ batch=batch,
+ dim=3,
+ width=min(width_px, max_size_px),
+ height=min(height_px, max_size_px),
+ unfeather_left=start_w==0,
+ unfeather_top=start_h==0,
+ unfeather_right=end_w==width,
+ unfeather_bottom=end_h==height,
+ )
+ decoded_view = self.vae_preview.decode(
+ tensor[:, :, start_h:end_h, start_w:end_w],
+ return_dict=False
+ )[0].to(device)
+ decoded_preview[:, :, start_h_px:end_h_px, start_w_px:end_w_px] += decoded_view * mask
+ multiplier[:, :, start_h_px:end_h_px, start_w_px:end_w_px] += mask
+ return decoded_preview / multiplier
+ else:
+ return self.vae_preview.decode(tensor, return_dict=False)[0].to(device)
+
+ # If there are frames, decode them one at a time
+ if frames is not None:
+ decoded_frames = [
+ decode_preview(latents[:, :, i, :, :]).unsqueeze(2)
+ for i in range(frames)
+ ]
+ return torch.cat(decoded_frames, dim=2)
else:
- return self.vae_preview.decode(latents, return_dict=False)[0].to(device)
+ return decode_preview(latents)
def decode_latent_view(self, latents: torch.Tensor) -> torch.Tensor:
"""
@@ -2450,21 +2774,27 @@ def decode_latents(
self,
latents: torch.Tensor,
device: Union[str, torch.device],
+ chunker: Chunker,
weight_builder: MaskWeightBuilder,
- progress_callback: Optional[Callable[[bool], None]] = None,
+ progress_callback: Optional[Callable[[bool], None]]=None,
+ scale_latents: bool=True
) -> torch.Tensor:
"""
Decodes the latents in chunks as necessary.
"""
- samples, num_channels, height, width = latents.shape
+ if len(latents.shape) == 5:
+ samples, num_channels, num_frames, height, width = latents.shape
+ else:
+ samples, num_channels, height, width = latents.shape
+ num_frames = None
+
height *= self.vae_scale_factor
width *= self.vae_scale_factor
- latents = 1 / self.vae.config.scaling_factor * latents # type: ignore[attr-defined]
-
- chunks = self.get_chunks(height, width)
- total_steps = len(chunks)
+ if scale_latents:
+ latents = 1 / self.vae.config.scaling_factor * latents # type: ignore[attr-defined]
+ total_steps = chunker.num_chunks
revert_dtype = None
if self.config.force_full_precision_vae: # type: ignore[attr-defined]
@@ -2486,41 +2816,100 @@ def decode_latents(
latent_height = height // self.vae_scale_factor
engine_latent_size = self.engine_size // self.vae_scale_factor
- count = torch.zeros((samples, 3, height, width)).to(device=device, dtype=latents.dtype)
- value = torch.zeros_like(count)
+ if num_frames is None:
+ count = torch.zeros((samples, 3, height, width)).to(device=device, dtype=latents.dtype)
+ else:
+ count = torch.zeros((samples, 3, num_frames, height, width)).to(device=device, dtype=latents.dtype)
+ value = torch.zeros_like(count)
logger.debug(f"Decoding latents in {total_steps} steps")
- # iterate over chunks
- for i, (top, bottom, left, right) in enumerate(chunks):
- # Slice latents
- latents_for_view = latents[:, :, top:bottom, left:right]
+ for j, ((top, bottom), (left, right)) in enumerate(chunker.chunks):
+ # Memoize wrap for later
+ wrap_x = right <= left
+ wrap_y = bottom <= top
+
+ mask_width = ((latent_width - left) + right if wrap_x else right - left) * self.vae_scale_factor
+ mask_height = ((latent_height - top) + bottom if wrap_y else bottom - top) * self.vae_scale_factor
+
+ # Define some helpers for chunked decoding
+ def slice_for_view(tensor: torch.Tensor, scale_factor: int = 1) -> torch.Tensor:
+ """
+ Copies and slices input tensors
+ """
+ left_idx = left * scale_factor
+ right_idx = right * scale_factor
+ top_idx = top * scale_factor
+ bottom_idx = bottom * scale_factor
+ height_idx = latent_height * scale_factor
+ width_idx = latent_width * scale_factor
+ tensor_for_view = torch.clone(tensor)
+
+ if wrap_x:
+ tensor_for_view = torch.cat([tensor_for_view[:, :, :, left_idx:width_idx], tensor_for_view[:, :, :, :right_idx]], dim=3)
+ else:
+ tensor_for_view = tensor_for_view[:, :, :, left_idx:right_idx]
+
+ if wrap_y:
+ tensor_for_view = torch.cat([tensor_for_view[:, :, top_idx:height_idx, :], tensor_for_view[:, :, :bottom_idx, :]], dim=2)
+ else:
+ tensor_for_view = tensor_for_view[:, :, top_idx:bottom_idx, :]
+
+ return tensor_for_view
+
+ def fill_value(tensor: torch.Tensor, multiplier: torch.Tensor) -> None:
+ """
+ Fills the value and count tensors
+ """
+ nonlocal value, count
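+ # Convert latent coordinates to pixel coordinates and scatter the decoded chunk into the output canvas, splitting wrapped regions as in the denoising loop.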
+ start_x = left
+ end_x = latent_width if wrap_x else right
+ start_x *= self.vae_scale_factor
+ end_x *= self.vae_scale_factor
+ initial_x = end_x - start_x
+ right_px = right * self.vae_scale_factor
+
+ start_y = top
+ end_y = latent_height if wrap_y else bottom
+ start_y *= self.vae_scale_factor
+ end_y *= self.vae_scale_factor
+ initial_y = end_y - start_y
+ bottom_px = bottom * self.vae_scale_factor
+
+ value[:, :, start_y:end_y, start_x:end_x] += tensor[:, :, :initial_y, :initial_x]
+ count[:, :, start_y:end_y, start_x:end_x] += multiplier[:, :, :initial_y, :initial_x]
+ if wrap_x:
+ value[:, :, start_y:end_y, :right_px] += tensor[:, :, :initial_y, initial_x:]
+ count[:, :, start_y:end_y, :right_px] += multiplier[:, :, :initial_y, initial_x:]
+ if wrap_y:
+ value[:, :, :bottom_px, :right_px] += tensor[:, :, initial_y:, initial_x:]
+ count[:, :, :bottom_px, :right_px] += multiplier[:, :, initial_y:, initial_x:]
+ if wrap_y:
+ value[:, :, :bottom_px, start_x:end_x] += tensor[:, :, initial_y:, :initial_x]
+ count[:, :, :bottom_px, start_x:end_x] += multiplier[:, :, initial_y:, :initial_x]
- # Get pixel indices
- top_px = top * self.vae_scale_factor
- bottom_px = bottom * self.vae_scale_factor
- left_px = left * self.vae_scale_factor
- right_px = right * self.vae_scale_factor
+ # Slice latents
+ latents_for_view = slice_for_view(latents)
# Decode latents
decoded_latents = self.decode_latent_view(latents_for_view).to(device=device)
# Build weights
multiplier = weight_builder(
- mask_type=self.chunking_mask_type,
+ mask_type=self.tiling_mask_type,
batch=samples,
dim=3,
- width=self.engine_size,
- height=self.engine_size,
+ frames=None,
+ width=mask_width,
+ height=mask_height,
unfeather_left=left==0,
unfeather_top=top==0,
unfeather_right=right==latent_width,
unfeather_bottom=bottom==latent_height,
- **self.chunking_mask_kwargs
+ **self.tiling_mask_kwargs
)
- value[:, :, top_px:bottom_px, left_px:right_px] += decoded_latents * multiplier
- count[:, :, top_px:bottom_px, left_px:right_px] += multiplier
+ fill_value(decoded_latents * multiplier, multiplier)
if progress_callback is not None:
progress_callback(True)
@@ -2530,8 +2919,29 @@ def decode_latents(
if revert_dtype is not None:
latents = latents.to(dtype=revert_dtype)
self.vae.to(dtype=revert_dtype)
+
return latents
+ def decode_animation_frames(
+ self,
+ videos: torch.Tensor,
+ n_rows: int = 8,
+ rescale: bool = False
+ ) -> List[np.ndarray]:
+ """
+ Decode an animation
+ """
+ videos = rearrange(videos, "b c t h w -> t b c h w")
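+ # Iterate frame-first: each timestep's batch is tiled into one grid image, then converted to an HWC uint8 array.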
+ outputs = []
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = (255 - (x * 255)).cpu().numpy().astype(np.uint8)
+ outputs.append(x)
+ return outputs
+
def prepare_extra_step_kwargs(
self,
generator: Optional[torch.Generator],
@@ -2573,7 +2983,7 @@ def step_complete(increment_step: bool = True) -> None:
seconds_in_window = (datetime.datetime.now() - window_start).total_seconds()
its = (overall_step - window_start_step) / seconds_in_window
unit = "s/it" if its < 1 else "it/s"
- its_display = 1 / its if its < 1 else its
+ its_display = 0 if its == 0 else 1 / its if its < 1 else its
logger.debug(
f"{{0:0{digits}d}}/{{1:0{digits}d}}: {{2:0.2f}} {{3:s}}".format(
overall_step, overall_steps, its_display, unit
@@ -2589,56 +2999,214 @@ def step_complete(increment_step: bool = True) -> None:
return step_complete
+ def standardize_image(
+ self,
+ image: Optional[Union[ImageArgType, torch.Tensor]]=None,
+ animation_frames: Optional[int]=None,
+ ) -> Optional[Union[torch.Tensor, List[PIL.Image.Image]]]:
+ """
+ Standardizes image args to list
+ """
+ if image is None or isinstance(image, torch.Tensor):
+ return image
+ if not isinstance(image, list):
+ image = [image]
+
+ images = []
+ for img in image:
+ if isinstance(img, str):
+ img = self.open_image(img)
+ if isinstance(img, list):
+ images.extend(img)
+ else:
+ images.append(img)
+
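+ # Pad by repeating the final image until the requested frame count is reached; without animation, keep only the first image.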
+ if animation_frames:
+ image_len = len(images)
+ if image_len < animation_frames:
+ images += [
+ images[image_len-1]
+ for i in range(animation_frames - image_len)
+ ]
+ else:
+ images = images[:1]
+
+ return images
+
+ def standardize_ip_adapter_images(
+ self,
+ images: ImagePromptArgType=None,
+ animation_frames: Optional[int]=None,
+ ) -> Optional[List[Tuple[List[PIL.Image.Image], float]]]:
+ """
+ Standardizes IP adapter args to list
+ """
+ if not images:
+ return None
+
+ if not isinstance(images, list):
+ images = [images]
+
+ ip_adapter_tuples = []
+
+ for image in images:
+ if isinstance(image, tuple):
+ img, scale = image
+ elif isinstance(image, dict):
+ img = image["image"]
+ scale = float(image["scale"])
+ elif isinstance(image, str):
+ img = self.open_image(image)
+ scale = 1.0
+ else:
+ img = image
+ scale = 1.0
+
+ if not isinstance(img, list):
+ img = [img]
+
+ if animation_frames:
+ image_len = len(img)
+ if image_len < animation_frames:
+ img += [
+ img[image_len-1]
+ for i in range(animation_frames - image_len)
+ ]
+ else:
+ img = img[:animation_frames]
+ else:
+ img = img[:1]
+
+ ip_adapter_tuples.append((img, scale))
+
+ return ip_adapter_tuples
+
+ def standardize_control_images(
+ self,
+ control_images: ControlImageArgType=None,
+ animation_frames: Optional[int]=None,
+ ) -> Optional[
+ Dict[
+ str,
+ List[Tuple[List[PIL.Image.Image], float, Optional[float], Optional[float]]]
+ ]
+ ]:
+ """
+ Standardizes control images to dict of list of tuple
+ """
+ if control_images is None:
+ return None
+
+ standardized: Dict[str, List[Tuple[List[PIL.Image.Image], float, Optional[float], Optional[float]]]] = {}
+
+ for name in control_images:
+ if name not in self.controlnets: # type: ignore[operator]
+ raise RuntimeError(f"Control image mapped to ControlNet {name}, but it is not loaded.")
+
+ standardized[name] = []
+
+ image_list = control_images[name]
+ if not isinstance(image_list, list):
+ image_list = [image_list]
+
+ for controlnet_image in image_list:
+ if isinstance(controlnet_image, tuple):
+ if len(controlnet_image) == 4:
+ controlnet_image, conditioning_scale, conditioning_start, conditioning_end = controlnet_image
+ elif len(controlnet_image) == 3:
+ controlnet_image, conditioning_scale, conditioning_start = controlnet_image
+ conditioning_end = 1.0
+ elif len(controlnet_image) == 2:
+ controlnet_image, conditioning_scale = controlnet_image
+ conditioning_start, conditioning_end = None, None
+
+ elif isinstance(controlnet_image, dict):
+ conditioning_scale = controlnet_image.get("scale", 1.0)
+ conditioning_start = controlnet_image.get("start", None)
+ conditioning_end = controlnet_image.get("end", None)
+ controlnet_image = controlnet_image["image"]
+ else:
+ conditioning_scale = 1.0
+ conditioning_start, conditioning_end = None, None
+
+ if isinstance(controlnet_image, str):
+ controlnet_image = self.open_image(controlnet_image)
+
+ if not isinstance(controlnet_image, list):
+ controlnet_image = [controlnet_image]
+
+ if animation_frames:
+ image_len = len(controlnet_image)
+ if image_len < animation_frames:
+ controlnet_image += [
+ controlnet_image[image_len-1]
+ for i in range(animation_frames - image_len)
+ ]
+
+ standardized[name].append((
+ controlnet_image,
+ conditioning_scale,
+ conditioning_start,
+ conditioning_end
+ ))
+ return standardized
+
@torch.no_grad()
def __call__(
self,
- device: Optional[Union[str, torch.device]] = None,
- offload_models: bool = False,
- prompt: Optional[str] = None,
- prompt_2: Optional[str] = None,
- negative_prompt: Optional[str] = None,
- negative_prompt_2: Optional[str] = None,
- clip_skip: Optional[int] = None,
- freeu_factors: Optional[Tuple[float, float, float, float]] = None,
- image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None,
- mask: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None,
- control_images: ControlImageArgType = None,
- ip_adapter_images: ImagePromptArgType = None,
- ip_adapter_plus: bool = False,
- ip_adapter_face: bool = False,
- height: Optional[int] = None,
- width: Optional[int] = None,
- chunking_size: Optional[int] = None,
- denoising_start: Optional[float] = None,
- denoising_end: Optional[float] = None,
- strength: Optional[float] = 0.8,
- num_inference_steps: int = 40,
- guidance_scale: float = 7.5,
- num_images_per_prompt: int = 1,
- eta: float = 0.0,
- generator: Optional[torch.Generator] = None,
- noise_generator: Optional[torch.Generator] = None,
- latents: Optional[torch.Tensor] = None,
- prompt_embeds: Optional[torch.Tensor] = None,
- negative_prompt_embeds: Optional[torch.Tensor] = None,
- output_type: Literal["latent", "pt", "np", "pil"] = "pil",
- return_dict: bool = True,
- scale_image: bool = True,
- progress_callback: Optional[Callable[[int, int, float], None]] = None,
- latent_callback: Optional[Callable[[Union[torch.Tensor, np.ndarray, List[PIL.Image.Image]]], None]] = None,
- latent_callback_steps: Optional[int] = None,
- latent_callback_type: Literal["latent", "pt", "np", "pil"] = "latent",
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- original_size: Optional[Tuple[int, int]] = None,
- crops_coords_top_left: Tuple[int, int] = (0, 0),
- target_size: Optional[Tuple[int, int]] = None,
- aesthetic_score: float = 6.0,
- negative_aesthetic_score: float = 2.5,
- chunking_mask_type: Optional[MASK_TYPE_LITERAL] = None,
- chunking_mask_kwargs: Optional[Dict[str, Any]] = None,
- noise_offset: Optional[float] = None,
- noise_method: NOISE_METHOD_LITERAL = "perlin",
- noise_blend_method: LATENT_BLEND_METHOD_LITERAL = "inject",
+ device: Optional[Union[str, torch.device]]=None,
+ offload_models: bool=False,
+ prompt: Optional[str]=None,
+ prompt_2: Optional[str]=None,
+ negative_prompt: Optional[str]=None,
+ negative_prompt_2: Optional[str]=None,
+ prompts: Optional[List[Prompt]]=None,
+ image: Optional[Union[ImageArgType, torch.Tensor]]=None,
+ mask: Optional[Union[ImageArgType, torch.Tensor]]=None,
+ clip_skip: Optional[int]=None,
+ freeu_factors: Optional[Tuple[float, float, float, float]]=None,
+ control_images: ControlImageArgType=None,
+ ip_adapter_images: ImagePromptArgType=None,
+ ip_adapter_model: Optional[IP_ADAPTER_LITERAL]=None,
+ height: Optional[int]=None,
+ width: Optional[int]=None,
+ tiling_size: Optional[int]=None,
+ tiling_stride: Optional[int]=None,
+ frame_window_size: Optional[int]=None,
+ frame_window_stride: Optional[int]=None,
+ denoising_start: Optional[float]=None,
+ denoising_end: Optional[float]=None,
+ strength: Optional[float]=0.8,
+ num_inference_steps: int=40,
+ guidance_scale: float=7.5,
+ num_results_per_prompt: int=1,
+ animation_frames: Optional[int]=None,
+ motion_scale: Optional[float]=None,
+ loop: bool=False,
+ tile: Union[bool, Tuple[bool, bool]]=False,
+ eta: float=0.0,
+ generator: Optional[torch.Generator]=None,
+ noise_generator: Optional[torch.Generator]=None,
+ latents: Optional[torch.Tensor]=None,
+ prompt_embeds: Optional[torch.Tensor]=None,
+ negative_prompt_embeds: Optional[torch.Tensor]=None,
+ output_type: Literal["latent", "pt", "np", "pil"]="pil",
+ return_dict: bool=True,
+ progress_callback: Optional[Callable[[int, int, float], None]]=None,
+ latent_callback: Optional[Callable[[Union[torch.Tensor, np.ndarray, List[PIL.Image.Image]]], None]]=None,
+ latent_callback_steps: Optional[int]=None,
+ latent_callback_type: Literal["latent", "pt", "np", "pil"]="latent",
+ cross_attention_kwargs: Optional[Dict[str, Any]]=None,
+ original_size: Optional[Tuple[int, int]]=None,
+ crops_coords_top_left: Tuple[int, int]=(0, 0),
+ target_size: Optional[Tuple[int, int]]=None,
+ aesthetic_score: float=6.0,
+ negative_aesthetic_score: float=2.5,
+ tiling_mask_type: Optional[MASK_TYPE_LITERAL]=None,
+ tiling_mask_kwargs: Optional[Dict[str, Any]]=None,
+ noise_offset: Optional[float]=None,
+ noise_method: NOISE_METHOD_LITERAL="perlin",
+ noise_blend_method: LATENT_BLEND_METHOD_LITERAL="inject",
) -> Union[
StableDiffusionPipelineOutput,
Tuple[Union[torch.Tensor, np.ndarray, List[PIL.Image.Image]], Optional[List[bool]]],
@@ -2646,26 +3214,43 @@ def __call__(
"""
Invokes the pipeline.
"""
+ # 0. Standardize arguments
+ image = self.standardize_image(
+ image,
+ animation_frames=animation_frames
+ )
+ mask = self.standardize_image(
+ mask,
+ animation_frames=animation_frames
+ )
+ control_images = self.standardize_control_images( # type: ignore[assignment]
+ control_images,
+ animation_frames=animation_frames
+ )
+
+ if ip_adapter_images is not None:
+ ip_adapter_images = self.standardize_ip_adapter_images(
+ ip_adapter_images,
+ animation_frames=animation_frames
+ )
+ ip_adapter_scale = max([scale for _, scale in ip_adapter_images]) # type: ignore[union-attr]
+ else:
+ ip_adapter_scale = None
+
# 1. Default height and width to image or unet config
if not height:
- if image is not None:
- if isinstance(image, str):
- image = PIL.Image.open(image)
- if isinstance(image, PIL.Image.Image):
- _, height = image.size
- else:
- height = image.shape[-2] * self.vae_scale_factor
+ if isinstance(image, list):
+ _, height = image[0].size
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[-2] * self.vae_scale_factor
else:
height = self.unet.config.sample_size * self.vae_scale_factor # type: ignore[attr-defined]
if not width:
- if image is not None:
- if isinstance(image, str):
- image = PIL.Image.open(image)
- if isinstance(image, PIL.Image.Image):
- width, _ = image.size
- else:
- width = image.shape[-1] * self.vae_scale_factor
+ if isinstance(image, list):
+ width, _ = image[0].size
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[-1] * self.vae_scale_factor
else:
width = self.unet.config.sample_size * self.vae_scale_factor # type: ignore[attr-defined]
@@ -2676,38 +3261,25 @@ def __call__(
original_size = original_size or (height, width)
target_size = target_size or (height, width)
- # Allow overridding chunking variables
- if chunking_size is not None:
- self.chunking_size = chunking_size
- if chunking_mask_type is not None:
- self.chunking_mask_type = chunking_mask_type
- if chunking_mask_kwargs is not None:
- self.chunking_mask_kwargs = chunking_mask_kwargs
-
- # Check latent callback steps, disable if 0 or maximum offloading set
+ # Allow overriding tiling variables
+ if tiling_size:
+ self.tiling_size = tiling_size
+ if tiling_stride is not None:
+ self.tiling_stride = tiling_stride
+ if tiling_mask_type is not None:
+ self.tiling_mask_type = tiling_mask_type
+ if tiling_mask_kwargs is not None:
+ self.tiling_mask_kwargs = tiling_mask_kwargs
+ if frame_window_size is not None:
+ self.frame_window_size = frame_window_size
+ if frame_window_stride is not None:
+ self.frame_window_stride = frame_window_stride
+
+ # Check 0/None
if latent_callback_steps == 0:
latent_callback_steps = None
-
- # Standardize IP adapter tuples
- ip_adapter_tuples: Optional[List[Tuple[PIL.Image.Image, float]]] = None
- if ip_adapter_images:
- ip_adapter_tuples = []
- for ip_adapter_argument in (ip_adapter_images if isinstance(ip_adapter_images, list) else [ip_adapter_images]):
- if isinstance(ip_adapter_argument, dict):
- ip_adapter_tuples.append((
- ip_adapter_argument["image"],
- ip_adapter_argument.get("scale", 1.0)
- ))
- elif isinstance(ip_adapter_argument, list):
- ip_adapter_tuples.append(tuple(ip_adapter_argument)) # type: ignore[arg-type]
- elif isinstance(ip_adapter_argument, tuple):
- ip_adapter_tuples.append(ip_adapter_argument) # type: ignore[arg-type]
- else:
- ip_adapter_tuples.append((ip_adapter_argument, 1.0))
- ip_adapter_scale = max([scale for img, scale in ip_adapter_tuples])
- else:
- ip_adapter_tuples = None
- ip_adapter_scale = None
+ if animation_frames == 0:
+ animation_frames = None
# Convenient bool for later
decode_intermediates = latent_callback_steps is not None and latent_callback is not None
@@ -2716,18 +3288,15 @@ def __call__(
prepared_latents: Optional[torch.Tensor] = None
output_nsfw: Optional[List[bool]] = None
- # Determine dimensionality
- is_inpainting_unet = self.unet.config.in_channels == 9 # type: ignore[attr-defined]
-
# Define call parameters
- if prompt is not None:
- batch_size = 1
- elif prompt_embeds:
+ if prompt_embeds:
batch_size = prompt_embeds.shape[0]
else:
- raise ValueError("Prompt or prompt embeds are required.")
+ batch_size = 1
+ if prompt is None and prompts is None:
+ prompt = "high-quality, best quality, aesthetically pleasing" # Good luck!
- if is_inpainting_unet:
+ if self.is_inpainting_unet:
if image is None:
logger.warning("No image present, but using inpainting model. Adding blank image.")
image = PIL.Image.new("RGB", (width, height))
@@ -2743,10 +3312,24 @@ def __call__(
do_classifier_free_guidance = guidance_scale > 1.0
# Calculate chunks
- num_chunks = max(1, len(self.get_chunks(height, width)))
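+ # The Chunker computes both spatial tiles and temporal frame windows; the tile size falls back to the engine default (1024 for SDXL, 512 otherwise) when no tiling size is configured.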
+ chunker = Chunker(
+ height=height,
+ width=width,
+ size=self.tiling_size if self.tiling_size else 1024 if self.is_sdxl else 512,
+ stride=self.tiling_stride,
+ frames=animation_frames,
+ frame_size=self.frame_window_size,
+ frame_stride=self.frame_window_stride,
+ loop=loop,
+ tile=tile,
+ )
+
+ num_chunks = chunker.num_chunks
+ num_temporal_chunks = chunker.num_frame_chunks
+
self.scheduler.set_timesteps(num_inference_steps, device=device) # type: ignore[attr-defined]
- if image is not None and mask is None and (strength is not None or denoising_start is not None):
+ if image is not None and (strength is not None or denoising_start is not None):
# Scale timesteps by strength
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
@@ -2756,7 +3339,7 @@ def __call__(
else:
timesteps = self.scheduler.timesteps # type: ignore[attr-defined]
- batch_size *= num_images_per_prompt
+ batch_size *= num_results_per_prompt
num_scheduled_inference_steps = len(timesteps)
# Calculate end of steps if we aren't going all the way
@@ -2775,27 +3358,48 @@ def __call__(
encoding_steps = 0
decoding_steps = 1
+ # Open images if they're files
+ if isinstance(image, str):
+ image = self.open_image(image)
+ if isinstance(mask, str):
+ mask = self.open_image(mask)
+
if image is not None and type(image) is not torch.Tensor:
- encoding_steps += 1
+ if isinstance(image, list):
+ encoding_steps += len(image)
+ else:
+ encoding_steps += 1
if mask is not None and type(mask) is not torch.Tensor:
- encoding_steps += 1
- if not is_inpainting_unet:
+ if isinstance(mask, list):
+ encoding_steps += len(mask)
+ else:
encoding_steps += 1
+ if not self.is_inpainting_unet:
+ if isinstance(image, list):
+ encoding_steps += len(image)
+ else:
+ encoding_steps += 1
+ if ip_adapter_images is not None:
+ image_prompt_probes = sum([
+ len(images) for images, scale in ip_adapter_images
+ ])
+ else:
+ image_prompt_probes = 0
- chunk_plural = "s" if num_chunks != 1 else ""
- step_plural = "s" if num_scheduled_inference_steps != 1 else ""
- encoding_plural = "s" if encoding_steps != 1 else ""
- decoding_plural = "s" if decoding_steps != 1 else ""
- overall_num_steps = num_chunks * (num_scheduled_inference_steps + encoding_steps + decoding_steps)
+ num_frames = 1 if not animation_frames else animation_frames
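+ # Progress total: image prompt probes, plus, for each spatial chunk, its encoding steps, its per-frame decoding steps, and its inference steps across every temporal chunk.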
+ overall_num_steps = num_chunks * (encoding_steps + (decoding_steps * num_frames) + (num_scheduled_inference_steps * num_temporal_chunks)) + image_prompt_probes
logger.debug(
" ".join([
- f"Calculated overall steps to be {overall_num_steps}",
- f"[{num_chunks} chunk{chunk_plural} * ({num_scheduled_inference_steps} inference step{step_plural}",
- f"+ {encoding_steps} encoding step{encoding_plural} + {decoding_steps} decoding step{decoding_plural})]"
+ f"Calculated overall steps to be {overall_num_steps} -",
+ f"{image_prompt_probes} image prompt embedding probe(s) +",
+ f"[{num_chunks} chunk(s) * ({encoding_steps} encoding step(s) + ({decoding_steps} decoding step(s) * {num_frames} frame(s)) +",
+ f"({num_temporal_chunks} temporal chunk(s) * {num_scheduled_inference_steps} inference step(s))]"
])
)
+
+ # Create a callback which gets passed to stepped functions
step_complete = self.get_step_complete_callback(overall_num_steps, progress_callback)
-
+
if self.config.force_full_precision_vae: # type: ignore[attr-defined]
logger.debug(f"Configuration indicates VAE must be used in full precision")
# make sure the VAE is in float32 mode, as it overflows in float16
@@ -2806,137 +3410,205 @@ def __call__(
with self.get_runtime_context(
batch_size=batch_size,
+ animation_frames=animation_frames,
device=device,
ip_adapter_scale=ip_adapter_scale,
- ip_adapter_plus=ip_adapter_plus,
- ip_adapter_face=ip_adapter_face,
+ ip_adapter_model=ip_adapter_model,
step_complete=step_complete
):
- if self.is_sdxl:
- # XL uses more inputs for prompts than 1.5
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = self.encode_prompt(
- prompt,
- device,
- num_images_per_prompt,
- do_classifier_free_guidance,
- negative_prompt,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- prompt_2=prompt_2,
- negative_prompt_2=negative_prompt_2,
- clip_skip=clip_skip
+ # First standardize to list of prompts
+ if prompts is None:
+ prompts = [
+ Prompt(
+ positive=prompt,
+ positive_2=prompt_2,
+ negative=negative_prompt,
+ negative_2=negative_prompt_2,
+ start=None,
+ end=None,
+ weight=None
+ )
+ ]
+
+ encoded_prompt_list = []
+ # Iterate over given prompts and encode
+ for given_prompt in prompts:
+ if self.is_sdxl:
+ # XL uses more inputs for prompts than 1.5
+ (
+ these_prompt_embeds,
+ these_negative_prompt_embeds,
+ these_pooled_prompt_embeds,
+ these_negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ given_prompt.positive,
+ device,
+ num_results_per_prompt,
+ do_classifier_free_guidance,
+ given_prompt.negative,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_2=given_prompt.positive_2,
+ negative_prompt_2=given_prompt.negative_2
+ )
+ else:
+ these_prompt_embeds = self.encode_prompt(
+ given_prompt.positive,
+ device,
+ num_results_per_prompt,
+ do_classifier_free_guidance,
+ given_prompt.negative,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_2=given_prompt.positive_2,
+ negative_prompt_2=given_prompt.negative_2
+ ) # type: ignore
+ these_pooled_prompt_embeds = None
+ these_negative_prompt_embeds = None
+ these_negative_pooled_prompt_embeds = None
+
+ encoded_prompt_list.append(
+ EncodedPrompt(
+ prompt=given_prompt,
+ embeds=these_prompt_embeds, # type: ignore[arg-type]
+ negative_embeds=these_negative_prompt_embeds,
+ pooled_embeds=these_pooled_prompt_embeds,
+ negative_pooled_embeds=these_negative_pooled_prompt_embeds
+ )
)
- else:
- prompt_embeds = self.encode_prompt(
- prompt,
- device,
- num_images_per_prompt,
- do_classifier_free_guidance,
- negative_prompt,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- prompt_2=prompt_2,
- negative_prompt_2=negative_prompt_2,
- clip_skip=clip_skip
- ) # type: ignore
- pooled_prompt_embeds = None
- negative_prompt_embeds = None
- negative_pooled_prompt_embeds = None
-
- # Open images if they're files
- if isinstance(image, str):
- image = PIL.Image.open(image)
-
- if isinstance(mask, str):
- mask = PIL.Image.open(mask)
-
- # Scale images if requested
- if scale_image and isinstance(image, PIL.Image.Image):
- image_width, image_height = image.size
- if image_width != width or image_height != height:
- logger.debug(f"Resizing input image from {image_width}x{image_height} to {width}x{height}")
- image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-
- if scale_image and mask:
- mask_width, mask_height = mask.size
- if mask_width != width or mask_height != height:
- logger.debug(f"Resizing input mask from {mask_width}x{mask_height} to {width}x{height}")
- mask = mask.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-
- # Remove any alpha mask on image, convert mask to grayscale
- if isinstance(image, PIL.Image.Image):
- image = image.convert("RGB")
- if isinstance(mask, PIL.Image.Image):
- mask = mask.convert("L")
- if isinstance(image, PIL.Image.Image) and isinstance(mask, PIL.Image.Image):
- if is_inpainting_unet:
- prepared_mask, prepared_image = self.prepare_mask_and_image(mask, image, False) # type: ignore
- init_image = None
+
+ encoded_prompts = EncodedPrompts(
+ prompts=encoded_prompt_list,
+ is_sdxl=self.is_sdxl,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ image_prompt_embeds=None, # Will be set later
+ image_uncond_prompt_embeds=None # Will be set later
+ )
+
+ # Remove any alpha mask on image, convert mask to grayscale, align tensors
+ if image is not None:
+ if isinstance(image, torch.Tensor):
+ image = image.to(device=device, dtype=encoded_prompts.dtype)
else:
- prepared_mask, prepared_image, init_image = self.prepare_mask_and_image(mask, image, True) # type: ignore
- elif image is not None and mask is None:
- prepared_image = self.image_processor.preprocess(image)
- prepared_mask, prepared_image_latents, init_image = None, None, None
- else:
- prepared_mask, prepared_image, prepared_image_latents, init_image = None, None, None, None
+ image = [img.convert("RGB") for img in image]
+ if mask is not None:
+ if isinstance(mask, torch.Tensor):
+ mask = mask.to(device=device, dtype=encoded_prompts.dtype)
+ else:
+ mask = [img.convert("L") for img in mask]
+
+ # Repeat images as necessary to get the same size
+ image_length = max([
+ 0 if image is None else len(image),
+ 0 if mask is None else len(mask),
+ ])
- if width < self.engine_size or height < self.engine_size:
- # Disable chunking
- logger.debug(f"{width}x{height} is smaller than is chunkable, disabling.")
- self.chunking_size = 0
+ if image is not None and not isinstance(image, torch.Tensor):
+ current_length = len(image)
+ for i in range(image_length - current_length):
+ image.append(image[-1])
+ if mask is not None and not isinstance(mask, torch.Tensor):
+ current_length = len(mask)
+ for i in range(image_length - current_length):
+ mask.append(mask[-1])
+
+ # Process image and mask or image
+ prepared_image: Optional[torch.Tensor] = None
+ prepared_mask: Optional[torch.Tensor] = None
+ init_image: Optional[torch.Tensor] = None
+
+ if image is not None and mask is not None:
+ prepared_image = torch.Tensor()
+ prepared_mask = torch.Tensor()
+ init_image = torch.Tensor()
+
+ for m, i in zip(mask, image):
+ p_m, p_i, i_i = self.prepare_mask_and_image(m, i, True) # type: ignore
+ prepared_mask = torch.cat([prepared_mask, p_m.unsqueeze(0)])
+ prepared_image = torch.cat([prepared_image, p_i.unsqueeze(0)])
+ init_image = torch.cat([init_image, i_i.unsqueeze(0)])
- # No longer none
- prompt_embeds = cast(torch.Tensor, prompt_embeds)
+ elif image is not None and mask is None:
+ if isinstance(image, torch.Tensor):
+ prepared_image = image.unsqueeze(0)
+ else:
+ prepared_image = torch.Tensor()
+ for i in image:
+ prepared_image = torch.cat([
+ prepared_image,
+ self.image_processor.preprocess(i).unsqueeze(0)
+ ])
# Build the weight builder
- weight_builder = MaskWeightBuilder(device=device, dtype=prompt_embeds.dtype)
+ weight_builder = MaskWeightBuilder(
+ device=device,
+ dtype=encoded_prompts.dtype
+ )
+
with weight_builder:
if prepared_image is not None and prepared_mask is not None:
# Inpainting
num_channels_latents = self.vae.config.latent_channels # type: ignore[attr-defined]
if latents:
- prepared_latents = latents.to(device) * self.scheduler.init_noise_sigma # type: ignore[attr-defined]
+ prepared_latents = latents.to(device) * self.schedule.init_noise_sigma # type: ignore[attr-defined]
else:
- prepared_latents = self.create_latents(
- batch_size,
- num_channels_latents,
- height,
- width,
- prompt_embeds.dtype,
- device,
- generator,
- )
+ if strength is not None and strength < 1.0:
+ prepared_latents = self.prepare_image_latents(
+ image=init_image.to(device=device), # type: ignore[union-attr]
+ timestep=timesteps[:1].repeat(batch_size),
+ batch_size=batch_size,
+ dtype=encoded_prompts.dtype,
+ device=device,
+ chunker=chunker,
+ weight_builder=weight_builder,
+ generator=generator,
+ progress_callback=step_complete,
+ add_noise=denoising_start is None,
+ animation_frames=animation_frames
+ )
+ else:
+ prepared_latents = self.create_latents(
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ encoded_prompts.dtype,
+ device,
+ generator,
+ animation_frames=animation_frames
+ )
prepared_mask, prepared_image_latents = self.prepare_mask_latents(
- mask=prepared_mask,
- image=prepared_image,
+ mask=prepared_mask.to(device=device),
+ image=prepared_image.to(device=device),
batch_size=batch_size,
height=height,
width=width,
- dtype=prompt_embeds.dtype,
+ dtype=encoded_prompts.dtype,
device=device,
- weight_builder=weight_builder,
+ chunker=chunker,
generator=generator,
+ weight_builder=weight_builder,
do_classifier_free_guidance=do_classifier_free_guidance,
progress_callback=step_complete,
+ animation_frames=animation_frames
)
if init_image is not None:
- init_image = init_image.to(device=device, dtype=prompt_embeds.dtype)
- init_image = self.encode_image(
- init_image,
+ init_image = self.prepare_image_latents(
+ image=init_image.to(device=device),
+ timestep=timesteps[:1].repeat(batch_size),
+ batch_size=batch_size,
+ dtype=encoded_prompts.dtype,
device=device,
- dtype=prompt_embeds.dtype,
+ chunker=chunker,
+ weight_builder=weight_builder,
generator=generator,
- weight_builder=weight_builder
+ progress_callback=step_complete,
+ add_noise=False,
+ animation_frames=animation_frames
)
-
# prepared_latents = noise or init latents + noise
# prepared_mask = only mask
# prepared_image_latents = masked image
@@ -2944,34 +3616,39 @@ def __call__(
elif prepared_image is not None and strength is not None:
# img2img
prepared_latents = self.prepare_image_latents(
- image=prepared_image,
+ image=prepared_image.to(device=device),
timestep=timesteps[:1].repeat(batch_size),
batch_size=batch_size,
- dtype=prompt_embeds.dtype,
+ dtype=encoded_prompts.dtype,
device=device,
+ chunker=chunker,
weight_builder=weight_builder,
generator=generator,
progress_callback=step_complete,
add_noise=denoising_start is None,
+ animation_frames=animation_frames
)
+ prepared_image_latents = None # Don't need to store these separately
# prepared_latents = img + noise
elif latents:
prepared_latents = latents.to(device) * self.scheduler.init_noise_sigma # type: ignore[attr-defined]
# prepared_latents = passed latents + noise
else:
# txt2img
+ prepared_image_latents = None
prepared_latents = self.create_latents(
batch_size=batch_size,
num_channels_latents=self.unet.config.in_channels, # type: ignore[attr-defined]
height=height,
width=width,
- dtype=prompt_embeds.dtype,
+ dtype=encoded_prompts.dtype,
device=device,
generator=generator,
+ animation_frames=animation_frames
)
# prepared_latents = noise
- # Look for controlnet and conditioning image
+ # Look for controlnet and conditioning image, prepare
prepared_control_images: PreparedControlImageArgType = {}
if control_images is not None:
if not self.controlnets:
@@ -2979,50 +3656,20 @@ def __call__(
prepared_control_images = None
else:
for name in control_images:
- if name not in self.controlnets:
- raise RuntimeError(f"Control image mapped to ControlNet {name}, but it is not loaded.")
-
- image_list = control_images[name]
- if not isinstance(image_list, list):
- image_list = [image_list]
-
- for controlnet_image in image_list:
- if isinstance(controlnet_image, tuple):
- if len(controlnet_image) == 4:
- controlnet_image, conditioning_scale, conditioning_start, conditioning_end = controlnet_image
- elif len(controlnet_image) == 3:
- controlnet_image, conditioning_scale, conditioning_start = controlnet_image
- conditioning_end = None
- elif len(controlnet_image) == 2:
- controlnet_image, conditioning_scale = controlnet_image
- conditioning_start, conditioning_end = None, None
-
- elif isinstance(controlnet_image, dict):
- conditioning_scale = controlnet_image.get("scale", 1.0)
- conditioning_start = controlnet_image.get("start", None)
- conditioning_end = controlnet_image.get("end", None)
- controlnet_image = controlnet_image["image"]
- else:
- conditioning_scale = 1.0
- conditioning_start, conditioning_end = None, None
-
- if isinstance(controlnet_image, str):
- controlnet_image = PIL.Image.open(controlnet_image)
-
+ prepared_control_images[name] = [] # type: ignore[index]
+ for controlnet_image, conditioning_scale, conditioning_start, conditioning_end in control_images[name]: # type: ignore
prepared_controlnet_image = self.prepare_control_image(
image=controlnet_image,
height=height,
width=width,
batch_size=batch_size,
- num_images_per_prompt=num_images_per_prompt,
+ num_results_per_prompt=num_results_per_prompt,
device=device,
- dtype=prompt_embeds.dtype,
+ dtype=encoded_prompts.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
+ animation_frames=animation_frames
)
- if name not in prepared_control_images: # type: ignore[operator]
- prepared_control_images[name] = [] # type: ignore[index]
-
prepared_control_images[name].append( # type: ignore[index]
(prepared_controlnet_image, conditioning_scale, conditioning_start, conditioning_end)
)
@@ -3030,69 +3677,77 @@ def __call__(
# Should no longer be None
prepared_latents = cast(torch.Tensor, prepared_latents)
+ # Check if we need to cut multi-images
+ if not animation_frames:
+ if prepared_mask is not None:
+ prepared_mask = prepared_mask[:, 0]
+ if prepared_image_latents is not None:
+ prepared_image_latents = prepared_image_latents[:, 0]
+
# Prepare extra step kwargs
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # Get prompt embeds here if using ip adapter
- if ip_adapter_tuples is not None:
- image_prompt_embeds = torch.Tensor().to(
+ # Get prompt embeds here if using IP adapter
+ if ip_adapter_images is not None:
+ logger.debug(f"Performing {image_prompt_probes} image prompt probe(s)")
+ ip_adapter_image_embeds = torch.Tensor().to(
device=device,
- dtype=prompt_embeds.dtype
+ dtype=encoded_prompts.dtype
)
- image_uncond_prompt_embeds = torch.Tensor().to(
+ ip_adapter_image_uncond_embeds = torch.Tensor().to(
device=device,
- dtype=prompt_embeds.dtype
+ dtype=encoded_prompts.dtype
)
- for img, scale in ip_adapter_tuples:
- these_prompt_embeds, these_uncond_prompt_embeds = self.get_image_embeds(
- img,
- num_images_per_prompt
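+ # Probe embeddings for each IP adapter image set; per-frame embeds are stacked along a new leading dimension and rescaled relative to the largest configured adapter scale.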
+ for images, scale in ip_adapter_images:
+ image_prompt_embeds = torch.Tensor().to(
+ device=device,
+ dtype=encoded_prompts.dtype
)
-
- image_prompt_embeds = torch.cat([
- image_prompt_embeds,
- (these_prompt_embeds * scale / ip_adapter_scale)
- ], dim=1)
-
- image_uncond_prompt_embeds = torch.cat([
- image_uncond_prompt_embeds,
- (these_uncond_prompt_embeds * scale / ip_adapter_scale)
- ], dim=1)
-
- if self.is_sdxl:
- negative_prompt_embeds = cast(torch.Tensor, negative_prompt_embeds)
- prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
- negative_prompt_embeds = torch.cat([negative_prompt_embeds, image_uncond_prompt_embeds], dim=1)
- else:
- if do_classifier_free_guidance:
- _negative_prompt_embeds, _prompt_embeds = prompt_embeds.chunk(2)
- else:
- _negative_prompt_embeds, _prompt_embeds = negative_prompt_embeds, prompt_embeds # type: ignore
- prompt_embeds = torch.cat([_prompt_embeds, image_prompt_embeds], dim=1)
- negative_prompt_embeds = torch.cat([_negative_prompt_embeds, image_uncond_prompt_embeds], dim=1)
- if do_classifier_free_guidance:
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
- if ip_adapter_images:
- del ip_adapter_images
-
- # Prepared added time IDs and embeddings (SDXL)
+ image_uncond_prompt_embeds = torch.Tensor().to(
+ device=device,
+ dtype=encoded_prompts.dtype
+ )
+ for img in images:
+ image_embeds, uncond_embeds = self.get_image_embeds(
+ img,
+ num_results_per_prompt
+ )
+ step_complete(True)
+ image_prompt_embeds = torch.cat([
+ image_prompt_embeds,
+ image_embeds.unsqueeze(0)
+ ], dim=0)
+ image_uncond_prompt_embeds = torch.cat([
+ image_uncond_prompt_embeds,
+ uncond_embeds.unsqueeze(0)
+ ], dim=0)
+
+ image_prompt_embeds *= scale / ip_adapter_scale # type: ignore[operator]
+ image_uncond_prompt_embeds *= scale / ip_adapter_scale # type: ignore[operator]
+
+ ip_adapter_image_embeds = torch.cat([
+ ip_adapter_image_embeds,
+ image_prompt_embeds.unsqueeze(0)
+ ], dim=0)
+ ip_adapter_image_uncond_embeds = torch.cat([
+ ip_adapter_image_uncond_embeds,
+ image_uncond_prompt_embeds.unsqueeze(0)
+ ], dim=0)
+
+ # Assign to helper data class
+ encoded_prompts.image_prompt_embeds = ip_adapter_image_embeds
+ encoded_prompts.image_uncond_prompt_embeds = ip_adapter_image_uncond_embeds
+
+ # Prepared added time IDs (SDXL)
+ added_cond_kwargs: Optional[Dict[str, Any]] = None
if self.is_sdxl:
- negative_prompt_embeds = cast(torch.Tensor, negative_prompt_embeds)
- pooled_prompt_embeds = cast(torch.Tensor, pooled_prompt_embeds)
- negative_pooled_prompt_embeds = cast(torch.Tensor, negative_pooled_prompt_embeds)
- add_text_embeds = pooled_prompt_embeds
-
- if do_classifier_free_guidance:
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
-
+ added_cond_kwargs = {}
if self.config.requires_aesthetic_score: # type: ignore[attr-defined]
add_time_ids, add_neg_time_ids = self.get_add_time_ids(
original_size=original_size,
crops_coords_top_left=crops_coords_top_left,
target_size=target_size,
- dtype=prompt_embeds.dtype,
+ dtype=encoded_prompts.dtype,
aesthetic_score=aesthetic_score,
negative_aesthetic_score=negative_aesthetic_score,
)
@@ -3103,21 +3758,22 @@ def __call__(
original_size=original_size,
crops_coords_top_left=crops_coords_top_left,
target_size=target_size,
- dtype=prompt_embeds.dtype,
+ dtype=encoded_prompts.dtype,
)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
-
- prompt_embeds = prompt_embeds.to(device)
- add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size, 1)
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
- else:
- added_cond_kwargs = None
+ added_cond_kwargs["time_ids"] = add_time_ids
+
+ # Make sure controlnet on device
+ if self.controlnets is not None:
+ for name in self.controlnets:
+ self.controlnets[name].to(device=device)
# Unload VAE, and maybe preview VAE
self.vae.to("cpu")
if decode_intermediates:
- self.vae_preview.to(device)
+ self.vae_preview.to(device, dtype=encoded_prompts.dtype)
+
empty_cache()
# Inject noise
@@ -3129,6 +3785,7 @@ def __call__(
noise_latents = make_noise(
batch_size=prepared_latents.shape[0],
channels=prepared_latents.shape[1],
+ animation_frames=animation_frames,
height=height // self.vae_scale_factor,
width=width // self.vae_scale_factor,
generator=noise_generator,
@@ -3143,16 +3800,13 @@ def __call__(
method=noise_blend_method
)
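
# Editor's illustrative sketch (not part of this patch): a minimal, hypothetical
# version of seeded noise creation plus a linear "inject"-style blend. The real
# make_noise supports several noise methods (e.g. perlin) and blend modes; the
# simple lerp below is only an assumption for illustration.
import torch

def make_gaussian_noise(batch_size, channels, height, width, seed=0):
    generator = torch.Generator().manual_seed(seed)          # reproducible noise
    return torch.randn(batch_size, channels, height, width, generator=generator)

def inject_noise(latents, noise, offset=0.1):
    # Keep most of the latents, add a small amount of noise.
    return (1.0 - offset) * latents + offset * noise

latents = torch.zeros(1, 4, 64, 64)
noised = inject_noise(latents, make_gaussian_noise(1, 4, 64, 64, seed=42), offset=0.2)
print(noised.shape, noised.std().item() > 0)  # torch.Size([1, 4, 64, 64]) True
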
- # Make sure controlnet on device
- if self.controlnets is not None:
- for name in self.controlnets:
- self.controlnets[name].to(device=device)
-
# Make sure unet is on device
self.align_unet(
device=device,
- dtype=prompt_embeds.dtype,
+ dtype=encoded_prompts.dtype,
freeu_factors=freeu_factors,
+ animation_frames=animation_frames,
+ motion_scale=motion_scale,
offload_models=offload_models
) # May be overridden by RT
@@ -3162,12 +3816,13 @@ def __call__(
width=width,
device=device,
num_inference_steps=num_inference_steps,
+ chunker=chunker,
+ weight_builder=weight_builder,
timesteps=timesteps,
latents=prepared_latents,
- prompt_embeds=prompt_embeds,
+ encoded_prompts=encoded_prompts,
guidance_scale=guidance_scale,
do_classifier_free_guidance=do_classifier_free_guidance,
- is_inpainting_unet=is_inpainting_unet,
mask=prepared_mask,
mask_image=prepared_image_latents,
image=init_image,
@@ -3179,7 +3834,6 @@ def __call__(
extra_step_kwargs=extra_step_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
- weight_builder=weight_builder
)
# Clear no longer needed tensors
@@ -3195,16 +3849,16 @@ def __call__(
# Unload UNet to free memory
if offload_models:
self.unet.to("cpu")
+ empty_cache()
- # Empty caches for more memory
- empty_cache()
+ # Load VAE if decoding
if output_type != "latent":
self.vae.to(
dtype=torch.float32 if self.config.force_full_precision_vae else prepared_latents.dtype, # type: ignore[attr-defined]
device=device
)
if self.is_sdxl:
- use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [
+ use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [ # type: ignore[union-attr]
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
@@ -3215,34 +3869,38 @@ def __call__(
if not use_torch_2_0_or_xformers:
self.vae.post_quant_conv.to(prepared_latents.dtype)
self.vae.decoder.conv_in.to(prepared_latents.dtype)
- self.vae.decoder.mid_block.to(prepared_latents.dtype)
+ self.vae.decoder.mid_block.to(prepared_latents.dtype) # type: ignore
else:
prepared_latents = prepared_latents.float()
+ if output_type == "latent":
+ output = prepared_latents
+ else:
prepared_latents = self.decode_latents(
prepared_latents,
device=device,
+ chunker=chunker,
progress_callback=step_complete,
weight_builder=weight_builder
)
+ if not animation_frames:
+ output = self.denormalize_latents(prepared_latents)
+ if output_type != "pt":
+ output = self.image_processor.pt_to_numpy(output)
+ output_nsfw = self.run_safety_checker(output, device, encoded_prompts.dtype)[1] # type: ignore[arg-type]
+ if output_type == "pil":
+ output = self.image_processor.numpy_to_pil(output)
+ else:
+ output = [] # type: ignore[assignment]
+ for frame in self.decode_animation_frames(prepared_latents):
+ output.extend(self.image_processor.numpy_to_pil(frame)) # type: ignore[attr-defined]
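
# Editor's illustrative sketch (not part of this patch): the general shape of the
# decode path above. VAE output in [-1, 1] is denormalized to [0, 1], moved to
# numpy HWC, then converted to PIL. This mirrors the usual diffusers image
# processing steps but is written as a hypothetical standalone helper.
import numpy as np
import torch
from PIL import Image

def latents_to_pil(decoded: torch.Tensor):
    """decoded: (batch, 3, H, W) in [-1, 1] -> list of PIL images."""
    images = (decoded / 2 + 0.5).clamp(0, 1)                    # denormalize
    arrays = images.cpu().permute(0, 2, 3, 1).float().numpy()   # BCHW -> BHWC
    return [Image.fromarray((a * 255).round().astype(np.uint8)) for a in arrays]

print(latents_to_pil(torch.rand(2, 3, 64, 64) * 2 - 1)[0].size)  # (64, 64)
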
- if self.config.force_full_precision_vae: # type: ignore[attr-defined]
- self.vae.to(dtype=prepared_latents.dtype)
-
- if output_type == "latent":
- output = prepared_latents
- else:
if offload_models:
# Offload VAE again
self.vae.to("cpu")
self.vae_preview.to("cpu")
- empty_cache()
- output = self.denormalize_latents(prepared_latents)
- if output_type != "pt":
- output = self.image_processor.pt_to_numpy(output)
- output_nsfw = self.run_safety_checker(output, device, prompt_embeds.dtype)[1]# type: ignore[arg-type]
- if output_type == "pil":
- output = self.image_processor.numpy_to_pil(output)
+ elif self.config.force_full_precision_vae: #type: ignore[attr-defined]
+ self.vae.to(dtype=prepared_latents.dtype)
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
diff --git a/src/python/enfugue/diffusion/plan.py b/src/python/enfugue/diffusion/plan.py
deleted file mode 100644
index 95da1965..00000000
--- a/src/python/enfugue/diffusion/plan.py
+++ /dev/null
@@ -1,2271 +0,0 @@
-from __future__ import annotations
-
-import io
-import os
-import sys
-import PIL
-import PIL.Image
-import PIL.ImageDraw
-import PIL.ImageOps
-import math
-
-from random import randint
-
-from PIL.PngImagePlugin import PngInfo
-
-from typing import (
- Optional,
- Dict,
- Any,
- Union,
- Tuple,
- List,
- Callable,
- Iterator,
- TYPE_CHECKING,
-)
-from typing_extensions import (
- TypedDict,
- NotRequired
-)
-
-from pibble.util.strings import get_uuid, Serializer
-
-from enfugue.util import (
- logger,
- feather_mask,
- fit_image,
- images_are_equal,
- TokenMerger,
- IMAGE_FIT_LITERAL,
- IMAGE_ANCHOR_LITERAL,
-)
-
-if TYPE_CHECKING:
- from enfugue.diffusers.manager import DiffusionPipelineManager
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
- from enfugue.diffusion.constants import (
- SCHEDULER_LITERAL,
- CONTROLNET_LITERAL,
- UPSCALE_LITERAL,
- MASK_TYPE_LITERAL,
- NOISE_METHOD_LITERAL,
- LATENT_BLEND_METHOD_LITERAL,
- )
-
-DEFAULT_SIZE = 512
-DEFAULT_IMAGE_CALLBACK_STEPS = 10
-DEFAULT_CONDITIONING_SCALE = 1.0
-DEFAULT_IMG2IMG_STRENGTH = 0.8
-DEFAULT_INFERENCE_STEPS = 40
-DEFAULT_GUIDANCE_SCALE = 7.5
-DEFAULT_UPSCALE_PROMPT = "highly detailed, ultra-detailed, intricate detail, high definition, HD, 4k, 8k UHD"
-DEFAULT_UPSCALE_INFERENCE_STEPS = 100
-DEFAULT_UPSCALE_GUIDANCE_SCALE = 12
-DEFAULT_UPSCALE_CHUNKING_SIZE = 128
-
-DEFAULT_REFINER_START = 0.85
-DEFAULT_REFINER_STRENGTH = 0.3
-DEFAULT_REFINER_GUIDANCE_SCALE = 5.0
-DEFAULT_AESTHETIC_SCORE = 6.0
-DEFAULT_NEGATIVE_AESTHETIC_SCORE = 2.5
-
-MODEL_PROMPT_WEIGHT = 0.2
-GLOBAL_PROMPT_STEP_WEIGHT = 0.4
-GLOBAL_PROMPT_UPSCALE_WEIGHT = 0.4
-UPSCALE_PROMPT_STEP_WEIGHT = 0.1
-MAX_IMAGE_SCALE = 3.0
-
-__all__ = ["NodeDict", "DiffusionStep", "DiffusionPlan"]
-
-class UpscaleStepDict(TypedDict):
- method: UPSCALE_LITERAL
- amount: Union[int, float]
- strength: NotRequired[float]
- num_inference_steps: NotRequired[int]
- scheduler: NotRequired[SCHEDULER_LITERAL]
- guidance_scale: NotRequired[float]
- controlnets: NotRequired[List[Union[CONTROLNET_LITERAL, Tuple[CONTROLNET_LITERAL, float]]]]
- prompt: NotRequired[str]
- prompt_2: NotRequired[str]
- negative_prompt: NotRequired[str]
- negative_prompt_2: NotRequired[str]
- chunking_size: NotRequired[int]
- chunking_mask_type: NotRequired[MASK_TYPE_LITERAL]
- chunking_mask_kwargs: NotRequired[Dict[str, Any]]
- noise_offset: NotRequired[float]
- noise_method: NotRequired[NOISE_METHOD_LITERAL]
- noise_blend_method: NotRequired[LATENT_BLEND_METHOD_LITERAL]
-
-class ControlImageDict(TypedDict):
- controlnet: CONTROLNET_LITERAL
- image: PIL.Image.Image
- fit: NotRequired[IMAGE_FIT_LITERAL]
- anchor: NotRequired[IMAGE_ANCHOR_LITERAL]
- scale: NotRequired[float]
- start: NotRequired[Optional[float]]
- end: NotRequired[Optional[float]]
- process: NotRequired[bool]
- invert: NotRequired[bool]
- refiner: NotRequired[bool]
-
-class IPAdapterImageDict(TypedDict):
- image: PIL.Image.Image
- fit: NotRequired[IMAGE_FIT_LITERAL]
- anchor: NotRequired[IMAGE_ANCHOR_LITERAL]
- scale: NotRequired[float]
-
-class NodeDict(TypedDict):
- w: int
- h: int
- x: int
- y: int
- control_images: NotRequired[List[ControlImageDict]]
- ip_adapter_images: NotRequired[List[IPAdapterImageDict]]
- ip_adapter_plus: NotRequired[bool]
- image: NotRequired[PIL.Image.Image]
- mask: NotRequired[PIL.Image.Image]
- fit: NotRequired[IMAGE_FIT_LITERAL]
- anchor: NotRequired[IMAGE_ANCHOR_LITERAL]
- prompt: NotRequired[str]
- prompt_2: NotRequired[str]
- negative_prompt: NotRequired[str]
- negative_prompt_2: NotRequired[str]
- strength: NotRequired[float]
- remove_background: NotRequired[bool]
- invert_mask: NotRequired[bool]
- crop_inpaint: NotRequired[bool]
- inpaint_feather: NotRequired[int]
-
-class DiffusionStep:
- """
- A step holds most of the inputs that describe the desired image and control how inference runs
- """
-
- result: StableDiffusionPipelineOutput
-
- def __init__(
- self,
- name: str = "Step", # Can be set later
- width: Optional[int] = None,
- height: Optional[int] = None,
- prompt: Optional[str] = None,
- prompt_2: Optional[str] = None,
- negative_prompt: Optional[str] = None,
- negative_prompt_2: Optional[str] = None,
- image: Optional[Union[DiffusionStep, PIL.Image.Image, str]] = None,
- mask: Optional[Union[DiffusionStep, PIL.Image.Image, str]] = None,
- control_images: Optional[List[ControlImageDict]] = None,
- ip_adapter_images: Optional[List[IPAdapterImageDict]] = None,
- ip_adapter_plus: bool = False,
- ip_adapter_face: bool = False,
- strength: Optional[float] = None,
- num_inference_steps: Optional[int] = DEFAULT_INFERENCE_STEPS,
- guidance_scale: Optional[float] = DEFAULT_GUIDANCE_SCALE,
- refiner_start: Optional[float] = None,
- refiner_strength: Optional[float] = None,
- refiner_guidance_scale: Optional[float] = DEFAULT_REFINER_GUIDANCE_SCALE,
- refiner_aesthetic_score: Optional[float] = DEFAULT_AESTHETIC_SCORE,
- refiner_negative_aesthetic_score: Optional[float] = DEFAULT_NEGATIVE_AESTHETIC_SCORE,
- refiner_prompt: Optional[str] = None,
- refiner_prompt_2: Optional[str] = None,
- refiner_negative_prompt: Optional[str] = None,
- refiner_negative_prompt_2: Optional[str] = None,
- crop_inpaint: Optional[bool] = True,
- inpaint_feather: Optional[int] = None,
- remove_background: bool = False,
- fill_background: bool = False,
- scale_to_model_size: bool = False,
- ) -> None:
- self.name = name
- self.width = width
- self.height = height
- self.prompt = prompt
- self.prompt_2 = prompt_2
- self.negative_prompt = negative_prompt
- self.negative_prompt_2 = negative_prompt_2
- self.image = image
- self.mask = mask
- self.ip_adapter_images = ip_adapter_images
- self.ip_adapter_plus = ip_adapter_plus
- self.ip_adapter_face = ip_adapter_face
- self.control_images = control_images
- self.strength = strength
- self.refiner_start = refiner_start
- self.refiner_strength = refiner_strength
- self.refiner_prompt = refiner_prompt
- self.refiner_prompt_2 = refiner_prompt_2
- self.refiner_negative_prompt = refiner_negative_prompt
- self.refiner_negative_prompt_2 = refiner_negative_prompt_2
- self.remove_background = remove_background
- self.fill_background = fill_background
- self.scale_to_model_size = scale_to_model_size
- self.num_inference_steps = num_inference_steps if num_inference_steps is not None else DEFAULT_INFERENCE_STEPS
- self.guidance_scale = guidance_scale if guidance_scale is not None else DEFAULT_GUIDANCE_SCALE
- self.refiner_guidance_scale = (
- refiner_guidance_scale if refiner_guidance_scale is not None else DEFAULT_REFINER_GUIDANCE_SCALE
- )
- self.refiner_aesthetic_score = (
- refiner_aesthetic_score if refiner_aesthetic_score is not None else DEFAULT_AESTHETIC_SCORE
- )
- self.refiner_negative_aesthetic_score = (
- refiner_negative_aesthetic_score
- if refiner_negative_aesthetic_score is not None
- else DEFAULT_NEGATIVE_AESTHETIC_SCORE
- )
- self.crop_inpaint = crop_inpaint if crop_inpaint is not None else True
- self.inpaint_feather = inpaint_feather if inpaint_feather is not None else 32
-
- def get_serialization_dict(self, image_directory: Optional[str]=None) -> Dict[str, Any]:
- """
- Gets the dictionary that will be returned when serializing this step
- """
- serialized: Dict[str, Any] = {
- "name": self.name,
- "width": self.width,
- "height": self.height,
- "prompt": self.prompt,
- "prompt_2": self.prompt_2,
- "negative_prompt": self.negative_prompt,
- "negative_prompt_2": self.negative_prompt_2,
- "strength": self.strength,
- "num_inference_steps": self.num_inference_steps,
- "guidance_scale": self.guidance_scale,
- "remove_background": self.remove_background,
- "fill_background": self.fill_background,
- "refiner_start": self.refiner_start,
- "refiner_strength": self.refiner_strength,
- "refiner_guidance_scale": self.refiner_guidance_scale,
- "refiner_aesthetic_score": self.refiner_aesthetic_score,
- "refiner_negative_aesthetic_score": self.refiner_negative_aesthetic_score,
- "refiner_prompt": self.refiner_prompt,
- "refiner_prompt_2": self.refiner_prompt_2,
- "refiner_negative_prompt": self.refiner_negative_prompt,
- "refiner_negative_prompt_2": self.refiner_negative_prompt_2,
- "scale_to_model_size": self.scale_to_model_size,
- "crop_inpaint": self.crop_inpaint,
- "inpaint_feather": self.inpaint_feather,
- "ip_adapter_plus": self.ip_adapter_plus,
- "ip_adapter_face": self.ip_adapter_face,
- }
-
- serialize_children: List[DiffusionStep] = []
- for key in ["image", "mask"]:
- child = getattr(self, key)
- if isinstance(child, DiffusionStep):
- if child in serialize_children:
- serialized[key] = serialize_children.index(child)
- else:
- serialize_children.append(child)
- serialized[key] = len(serialize_children) - 1
- elif child is not None and image_directory is not None:
- path = os.path.join(image_directory, f"{get_uuid()}.png")
- child.save(path)
- serialized[key] = path
- else:
- serialized[key] = child
-
- if self.ip_adapter_images:
- adapter_images = []
- for adapter_image in self.ip_adapter_images:
- image_dict = {
- "fit": adapter_image.get("fit", None),
- "anchor": adapter_image.get("anchor", None),
- "scale": adapter_image.get("scale", None)
- }
- if isinstance(adapter_image["image"], DiffusionStep):
- if adapter_image["image"] in serialize_children:
- image_dict["image"] = serialize_children[adapter_image["image"]] # type: ignore
- else:
- serialize_children.append(adapter_image["image"])
- image_dict["image"] = len(serialize_children) - 1
- elif adapter_image["image"] is not None and image_directory is not None:
- path = os.path.join(image_directory, f"{get_uuid()}.png")
- adapter_image["image"].save(path)
- image_dict["image"] = path # type: ignore
- else:
- image_dict["image"] = adapter_image["image"]
- adapter_images.append(image_dict)
- serialized["ip_adapter_images"] = adapter_images
-
- if self.control_images:
- control_images = []
- for control_image in self.control_images:
- image_dict = {
- "controlnet": control_image["controlnet"], # type: ignore
- "scale": control_image.get("scale", 1.0),
- "fit": control_image.get("fit", None),
- "anchor": control_image.get("anchor", None),
- "start": control_image.get("start", None),
- "end": control_image.get("end", None)
- }
- if isinstance(control_image["image"], DiffusionStep):
- if control_image["image"] in serialize_children:
- image_dict["image"] = serialize_children.index(control_image["image"])
- else:
- serialize_children.append(control_image["image"])
- image_dict["image"] = len(serialize_children) - 1
- elif control_image["image"] is not None and image_directory is not None:
- path = os.path.join(image_directory, f"{get_uuid()}.png")
- control_image["image"].save(path)
- image_dict["image"] = path # type:ignore[assignment]
- else:
- image_dict["image"] = control_image["image"]
- control_images.append(image_dict)
- serialized["control_images"] = control_images
- serialized["children"] = [child.get_serialization_dict(image_directory) for child in serialize_children]
- return serialized
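
# Editor's illustrative sketch (not part of this patch): the serialization above
# de-duplicates nested steps by storing an integer index into a shared
# "children" list instead of re-serializing the same child twice. A tiny
# standalone version of that pattern (all names hypothetical):
def serialize_with_children(values):
    children, refs = [], []
    for value in values:
        if value in children:
            refs.append(children.index(value))   # reuse the existing entry
        else:
            children.append(value)
            refs.append(len(children) - 1)
    return {"refs": refs, "children": children}

print(serialize_with_children(["a", "b", "a"]))
# {'refs': [0, 1, 0], 'children': ['a', 'b']}
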
-
- @property
- def kwargs(self) -> Dict[str, Any]:
- """
- Returns the keyword arguments that will be passed to the pipeline invocation.
- """
- return {
- "width": self.width,
- "height": self.height,
- "prompt": self.prompt,
- "prompt_2": self.prompt_2,
- "negative_prompt": self.negative_prompt,
- "negative_prompt_2": self.negative_prompt_2,
- "image": self.image,
- "strength": self.strength,
- "num_inference_steps": self.num_inference_steps,
- "guidance_scale": self.guidance_scale,
- "refiner_start": self.refiner_start,
- "refiner_strength": self.refiner_strength,
- "refiner_guidance_scale": self.refiner_guidance_scale,
- "refiner_aesthetic_score": self.refiner_aesthetic_score,
- "refiner_negative_aesthetic_score": self.refiner_negative_aesthetic_score,
- "refiner_prompt": self.refiner_prompt,
- "refiner_prompt_2": self.refiner_prompt_2,
- "refiner_negative_prompt": self.refiner_negative_prompt,
- "refiner_negative_prompt_2": self.refiner_negative_prompt_2,
- "ip_adapter_plus": self.ip_adapter_plus,
- "ip_adapter_face": self.ip_adapter_face,
- }
-
- def get_inpaint_bounding_box(self, pipeline_size: int) -> List[Tuple[int, int]]:
- """
- Gets the bounding box of the region to be inpainted
- """
- if isinstance(self.mask, str):
- mask = PIL.Image.open(self.mask)
- elif isinstance(self.mask, PIL.Image.Image):
- mask = self.mask
- else:
- raise ValueError("Cannot get bounding box for empty or dynamic mask.")
-
- width, height = mask.size
- x0, y0, x1, y1 = mask.getbbox()
-
- # Add feather
- x0 = max(0, x0 - self.inpaint_feather)
- x1 = min(width, x1 + self.inpaint_feather)
- y0 = max(0, y0 - self.inpaint_feather)
- y1 = min(height, y1 + self.inpaint_feather)
-
- # Create centered frame about the bounding box
- bbox_width = x1 - x0
- bbox_height = y1 - y0
-
- if bbox_width < pipeline_size:
- x0 = max(0, x0 - ((pipeline_size - bbox_width) // 2))
- x1 = min(width, x0 + pipeline_size)
- x0 = max(0, x1 - pipeline_size)
- if bbox_height < pipeline_size:
- y0 = max(0, y0 - ((pipeline_size - bbox_height) // 2))
- y1 = min(height, y0 + pipeline_size)
- y0 = max(0, y1 - pipeline_size)
-
- return [(x0, y0), (x1, y1)]
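
# Editor's illustrative sketch (not part of this patch): the geometry used above
# in get_inpaint_bounding_box, in isolation. Pad the mask's bounding box by a
# feather margin, then, if the box is smaller than the pipeline size, center a
# window of at least that size around it while clamping to the image bounds.
def expand_and_center(bbox, image_size, feather, minimum):
    (x0, y0, x1, y1), (width, height) = bbox, image_size
    x0, y0 = max(0, x0 - feather), max(0, y0 - feather)
    x1, y1 = min(width, x1 + feather), min(height, y1 + feather)
    if x1 - x0 < minimum:
        x0 = max(0, x0 - (minimum - (x1 - x0)) // 2)
        x1 = min(width, x0 + minimum)
        x0 = max(0, x1 - minimum)
    if y1 - y0 < minimum:
        y0 = max(0, y0 - (minimum - (y1 - y0)) // 2)
        y1 = min(height, y0 + minimum)
        y0 = max(0, y1 - minimum)
    return (x0, y0), (x1, y1)

print(expand_and_center((200, 200, 260, 240), (1024, 768), feather=32, minimum=512))
# ((0, 0), (512, 512)) -- a 512x512 crop window containing the padded box
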
-
- def paste_inpaint_image(
- self, background: PIL.Image.Image, foreground: PIL.Image.Image, position: Tuple[int, int]
- ) -> PIL.Image.Image:
- """
- Pastes the inpaint image on the background with an appropriately feathered mask.
- """
- image = background.copy()
-
- width, height = image.size
- foreground_width, foreground_height = foreground.size
- left, top = position
- right, bottom = left + foreground_width, top + foreground_height
-
- feather_left = left > 0
- feather_top = top > 0
- feather_right = right < width
- feather_bottom = bottom < height
-
- mask = PIL.Image.new("L", (foreground_width, foreground_height), 255)
-
- for i in range(self.inpaint_feather):
- multiplier = (i + 1) / (self.inpaint_feather + 1)
- pixels = []
- if feather_left:
- pixels.extend([(i, j) for j in range(foreground_height)])
- if feather_top:
- pixels.extend([(j, i) for j in range(foreground_width)])
- if feather_right:
- pixels.extend([(foreground_width - i - 1, j) for j in range(foreground_height)])
- if feather_bottom:
- pixels.extend([(j, foreground_height - i - 1) for j in range(foreground_width)])
- for x, y in pixels:
- mask.putpixel((x, y), int(mask.getpixel((x, y)) * multiplier))
-
- image.paste(foreground, position, mask=mask)
- return image
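
# Editor's illustrative sketch (not part of this patch): a compact version of
# the feathered paste above. A greyscale mask ramps from 0 to 255 over a border
# of `feather` pixels so the pasted patch blends into the background; edges that
# are flush with the canvas border are skipped, as in paste_inpaint_image.
from PIL import Image

def feathered_paste(background, foreground, position, feather=16):
    canvas = background.copy()
    fw, fh = foreground.size
    left, top = position
    mask = Image.new("L", (fw, fh), 255)
    px = mask.load()
    for i in range(feather):
        value = int(255 * (i + 1) / (feather + 1))
        for x in range(fw):
            if top > 0: px[x, i] = min(px[x, i], value)
            if top + fh < canvas.height: px[x, fh - 1 - i] = min(px[x, fh - 1 - i], value)
        for y in range(fh):
            if left > 0: px[i, y] = min(px[i, y], value)
            if left + fw < canvas.width: px[fw - 1 - i, y] = min(px[fw - 1 - i, y], value)
    canvas.paste(foreground, position, mask=mask)
    return canvas

result = feathered_paste(Image.new("RGB", (256, 256), "black"),
                         Image.new("RGB", (128, 128), "white"), (64, 64))
print(result.size)  # (256, 256)
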
-
- def execute(
- self,
- pipeline: DiffusionPipelineManager,
- use_cached: bool = True,
- **kwargs: Any,
- ) -> StableDiffusionPipelineOutput:
- """
- Executes this pipeline step.
- """
- if hasattr(self, "result") and use_cached:
- return self.result
-
- samples = kwargs.pop("samples", 1)
-
- if isinstance(self.image, DiffusionStep):
- image = self.image.execute(pipeline, samples=1, **kwargs)["images"][0]
- elif isinstance(self.image, str):
- image = PIL.Image.open(self.image)
- else:
- image = self.image
-
- if isinstance(self.mask, DiffusionStep):
- mask = self.mask.execute(pipeline, samples=1, **kwargs)["images"][0]
- elif isinstance(self.mask, str):
- mask = PIL.Image.open(self.mask)
- else:
- mask = self.mask
-
- if self.ip_adapter_images is not None:
- ip_adapter_images: List[Tuple[PIL.Image.Image, float]] = []
- for adapter_image_dict in self.ip_adapter_images:
- adapter_image = adapter_image_dict["image"]
-
- if isinstance(adapter_image, DiffusionStep):
- adapter_image = adapter_image.execute(pipeline, samples=1, **kwargs)["images"][0]
- elif isinstance(adapter_image, str):
- adapter_image = PIL.Image.open(adapter_image)
-
- adapter_scale = adapter_image_dict.get("scale", 1.0)
- ip_adapter_images.append((
- adapter_image,
- adapter_scale
- ))
- else:
- ip_adapter_images = None # type: ignore[assignment]
-
- if self.control_images is not None:
- control_images: Dict[str, List[Tuple[PIL.Image.Image, float, Optional[float], Optional[float]]]] = {}
- for control_image_dict in self.control_images:
- control_image = control_image_dict["image"]
- controlnet = control_image_dict["controlnet"]
-
- if isinstance(control_image, DiffusionStep):
- control_image = control_image.execute(pipeline, samples=1, **kwargs)["images"][0]
- elif isinstance(control_image, str):
- control_image = PIL.Image.open(control_image)
-
- conditioning_scale = control_image_dict.get("scale", 1.0)
- conditioning_start = control_image_dict.get("start", None)
- conditioning_end = control_image_dict.get("end", None)
-
- if control_image_dict.get("process", True):
- control_image = pipeline.control_image_processor(controlnet, control_image)
- elif control_image_dict.get("invert", False):
- control_image = PIL.ImageOps.invert(control_image)
-
- if controlnet not in control_images:
- control_images[controlnet] = [] # type: ignore[assignment]
-
- control_images[controlnet].append((
- control_image,
- conditioning_scale,
- conditioning_start,
- conditioning_end,
- ))
- else:
- control_images = None # type: ignore[assignment]
-
- if (
- not self.prompt and
- not mask and
- not control_images and
- not ip_adapter_images
- ):
- if image:
- if self.remove_background:
- with pipeline.background_remover.remover() as remove_background:
- image = remove_background(image)
-
- samples = kwargs.get("num_images_per_prompt", 1)
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-
- self.result = StableDiffusionPipelineOutput(
- images=[image] * samples, nsfw_content_detected=[False] * samples
- )
- return self.result
- raise ValueError("No prompt or image in this step; cannot invoke or pass through.")
-
- invocation_kwargs = {**kwargs, **self.kwargs}
-
- image_scale = 1
- pipeline_size = pipeline.inpainter_size if mask is not None else pipeline.size
- image_width, image_height, image_background, image_position = None, None, None, None
-
- if image is not None:
- if self.remove_background and self.fill_background:
- # Execute remove background here
- with pipeline.background_remover.remover() as remove_background:
- image = remove_background(image)
- white = PIL.Image.new("RGB", image.size, (255, 255, 255))
- black = PIL.Image.new("RGB", image.size, (0, 0, 0))
- image_mask = white.copy()
- alpha = image.split()[-1]
- alpha_clamp = PIL.Image.eval(alpha, lambda a: 255 if a > 128 else 0)
- image_mask.paste(black, mask=alpha)
- if mask is not None:
- assert mask.size == image.size, "image and mask must be the same size"
- # Merge mask and alpha
- image_mask.paste(mask)
- inverse_alpha = PIL.Image.eval(alpha_clamp, lambda a: 255 - a)
- image_mask.paste(white, mask=inverse_alpha)
- mask = image_mask
- image_width, image_height = image.size
- invocation_kwargs["image"] = image
-
- if mask is not None:
- mask_width, mask_height = mask.size
- if (
- self.crop_inpaint
- and (mask_width > pipeline_size or mask_height > pipeline_size)
- and image is not None
- ):
- (x0, y0), (x1, y1) = self.get_inpaint_bounding_box(pipeline_size)
-
- bbox_width = x1 - x0
- bbox_height = y1 - y0
-
- pixel_ratio = (bbox_height * bbox_width) / (mask_width * mask_height)
- pixel_savings = (1.0 - pixel_ratio) * 100
- if pixel_ratio < 0.75:
- logger.debug(f"Calculated pixel area savings of {pixel_savings:.1f}% by cropping to ({x0}, {y0}), ({x1}, {y1}) ({bbox_width}px by {bbox_height}px)")
- # Disable refining
- invocation_kwargs["refiner_strength"] = 0
- invocation_kwargs["refiner_start"] = 1
- image_position = (x0, y0)
- image_background = image.copy()
- image = image.crop((x0, y0, x1, y1))
- mask = mask.crop((x0, y0, x1, y1))
- image_width, image_height = bbox_width, bbox_height
- invocation_kwargs["image"] = image # Override what was set above
- else:
- logger.debug(
- f"Calculated pixel area savings of {pixel_savings:.1f}% are insufficient, will not crop"
- )
- invocation_kwargs["mask"] = mask
- if image is not None:
- assert image.size == mask.size, "image and mask must be the same size"
- else:
- image_width, image_height = mask.size
-
- if control_images is not None:
- for controlnet_name in control_images:
- for i, (
- control_image,
- conditioning_scale,
- conditioning_start,
- conditioning_end
- ) in enumerate(control_images[controlnet_name]):
- if image_position is not None and image_width is not None and image_height is not None:
- # Also crop control image
- x0, y0 = image_position
- x1 = x0 + image_width
- y1 = y0 + image_height
- control_image = control_image.crop((x0, y0, x1, y1))
- control_images[controlnet_name][i] = (control_image, conditioning_scale, conditioning_start, conditioning_end)
- if image_width is None or image_height is None:
- image_width, image_height = control_image.size
- else:
- this_width, this_height = control_image.size
- assert image_width == this_width and image_height == this_height, "all images must be the same size"
- invocation_kwargs["control_images"] = control_images
- if mask is not None:
- pipeline.inpainter_controlnets = list(control_images.keys())
- else:
- pipeline.controlnets = list(control_images.keys())
-
- if ip_adapter_images is not None:
- invocation_kwargs["ip_adapter_images"] = ip_adapter_images
-
- if self.width is not None and self.height is not None and image_width is None and image_height is None:
- image_width, image_height = self.width, self.height
-
- if image_width is None or image_height is None:
- logger.warning("No known invocation size, defaulting to engine size")
- image_width, image_height = pipeline_size, pipeline_size
-
- if image_width is not None and image_width < pipeline_size:
- image_scale = pipeline_size / image_width
- if image_height is not None and image_height < pipeline_size:
- image_scale = max(image_scale, pipeline_size / image_height)
-
- if image_scale > MAX_IMAGE_SCALE or not self.scale_to_model_size:
- # Refuse; it's too oblong (or scaling is disabled). We'll just calculate at the appropriate size.
- image_scale = 1
-
- invocation_kwargs["width"] = 8 * math.ceil((image_width * image_scale) / 8)
- invocation_kwargs["height"] = 8 * math.ceil((image_height * image_scale) / 8)
-
- if image_scale > 1:
- # scale input images up
- for key in ["image", "mask"]:
- if invocation_kwargs.get(key, None) is not None:
- invocation_kwargs[key] = self.scale_image(invocation_kwargs[key], image_scale)
- for controlnet_name in invocation_kwargs.get("control_images", {}):
- for i, (control_image, conditioning_scale, conditioning_start, conditioning_end) in enumerate(invocation_kwargs["control_images"].get(controlnet_name, [])):
- invocation_kwargs["control_images"][controlnet_name][i] = (
- self.scale_image(control_image, image_scale),
- conditioning_scale,
- conditioning_start,
- conditioning_end
- )
-
- latent_callback = invocation_kwargs.get("latent_callback", None)
- if image_background is not None and image_position is not None and latent_callback is not None:
- # Hijack latent callback to paste onto background
- def pasted_latent_callback(images: List[PIL.Image.Image]) -> None:
- images = [
- self.paste_inpaint_image(image_background, image, image_position) # type: ignore
- for image in images
- ]
- latent_callback(images)
-
- invocation_kwargs["latent_callback"] = pasted_latent_callback
-
- result = pipeline(**invocation_kwargs)
-
- if image_background is not None and image_position is not None:
- for i, image in enumerate(result["images"]):
- result["images"][i] = self.paste_inpaint_image(image_background, image, image_position)
-
- if self.remove_background and not self.fill_background:
- with pipeline.background_remover.remover() as remove_background:
- for i, image in enumerate(result["images"]):
- result["images"][i] = remove_background(image)
-
- if image_scale > 1:
- for i, image in enumerate(result["images"]):
- result["images"][i] = self.scale_image(image, 1 / image_scale)
-
- self.result = result
- return result
-
- @staticmethod
- def scale_image(image: PIL.Image.Image, scale: Union[int, float]) -> PIL.Image.Image:
- """
- Scales an image proportionally.
- """
- width, height = image.size
- scaled_width = 8 * round((width * scale) / 8)
- scaled_height = 8 * round((height * scale) / 8)
- return image.resize((scaled_width, scaled_height))
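
# Editor's illustrative sketch (not part of this patch): dimensions handed to
# the pipeline are kept at multiples of 8 (the VAE downsampling factor), which
# is why scale_image rounds both sides to the nearest multiple of 8.
def round_to_multiple(value, multiple=8):
    return multiple * round(value / multiple)

print(round_to_multiple(513), round_to_multiple(1023))  # 512 1024
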
-
- @staticmethod
- def deserialize_dict(step_dict: Dict[str, Any]) -> DiffusionStep:
- """
- Given a serialized dict, instantiate a diffusion step
- """
- kwargs: Dict[str, Any] = {}
- for key in [
- "name",
- "prompt",
- "prompt_2",
- "negative_prompt",
- "negative_prompt_2",
- "strength",
- "num_inference_steps",
- "guidance_scale",
- "refiner_start",
- "refiner_strength",
- "refiner_guidance_scale",
- "refiner_aesthetic_score",
- "refiner_negative_aesthetic_score",
- "refiner_prompt",
- "refiner_prompt_2",
- "refiner_negative_prompt",
- "refiner_negative_prompt_2",
- "width",
- "height",
- "fill_background",
- "remove_background",
- "scale_to_model_size",
- "crop_inpaint",
- "inpaint_feather",
- "ip_adapter_plus",
- "ip_adapter_face",
- ]:
- if key in step_dict:
- kwargs[key] = step_dict[key]
-
- deserialized_children = [DiffusionStep.deserialize_dict(child) for child in step_dict.get("children", [])]
- for key in ["image", "mask"]:
- if key not in step_dict:
- continue
- if isinstance(step_dict[key], int):
- kwargs[key] = deserialized_children[step_dict[key]]
- elif isinstance(step_dict[key], str) and os.path.exists(step_dict[key]):
- kwargs[key] = PIL.Image.open(step_dict[key])
- elif isinstance(step_dict[key], list):
- kwargs[key] = [
- PIL.Image.open(path)
- for path in step_dict[key]
- ]
- else:
- kwargs[key] = step_dict[key]
- if "control_images" in step_dict:
- control_images: List[Dict[str, Any]] = []
- for control_image_dict in step_dict["control_images"]:
- control_image = control_image_dict["image"]
- if isinstance(control_image, int):
- control_image = deserialized_children[control_image]
- elif isinstance(control_image, str):
- control_image = PIL.Image.open(control_image)
- control_images.append({
- "image": control_image,
- "controlnet": control_image_dict["controlnet"],
- "scale": control_image_dict.get("scale", 1.0),
- "start": control_image_dict.get("start", None),
- "end": control_image_dict.get("end", None),
- "process": control_image_dict.get("process", True),
- "invert": control_image_dict.get("invert", False)
- })
- kwargs["control_images"] = control_images
-
- if "ip_adapter_images" in step_dict:
- ip_adapter_images: List[Dict[str, Any]] = []
- for ip_adapter_image_dict in step_dict["ip_adapter_images"]:
- ip_adapter_image = ip_adapter_image_dict["image"]
- if isinstance(ip_adapter_image, int):
- ip_adapter_image = deserialized_children[ip_adapter_image]
- elif isinstance(ip_adapter_image, str):
- ip_adapter_image = PIL.Image.open(ip_adapter_image)
- ip_adapter_images.append({
- "image": ip_adapter_image,
- "scale": ip_adapter_image_dict.get("scale", 1.0),
- })
- kwargs["ip_adapter_images"] = ip_adapter_images
- return DiffusionStep(**kwargs)
-
-class DiffusionNode:
- """
- A diffusion node pairs a step (which may itself be recursive) with bounds on the canvas.
- """
- def __init__(self, bounds: List[Tuple[int, int]], step: DiffusionStep) -> None:
- self.bounds = bounds
- self.step = step
-
- def resize_image(self, image: PIL.Image.Image) -> PIL.Image.Image:
- """
- Resizes the image to fit the bounds.
- """
- x, y = self.bounds[0]
- w, h = self.bounds[1]
- return image.resize((w - x, h - y))
-
- def get_serialization_dict(self, image_directory: Optional[str] = None) -> Dict[str, Any]:
- """
- Gets the step's dict and adds bounds.
- """
- step_dict = self.step.get_serialization_dict(image_directory)
- step_dict["bounds"] = self.bounds
- return step_dict
-
- def execute(
- self,
- pipeline: DiffusionPipelineManager,
- **kwargs: Any,
- ) -> StableDiffusionPipelineOutput:
- """
- Passes through the execution to the step.
- """
- return self.step.execute(pipeline, **kwargs)
-
- @property
- def name(self) -> str:
- """
- Pass-through the step name
- """
- return self.step.name
-
- @staticmethod
- def deserialize_dict(step_dict: Dict[str, Any]) -> DiffusionNode:
- """
- Given a serialized dict, instantiate a diffusion Node
- """
- bounds = step_dict.pop("bounds", None)
- if bounds is None:
- raise TypeError("Bounds are required")
-
- return DiffusionNode(
- [(int(bounds[0][0]), int(bounds[0][1])), (int(bounds[1][0]), int(bounds[1][1]))],
- DiffusionStep.deserialize_dict(step_dict),
- )
-
-
-class DiffusionPlan:
- """
- A diffusion plan represents any number of steps, with each step receiving the output of the previous.
-
- Additionally, we handle upscaling as part of the plan. If we only want to upscale a previously generated image,
- the plan can be initialized with an empty steps array and an initial image.
- """
-
- def __init__(
- self,
- size: int, # Required
- prompt: Optional[str] = None, # Global
- prompt_2: Optional[str] = None, # Global
- negative_prompt: Optional[str] = None, # Global
- negative_prompt_2: Optional[str] = None, # Global
- clip_skip: Optional[int] = None,
- refiner_size: Optional[int] = None,
- inpainter_size: Optional[int] = None,
- model: Optional[str] = None,
- refiner: Optional[str] = None,
- inpainter: Optional[str] = None,
- lora: Optional[Union[str, List[str], Tuple[str, float], List[Union[str, Tuple[str, float]]]]] = None,
- lycoris: Optional[Union[str, List[str], Tuple[str, float], List[Union[str, Tuple[str, float]]]]] = None,
- inversion: Optional[Union[str, List[str]]] = None,
- scheduler: Optional[SCHEDULER_LITERAL] = None,
- vae: Optional[str] = None,
- refiner_vae: Optional[str] = None,
- inpainter_vae: Optional[str] = None,
- width: Optional[int] = None,
- height: Optional[int] = None,
- nodes: List[DiffusionNode] = [],
- image: Optional[Union[str, PIL.Image.Image]] = None,
- chunking_size: Optional[int] = None,
- chunking_mask_type: Optional[MASK_TYPE_LITERAL] = None,
- chunking_mask_kwargs: Optional[Dict[str, Any]] = None,
- samples: Optional[int] = 1,
- iterations: Optional[int] = 1,
- seed: Optional[int] = None,
- build_tensorrt: bool = False,
- outpaint: bool = True,
- upscale_steps: Optional[Union[UpscaleStepDict, List[UpscaleStepDict]]] = None,
- freeu_factors: Optional[Tuple[float, float, float, float]] = None,
- guidance_scale: Optional[float] = None,
- num_inference_steps: Optional[int] = None,
- noise_offset: Optional[float] = None,
- noise_method: NOISE_METHOD_LITERAL = "perlin",
- noise_blend_method: LATENT_BLEND_METHOD_LITERAL = "inject",
- ) -> None:
- self.size = size
- self.inpainter_size = inpainter_size
- self.refiner_size = refiner_size
- self.prompt = prompt
- self.prompt_2 = prompt_2
- self.negative_prompt = negative_prompt
- self.negative_prompt_2 = negative_prompt_2
- self.clip_skip = clip_skip
- self.model = model
- self.refiner = refiner
- self.inpainter = inpainter
- self.lora = lora
- self.lycoris = lycoris
- self.inversion = inversion
- self.scheduler = scheduler
- self.vae = vae
- self.refiner_vae = refiner_vae
- self.inpainter_vae = inpainter_vae
- self.width = width if width is not None else self.size
- self.height = height if height is not None else self.size
- self.image = image
- self.chunking_size = chunking_size if chunking_size is not None else self.size // 8 # Pass 0 to disable
- self.chunking_mask_type = chunking_mask_type
- self.chunking_mask_kwargs = chunking_mask_kwargs
- self.samples = samples if samples is not None else 1
- self.iterations = iterations if iterations is not None else 1
- self.seed = seed if seed is not None else randint(1, sys.maxsize)
-
- self.outpaint = outpaint
- self.build_tensorrt = build_tensorrt
- self.nodes = nodes
- self.upscale_steps = upscale_steps
- self.freeu_factors = freeu_factors
- self.guidance_scale = guidance_scale
- self.num_inference_steps = num_inference_steps
- self.noise_offset = noise_offset
- self.noise_method = noise_method
- self.noise_blend_method = noise_blend_method
-
- @property
- def kwargs(self) -> Dict[str, Any]:
- """
- Returns the keyword arguments that will be passed to the pipeline call.
- """
- return {
- "width": self.width,
- "height": self.height,
- "clip_skip": self.clip_skip,
- "freeu_factors": self.freeu_factors,
- "chunking_size": self.chunking_size,
- "chunking_mask_type": self.chunking_mask_type,
- "chunking_mask_kwargs": self.chunking_mask_kwargs,
- "num_images_per_prompt": self.samples,
- "noise_offset": self.noise_offset,
- "noise_method": self.noise_method,
- "noise_blend_method": self.noise_blend_method,
- }
-
- @property
- def upscale(self) -> Iterator[UpscaleStepDict]:
- """
- Iterates over upscale steps.
- """
- if self.upscale_steps is not None:
- if isinstance(self.upscale_steps, list):
- for step in self.upscale_steps:
- yield step
- else:
- yield self.upscale_steps
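
# Editor's illustrative sketch (not part of this patch): the upscale property
# above normalizes a "single dict or list of dicts" configuration value into a
# uniform iterator. The same pattern in isolation (names hypothetical):
def iter_steps(steps):
    if steps is None:
        return
    if isinstance(steps, list):
        yield from steps
    else:
        yield steps

print(list(iter_steps({"method": "esrgan", "amount": 2})))    # single step wrapped
print(list(iter_steps([{"method": "lanczos", "amount": 4}])))  # already a list
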
-
- def execute(
- self,
- pipeline: DiffusionPipelineManager,
- task_callback: Optional[Callable[[str], None]] = None,
- progress_callback: Optional[Callable[[int, int, float], None]] = None,
- image_callback: Optional[Callable[[List[PIL.Image.Image]], None]] = None,
- image_callback_steps: Optional[int] = None,
- ) -> StableDiffusionPipelineOutput:
- """
- This is the main interface for execution.
-
- The first step will be the one that executes with the selected number of samples,
- and then each subsequent step will be performed on the number of outputs from the
- first step.
- """
- # We import here so this file can be imported by processes without initializing torch
- from diffusers.utils.pil_utils import PIL_INTERPOLATION
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
- if task_callback is None:
- task_callback = lambda arg: None
-
- images, nsfw = self.execute_nodes(
- pipeline,
- task_callback,
- progress_callback,
- image_callback,
- image_callback_steps
- )
-
- for upscale_step in self.upscale:
- method = upscale_step["method"]
- amount = upscale_step["amount"]
- num_inference_steps = upscale_step.get("num_inference_steps", DEFAULT_UPSCALE_INFERENCE_STEPS)
- guidance_scale = upscale_step.get("guidance_scale", DEFAULT_UPSCALE_GUIDANCE_SCALE)
- prompt = upscale_step.get("prompt", DEFAULT_UPSCALE_PROMPT)
- prompt_2 = upscale_step.get("prompt_2", None)
- negative_prompt = upscale_step.get("negative_prompt", None)
- negative_prompt_2 = upscale_step.get("negative_prompt_2", None)
- strength = upscale_step.get("strength", None)
- controlnets = upscale_step.get("controlnets", None)
- chunking_size = upscale_step.get("chunking_size", DEFAULT_UPSCALE_CHUNKING_SIZE)
- scheduler = upscale_step.get("scheduler", self.scheduler)
- chunking_mask_type = upscale_step.get("chunking_mask_type", None)
- chunking_mask_kwargs = upscale_step.get("chunking_mask_kwargs", None)
- noise_offset = upscale_step.get("noise_offset", None)
- noise_method = upscale_step.get("noise_method", None)
- noise_blend_method = upscale_step.get("noise_blend_method", None)
- refiner = self.refiner is not None and upscale_step.get("refiner", True)
-
- for i, image in enumerate(images):
- if nsfw is not None and nsfw[i]:
- logger.debug(f"Image {i} had NSFW content, not upscaling.")
- continue
-
- logger.debug(f"Upscaling sample {i} by {amount} using {method}")
- task_callback(f"Upscaling sample {i+1}")
-
- if method in ["esrgan", "esrganime", "gfpgan"]:
- if refiner:
- pipeline.unload_pipeline("clearing memory for upscaler")
- pipeline.offload_refiner()
- else:
- pipeline.offload_pipeline()
- pipeline.unload_refiner("clearing memory for upscaler")
- image = pipeline.upscaler(method, image, tile=pipeline.size, outscale=amount)
- elif method in PIL_INTERPOLATION:
- width, height = image.size
- image = image.resize(
- (int(width * amount), int(height * amount)),
- resample=PIL_INTERPOLATION[method]
- )
- else:
- logger.error(f"Unknown upscaler {method}")
- return self.format_output(images, nsfw)
-
- images[i] = image
- if image_callback is not None:
- image_callback(images)
-
- if strength is not None and strength > 0:
- task_callback("Preparing upscale pipeline")
-
- if refiner:
- # Refiners have safety disabled from the jump
- logger.debug("Using refiner for upscaling.")
- re_enable_safety = False
- chunking_size = min(chunking_size, pipeline.refiner_size // 2)
- else:
- # Disable pipeline safety here, it gives many false positives when upscaling.
- # We'll re-enable it after.
- logger.debug("Using base pipeline for upscaling.")
- re_enable_safety = pipeline.safe
- chunking_size = min(chunking_size, pipeline.size // 2)
- pipeline.safe = False
-
- if scheduler is not None:
- pipeline.scheduler = scheduler
-
- for i, image in enumerate(images):
- if nsfw is not None and nsfw[i]:
- logger.debug(f"Image {i} had NSFW content, not upscaling.")
- continue
-
- width, height = image.size
- kwargs = {
- "width": width,
- "height": height,
- "image": image,
- "num_images_per_prompt": 1,
- "prompt": prompt,
- "prompt_2": prompt_2,
- "negative_prompt": negative_prompt,
- "negative_prompt_2": negative_prompt_2,
- "strength": strength,
- "num_inference_steps": num_inference_steps,
- "guidance_scale": guidance_scale,
- "chunking_size": chunking_size,
- "chunking_mask_type": chunking_mask_type,
- "chunking_mask_kwargs": chunking_mask_kwargs,
- "progress_callback": progress_callback,
- "latent_callback": image_callback,
- "latent_callback_type": "pil",
- "latent_callback_steps": image_callback_steps,
- "noise_offset": noise_offset,
- "noise_method": noise_method,
- "noise_blend_method": noise_blend_method,
- }
-
- if controlnets is not None:
- if not isinstance(controlnets, list):
- controlnets = [controlnets] # type: ignore[unreachable]
-
- controlnet_names = []
- controlnet_weights = []
-
- for controlnet in controlnets:
- if isinstance(controlnet, tuple):
- controlnet, weight = controlnet
- else:
- weight = 1.0
- if controlnet not in controlnet_names:
- controlnet_names.append(controlnet)
- controlnet_weights.append(weight)
-
- logger.debug(f"Enabling controlnet(s) {controlnet_names} for upscaling")
-
- if refiner:
- pipeline.refiner_controlnets = controlnet_names
- upscale_pipeline = pipeline.refiner_pipeline
- is_sdxl = pipeline.refiner_is_sdxl
- else:
- pipeline.controlnets = controlnet_names
- upscale_pipeline = pipeline.pipeline
- is_sdxl = pipeline.is_sdxl
-
- kwargs["control_images"] = dict([
- (
- controlnet_name,
- [(
- pipeline.control_image_processor(controlnet_name, image),
- controlnet_weight
- )]
- )
- for controlnet_name, controlnet_weight in zip(controlnet_names, controlnet_weights)
- ])
- elif refiner:
- pipeline.refiner_controlnets = None
- upscale_pipeline = pipeline.refiner_pipeline
- else:
- pipeline.controlnets = None
- upscale_pipeline = pipeline.pipeline
-
- logger.debug(f"Upscaling sample {i} with arguments {kwargs}")
- pipeline.stop_keepalive() # Stop here to kill during upscale diffusion
- task_callback(f"Re-diffusing Upscaled Sample {i+1}")
- image = upscale_pipeline(
- generator=pipeline.generator,
- device=pipeline.device,
- offload_models=pipeline.pipeline_sequential_onload,
- **kwargs
- ).images[0]
- pipeline.start_keepalive() # Return keepalive between iterations
- images[i] = image
- if image_callback is not None:
- image_callback(images)
- if re_enable_safety:
- pipeline.safe = True
- if refiner:
- logger.debug("Offloading refiner for next inference.")
- pipeline.refiner_controlnets = None
- pipeline.offload_refiner()
- else:
- pipeline.controlnets = None # Make sure we reset controlnets
- pipeline.stop_keepalive() # Make sure this is stopped
- return self.format_output(images, nsfw)
-
- def get_image_metadata(self, image: PIL.Image.Image) -> Dict[str, Any]:
- """
- Gets metadata from an image
- """
- (width, height) = image.size
- return {
- "width": width,
- "height": height,
- "metadata": getattr(image, "text", {})
- }
-
- def redact_images_from_metadata(self, metadata: Dict[str, Any]) -> None:
- """
- Removes images from a metadata dictionary
- """
- for key in ["image", "mask"]:
- image = metadata.get(key, None)
- if image is not None:
- metadata[key] = self.get_image_metadata(metadata[key])
- if "control_images" in metadata:
- for i, control_dict in enumerate(metadata["control_images"]):
- control_dict["image"] = self.get_image_metadata(control_dict["image"])
- if "ip_adapter_images" in metadata:
- for i, ip_adapter_dict in enumerate(metadata["ip_adapter_images"]):
- ip_adapter_dict["image"] = self.get_image_metadata(ip_adapter_dict["image"])
- if "children" in metadata:
- for child in metadata["children"]:
- self.redact_images_from_metadata(child)
- if "nodes" in metadata:
- for child in metadata["nodes"]:
- self.redact_images_from_metadata(child)
-
- def format_output(self, images: List[PIL.Image.Image], nsfw: List[bool]) -> StableDiffusionPipelineOutput:
- """
- Adds Enfugue metadata to an image result
- """
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-
- metadata_dict = self.get_serialization_dict()
- self.redact_images_from_metadata(metadata_dict)
- formatted_images = []
- for i, image in enumerate(images):
- byte_io = io.BytesIO()
- metadata = PngInfo()
- metadata.add_text("EnfugueGenerationData", Serializer.serialize(metadata_dict))
- image.save(byte_io, format="PNG", pnginfo=metadata)
- formatted_images.append(PIL.Image.open(byte_io))
-
- return StableDiffusionPipelineOutput(
- images=formatted_images,
- nsfw_content_detected=nsfw
- )
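
# Editor's illustrative sketch (not part of this patch): format_output embeds
# the serialized generation data into a PNG text chunk via PngInfo. A round-trip
# in isolation looks like this; the key name mirrors the one used above, but the
# real code serializes with pibble's Serializer rather than json.
import io
import json
from PIL import Image
from PIL.PngImagePlugin import PngInfo

payload = {"prompt": "a photo of a cat", "seed": 12345}      # hypothetical metadata
metadata = PngInfo()
metadata.add_text("EnfugueGenerationData", json.dumps(payload))

buffer = io.BytesIO()
Image.new("RGB", (64, 64), "gray").save(buffer, format="PNG", pnginfo=metadata)

reloaded = Image.open(io.BytesIO(buffer.getvalue()))
print(json.loads(reloaded.text["EnfugueGenerationData"]))    # {'prompt': ..., 'seed': 12345}
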
-
- def prepare_pipeline(self, pipeline: DiffusionPipelineManager) -> None:
- """
- Assigns pipeline-level variables.
- """
- pipeline.start_keepalive() # Make sure this is going
- pipeline.model = self.model
- pipeline.refiner = self.refiner
- pipeline.inpainter = self.inpainter
- pipeline.lora = self.lora
- pipeline.lycoris = self.lycoris
- pipeline.inversion = self.inversion
- pipeline.size = self.size
- pipeline.scheduler = self.scheduler
- pipeline.vae = self.vae
- pipeline.refiner_vae = self.refiner_vae
- pipeline.refiner_size = self.refiner_size
- pipeline.inpainter_vae = self.inpainter_vae
- pipeline.inpainter_size = self.inpainter_size
- if self.build_tensorrt:
- pipeline.build_tensorrt = True
-
- def execute_nodes(
- self,
- pipeline: DiffusionPipelineManager,
- task_callback: Callable[[str], None],
- progress_callback: Optional[Callable[[int, int, float], None]] = None,
- image_callback: Optional[Callable[[List[PIL.Image.Image]], None]] = None,
- image_callback_steps: Optional[int] = None,
- ) -> Tuple[List[PIL.Image.Image], List[bool]]:
- """
- This is called during execute(). It will go through the steps in order and return
- the resulting image(s).
- """
- if not self.nodes:
- if not self.image:
- raise ValueError("No image and no steps; cannot execute plan.")
- return [self.image], [False]
-
- # Define progress and latent callback kwargs, we'll add task callbacks ourself later
- callback_kwargs = {
- "progress_callback": progress_callback,
- "latent_callback_steps": image_callback_steps,
- "latent_callback_type": "pil",
- }
-
- # Set up the pipeline
- pipeline._task_callback = task_callback
- self.prepare_pipeline(pipeline)
-
- if self.seed is not None:
- # Set up the RNG
- pipeline.seed = self.seed
-
- images = [PIL.Image.new("RGBA", (self.width, self.height)) for i in range(self.samples * self.iterations)]
- image_draw = [PIL.ImageDraw.Draw(image) for image in images]
- nsfw_content_detected = [False] * self.samples * self.iterations
-
- # Keep a final mask of all nodes to outpaint in the end
- outpaint_mask = PIL.Image.new("RGB", (self.width, self.height), (255, 255, 255))
- outpaint_draw = PIL.ImageDraw.Draw(outpaint_mask)
-
- for i, node in enumerate(self.nodes):
- def node_task_callback(task: str) -> None:
- """
- Wrap the task callback so we indicate what node we're on
- """
- task_callback(f"{node.name}: {task}")
-
- invocation_kwargs = {**self.kwargs, **callback_kwargs}
- invocation_kwargs["task_callback"] = node_task_callback
- this_intention = "inpainting" if node.step.mask is not None else "inference"
- next_intention: Optional[str] = None
-
- if i < len(self.nodes) - 1:
- next_node = self.nodes[i+1]
- next_intention = "inpainting" if next_node.step.mask is not None else "inference"
- elif self.upscale_steps is not None and not (isinstance(self.upscale_steps, list) and len(self.upscale_steps) == 0):
- upscale_step = self.upscale_steps
- if isinstance(upscale_step, list):
- upscale_step = upscale_step[0]
-
- upscale_strength = upscale_step.get("strength", None)
- use_ai_upscaler = "gan" in upscale_step["method"]
- use_sd_upscaler = upscale_strength is not None and upscale_strength > 0
-
- if use_ai_upscaler:
- next_intention = "upscaling"
- elif use_sd_upscaler:
- if self.refiner is not None and upscale_step.get("refiner", True):
- next_intention = "refining"
- else:
- next_intention = "inference"
-
- for it in range(self.iterations):
- if image_callback is not None:
- def node_image_callback(callback_images: List[PIL.Image.Image]) -> None:
- """
- Wrap the original image callback so we're actually pasting the initial image on the main canvas
- """
- for j, callback_image in enumerate(callback_images):
- image_index = (it * self.samples) + j
- images[image_index].paste(node.resize_image(callback_image), node.bounds[0])
- image_callback(images) # type: ignore
-
- else:
- node_image_callback = None # type: ignore
-
- result = node.execute(
- pipeline,
- latent_callback=node_image_callback,
- next_intention=this_intention if it < self.iterations - 1 else next_intention,
- use_cached=False,
- **invocation_kwargs
- )
-
- for j, image in enumerate(result["images"]):
- image_index = (it * self.samples) + j
- image = node.resize_image(image)
- if image.mode == "RGBA":
- # Draw the alpha mask of the return image onto the outpaint mask
- alpha = image.split()[-1]
- black = PIL.Image.new("RGB", alpha.size, (0, 0, 0))
- outpaint_mask.paste(black, node.bounds[0], mask=alpha)
- image_draw[image_index].rectangle((*node.bounds[0], *node.bounds[1]), fill=(0, 0, 0, 0))
- images[image_index].paste(image, node.bounds[0], mask=alpha)
- else:
- # Draw a rectangle directly
- outpaint_draw.rectangle(node.bounds, fill="#000000")
- images[image_index].paste(node.resize_image(image), node.bounds[0])
-
- nsfw_content_detected[image_index] = nsfw_content_detected[image_index] or (
- "nsfw_content_detected" in result and result["nsfw_content_detected"][j]
- )
-
- # Call the callback
- if image_callback is not None:
- image_callback(images)
-
- # Determine if there's anything left to outpaint
- outpaint_mask = outpaint_mask.convert("L")
- image_r_min, image_r_max = outpaint_mask.getextrema()
- if image_r_max > 0 and self.prompt and self.outpaint:
- # Outpaint
- del invocation_kwargs["num_images_per_prompt"]
- outpaint_mask = feather_mask(outpaint_mask)
-
- outpaint_prompt_tokens = TokenMerger()
- outpaint_prompt_2_tokens = TokenMerger()
-
- outpaint_negative_prompt_tokens = TokenMerger()
- outpaint_negative_prompt_2_tokens = TokenMerger()
-
- for i, node in enumerate(self.nodes):
- if node.step.prompt is not None:
- outpaint_prompt_tokens.add(node.step.prompt)
- if node.step.prompt_2 is not None:
- outpaint_prompt_2_tokens.add(node.step.prompt_2)
- if node.step.negative_prompt is not None:
- outpaint_negative_prompt_tokens.add(node.step.negative_prompt)
- if node.step.negative_prompt_2 is not None:
- outpaint_negative_prompt_2_tokens.add(node.step.negative_prompt_2)
-
- if self.prompt is not None:
- outpaint_prompt_tokens.add(self.prompt, 2) # Weighted
- if self.prompt_2 is not None:
- outpaint_prompt_2_tokens.add(self.prompt_2, 2) # Weighted
- if self.negative_prompt is not None:
- outpaint_negative_prompt_tokens.add(self.negative_prompt, 2)
- if self.negative_prompt_2 is not None:
- outpaint_negative_prompt_2_tokens.add(self.negative_prompt_2, 2)
-
- def outpaint_task_callback(task: str) -> None:
- """
- Wrap the outpaint task callback to include the overall plan task itself
- """
- task_callback(f"Outpaint: {task}")
-
- invocation_kwargs["strength"] = 0.99
- invocation_kwargs["task_callback"] = outpaint_task_callback
- if self.guidance_scale is not None:
- invocation_kwargs["guidance_scale"] = self.guidance_scale
- if self.num_inference_steps is not None:
- invocation_kwargs["num_inference_steps"] = self.num_inference_steps
-
- for i, image in enumerate(images):
- pipeline.controlnet = None
- if image_callback is not None:
- def outpaint_image_callback(callback_images: List[PIL.Image.Image]) -> None:
- """
- Wrap the original image callback so we're actually pasting the initial image on the main canvas
- """
- images[i] = callback_images[0]
- image_callback(images) # type: ignore
- else:
- outpaint_image_callback = None # type: ignore
-
- result = pipeline(
- image=image,
- mask=outpaint_mask,
- prompt=str(outpaint_prompt_tokens),
- prompt_2=str(outpaint_prompt_2_tokens),
- negative_prompt=str(outpaint_negative_prompt_tokens),
- negative_prompt_2=str(outpaint_negative_prompt_2_tokens),
- latent_callback=outpaint_image_callback,
- num_images_per_prompt=1,
- **invocation_kwargs,
- )
-
- images[i] = result["images"][0]
- nsfw_content_detected[i] = nsfw_content_detected[i] or (
- "nsfw_content_detected" in result and result["nsfw_content_detected"][0]
- )
-
- return images, nsfw_content_detected
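
# Editor's illustrative sketch (not part of this patch): the core bookkeeping in
# execute_nodes. Each node's result is pasted onto a shared canvas, its region
# is blackened on a white "still needs outpainting" mask, and afterwards the
# mask's extrema decide whether any uncovered area remains. Bounds are made up.
from PIL import Image, ImageDraw

canvas = Image.new("RGBA", (512, 512))
outpaint_mask = Image.new("RGB", (512, 512), (255, 255, 255))
outpaint_draw = ImageDraw.Draw(outpaint_mask)

node_bounds = [(0, 0), (256, 512)]                  # hypothetical single node
node_result = Image.new("RGB", (256, 512), "blue")  # stand-in for a node's output

canvas.paste(node_result, node_bounds[0])
outpaint_draw.rectangle((*node_bounds[0], *node_bounds[1]), fill="#000000")

minimum, maximum = outpaint_mask.convert("L").getextrema()
print("needs outpainting:", maximum > 0)            # True: the right half is uncovered
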
-
- def get_serialization_dict(self, image_directory: Optional[str] = None) -> Dict[str, Any]:
- """
- Serializes the whole plan for storage or passing between processes.
- """
- serialized_image = self.image
- if image_directory is not None and isinstance(self.image, PIL.Image.Image):
- serialized_image = os.path.join(image_directory, f"{get_uuid()}.png")
- self.image.save(serialized_image)
-
- return {
- "model": self.model,
- "refiner": self.refiner,
- "inpainter": self.inpainter,
- "lora": self.lora,
- "lycoris": self.lycoris,
- "inversion": self.inversion,
- "scheduler": self.scheduler,
- "vae": self.vae,
- "refiner_vae": self.refiner_vae,
- "inpainter_vae": self.inpainter_vae,
- "width": self.width,
- "height": self.height,
- "size": self.size,
- "inpainter_size": self.inpainter_size,
- "refiner_size": self.refiner_size,
- "seed": self.seed,
- "prompt": self.prompt,
- "prompt_2": self.prompt_2,
- "negative_prompt": self.negative_prompt,
- "negative_prompt_2": self.negative_prompt_2,
- "image": serialized_image,
- "nodes": [node.get_serialization_dict(image_directory) for node in self.nodes],
- "samples": self.samples,
- "iterations": self.iterations,
- "upscale_steps": self.upscale_steps,
- "chunking_size": self.chunking_size,
- "chunking_mask_type": self.chunking_mask_type,
- "chunking_mask_kwargs": self.chunking_mask_kwargs,
- "build_tensorrt": self.build_tensorrt,
- "outpaint": self.outpaint,
- "clip_skip": self.clip_skip,
- "freeu_factors": self.freeu_factors,
- "guidance_scale": self.guidance_scale,
- "num_inference_steps": self.num_inference_steps,
- "noise_offset": self.noise_offset,
- "noise_method": self.noise_method,
- "noise_blend_method": self.noise_blend_method,
- }
-
- @staticmethod
- def deserialize_dict(plan_dict: Dict[str, Any]) -> DiffusionPlan:
- """
- Given a serialized dictionary, instantiate a diffusion plan.
- """
- kwargs = {
- "model": plan_dict["model"],
- "nodes": [DiffusionNode.deserialize_dict(node_dict) for node_dict in plan_dict.get("nodes", [])],
- }
-
- for arg in [
- "refiner",
- "inpainter",
- "size",
- "refiner_size",
- "inpainter_size",
- "lora",
- "lycoris",
- "inversion",
- "scheduler",
- "vae",
- "refiner_vae",
- "inpainter_vae",
- "width",
- "height",
- "chunking_size",
- "chunking_mask_type",
- "chunking_mask_kwargs",
- "samples",
- "iterations",
- "seed",
- "prompt",
- "prompt_2",
- "negative_prompt",
- "negative_prompt_2",
- "build_tensorrt",
- "outpaint",
- "upscale_steps",
- "clip_skip",
- "freeu_factors",
- "guidance_scale",
- "num_inference_steps",
- "noise_offset",
- "noise_method",
- "noise_blend_method",
- ]:
- if arg in plan_dict:
- kwargs[arg] = plan_dict[arg]
-
- if "image" in plan_dict:
- if isinstance(plan_dict["image"], str) and os.path.exists(plan_dict["image"]):
- kwargs["image"] = PIL.Image.open(plan_dict["image"])
- else:
- kwargs["image"] = plan_dict["image"]
-
- result = DiffusionPlan(**kwargs)
- return result
-
- @staticmethod
- def create_mask(width: int, height: int, left: int, top: int, right: int, bottom: int) -> PIL.Image.Image:
- """
- Creates a mask from 6 dimensions
- """
- image = PIL.Image.new("RGB", (width, height))
- draw = PIL.ImageDraw.Draw(image)
- draw.rectangle([(left, top), (right, bottom)], fill="#ffffff")
- return image
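
# Editor's illustrative sketch (not part of this patch): create_mask draws a
# white rectangle on a black canvas; the white region is what the inpainting
# pass is allowed to change. A standalone equivalent and a usage example:
from PIL import Image, ImageDraw

def create_mask(width, height, left, top, right, bottom):
    mask = Image.new("RGB", (width, height))                 # black canvas
    ImageDraw.Draw(mask).rectangle([(left, top), (right, bottom)], fill="#ffffff")
    return mask

mask = create_mask(512, 512, left=128, top=128, right=384, bottom=384)
print(mask.getpixel((256, 256)), mask.getpixel((10, 10)))    # (255, 255, 255) (0, 0, 0)
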
-
- @staticmethod
- def upscale_image(
- size: int,
- image: PIL.Image,
- upscale_steps: Union[UpscaleStepDict, List[UpscaleStepDict]],
- refiner_size: Optional[int] = None,
- inpainter_size: Optional[int] = None,
- model: Optional[str] = None,
- refiner: Optional[str] = None,
- inpainter: Optional[str] = None,
- lora: Optional[Union[str, List[str], Tuple[str, float], List[Union[str, Tuple[str, float]]]]] = None,
- lycoris: Optional[Union[str, List[str], Tuple[str, float], List[Union[str, Tuple[str, float]]]]] = None,
- inversion: Optional[Union[str, List[str]]] = None,
- scheduler: Optional[SCHEDULER_LITERAL] = None,
- vae: Optional[str] = None,
- refiner_vae: Optional[str] = None,
- inpainter_vae: Optional[str] = None,
- seed: Optional[int] = None,
- noise_offset: Optional[float] = None,
- noise_method: NOISE_METHOD_LITERAL = "perlin",
- noise_blend_method: LATENT_BLEND_METHOD_LITERAL = "inject",
- **kwargs: Any,
- ) -> DiffusionPlan:
- """
- Generates a plan to upscale a single image
- """
- if kwargs:
- logger.warning(f"Plan `upscale_image` keyword arguments ignored: {kwargs}")
- width, height = image.size
- nodes: List[NodeDict] = [
- {
- "image": image,
- "w": width,
- "h": height,
- "x": 0,
- "y": 0,
- }
- ]
- return DiffusionPlan.assemble(
- size=size,
- refiner_size=refiner_size,
- inpainter_size=inpainter_size,
- model=model,
- refiner=refiner,
- inpainter=inpainter,
- lora=lora,
- lycoris=lycoris,
- inversion=inversion,
- scheduler=scheduler,
- vae=vae,
- refiner_vae=refiner_vae,
- inpainter_vae=inpainter_vae,
- seed=seed,
- width=width,
- height=height,
- upscale_steps=upscale_steps,
- noise_offset=noise_offset,
- noise_method=noise_method,
- noise_blend_method=noise_blend_method,
- nodes=nodes
- )
-
- @staticmethod
- def assemble(
- size: int,
- refiner_size: Optional[int] = None,
- inpainter_size: Optional[int] = None,
- model: Optional[str] = None,
- refiner: Optional[str] = None,
- inpainter: Optional[str] = None,
- lora: Optional[Union[str, List[str], Tuple[str, float], List[Union[str, Tuple[str, float]]]]] = None,
- lycoris: Optional[Union[str, List[str], Tuple[str, float], List[Union[str, Tuple[str, float]]]]] = None,
- inversion: Optional[Union[str, List[str]]] = None,
- scheduler: Optional[SCHEDULER_LITERAL] = None,
- vae: Optional[str] = None,
- refiner_vae: Optional[str] = None,
- inpainter_vae: Optional[str] = None,
- model_prompt: Optional[str] = None,
- model_prompt_2: Optional[str] = None,
- model_negative_prompt: Optional[str] = None,
- model_negative_prompt_2: Optional[str] = None,
- samples: int = 1,
- iterations: int = 1,
- seed: Optional[int] = None,
- width: Optional[int] = None,
- height: Optional[int] = None,
- nodes: List[NodeDict] = [],
- chunking_size: Optional[int] = None,
- chunking_mask_type: Optional[MASK_TYPE_LITERAL] = None,
- chunking_mask_kwargs: Optional[Dict[str, Any]] = None,
- prompt: Optional[str] = None,
- prompt_2: Optional[str] = None,
- negative_prompt: Optional[str] = None,
- negative_prompt_2: Optional[str] = None,
- clip_skip: Optional[int] = None,
- freeu_factors: Optional[Tuple[float, float, float, float]] = None,
- num_inference_steps: Optional[int] = DEFAULT_INFERENCE_STEPS,
- mask: Optional[Union[str, PIL.Image.Image]] = None,
- image: Optional[Union[str, PIL.Image.Image]] = None,
- fit: Optional[IMAGE_FIT_LITERAL] = None,
- anchor: Optional[IMAGE_ANCHOR_LITERAL] = None,
- strength: Optional[float] = None,
- ip_adapter_images: Optional[List[IPAdapterImageDict]] = None,
- ip_adapter_plus: bool = False,
- ip_adapter_face: bool = False,
- control_images: Optional[List[ControlImageDict]] = None,
- remove_background: bool = False,
- fill_background: bool = False,
- scale_to_model_size: bool = False,
- invert_mask: bool = False,
- crop_inpaint: bool = True,
- inpaint_feather: int = 32,
- guidance_scale: Optional[float] = DEFAULT_GUIDANCE_SCALE,
- refiner_start: Optional[float] = DEFAULT_REFINER_START,
- refiner_strength: Optional[float] = DEFAULT_REFINER_STRENGTH,
- refiner_guidance_scale: Optional[float] = DEFAULT_REFINER_GUIDANCE_SCALE,
- refiner_aesthetic_score: Optional[float] = DEFAULT_AESTHETIC_SCORE,
- refiner_negative_aesthetic_score: Optional[float] = DEFAULT_NEGATIVE_AESTHETIC_SCORE,
- refiner_prompt: Optional[str] = None,
- refiner_prompt_2: Optional[str] = None,
- refiner_negative_prompt: Optional[str] = None,
- refiner_negative_prompt_2: Optional[str] = None,
- upscale_steps: Optional[Union[UpscaleStepDict, List[UpscaleStepDict]]] = None,
- noise_offset: Optional[float] = None,
- noise_method: NOISE_METHOD_LITERAL = "perlin",
- noise_blend_method: LATENT_BLEND_METHOD_LITERAL = "inject",
- **kwargs: Any,
- ) -> DiffusionPlan:
- """
- Assembles a diffusion plan from step dictionaries.
- """
- if kwargs:
- logger.warning(f"Plan `assemble` keyword arguments ignored: {kwargs}")
-
- # First instantiate the plan
- plan = DiffusionPlan(
- model=model,
- refiner=refiner,
- inpainter=inpainter,
- lora=lora,
- lycoris=lycoris,
- inversion=inversion,
- scheduler=scheduler,
- vae=vae,
- refiner_vae=refiner_vae,
- inpainter_vae=inpainter_vae,
- samples=samples,
- iterations=iterations,
- size=size,
- refiner_size=refiner_size,
- inpainter_size=inpainter_size,
- seed=seed,
- width=width,
- height=height,
- prompt=prompt,
- prompt_2=prompt_2,
- negative_prompt=negative_prompt,
- negative_prompt_2=negative_prompt_2,
- chunking_size=chunking_size,
- chunking_mask_type=chunking_mask_type,
- chunking_mask_kwargs=chunking_mask_kwargs,
- clip_skip=clip_skip,
- freeu_factors=freeu_factors,
- guidance_scale=guidance_scale,
- num_inference_steps=num_inference_steps,
- noise_offset=noise_offset,
- noise_method=noise_method,
- noise_blend_method=noise_blend_method,
- nodes=[],
- )
-
- # We'll assemble multiple token sets for overall diffusion
- upscale_prompt_tokens = TokenMerger()
- upscale_prompt_2_tokens = TokenMerger()
- upscale_negative_prompt_tokens = TokenMerger()
- upscale_negative_prompt_2_tokens = TokenMerger()
-
- # Helper method for getting the upscale list with merged prompts
- def get_upscale_steps() -> Optional[Union[UpscaleStepDict, List[UpscaleStepDict]]]:
- if upscale_steps is None:
- return None
- elif isinstance(upscale_steps, list):
- return [
- {
- **step, # type: ignore[misc]
- **{
- "prompt": str(
- upscale_prompt_tokens.clone(step.get("prompt", None))
- ),
- "prompt_2": str(
- upscale_prompt_2_tokens.clone(step.get("prompt_2", None))
- ),
- "negative_prompt": str(
- upscale_negative_prompt_tokens.clone(step.get("negative_prompt", None))
- ),
- "negative_prompt_2": str(
- upscale_negative_prompt_2_tokens.clone(step.get("negative_prompt_2", None))
- ),
- }
- }
- for step in upscale_steps
- ]
- else:
- return { # type: ignore[return-value]
- **upscale_steps, # type: ignore[misc]
- **{
- "prompt": str(
- upscale_prompt_tokens.clone(upscale_steps.get("prompt", None))
- ),
- "prompt_2": str(
- upscale_prompt_2_tokens.clone(upscale_steps.get("prompt_2", None))
- ),
- "negative_prompt": str(
- upscale_negative_prompt_tokens.clone(upscale_steps.get("negative_prompt", None))
- ),
- "negative_prompt_2": str(
- upscale_negative_prompt_2_tokens.clone(upscale_steps.get("negative_prompt_2", None))
- ),
- }
- }
-
- refiner_prompt_tokens = TokenMerger()
- refiner_prompt_2_tokens = TokenMerger()
- refiner_negative_prompt_tokens = TokenMerger()
- refiner_negative_prompt_2_tokens = TokenMerger()
-
- if prompt:
- upscale_prompt_tokens.add(prompt, GLOBAL_PROMPT_UPSCALE_WEIGHT)
- if prompt_2:
- upscale_prompt_2_tokens.add(prompt_2, GLOBAL_PROMPT_UPSCALE_WEIGHT)
- if negative_prompt:
- upscale_negative_prompt_tokens.add(negative_prompt, GLOBAL_PROMPT_UPSCALE_WEIGHT)
- if negative_prompt_2:
- upscale_negative_prompt_2_tokens.add(negative_prompt_2, GLOBAL_PROMPT_UPSCALE_WEIGHT)
-
- if model_prompt:
- refiner_prompt_tokens.add(model_prompt, MODEL_PROMPT_WEIGHT)
- upscale_prompt_tokens.add(model_prompt, MODEL_PROMPT_WEIGHT)
- if model_prompt_2:
- refiner_prompt_2_tokens.add(model_prompt_2, MODEL_PROMPT_WEIGHT)
- upscale_prompt_2_tokens.add(model_prompt_2, MODEL_PROMPT_WEIGHT)
-
- if model_negative_prompt:
- refiner_negative_prompt_tokens.add(model_negative_prompt, MODEL_PROMPT_WEIGHT)
- upscale_negative_prompt_tokens.add(model_negative_prompt, MODEL_PROMPT_WEIGHT)
- if model_negative_prompt_2:
- refiner_negative_prompt_2_tokens.add(model_negative_prompt_2, MODEL_PROMPT_WEIGHT)
- upscale_negative_prompt_2_tokens.add(model_negative_prompt_2, MODEL_PROMPT_WEIGHT)
-
- if refiner_prompt:
- refiner_prompt_tokens.add(refiner_prompt)
- refiner_prompt = str(refiner_prompt_tokens)
- else:
- refiner_prompt = None
-
- if refiner_prompt_2:
- refiner_prompt_2_tokens.add(refiner_prompt_2)
- refiner_prompt_2 = str(refiner_prompt_2_tokens)
- else:
- refiner_prompt_2 = None
-
- if refiner_negative_prompt:
- refiner_negative_prompt_tokens.add(refiner_negative_prompt)
- refiner_negative_prompt = str(refiner_negative_prompt_tokens)
- else:
- refiner_negative_prompt = None
-
- if refiner_negative_prompt_2:
- refiner_negative_prompt_2_tokens.add(refiner_negative_prompt_2)
- refiner_negative_prompt_2 = str(refiner_negative_prompt_2_tokens)
- else:
- refiner_negative_prompt_2 = None
-
- # Now assemble the diffusion steps
- node_count = len(nodes)
-
- if node_count == 0:
- # No nodes/canvas, create a plan from one given step
- name = "Text to Image"
- prompt_tokens = TokenMerger()
- if prompt:
- prompt_tokens.add(prompt)
- if model_prompt:
- prompt_tokens.add(model_prompt, MODEL_PROMPT_WEIGHT)
-
- prompt_2_tokens = TokenMerger()
- if prompt_2:
- prompt_2_tokens.add(prompt_2)
- if model_prompt_2:
- prompt_2_tokens.add(model_prompt_2, MODEL_PROMPT_WEIGHT)
-
- negative_prompt_tokens = TokenMerger()
- if negative_prompt:
- negative_prompt_tokens.add(negative_prompt)
- if model_negative_prompt:
- negative_prompt_tokens.add(model_negative_prompt, MODEL_PROMPT_WEIGHT)
-
- negative_prompt_2_tokens = TokenMerger()
- if negative_prompt_2:
- negative_prompt_2_tokens.add(negative_prompt_2)
- if model_negative_prompt_2:
- negative_prompt_2_tokens.add(model_negative_prompt_2, MODEL_PROMPT_WEIGHT)
-
- if image:
- if isinstance(image, str):
- image = PIL.Image.open(image)
- if width and height:
- image = fit_image(image, width, height, fit, anchor)
- else:
- width, height = image.size
-
- if mask:
- if isinstance(mask, str):
- mask = PIL.Image.open(mask)
- if width and height:
- mask = fit_image(mask, width, height, fit, anchor)
- else:
- width, height = mask.size
-
- if ip_adapter_images:
- for i, ip_adapter_image_dict in enumerate(ip_adapter_images):
- ip_adapter_image = ip_adapter_image_dict["image"]
- ip_adapter_anchor = ip_adapter_image_dict.get("anchor", anchor)
- ip_adapter_fit = ip_adapter_image_dict.get("fit", fit)
-
- if isinstance(ip_adapter_image, str):
- ip_adapter_image = PIL.Image.open(ip_adapter_image)
- if width and height:
- ip_adapter_image = fit_image(ip_adapter_image, width, height, fit, anchor)
- else:
- width, height = ip_adapter_image.size
-
- ip_adapter_images[i] = { # type: ignore[call-overload]
- "image": ip_adapter_image,
- "scale": ip_adapter_image_dict.get("scale", 0.5),
- }
-
- if control_images:
- for i, control_image_dict in enumerate(control_images):
- control_image = control_image_dict["image"]
- control_anchor = control_image_dict.get("anchor", anchor)
- control_fit = control_image_dict.get("fit", fit)
-
- if isinstance(control_image, str):
- control_image = PIL.Image.open(control_image)
- if width and height:
- control_image = fit_image(control_image, width, height, fit, anchor)
- else:
- width, height = control_image.size
-
- control_images[i] = { # type: ignore[call-overload]
- "image": control_image,
- "controlnet": control_image_dict["controlnet"],
- "process": control_image_dict.get("process", True),
- "scale": control_image_dict.get("scale", 0.5),
- "invert": control_image_dict.get("invert", False),
- "start": control_image_dict.get("start", None),
- "end": control_image_dict.get("end", None),
- }
-
- if mask and invert_mask:
- mask = PIL.ImageOps.invert(mask.convert("L"))
- if control_images and image and mask and ip_adapter_images:
- name = "Controlled Inpainting with Image Prompting"
- elif control_images and image and mask:
- name = "Controlled Inpainting"
- elif control_images and image and ip_adapter_images and strength:
- name = "Controlled Image to Image with Image Prompting"
- elif control_images and image and ip_adapter_images:
- name = "Controlled Text to Image with Image Prompting"
- elif control_images and image and strength:
- name = "Controlled Image to Image"
- elif control_images:
- name = "Controlled Text to Image"
- elif image and mask and ip_adapter_images:
- name = "Inpainting with Image Prompting"
- elif image and mask:
- name = "Inpainting"
- elif image and strength and ip_adapter_images:
- name = "Image to Image with Image Prompting"
- elif image and ip_adapter_images:
- name = "Text to Image with Image Prompting"
- elif image and strength:
- name = "Image to Image"
- else:
- name = "Text to Image"
-
- step = DiffusionStep(
- name=name,
- width=width,
- height=height,
- prompt=str(prompt_tokens),
- prompt_2=str(prompt_2_tokens),
- negative_prompt=str(negative_prompt_tokens),
- negative_prompt_2=str(negative_prompt_2_tokens),
- num_inference_steps=num_inference_steps,
- guidance_scale=guidance_scale,
- strength=strength,
- refiner_start=refiner_start,
- refiner_strength=refiner_strength,
- refiner_guidance_scale=refiner_guidance_scale,
- refiner_aesthetic_score=refiner_aesthetic_score,
- refiner_negative_aesthetic_score=refiner_negative_aesthetic_score,
- refiner_prompt=refiner_prompt,
- refiner_prompt_2=refiner_prompt_2,
- refiner_negative_prompt=refiner_negative_prompt,
- refiner_negative_prompt_2=refiner_negative_prompt_2,
- crop_inpaint=crop_inpaint,
- inpaint_feather=inpaint_feather,
- image=image,
- mask=mask,
- remove_background=remove_background,
- fill_background=fill_background,
- scale_to_model_size=scale_to_model_size,
- ip_adapter_plus=ip_adapter_plus,
- ip_adapter_face=ip_adapter_face,
- ip_adapter_images=ip_adapter_images,
- control_images=control_images
- )
-
- if not width:
- width = plan.width # Default
- if not height:
- height = plan.height # Default
-
- # Change plan defaults if passed
- plan.width = width
- plan.height = height
-
- # Assemble node
- plan.nodes = [DiffusionNode([(0, 0), (width, height)], step)]
- plan.upscale_steps = get_upscale_steps()
- return plan
-
- # Using the diffusion canvas, assemble a multi-step plan
- for i, node_dict in enumerate(nodes):
- step = DiffusionStep(
- num_inference_steps=num_inference_steps,
- guidance_scale=guidance_scale,
- refiner_start=refiner_start,
- refiner_strength=refiner_strength,
- refiner_guidance_scale=refiner_guidance_scale,
- refiner_aesthetic_score=refiner_aesthetic_score,
- refiner_negative_aesthetic_score=refiner_negative_aesthetic_score,
- refiner_prompt=refiner_prompt,
- refiner_prompt_2=refiner_prompt_2,
- refiner_negative_prompt=refiner_negative_prompt,
- refiner_negative_prompt_2=refiner_negative_prompt_2,
- )
-
- node_left = int(node_dict.get("x", 0))
- node_top = int(node_dict.get("y", 0))
- node_fit = node_dict.get("fit", None)
- node_anchor = node_dict.get("anchor", None)
-
- node_prompt = node_dict.get("prompt", None)
- node_prompt_2 = node_dict.get("prompt_2", None)
- node_negative_prompt = node_dict.get("negative_prompt", None)
- node_negative_prompt_2 = node_dict.get("negative_prompt_2", None)
- node_strength: Optional[float] = node_dict.get("strength", None)
- node_image = node_dict.get("image", None)
- node_inpaint_mask = node_dict.get("mask", None)
- node_crop_inpaint = node_dict.get("crop_inpaint", crop_inpaint)
- node_inpaint_feather = node_dict.get("inpaint_feather", inpaint_feather)
- node_invert_mask = node_dict.get("invert_mask", False)
- node_scale_to_model_size = bool(node_dict.get("scale_to_model_size", False))
- node_remove_background = bool(node_dict.get("remove_background", False))
-
- node_ip_adapter_plus = bool(node_dict.get("ip_adapter_plus", False))
- node_ip_adapter_face = bool(node_dict.get("ip_adapter_face", False))
- node_ip_adapter_images: Optional[List[IPAdapterImageDict]] = node_dict.get("ip_adapter_images", None)
- node_control_images: Optional[List[ControlImageDict]] = node_dict.get("control_images", None)
-
- node_inference_steps: Optional[int] = node_dict.get("inference_steps", None) # type: ignore[assignment]
- node_guidance_scale: Optional[float] = node_dict.get("guidance_scale", None) # type: ignore[assignment]
- node_refiner_start: Optional[float] = node_dict.get("refiner_start", None) # type: ignore[assignment]
- node_refiner_strength: Optional[float] = node_dict.get("refiner_strength", None) # type: ignore[assignment]
- node_refiner_guidance_scale: Optional[float] = node_dict.get("refiner_guidance_scale", None) # type: ignore[assignment]
- node_refiner_aesthetic_score: Optional[float] = node_dict.get("refiner_aesthetic_score", None) # type: ignore[assignment]
- node_refiner_negative_aesthetic_score: Optional[float] = node_dict.get("refiner_negative_aesthetic_score", None) # type: ignore[assignment]
-
- node_prompt_tokens = TokenMerger()
- node_prompt_2_tokens = TokenMerger()
- node_negative_prompt_tokens = TokenMerger()
- node_negative_prompt_2_tokens = TokenMerger()
-
- if "w" in node_dict:
- node_width = int(node_dict["w"])
- elif node_image is not None: # type: ignore[unreachable]
- node_width, _ = node_image.size
- elif node_inpaint_mask is not None:
- node_width, _ = node_inpaint_mask.size
- elif node_ip_adapter_images:
- node_width, _ = node_ip_adapter_images[0]["image"].size
- elif node_control_images:
- node_width, _ = node_control_images[next(iter(node_control_images))][0]["image"].size
- else:
- raise ValueError(f"Node {i} missing width, pass 'w' or an image")
- if "h" in node_dict:
- node_height = int(node_dict["h"])
- elif node_image is not None: # type: ignore[unreachable]
- _, node_height = node_image.size
- elif node_inpaint_mask is not None:
- _, node_height = node_inpaint_mask.size
- elif node_ip_adapter_images:
- _, node_height = node_ip_adapter_images[0]["image"].size
- elif node_control_images:
- _, node_height = node_control_images[next(iter(node_control_images))][0]["image"].size
- else:
- raise ValueError(f"Node {i} missing height, pass 'h' or an image")
-
- node_bounds = [
- (node_left, node_top),
- (node_left + node_width, node_top + node_height),
- ]
-
- if node_prompt:
- node_prompt_tokens.add(node_prompt)
- upscale_prompt_tokens.add(node_prompt, UPSCALE_PROMPT_STEP_WEIGHT / node_count)
- if prompt and (node_image or node_ip_adapter_images or node_control_images):
- # Only add global prompt to image nodes, it overrides too much on region nodes
- node_prompt_tokens.add(prompt, GLOBAL_PROMPT_STEP_WEIGHT)
- if model_prompt:
- node_prompt_tokens.add(model_prompt, MODEL_PROMPT_WEIGHT)
-
- if node_prompt_2:
- node_prompt_2_tokens.add(node_prompt_2)
- upscale_prompt_2_tokens.add(node_prompt_2, UPSCALE_PROMPT_STEP_WEIGHT / node_count)
- if prompt_2 and (node_image or node_ip_adapter_images or node_control_images):
- # Only add global prompt to image nodes, it overrides too much on region nodes
- node_prompt_2_tokens.add(prompt_2, GLOBAL_PROMPT_STEP_WEIGHT)
- if model_prompt_2:
- node_prompt_2_tokens.add(model_prompt_2, MODEL_PROMPT_WEIGHT)
-
- if node_negative_prompt:
- node_negative_prompt_tokens.add(node_negative_prompt)
- upscale_negative_prompt_tokens.add(node_negative_prompt, UPSCALE_PROMPT_STEP_WEIGHT / node_count)
- if negative_prompt and (node_image or node_ip_adapter_images or node_control_images):
- # Only add global prompt to image nodes, it overrides too much on region nodes
- node_negative_prompt_tokens.add(negative_prompt, GLOBAL_PROMPT_STEP_WEIGHT)
- if model_negative_prompt:
- node_negative_prompt_tokens.add(model_negative_prompt, MODEL_PROMPT_WEIGHT)
-
- if node_negative_prompt_2:
- node_negative_prompt_tokens.add(node_negative_prompt_2)
- upscale_negative_prompt_2_tokens.add(node_negative_prompt_2, UPSCALE_PROMPT_STEP_WEIGHT / node_count)
- if negative_prompt_2 and (node_image or node_ip_adapter_images or node_control_images):
- # Only add global prompt to image nodes, it overrides too much on region nodes
- node_negative_prompt_2_tokens.add(negative_prompt_2, GLOBAL_PROMPT_STEP_WEIGHT)
- if model_negative_prompt_2:
- node_negative_prompt_2_tokens.add(model_negative_prompt_2, MODEL_PROMPT_WEIGHT)
-
- black = PIL.Image.new("RGB", (node_width, node_height), (0, 0, 0))
- white = PIL.Image.new("RGB", (node_width, node_height), (255, 255, 255))
- outpaint_steps: List[DiffusionStep] = []
- outpainted_images: Dict[int, PIL.Image.Image] = {}
-
- def prepare_image(
- image: PIL.Image.Image,
- outpaint_if_necessary: bool = False,
- mask: Optional[PIL.Image.Image] = None,
- fit: Optional[IMAGE_FIT_LITERAL] = None,
- anchor: Optional[IMAGE_ANCHOR_LITERAL] = None
- )-> Union[Tuple[PIL.Image.Image, PIL.Image.Image], Tuple[DiffusionStep, Any]]:
- """
- Checks if the image needs to be outpainted
- """
- nonlocal outpainted_images
- for step_index, outpainted_image in outpainted_images.items():
- if images_are_equal(outpainted_image, image):
- return outpaint_steps[step_index], 0
-
- image_mask = PIL.Image.new("RGB", (node_width, node_height), (255, 255, 255)) # Mask for outpainting if needed
- fitted_image = fit_image(image, node_width, node_height, fit, anchor)
- fitted_alpha = fitted_image.split()[-1]
- fitted_alpha_clamp = PIL.Image.eval(fitted_alpha, lambda a: 255 if a > 128 else 0)
-
- image_mask.paste(black, mask=fitted_alpha)
-
- if mask:
- image_mask.paste(mask)
- fitted_inverse_alpha = PIL.Image.eval(fitted_alpha_clamp, lambda a: 255 - a)
- image_mask.paste(white, mask=fitted_inverse_alpha)
-
- image_mask_r_min, image_mask_r_max = image_mask.getextrema()[1]
- image_needs_outpainting = image_mask_r_max > 0
-
- if image_needs_outpainting and outpaint_if_necessary:
- step = DiffusionStep(
- name=f"Outpaint Node {i+1}",
- image=fitted_image,
- mask=feather_mask(image_mask.convert("1")),
- prompt=str(node_prompt_tokens),
- prompt_2=str(node_prompt_2_tokens),
- negative_prompt=str(node_negative_prompt_tokens),
- negative_prompt_2=str(node_negative_prompt_2_tokens),
- guidance_scale=guidance_scale,
- num_inference_steps=node_inference_steps if node_inference_steps else num_inference_steps,
- crop_inpaint=node_crop_inpaint,
- inpaint_feather=node_inpaint_feather,
- refiner_start=refiner_start,
- refiner_strength=refiner_strength,
- refiner_guidance_scale=refiner_guidance_scale,
- refiner_aesthetic_score=refiner_aesthetic_score,
- refiner_negative_aesthetic_score=refiner_negative_aesthetic_score,
- refiner_prompt=refiner_prompt,
- refiner_prompt_2=refiner_prompt_2,
- refiner_negative_prompt=refiner_negative_prompt,
- refiner_negative_prompt_2=refiner_negative_prompt_2,
- )
- outpaint_steps.append(step)
- outpainted_images[len(outpaint_steps)-1] = image
- return step, None
- return fitted_image, image_mask
-
- will_infer = (node_image is not None and node_strength is not None) or node_inpaint_mask is not None
- node_fill_background = node_remove_background and will_infer
-
- if node_inpaint_mask:
- node_inpaint_mask = node_inpaint_mask.convert("L")
- if node_invert_mask:
- node_inpaint_mask = PIL.ImageOps.invert(node_inpaint_mask)
-
- if node_image:
- node_image, new_inpaint_mask = prepare_image(
- node_image,
- mask=node_inpaint_mask,
- fit=node_fit,
- anchor=node_anchor
- )
- if node_inpaint_mask:
- node_inpaint_mask = new_inpaint_mask
-
- if node_ip_adapter_images:
- node_ip_adapter_images = [
- {
- "image": prepare_image(
- ip_adapter_image["image"],
- fit=ip_adapter_image.get("fit", None),
- anchor=ip_adapter_image.get("anchor", None),
- outpaint_if_necessary=False
- )[0],
- "scale": ip_adapter_image.get("scale", 1.0),
- }
- for ip_adapter_image in node_ip_adapter_images
- ]
-
- if node_control_images:
- node_control_images = [
- {
- "image": prepare_image(
- control_image["image"],
- fit=control_image.get("fit", None),
- anchor=control_image.get("anchor", None),
- outpaint_if_necessary=True
- )[0],
- "controlnet": control_image["controlnet"],
- "scale": control_image.get("scale", 1.0),
- "process": control_image.get("process", True),
- "invert": control_image.get("invert", False),
- "start": control_image.get("start", None),
- "end": control_image.get("end", None),
- }
- for control_image in node_control_images
- ]
-
- node_prompt_str = str(node_prompt_tokens)
- node_prompt_2_str = str(node_prompt_2_tokens)
- node_negative_prompt_str = str(node_negative_prompt_tokens)
- node_negative_prompt_2_str = str(node_negative_prompt_2_tokens)
-
- if node_inpaint_mask:
- node_inpaint_mask = node_inpaint_mask.convert("L")
- node_inpaint_mask_r_min, node_inpaint_mask_r_max = node_inpaint_mask.getextrema()
- image_needs_inpainting = node_inpaint_mask_r_max > 0
- else:
- image_needs_inpainting = False
-
- if node_strength is None or not image_needs_inpainting:
- node_inpaint_mask = None
-
- if node_control_images and node_image and node_inpaint_mask and node_ip_adapter_images:
- name = "Controlled Inpainting with Image Prompting"
- elif node_control_images and node_image and node_inpaint_mask:
- name = "Controlled Inpainting"
- elif node_control_images and node_image and node_ip_adapter_images and node_strength:
- name = "Controlled Image to Image with Image Prompting"
- elif node_control_images and node_ip_adapter_images:
- name = "Controlled Text to Image with Image Prompting"
- elif node_control_images and node_image and node_strength:
- name = "Controlled Image to Image"
- elif node_control_images:
- name = "Controlled Text to Image"
- elif node_image and node_inpaint_mask and node_ip_adapter_images:
- name = "Inpainting with Image Prompting"
- elif node_image and node_inpaint_mask:
- name = "Inpainting"
- elif node_image and node_strength and node_ip_adapter_images:
- name = "Image to Image with Image Prompting"
- elif node_ip_adapter_images:
- name = "Text to Image with Image Prompting"
- elif node_image and node_strength:
- name = "Image to Image"
- elif node_image:
- name = "Image Pass-Through"
- if node_width == width and node_height == height:
- plan.outpaint = False
- node_prompt_str = None # type: ignore[assignment]
- node_prompt_2_str = None # type: ignore[assignment]
- node_negative_prompt_str = None # type: ignore[assignment]
- node_negative_prompt_2_str = None # type: ignore[assignment]
- else:
- name = "Text to Image"
-
- step = DiffusionStep(
- name=f"{name} Node {i+1}",
- width=node_width,
- height=node_height,
- image=node_image,
- mask=node_inpaint_mask,
- prompt=node_prompt_str,
- prompt_2=node_prompt_2_str,
- negative_prompt=node_negative_prompt_str,
- negative_prompt_2=node_negative_prompt_2_str,
- crop_inpaint=node_crop_inpaint,
- inpaint_feather=node_inpaint_feather,
- strength=node_strength,
- guidance_scale=guidance_scale,
- num_inference_steps=node_inference_steps if node_inference_steps else num_inference_steps,
- ip_adapter_images=node_ip_adapter_images,
- ip_adapter_plus=node_ip_adapter_plus,
- ip_adapter_face=node_ip_adapter_face,
- control_images=node_control_images,
- refiner_start=refiner_start,
- refiner_strength=refiner_strength,
- refiner_guidance_scale=refiner_guidance_scale,
- refiner_aesthetic_score=refiner_aesthetic_score,
- refiner_negative_aesthetic_score=refiner_negative_aesthetic_score,
- refiner_prompt=refiner_prompt,
- refiner_prompt_2=refiner_prompt_2,
- refiner_negative_prompt=refiner_negative_prompt,
- refiner_negative_prompt_2=refiner_negative_prompt_2,
- remove_background=node_remove_background,
- fill_background=node_fill_background,
- scale_to_model_size=node_scale_to_model_size
- )
-
- # Add step to plan
- plan.nodes.append(DiffusionNode(node_bounds, step))
- plan.upscale_steps = get_upscale_steps()
- return plan
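
For reference, here is a minimal sketch of how the serialization round-trip removed above fit together, based only on the deleted methods; the checkpoint name, prompt, and directory are hypothetical, and passing image_directory by keyword is an assumption about the signature.

    from enfugue.diffusion.plan import DiffusionPlan

    # Assemble a single text-to-image plan (hypothetical checkpoint name and prompt)
    plan = DiffusionPlan.assemble(
        size=512,
        model="v1-5-pruned-emaonly.ckpt",
        prompt="a photograph of a cat",
    )

    # get_serialization_dict writes any in-memory PIL images to the given directory
    # as PNGs and stores their paths; deserialize_dict re-opens them from disk.
    plan_dict = plan.get_serialization_dict(image_directory="/tmp/plan-images")
    restored = DiffusionPlan.deserialize_dict(plan_dict)
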
diff --git a/src/python/enfugue/diffusion/process.py b/src/python/enfugue/diffusion/process.py
index dabd5103..f6c7c16c 100644
--- a/src/python/enfugue/diffusion/process.py
+++ b/src/python/enfugue/diffusion/process.py
@@ -16,8 +16,11 @@
List,
Dict,
Tuple,
+ Iterator,
TYPE_CHECKING,
)
+from typing_extensions import Self
+from contextlib import contextmanager
from multiprocessing import Process
from multiprocessing.queues import Queue
@@ -34,35 +37,180 @@
# We avoid importing them before the process starts at runtime,
# since we don't want torch to initialize itself.
from enfugue.diffusion.manager import DiffusionPipelineManager
- from enfugue.diffusion.plan import DiffusionPlan
+ from enfugue.diffusion.invocation import LayeredInvocation
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-__all__ = ["DiffusionEngineProcess"]
+__all__ = [
+ "EngineProcess",
+ "DiffusionEngineProcess"
+]
-class DiffusionEngineProcess(Process):
+class EngineProcess(Process):
"""
- This process allows for easy two-way communication with a waiting
- Stable Diffusion Pipeline. Torch is only initiated after the process
- has began.
+ This process allows for easy two-way communication with a waiting subprocess
"""
-
POLLING_DELAY_MS = 500
IDLE_SEC = 15
def __init__(
self,
+ configuration: Optional[APIConfiguration],
instructions: Queue,
results: Queue,
- intermediates: Queue,
- configuration: Optional[APIConfiguration] = None,
+ intermediates: Queue
) -> None:
- super(DiffusionEngineProcess, self).__init__()
+ super(EngineProcess, self).__init__()
self.configuration = APIConfiguration()
+ if configuration is not None:
+ self.configuration = configuration
self.instructions = instructions
self.results = results
self.intermediates = intermediates
- if configuration is not None:
- self.configuration = configuration
+
+ @property
+ def idle_seconds(self) -> int:
+ """
+ Gets the maximum number of seconds to go idle before exiting.
+ """
+ return self.configuration.get("enfugue.idle", self.IDLE_SEC)
+
+ def clear_state(self) -> None:
+ """
+ Clears all state
+ """
+ try:
+ while True:
+ self.results.get_nowait()
+ except Empty:
+ pass
+ try:
+ while True:
+ self.instructions.get_nowait()
+ except Empty:
+ pass
+
+ def clear_responses(self, instruction_id: int) -> None:
+ """
+ Clears responses for a specific instruction ID
+ """
+ try:
+ while True:
+ next_result = self.results.get_nowait()
+ # Avoid parsing
+ if f'"id": {instruction_id}' not in next_result[:40]:
+ # Not ours, put back on the queue
+ self.results.put_nowait(next_result)
+ except Empty:
+ return
+
+ def handle(
+ self,
+ instruction_id: int,
+ instruction_action: str,
+ instruction_payload: Any
+ ) -> Any:
+ """
+ Handles a request - should be overridden by implementations
+ """
+ raise NotImplementedError()
+
+ @contextmanager
+ def context(self) -> Iterator[Self]:
+ """
+ A contextmanager for the main engine process
+ """
+ from pibble.util.helpers import OutputCatcher
+ from pibble.util.log import ConfigurationLoggingContext
+
+ catcher = OutputCatcher()
+
+ with ConfigurationLoggingContext(
+ self.configuration,
+ prefix="enfugue.engine.logging."
+ ):
+ with catcher:
+ yield self
+ out, err = catcher.output()
+ if out:
+ logger.debug(f"stdout: {out}")
+ if err:
+ logger.info(f"stderr (may not be an error:) {err}")
+ catcher.clean()
+
+ def run(self) -> None:
+ """
+ This is the function that the process will run.
+ It polls the instruction queue and dispatches requests to handle() until exit is requested or the idle timeout is reached.
+ """
+ with self.context():
+ last_data = datetime.datetime.now()
+ idle_seconds = 0.0
+
+ while True:
+ try:
+ payload = self.instructions.get(timeout=self.POLLING_DELAY_MS / 1000)
+ except KeyboardInterrupt:
+ return
+ except Empty:
+ idle_seconds = (datetime.datetime.now() - last_data).total_seconds()
+ if idle_seconds > self.idle_seconds:
+ logger.info(
+ f"Reached maximum idle time after {idle_seconds:.1f} seconds, exiting engine process"
+ )
+ return
+ continue
+ except Exception as ex:
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(traceback.format_exc())
+ raise IOError("Received unexpected {0}, process will exit. {1}".format(type(ex).__name__, ex))
+
+ instruction = Serializer.deserialize(payload)
+ if not isinstance(instruction, dict):
+ logger.error(f"Unexpected non-dictionary argument {instruction}")
+ continue
+
+ instruction_id = instruction["id"]
+ instruction_action = instruction["action"]
+ instruction_payload = instruction.get("payload", None)
+
+ if instruction_action == "ping":
+ self.results.put(Serializer.serialize({"id": instruction_id, "result": "pong"}))
+ elif instruction_action in ["exit", "stop"]:
+ logger.debug("Exiting process")
+ self.clear_state()
+ return
+ else:
+ response = {"id": instruction_id, "payload": instruction_payload}
+ try:
+ response["result"] = self.handle(instruction_id, instruction_action, instruction_payload)
+ except Exception as ex:
+ response["error"] = qualify(type(ex))
+ response["message"] = str(ex)
+
+ # Also log so this appears in the engine log
+ logger.error(f"Received error {response['error']}: {response['message']}")
+ if logger.isEnabledFor(logging.DEBUG):
+ response["trace"] = traceback.format_exc()
+ logger.debug(response["trace"])
+ self.results.put(Serializer.serialize(response))
+ last_data = datetime.datetime.now()
+
+class DiffusionEngineProcess(EngineProcess):
+ def clear_state(self) -> None:
+ """
+ Clears all state
+ """
+ super(DiffusionEngineProcess, self).clear_state()
+ try:
+ while True:
+ self.intermediates.get_nowait()
+ except Empty:
+ pass
+ if hasattr(self, "_pipemanager"):
+ self.pipemanager.unload_inpainter("exiting")
+ self.pipemanager.unload_animator("exiting")
+ self.pipemanager.unload_refiner("exiting")
+ self.pipemanager.unload_pipeline("exiting")
@property
def pipemanager(self) -> DiffusionPipelineManager:
@@ -75,25 +223,49 @@ def pipemanager(self) -> DiffusionPipelineManager:
self._pipemanager = DiffusionPipelineManager(self.configuration, optimize=False)
return self._pipemanager
- @property
- def idle_seconds(self) -> int:
+ def get_diffusion_plan(self, payload: Dict[str, Any]) -> LayeredInvocation:
"""
- Gets the maximum number of seconds to go idle before exiting.
+ Assembles a layered invocation from a serialized payload.
"""
- return self.configuration.get("enfugue.idle", self.IDLE_SEC)
+ from enfugue.diffusion.invocation import LayeredInvocation
+ return LayeredInvocation.assemble(**payload)
- def get_diffusion_plan(self, payload: Dict[str, Any]) -> DiffusionPlan:
+ def handle(
+ self,
+ instruction_id: int,
+ instruction_action: str,
+ instruction_payload: Any
+ ) -> Any:
"""
- Deserializes a plan.
+ Handles plans or direct invocations
"""
- from enfugue.diffusion.plan import DiffusionPlan
-
- return DiffusionPlan.deserialize_dict(payload)
+ if not isinstance(instruction_payload, dict):
+ raise ValueError(f"Expected dictionary payload.")
+ try:
+ if instruction_action == "plan":
+ intermediate_dir = instruction_payload.get("intermediate_dir", None)
+ intermediate_steps = instruction_payload.get("intermediate_steps", None)
+ logger.debug("Received invocation payload, constructing plan.")
+ plan = self.get_diffusion_plan(instruction_payload)
+ return self.execute_diffusion_plan(
+ instruction_id,
+ plan,
+ intermediate_dir=intermediate_dir,
+ intermediate_steps=intermediate_steps,
+ )
+ else:
+ logger.debug("Received direct invocation payload, executing.")
+ payload = self.check_invoke_kwargs(instruction_id, **instruction_payload)
+ return self.pipemanager(**payload)
+ finally:
+ self.pipemanager.stop_keepalive()
+ self.clear_intermediates(instruction_id)
+ del self.pipemanager.keepalive_callback
def execute_diffusion_plan(
self,
instruction_id: int,
- plan: DiffusionPlan,
+ plan: LayeredInvocation,
intermediate_dir: Optional[str] = None,
intermediate_steps: Optional[int] = None,
) -> StableDiffusionPipelineOutput:
@@ -175,6 +347,7 @@ def check_invoke_kwargs(
model: Optional[str] = None,
refiner: Optional[str] = None,
inpainter: Optional[str] = None,
+ animator: Optional[str] = None,
lora: Optional[Union[str, Tuple[str, float], List[Union[str, Tuple[str, float]]]]] = None,
inversion: Optional[Union[str, List[str]]] = None,
vae: Optional[str] = None,
@@ -187,12 +360,9 @@ def check_invoke_kwargs(
intermediate_dir: Optional[str] = None,
width: Optional[int] = None,
height: Optional[int] = None,
- chunking_size: Optional[int] = None,
+ tiling_stride: Optional[int] = None,
guidance_scale: Optional[float] = None,
num_inference_steps: Optional[int] = None,
- size: Optional[int] = None,
- refiner_size: Optional[int] = None,
- inpainter_size: Optional[int] = None,
**kwargs: Any,
) -> dict:
"""
@@ -218,6 +388,9 @@ def check_invoke_kwargs(
if inpainter is not None:
self.pipemanager.inpainter = inpainter # type: ignore
+ if animator is not None:
+ self.pipemanager.animator = animator # type: ignore
+
if vae is not None:
self.pipemanager.vae = vae # type: ignore
@@ -239,23 +412,14 @@ def check_invoke_kwargs(
if build_tensorrt is not None:
self.pipemanager.build_tensorrt = build_tensorrt
- if size is not None:
- self.pipemanager.size = size
-
- if refiner_size is not None:
- self.pipemanager.refiner_size = refiner_size
-
- if inpainter_size is not None:
- self.pipemanager.inpainter_size = inpainter_size
-
if width is not None:
kwargs["width"] = int(width)
if height is not None:
kwargs["height"] = int(height)
- if chunking_size is not None:
- kwargs["chunking_size"] = int(chunking_size)
+ if tiling_stride is not None:
+ kwargs["tiling_stride"] = int(tiling_stride)
if guidance_scale is not None:
kwargs["guidance_scale"] = float(guidance_scale)
@@ -276,6 +440,8 @@ def check_invoke_kwargs(
if controlnet not in control_images_dict:
control_images_dict[controlnet] = []
control_images_dict[controlnet].append((control_image, scale))
+ if kwargs.get("animation_frames", None) is not None:
+ self.pipemanager.animator_controlnets = list(control_images_dict.keys()) # type: ignore[assignment]
if kwargs.get("mask", None) is not None:
self.pipemanager.inpainter_controlnets = list(control_images_dict.keys()) # type: ignore[assignment]
else:
@@ -297,115 +463,3 @@ def clear_intermediates(self, instruction_id: int) -> None:
self.intermediates.put_nowait(next_intermediate)
except Empty:
return
-
- def clear_responses(self, instruction_id: int) -> None:
- """
- Clears responses for a specific instruction ID
- """
- try:
- while True:
- next_result = self.results.get_nowait()
- # Avoid parsing
- if f'"id": {instruction_id}' not in next_result[:40]:
- # Not ours, put back on the queue
- self.results.put_nowait(next_result)
- except Empty:
- return
-
- def run(self) -> None:
- """
- This is the function that the process will run.
- First instantiate the diffusion pipeline, then communicate as needed.
- """
- from pibble.util.helpers import OutputCatcher
- from pibble.util.log import ConfigurationLoggingContext
-
- catcher = OutputCatcher()
-
- with ConfigurationLoggingContext(self.configuration, prefix="enfugue.engine.logging."):
- with catcher:
- last_data = datetime.datetime.now()
- idle_seconds = 0.0
-
- while True:
- try:
- payload = self.instructions.get(timeout=self.POLLING_DELAY_MS / 1000)
- except KeyboardInterrupt:
- return
- except Empty:
- idle_seconds = (datetime.datetime.now() - last_data).total_seconds()
- if idle_seconds > self.idle_seconds:
- logger.info(
- f"Reached maximum idle time after {idle_seconds:.1f} seconds, exiting engine process"
- )
- return
- continue
- except Exception as ex:
- if logger.isEnabledFor(logging.DEBUG):
- logger.debug(traceback.format_exc())
- raise IOError("Received unexpected {0}, process will exit. {1}".format(type(ex).__name__, ex))
-
- instruction = Serializer.deserialize(payload)
- if not isinstance(instruction, dict):
- logger.error(f"Unexpected non-dictionary argument {instruction}")
- continue
-
- instruction_id = instruction["id"]
- instruction_action = instruction["action"]
- instruction_payload = instruction.get("payload", None)
-
- logger.debug(f"Received instruction {instruction_id}, action {instruction_action}")
- if instruction_action == "ping":
- logger.debug("Responding with 'pong'")
- self.results.put(Serializer.serialize({"id": instruction_id, "result": "pong"}))
- elif instruction_action in ["exit", "stop"]:
- logger.debug("Exiting process")
- self.pipemanager.unload_inpainter("exiting")
- self.pipemanager.unload_refiner("exiting")
- self.pipemanager.unload_pipeline("exiting")
- return
- elif instruction_action in ["invoke", "plan"]:
- response = {"id": instruction_id, "payload": instruction_payload}
- try:
- if instruction_action == "plan":
- intermediate_dir = instruction_payload.get("intermediate_dir", None)
- intermediate_steps = instruction_payload.get("intermediate_steps", None)
- plan = self.get_diffusion_plan(instruction_payload)
- response["result"] = self.execute_diffusion_plan(
- instruction_id,
- plan,
- intermediate_dir=intermediate_dir,
- intermediate_steps=intermediate_steps,
- )
- else:
- payload = self.check_invoke_kwargs(instruction_id, **instruction_payload)
- response["result"] = self.pipemanager(**payload)
- except Exception as ex:
- response["error"] = qualify(type(ex))
- response["message"] = str(ex)
-
- # Also log so this appears in the engine log
- logger.error(f"Received error {response['error']}: {response['message']}")
- if logger.isEnabledFor(logging.DEBUG):
- response["trace"] = traceback.format_exc()
- logger.debug(response["trace"])
-
- del self.pipemanager.keepalive_callback
- self.results.put(Serializer.serialize(response))
- self.clear_intermediates(instruction_id)
- else:
- self.results.put(
- Serializer.serialize(
- {
- "id": instruction_id,
- "error": f"Unknown action '{instruction_action}'",
- }
- )
- )
- out, err = catcher.output()
- if out:
- logger.debug(f"stdout: {out}")
- if err:
- logger.error(f"stderr: {err}")
- catcher.clean()
- last_data = datetime.datetime.now()
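
To make the message protocol implemented by EngineProcess.run() above easier to follow, here is a toy, self-contained model of one request/response cycle; it is not enfugue code, uses json in place of the Serializer used by process.py, and substitutes a trivial echo for handle().

    import json
    import queue

    def handle_one(instructions: "queue.Queue[str]", results: "queue.Queue[str]") -> None:
        # Instructions are serialized dictionaries carrying "id", "action" and an optional "payload".
        instruction = json.loads(instructions.get())
        response = {"id": instruction["id"], "payload": instruction.get("payload")}
        if instruction["action"] == "ping":
            response["result"] = "pong"
        else:
            # Stand-in for handle(); in EngineProcess its return value becomes "result",
            # and any exception becomes "error"/"message" fields instead.
            response["result"] = {"echo": instruction.get("payload")}
        results.put(json.dumps(response))

    instructions: "queue.Queue[str]" = queue.Queue()
    results: "queue.Queue[str]" = queue.Queue()
    instructions.put(json.dumps({"id": 1, "action": "ping"}))
    handle_one(instructions, results)
    print(json.loads(results.get()))  # {'id': 1, 'payload': None, 'result': 'pong'}
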
diff --git a/src/python/enfugue/diffusion/rt/README.md b/src/python/enfugue/diffusion/rt/README.md
index 4581903b..46c5de86 100644
--- a/src/python/enfugue/diffusion/rt/README.md
+++ b/src/python/enfugue/diffusion/rt/README.md
@@ -30,6 +30,8 @@ def __init__(
chunking_size: int = 32,
chunking_mask_type: MASK_TYPE_LITERAL = "bilinear",
chunking_mask_kwargs: Dict[str, Any] = {},
+ temporal_engine_size: int = 16,
+ temporal_chunking_size: int = 4,
max_batch_size: int = 16,
force_engine_rebuild: bool = False,
vae_engine_dir: Optional[str] = None,
diff --git a/src/python/enfugue/diffusion/rt/pipeline.py b/src/python/enfugue/diffusion/rt/pipeline.py
index a630ffb8..70e3d47b 100644
--- a/src/python/enfugue/diffusion/rt/pipeline.py
+++ b/src/python/enfugue/diffusion/rt/pipeline.py
@@ -32,7 +32,7 @@
if TYPE_CHECKING:
from enfugue.diffusers.support.ip import IPAdapter
- from enfugue.diffusion.constants import MASK_TYPE_LITERAL
+ from enfugue.diffusion.constants import MASK_TYPE_LITERAL, IP_ADAPTER_LITERAL
class EnfugueTensorRTStableDiffusionPipeline(EnfugueStableDiffusionPipeline):
models: Dict[str, BaseModel]
@@ -57,9 +57,9 @@ def __init__(
controlnets: Optional[Dict[str, ControlNetModel]] = None,
ip_adapter: Optional[IPAdapter] = None,
engine_size: int = 512, # Recommended even for machines that can handle more
- chunking_size: int = 32,
- chunking_mask_type: MASK_TYPE_LITERAL = "bilinear",
- chunking_mask_kwargs: Dict[str, Any] = {},
+ tiling_size: int = 32,
+ tiling_mask_type: MASK_TYPE_LITERAL = "bilinear",
+ tiling_mask_kwargs: Dict[str, Any] = {},
max_batch_size: int = 16,
# ONNX export parameters
force_engine_rebuild: bool = False,
@@ -73,6 +73,9 @@ def __init__(
build_half: bool = False,
onnx_opset: int = 17,
) -> None:
+ if engine_size is None:
+ raise ValueError("Cannot use TensorRT with a 'None' engine size.")
+
super(EnfugueTensorRTStableDiffusionPipeline, self).__init__(
vae=vae,
vae_preview=vae_preview,
@@ -91,14 +94,15 @@ def __init__(
controlnets=controlnets,
ip_adapter=ip_adapter,
engine_size=engine_size,
- chunking_size=chunking_size,
- chunking_mask_type=chunking_mask_type,
- chunking_mask_kwargs=chunking_mask_kwargs,
+ tiling_size=tiling_size,
+ tiling_mask_type=tiling_mask_type,
+ tiling_mask_kwargs=tiling_mask_kwargs,
)
if self.controlnets:
# Hijack forward
self.unet.forward = self.controlled_unet_forward # type: ignore[method-assign]
+
self.vae.forward = self.vae.decode # type: ignore[method-assign]
self.onnx_opset = onnx_opset
self.force_engine_rebuild = force_engine_rebuild
@@ -112,7 +116,7 @@ def __init__(
self.build_preview_features = build_preview_features
self.max_batch_size = max_batch_size
- if self.build_dynamic_shape or self.engine_size > 512:
+ if self.build_dynamic_shape or self.engine_size > 512: # type: ignore
self.max_batch_size = 4
# Set default to DDIM - The PNDM default that some models have does not work with TRT
@@ -193,11 +197,11 @@ def load_resources(self, image_height: int, image_width: int, batch_size: int) -
def get_runtime_context(
self,
batch_size: int,
+ animation_frames: Optional[int],
device: Union[str, torch.device],
- ip_adapter_scale: Optional[Union[float, List[float]]] = None,
- ip_adapter_plus: bool = False,
- ip_adapter_face: bool = False,
- step_complete: Optional[Callable[[bool], None]] = None
+ ip_adapter_scale: Optional[Union[float, List[float]]]=None,
+ ip_adapter_mode: Optional[IP_ADAPTER_LITERAL]=None,
+ step_complete: Optional[Callable[[bool], None]]=None
) -> Iterator[None]:
"""
We initialize the TensorRT runtime here.
@@ -216,8 +220,10 @@ def align_unet(
self,
device: torch.device,
dtype: torch.dtype,
- freeu_factors: Optional[Tuple[float, float, float, float]] = None,
- offload_models: bool = False
+ animation_frames: Optional[int]=None,
+ motion_scale: Optional[float]=None,
+ freeu_factors: Optional[Tuple[float, float, float, float]]=None,
+ offload_models: bool=False
) -> None:
"""
TRT skips.
@@ -266,8 +272,8 @@ def prepare_engines(
engines_to_build,
self.onnx_opset,
use_fp16=self.build_half,
- opt_image_height=self.engine_size,
- opt_image_width=self.engine_size,
+ opt_image_height=self.engine_size, # type: ignore[arg-type]
+ opt_image_width=self.engine_size, # type: ignore[arg-type]
force_engine_rebuild=self.force_engine_rebuild,
static_batch=self.build_static_batch,
static_shape=not self.build_dynamic_shape,
@@ -282,22 +288,23 @@ def create_latents(
num_channels_latents: int,
height: int,
width: int,
- dtype: Union[str, torch.dtype],
+ dtype: torch.dtype,
device: Union[str, torch.device],
generator: Optional[torch.Generator] = None,
+ animation_frames: Optional[int] = None,
) -> torch.Tensor:
"""
Override to change to float32
"""
return super(EnfugueTensorRTStableDiffusionPipeline, self).create_latents(
- batch_size, num_channels_latents, height, width, torch.float32, device, generator,
+ batch_size, num_channels_latents, height, width, torch.float32, device, generator, animation_frames
)
def encode_prompt(
self,
prompt: Optional[str],
device: torch.device,
- num_images_per_prompt: int = 1,
+ num_results_per_prompt: int = 1,
do_classifier_free_guidance: bool = False,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
@@ -321,7 +328,7 @@ def encode_prompt(
return super(EnfugueTensorRTStableDiffusionPipeline, self).encode_prompt(
prompt=prompt,
device=device,
- num_images_per_prompt=num_images_per_prompt,
+ num_results_per_prompt=num_results_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
diff --git a/src/python/enfugue/diffusion/support/__init__.py b/src/python/enfugue/diffusion/support/__init__.py
index 3973d11f..29406947 100644
--- a/src/python/enfugue/diffusion/support/__init__.py
+++ b/src/python/enfugue/diffusion/support/__init__.py
@@ -7,6 +7,6 @@
from enfugue.diffusion.support.background import BackgroundRemover
from enfugue.diffusion.support.ip import IPAdapter
-EdgeDetector, LineDetector, DepthDetector, PoseDetector, ControlImageProcessor, Upscaler, IPAdapter, BackgroundRemover # Silence importchecker
+EdgeDetector, LineDetector, DepthDetector, PoseDetector, ControlImageProcessor, Upscaler, BackgroundRemover, IPAdapter # Silence importchecker
-__all__ = ["EdgeDetector", "LineDetector", "DepthDetector", "PoseDetector", "ControlImageProcessor", "Upscaler", "IPAdapter"]
+__all__ = ["EdgeDetector", "LineDetector", "DepthDetector", "PoseDetector", "ControlImageProcessor", "Upscaler", "BackgroundRemover", "IPAdapter"]
diff --git a/src/python/enfugue/diffusion/support/ip/adapter.py b/src/python/enfugue/diffusion/support/ip/adapter.py
index 57a928ea..273166a7 100644
--- a/src/python/enfugue/diffusion/support/ip/adapter.py
+++ b/src/python/enfugue/diffusion/support/ip/adapter.py
@@ -3,21 +3,21 @@
from typing import List, Union, Dict, Any, Iterator, Optional, Tuple, Callable, TYPE_CHECKING
from typing_extensions import Self
from contextlib import contextmanager
-
-from transformers import (
- CLIPVisionModelWithProjection,
- CLIPImageProcessor,
- PretrainedConfig
-)
from enfugue.util import logger
from enfugue.diffusion.support.model import SupportModel
if TYPE_CHECKING:
import torch
from PIL import Image
+ from enfugue.diffusion.constants import IP_ADAPTER_LITERAL
from enfugue.diffusion.support.ip.projection import ImageProjectionModel
from enfugue.diffusion.support.ip.resampler import Resampler # type: ignore
from diffusers.models import UNet2DConditionModel, ControlNetModel
+ from transformers import (
+ CLIPVisionModelWithProjection,
+ CLIPImageProcessor,
+ PretrainedConfig
+ )
class IPAdapter(SupportModel):
"""
@@ -25,8 +25,7 @@ class IPAdapter(SupportModel):
"""
cross_attention_dim: int = 768
is_sdxl: bool = False
- use_fine_grained: bool = False
- use_face_model: bool = False
+ model: IP_ADAPTER_LITERAL = "default"
DEFAULT_ENCODER_CONFIG_PATH = "https://huggingface.co/h94/IP-Adapter/resolve/main/models/image_encoder/config.json"
DEFAULT_ENCODER_PATH = "https://huggingface.co/h94/IP-Adapter/resolve/main/models/image_encoder/pytorch_model.bin"
@@ -40,16 +39,16 @@ class IPAdapter(SupportModel):
XL_ADAPTER_PATH = "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl.bin"
FINE_GRAINED_XL_ADAPTER_PATH = "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin"
+ FACE_XL_ADAPTER_PATH = "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus-face_sdxl_vit-h.bin"
def load(
self,
unet: UNet2DConditionModel,
- is_sdxl: bool = False,
- use_fine_grained: bool = False,
- use_face_model: bool = False,
- scale: float = 1.0,
- keepalive_callback: Optional[Callable[[],None]] = None,
- controlnets: Optional[Dict[str, ControlNetModel]] = None,
+ model: Optional[IP_ADAPTER_LITERAL]="default",
+ is_sdxl: bool=False,
+ scale: float=1.0,
+ keepalive_callback: Optional[Callable[[],None]]=None,
+ controlnets: Optional[Dict[str, ControlNetModel]]=None,
) -> None:
"""
Loads the IP adapter.
@@ -71,18 +70,19 @@ def load(
AttentionProcessor2_0,
)
+ if model is None:
+ model = "default"
+
if (
self.cross_attention_dim != unet.config.cross_attention_dim or # type: ignore[attr-defined]
self.is_sdxl != is_sdxl or
- self.use_fine_grained != use_fine_grained or
- self.use_face_model != use_face_model
+ self.model != model
):
del self.projector
del self.encoder
self.is_sdxl = is_sdxl
- self.use_fine_grained = use_fine_grained
- self.use_face_model = use_face_model
+ self.model = model
self.cross_attention_dim = unet.config.cross_attention_dim # type: ignore[attr-defined]
self._default_unet_attention_processors: Dict[str, Any] = {}
@@ -148,10 +148,36 @@ def load(
new_processors[key] = CNAttentionProcessor()
controlnets[controlnet].set_attn_processor(new_processors)
- def check_download(self, is_sdxl: bool) -> None:
+ @property
+ def use_fine_grained(self) -> bool:
"""
- Downloads necessary files for either pipeline
+ Returns true if using a plus model
"""
+ return self.model == "plus" or self.model == "plus-face"
+
+ def check_download(
+ self,
+ is_sdxl: bool=False,
+ model: Optional[IP_ADAPTER_LITERAL]="default",
+ task_callback: Optional[Callable[[str], None]]=None,
+ ) -> None:
+ """
+ Downloads necessary files for any pipeline
+ """
+ # Gather previous state
+ _task_callback = self.task_callback
+ _is_sdxl = self.is_sdxl
+ _model = self.model
+
+ # Set new state
+ self.task_callback = task_callback
+ self.is_sdxl = is_sdxl
+ if model is None:
+ self.model = "default"
+ else:
+ self.model = model
+
+ # Trigger getters
if is_sdxl:
_ = self.xl_encoder_config
_ = self.xl_encoder_model
@@ -161,32 +187,33 @@ def check_download(self, is_sdxl: bool) -> None:
_ = self.default_encoder_model
_ = self.default_image_prompt_checkpoint
+ # Reset state
+ self.task_callback = _task_callback
+ self.is_sdxl = _is_sdxl
+ self.model = _model
+
def set_scale(
self,
unet: UNet2DConditionModel,
- new_scale: float,
- is_sdxl: bool = False,
- use_fine_grained: bool = False,
- use_face_model: bool = False,
- keepalive_callback: Optional[Callable[[],None]] = None,
- controlnets: Optional[Dict[str, ControlNetModel]] = None,
+ scale: float,
+ is_sdxl: bool=False,
+ model: Optional[IP_ADAPTER_LITERAL]="default",
+ keepalive_callback: Optional[Callable[[],None]]=None,
+ controlnets: Optional[Dict[str, ControlNetModel]]=None,
) -> int:
"""
Sets the scale on attention processors.
"""
- if (
- self.is_sdxl != is_sdxl or
- self.use_fine_grained != use_fine_grained or
- self.use_face_model != use_face_model
- ):
+ if model is None:
+ model = "default"
+ if self.is_sdxl != is_sdxl or self.model != model:
# Completely reload adapter
self.unload(unet, controlnets)
self.load(
unet,
is_sdxl=is_sdxl,
- scale=new_scale,
- use_fine_grained=use_fine_grained,
- use_face_model=use_face_model,
+ scale=scale,
+ model=model,
keepalive_callback=keepalive_callback,
controlnets=controlnets
)
@@ -199,7 +226,7 @@ def set_scale(
for name in unet.attn_processors.keys():
processor = unet.attn_processors[name]
if isinstance(processor, IPAttentionProcessor) or isinstance(processor, IPAttentionProcessor2_0):
- processor.scale = new_scale
+ processor.scale = scale
processors_altered += 1
return processors_altered
@@ -251,10 +278,10 @@ def default_image_prompt_checkpoint(self) -> str:
Gets the path to the IP checkpoint for 1.5
Downloads if needed
"""
- if self.use_fine_grained and self.use_face_model:
+ if self.model == "plus-face":
model_url = self.FACE_ADAPTER_PATH
filename = "ip-adapter-plus-face_sd15.pth"
- elif self.use_fine_grained:
+ elif self.model == "plus":
model_url = self.FINE_GRAINED_ADAPTER_PATH
filename = "ip-adapter-plus_sd15.pth"
else:
@@ -275,6 +302,7 @@ def xl_encoder_model(self) -> str:
"""
if self.use_fine_grained:
return self.default_encoder_model
+
return self.get_model_file(
self.XL_ENCODER_PATH,
filename="ip-adapter_sdxl_encoder.pth",
@@ -289,6 +317,7 @@ def xl_encoder_config(self) -> str:
"""
if self.use_fine_grained:
return self.default_encoder_config
+
return self.get_model_file(
self.XL_ENCODER_CONFIG_PATH,
filename="ip-adapter_sdxl_encoder_config.json"
@@ -300,9 +329,19 @@ def xl_image_prompt_checkpoint(self) -> str:
Gets the path to the IP checkpoint for XL
Downloads if needed
"""
+ if self.model == "plus-face":
+ model_url = self.FACE_XL_ADAPTER_PATH
+ filename = "ip-adapter-plus-face_sdxl_vit-h.pth"
+ elif self.model == "plus":
+ model_url = self.FINE_GRAINED_XL_ADAPTER_PATH
+ filename = "ip-adapter-plus_sdxl_vit-h.pth"
+ else:
+ model_url = self.XL_ADAPTER_PATH
+ filename = "ip-adapter_sdxl.pth"
+
return self.get_model_file(
- self.FINE_GRAINED_XL_ADAPTER_PATH if self.use_fine_grained else self.XL_ADAPTER_PATH,
- filename="ip-adapter-plus_sdxl_vit-h.pth" if self.use_fine_grained else "ip-adapter_sdxl.pth",
+ model_url,
+ filename=filename,
extensions=[".bin", ".pth", ".safetensors"]
)
@@ -325,6 +364,10 @@ def encoder(self) -> CLIPVisionModelWithProjection:
"""
Gets the encoder, initializes if needed
"""
+ from transformers import (
+ CLIPVisionModelWithProjection,
+ PretrainedConfig
+ )
if not hasattr(self, "_encoder"):
if self.is_sdxl:
logger.debug(f"Initializing CLIPVisionModelWithProjection from {self.xl_encoder_model}")
@@ -409,6 +452,7 @@ def processor(self) -> CLIPImageProcessor:
"""
Gets the processor, initializes if needed
"""
+ from transformers import CLIPImageProcessor
if not hasattr(self, "_processor"):
self._processor = CLIPImageProcessor()
return self._processor
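
As a quick illustration of the new model-literal behavior above, the following standalone helper (not enfugue code; the import path follows support/__init__.py) paraphrases the SDXL checkpoint selection in xl_image_prompt_checkpoint, mapping the "model" value to the URL that would be downloaded.

    from enfugue.diffusion.support import IPAdapter

    def xl_adapter_url(model: str = "default") -> str:
        # Mirrors the branch order shown above: face variant, then plus, then the base adapter
        if model == "plus-face":
            return IPAdapter.FACE_XL_ADAPTER_PATH
        if model == "plus":
            return IPAdapter.FINE_GRAINED_XL_ADAPTER_PATH
        return IPAdapter.XL_ADAPTER_PATH

    print(xl_adapter_url("plus"))  # .../sdxl_models/ip-adapter-plus_sdxl_vit-h.bin
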
diff --git a/src/python/enfugue/diffusion/support/model.py b/src/python/enfugue/diffusion/support/model.py
index fbad1d33..9e9e73f6 100644
--- a/src/python/enfugue/diffusion/support/model.py
+++ b/src/python/enfugue/diffusion/support/model.py
@@ -54,6 +54,8 @@ def __init__(
dtype: torch.dtype,
offline: bool = False
) -> None:
+ if model_dir.startswith("~"):
+ model_dir = os.path.expanduser(model_dir)
self.model_dir = model_dir
self.device = device
self.dtype = dtype
@@ -92,6 +94,27 @@ def get_model_file(
return local_path
raise IOError(f"Cannot retrieve model file {uri}")
+ @classmethod
+ def get_default_instance(cls) -> SupportModel:
+ """
+ Builds a default interpolator without a configuration passed
+ """
+ import torch
+ from enfugue.diffusion.util import get_optimal_device
+ from enfugue.util import get_local_configuration
+ device = get_optimal_device()
+ try:
+ configuration = get_local_configuration()
+ except Exception:
+ from pibble.api.configuration import APIConfiguration
+ configuration = APIConfiguration()
+
+ return cls(
+ configuration.get("enfugue.engine.cache", "~/.cache/enfugue/other"),
+ device,
+ torch.float16 if device.type == "cuda" else torch.float32
+ )
+
@contextmanager
def context(self) -> Iterator[Self]:
"""
@@ -107,9 +130,11 @@ def context(self) -> Iterator[Self]:
import torch.cuda
torch.cuda.empty_cache()
+ torch.cuda.synchronize()
elif self.device.type == "mps":
import torch
import torch.mps
torch.mps.empty_cache()
+ torch.mps.synchronize()
gc.collect()
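
The new get_default_instance helper plus the existing context() manager make it possible to spin up a support model outside the server. A minimal sketch, assuming SupportModel (or a concrete subclass) is importable as shown; the usage itself is illustrative, not part of this diff:

    # Minimal sketch; defaults come from get_default_instance (cache dir, optimal device, fp16 on CUDA).
    from enfugue.diffusion.support.model import SupportModel

    model = SupportModel.get_default_instance()
    with model.context():
        pass  # run the model; the cache is emptied and synchronized on exit
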
diff --git a/src/python/enfugue/diffusion/util/__init__.py b/src/python/enfugue/diffusion/util/__init__.py
index 34196cd5..3948808a 100644
--- a/src/python/enfugue/diffusion/util/__init__.py
+++ b/src/python/enfugue/diffusion/util/__init__.py
@@ -1,4 +1,11 @@
-from enfugue.diffusion.util.vision_util import *
+from __future__ import annotations
from enfugue.diffusion.util.torch_util import *
from enfugue.diffusion.util.model_util import *
+from enfugue.diffusion.util.morph_util import *
+from enfugue.diffusion.util.video_util import *
+from enfugue.diffusion.util.chunk_util import *
+from enfugue.diffusion.util.prompt_util import *
+from enfugue.diffusion.util.vision_util import *
+from enfugue.diffusion.util.tensorrt_util import *
from enfugue.diffusion.util.generation_util import *
+from enfugue.diffusion.util.cuda_util import *
diff --git a/src/python/enfugue/diffusion/util/chunk_util.py b/src/python/enfugue/diffusion/util/chunk_util.py
new file mode 100644
index 00000000..8e36b091
--- /dev/null
+++ b/src/python/enfugue/diffusion/util/chunk_util.py
@@ -0,0 +1,259 @@
+from dataclasses import dataclass
+from typing import Optional, Iterator, List, Tuple, Union
+from math import ceil
+
+__all__ = ["Chunker"]
+
+@dataclass
+class Chunker:
+ width: int
+ height: int
+ frames: Optional[int] = None
+ size: Optional[Union[int, Tuple[int, int]]] = None
+ stride: Optional[Union[int, Tuple[int, int]]] = None
+ frame_size: Optional[int] = None
+ frame_stride: Optional[int] = None
+ tile: Union[bool, Tuple[bool, bool]] = False
+ loop: bool = False
+ vae_scale_factor: int = 8
+ temporal_first: bool = False
+
+ def get_pixel_from_latent(self, chunk: List[int]) -> List[int]:
+ """
+ Turns latent chunk into pixel chunk
+ """
+ start = chunk[0]
+ start_px = start * self.vae_scale_factor
+ end = chunk[-1]
+ end_px = end * self.vae_scale_factor
+ low = min(chunk)
+ high = max(chunk)
+ wrapped = start != low
+ if wrapped:
+ low_px = low * self.vae_scale_factor
+ high_px = high * self.vae_scale_factor
+ return list(range(low_px, high_px)) + list(range(end_px))
+ return list(range(start_px, end_px))
+
+ @property
+ def latent_width(self) -> int:
+ """
+ Returns latent (not pixel) width
+ """
+ return self.width // self.vae_scale_factor
+
+ @property
+ def latent_height(self) -> int:
+ """
+ Returns latent (not pixel) height
+ """
+ return self.height // self.vae_scale_factor
+
+ @property
+ def latent_size(self) -> Tuple[int, int]:
+ """
+ Returns latent (not pixel) size
+ """
+ if self.size is None:
+ return (self.latent_width, self.latent_height)
+ if isinstance(self.size, tuple):
+ width, height = self.size
+ else:
+ width, height = self.size, self.size
+ return (
+ width // self.vae_scale_factor,
+ height // self.vae_scale_factor,
+ )
+
+ @property
+ def latent_stride(self) -> Tuple[int, int]:
+ """
+ Returns latent (not pixel) stride
+ """
+ if self.stride is None:
+ return (self.latent_width, self.latent_height)
+ if isinstance(self.stride, tuple):
+ left, top = self.stride
+ else:
+ left, top = self.stride, self.stride
+ return (
+ left // self.vae_scale_factor,
+ top // self.vae_scale_factor,
+ )
+
+ @property
+ def num_horizontal_chunks(self) -> int:
+ """
+ Gets the number of horizontal chunks.
+ """
+ if not self.size or not self.stride:
+ return 1
+ if isinstance(self.tile, tuple):
+ tile_x, tile_y = self.tile
+ else:
+ tile_x = self.tile
+ if tile_x:
+ return max(ceil(self.latent_width / self.latent_stride[0]), 1)
+ return max(ceil((self.latent_width - self.latent_size[0]) / self.latent_stride[0] + 1), 1)
+
+ @property
+ def num_vertical_chunks(self) -> int:
+ """
+ Gets the number of vertical chunks.
+ """
+ if not self.size or not self.stride:
+ return 1
+ if isinstance(self.tile, tuple):
+ tile_x, tile_y = self.tile
+ else:
+ tile_y = self.tile
+ if tile_y:
+ return max(ceil(self.latent_height / self.latent_stride[1]), 1)
+ return max(ceil((self.latent_height - self.latent_size[1]) / self.latent_stride[1] + 1), 1)
+
+ @property
+ def num_chunks(self) -> int:
+ """
+ Gets the number of latent space image chunks
+ """
+ return self.num_horizontal_chunks * self.num_vertical_chunks
+
+ @property
+ def num_frame_chunks(self) -> int:
+ """
+ Gets the number of frame chunks.
+ """
+ if not self.frames or not self.frame_size or not self.frame_stride:
+ return 1
+ if self.loop:
+ return max(ceil(self.frames / self.frame_stride), 1)
+ return max(ceil((self.frames - self.frame_size) / self.frame_stride + 1), 1)
+
+ @property
+ def tile_x(self) -> bool:
+ """
+ Gets whether or not tiling is enabled on the X dimension.
+ """
+ if isinstance(self.tile, tuple):
+ return self.tile[0]
+ return self.tile
+
+ @property
+ def tile_y(self) -> bool:
+ """
+ Gets whether or not tiling is enabled on the Y dimension.
+ """
+ if isinstance(self.tile, tuple):
+ return self.tile[1]
+ return self.tile
+
+ @property
+ def chunks(self) -> Iterator[Tuple[Tuple[int, int], Tuple[int, int]]]:
+ """
+ Gets the chunked latent indices
+ """
+ if not self.size or not self.stride:
+ yield (
+ (0, self.latent_height),
+ (0, self.latent_width)
+ )
+ return
+
+ vertical_chunks = self.num_vertical_chunks
+ horizontal_chunks = self.num_horizontal_chunks
+ total = vertical_chunks * horizontal_chunks
+
+ latent_size_x, latent_size_y = self.latent_size
+ latent_stride_x, latent_stride_y = self.latent_stride
+ if isinstance(self.tile, tuple):
+ tile_x, tile_y = self.tile
+ else:
+ tile_x, tile_y = self.tile, self.tile
+
+ for i in range(total):
+ vertical_offset = None
+ horizontal_offset = None
+
+ top = (i // horizontal_chunks) * latent_stride_y
+ bottom = top + latent_size_y
+
+ left = (i % horizontal_chunks) * latent_stride_x
+ right = left + latent_size_x
+
+ if bottom > self.latent_height:
+ vertical_offset = bottom - self.latent_height
+ bottom -= vertical_offset
+ if not tile_y:
+ top = max(0, top - vertical_offset)
+
+ if right > self.latent_width:
+ horizontal_offset = right - self.latent_width
+ right -= horizontal_offset
+ if not tile_x:
+ left = max(0, left - horizontal_offset)
+
+ horizontal = [left, right]
+ vertical = [top, bottom]
+
+ if horizontal_offset is not None and tile_x:
+ horizontal[-1] = horizontal_offset
+
+ if vertical_offset is not None and tile_y:
+ vertical[-1] = vertical_offset
+
+ yield tuple(vertical), tuple(horizontal) # type: ignore
+
+ @property
+ def frame_chunks(self) -> Iterator[Tuple[int, int]]:
+ """
+ Iterates over the frame chunks.
+ """
+ if not self.frames:
+ return
+ if not self.frame_size or not self.frame_stride:
+ yield (0, self.frames)
+ return
+ for i in range(self.num_frame_chunks):
+ offset = None
+ start = i * self.frame_stride
+ end = start + self.frame_size
+
+ if end > self.frames:
+ offset = end - self.frames
+ end -= offset
+ if not self.loop:
+ start = max(0, start - offset)
+ frames = [start, end]
+ if offset is not None and self.loop:
+ frames[-1] = offset
+
+ yield tuple(frames) # type: ignore
+
+ def __len__(self) -> int:
+ """
+ Implements len() to return the total number of chunks
+ """
+ return self.num_chunks * self.num_frame_chunks
+
+ def __iter__(self) -> Iterator[
+ Tuple[
+ Tuple[int, int],
+ Tuple[int, int],
+ Tuple[Optional[int], Optional[int]]
+ ]
+ ]:
+ """
+ Iterates over all chunks, yielding (vertical, horizontal, temporal)
+ """
+ if self.frames:
+ if self.temporal_first:
+ for frame_chunk in self.frame_chunks:
+ for vertical_chunk, horizontal_chunk in self.chunks:
+ yield (vertical_chunk, horizontal_chunk, frame_chunk)
+ else:
+ for vertical_chunk, horizontal_chunk in self.chunks:
+ for frame_chunk in self.frame_chunks:
+ yield (vertical_chunk, horizontal_chunk, frame_chunk)
+ else:
+ for vertical_chunk, horizontal_chunk in self.chunks:
+ yield (vertical_chunk, horizontal_chunk, (None, None))
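
For reference, a short sketch of how the new Chunker is meant to be consumed; the dimensions below are illustrative, with tiling enabled only along the X axis:

    # Illustrative only: iterate tiled latent windows for a 1024x1024 canvas.
    from enfugue.diffusion.util import Chunker

    chunker = Chunker(width=1024, height=1024, size=512, stride=256, tile=(True, False))
    for (top, bottom), (left, right), (start, end) in chunker:
        # latent-space bounds of each window; (start, end) is (None, None) when frames is unset
        print(top, bottom, left, right, start, end)
    print(len(chunker))  # total number of windows
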
diff --git a/src/python/enfugue/diffusion/util/cuda_util.py b/src/python/enfugue/diffusion/util/cuda_util.py
new file mode 100644
index 00000000..b45d5123
--- /dev/null
+++ b/src/python/enfugue/diffusion/util/cuda_util.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+import os
+
+__all__ = ["get_cudnn_lib_dir"]
+
+def get_cudnn_lib_dir() -> str:
+ """
+ Gets the CUDNN directory
+ """
+ import nvidia
+ import nvidia.cudnn
+ cudnn_dir = os.path.dirname(nvidia.cudnn.__file__)
+ cudnn_lib_dir = os.path.join(cudnn_dir, "lib")
+ if os.path.exists(cudnn_lib_dir):
+ return cudnn_lib_dir
+ raise IOError("Couldn't find CUDNN directory.")
diff --git a/src/python/enfugue/diffusion/util/generation_util.py b/src/python/enfugue/diffusion/util/generation_util.py
index 7857a4b7..e11019ee 100644
--- a/src/python/enfugue/diffusion/util/generation_util.py
+++ b/src/python/enfugue/diffusion/util/generation_util.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from pibble.util.strings import Serializer
-from typing import Any, Dict, List, Tuple, TYPE_CHECKING
+from typing import Any, Dict, List, Tuple, Union, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from enfugue.diffusion.manager import DiffusionPipelineManager
@@ -22,6 +22,7 @@ def __init__(
grid_size: int = 256,
grid_columns: int = 4,
caption_height: int = 50,
+ use_video: bool = False,
**base_kwargs: Any
) -> None:
self.seed = seed
@@ -29,6 +30,7 @@ def __init__(
self.grid_columns = grid_columns
self.caption_height = caption_height
self.base_kwargs = base_kwargs
+ self.use_video = use_video
@property
def font(self) -> ImageFont:
@@ -97,7 +99,10 @@ def format_parameters(self, parameters: Dict[str, Any]) -> str:
for key in parameters
])
- def collage(self, results: List[Tuple[Dict[str, Any], List[Image]]]) -> Image:
+ def collage(
+ self,
+ results: List[Tuple[Dict[str, Any], Optional[str], List[Image]]]
+ ) -> Union[Image, List[Image]]:
"""
Builds the results into a collage.
"""
@@ -105,7 +110,11 @@ def collage(self, results: List[Tuple[Dict[str, Any], List[Image]]]) -> Image:
from PIL import Image, ImageDraw
# Get total images
- total_images = sum([len(images) for kwargs, images in results])
+ if self.use_video:
+ total_images = len(results)
+ else:
+ total_images = sum([len(images) for kwargs, label, images in results])
+
if total_images == 0:
raise RuntimeError("No images passed.")
@@ -122,41 +131,76 @@ def collage(self, results: List[Tuple[Dict[str, Any], List[Image]]]) -> Image:
# Create blank image
grid = Image.new("RGB", (width, height), (255, 255, 255))
- draw = ImageDraw.Draw(grid)
+
+ # Multiply if making a video
+ if self.use_video:
+ frame_count = max([len(images) for kwargs, label, images in results])
+ grid = [grid.copy() for i in range(frame_count)]
+ draw = [ImageDraw.Draw(image) for image in grid]
+ else:
+ draw = ImageDraw.Draw(grid)
# Iterate through each result image and paste
row, column = 0, 0
- for parameter_set, images in results:
+ for parameter_set, label, images in results:
for i, image in enumerate(images):
- width, height = image.size
# Fit the image to the grid size
+ width, height = image.size
image = fit_image(image, self.grid_size, self.grid_size, "contain", "center-center")
+ # Figure out which image/draw to use
+ if self.use_video:
+ target_image = grid[i]
+ target_draw = draw[i]
+ else:
+ target_image = grid
+ target_draw = draw
# Paste the image on the grid
- grid.paste(image, (column * self.grid_size, row * (self.grid_size + self.caption_height)))
+ target_image.paste(
+ image,
+ (column * self.grid_size, row * (self.grid_size + self.caption_height))
+ )
# Put the caption under the image
- draw.text(
+ if label is None:
+ if self.use_video:
+ label = f"{self.format_parameters(parameter_set)}, {width}×{height}"
+ else:
+ label = f"{self.format_parameters(parameter_set)}, sample {i+1}, {width}×{height}"
+ target_draw.text(
(column * self.grid_size + 5, row * (self.grid_size + self.caption_height) + self.grid_size + 2),
- self.split_text(f"{self.format_parameters(parameter_set)}, sample {i+1}, {width}×{height}"),
+ self.split_text(label),
fill=(0,0,0),
font=self.font
)
# Increment as necessary
+ if not self.use_video:
+ column += 1
+ if column >= self.grid_columns:
+ row += 1
+ column = 0
+ # When making a video, increment once per parameter set
+ if self.use_video:
column += 1
if column >= self.grid_columns:
row += 1
column = 0
return grid
- def execute(self, manager: DiffusionPipelineManager, *parameter_sets: Dict[str, Any]) -> Image:
+ def execute(
+ self,
+ manager: DiffusionPipelineManager,
+ *parameter_sets: Dict[str, Any]
+ ) -> Union[Image, List[Image]]:
"""
Executes each parameter set and pastes on the grid.
"""
- results: List[Tuple[Dict[str, Any], List[Image]]] = []
+ results: List[Tuple[Dict[str, Any], Optional[str], List[Image]]] = []
for parameter_set in parameter_sets:
manager.seed = self.seed
+ label = parameter_set.pop("label", None)
result = manager(**{**self.base_kwargs, **parameter_set})
results.append((
parameter_set,
+ label,
result["images"]
))
return self.collage(results)
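
A hedged sketch of the new video-collage path; "GridHelper" is a stand-in name since the class definition appears in an earlier hunk, and the parameter sets are illustrative:

    # Hypothetical usage; "GridHelper" stands in for the class defined earlier in this file.
    helper = GridHelper(seed=12345, use_video=True, prompt="a lighthouse at dusk")
    frames = helper.execute(
        manager,                                   # a DiffusionPipelineManager
        {"label": "base"},
        {"label": "high guidance", "guidance_scale": 10.0},
    )
    # with use_video=True each parameter set fills one cell per frame and
    # execute() returns a list of collage frames rather than a single image
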
diff --git a/src/python/enfugue/diffusion/util/model_util.py b/src/python/enfugue/diffusion/util/model_util.py
index 3be3d6a8..a3d9efeb 100644
--- a/src/python/enfugue/diffusion/util/model_util.py
+++ b/src/python/enfugue/diffusion/util/model_util.py
@@ -1,17 +1,18 @@
+from __future__ import annotations
+
import os
import re
import gc
-import torch
-import safetensors.torch
-
from enfugue.util import logger
+from enfugue.diffusion.constants import *
+
+from typing import Optional, Union, Literal, Dict, cast, TYPE_CHECKING
-from typing import Optional, Union, Literal, Dict, cast
+if TYPE_CHECKING:
+ import torch
-__all__ = [
- "ModelMerger"
-]
+__all__ = [ "ModelMerger"]
class ModelMerger:
"""
@@ -55,6 +56,7 @@ def as_half(tensor: torch.Tensor) -> torch.Tensor:
"""
Halves a tensor if necessary
"""
+ import torch
if tensor.dtype == torch.float:
return tensor.half()
return tensor
@@ -104,8 +106,10 @@ def load_checkpoint(checkpoint_path: str) -> Dict:
_, ext = os.path.splitext(checkpoint_path)
logger.debug(f"Model merger loading {checkpoint_path}")
if ext.lower() == ".safetensors":
+ import safetensors.torch
ckpt = safetensors.torch.load_file(checkpoint_path, device="cpu")
else:
+ import torch
ckpt = torch.load(checkpoint_path, map_location="cpu")
return ModelMerger.get_state_dict_from_checkpoint(ckpt)
@@ -131,6 +135,7 @@ def save(self, output_path: str) -> None:
"""
Runs the configured merger.
"""
+ import torch
logger.debug(
f"Executing model merger with interpolation '{self.interpolation}', primary model {self.primary_model}, secondary model {self.secondary_model}, tertiary model {self.tertiary_model}"
)
@@ -207,6 +212,7 @@ def save(self, output_path: str) -> None:
logger.debug(f"Saving merged model to {output_path}")
_, extension = os.path.splitext(output_path)
if extension.lower() == ".safetensors":
+ import safetensors.torch
safetensors.torch.save_file(theta_0, output_path)
else:
torch.save(theta_0, output_path)
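
A sketch of driving ModelMerger with the now-lazy imports; the constructor arguments are assumed from the attributes referenced in save(), and the interpolation name is illustrative:

    # Assumed constructor shape; only save(output_path) and the attribute names are shown in this diff.
    from enfugue.diffusion.util import ModelMerger

    merger = ModelMerger(
        primary_model="/models/a.safetensors",
        secondary_model="/models/b.safetensors",
        tertiary_model=None,
        interpolation="weighted-sum",  # hypothetical value
    )
    merger.save("/models/merged.safetensors")  # torch/safetensors are imported only when needed
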
diff --git a/src/python/enfugue/diffusion/util/morph_util.py b/src/python/enfugue/diffusion/util/morph_util.py
new file mode 100644
index 00000000..bcd83152
--- /dev/null
+++ b/src/python/enfugue/diffusion/util/morph_util.py
@@ -0,0 +1,290 @@
+# Inspired by the following:
+# https://github.com/ddowd97/Python-Image-Morpher/blob/master/Morphing/Morphing.py
+# https://github.com/jankovicsandras/autoimagemorph/blob/master/autoimagemorph.py
+# https://github.com/spmallick/learnopencv/blob/master/FaceMorph/faceMorph.py
+import cv2
+import numpy as np
+
+from typing import Union, Literal, Iterator, Tuple, List, Any
+
+from PIL import Image
+from scipy.spatial import Delaunay
+from matplotlib.path import Path
+
+class Triangle:
+ """
+ Stores vertices for a triangle and allows some calculations.
+ """
+ def __init__(self, vertices: np.ndarray) -> None:
+ self.vertices = vertices
+
+ @property
+ def points(self) -> np.ndarray:
+ """
+ Gets the points contained within this triangle.
+ """
+ if not hasattr(self, "_points"):
+ self.min_x = int(self.vertices[:, 0].min())
+ self.max_x = int(self.vertices[:, 0].max())
+ self.min_y = int(self.vertices[:, 1].min())
+ self.max_y = int(self.vertices[:, 1].max())
+ x_list = range(self.min_x, self.max_x + 1)
+ y_list = range(self.min_y, self.max_y + 1)
+ point_list = [(x, y) for x in x_list for y in y_list]
+
+ points = np.array(point_list, np.float64)
+ p = Path(self.vertices)
+ grid = p.contains_points(points)
+ mask = grid.reshape(self.max_x - self.min_x + 1, self.max_y - self.min_y + 1)
+ filtered = np.where(np.array(mask) == True)
+
+ self._points = np.vstack((filtered[0] + self.min_x, filtered[1] + self.min_y, np.ones(filtered[0].shape[0])))
+ return self._points
+
+
+class Morpher:
+ """
+ A quick-calculating morpher class that allows you to morph between two images.
+ """
+ def __init__(
+ self,
+ left: Union[str, np.ndarray, Image.Image],
+ right: Union[str, np.ndarray, Image.Image],
+ features: int = 8
+ ) -> None:
+ from enfugue.diffusion.util import ComputerVision
+ if isinstance(left, str):
+ left = Image.open(left)
+ if isinstance(left, np.ndarray):
+ left = ComputerVision.revert_image(left)
+ if isinstance(right, str):
+ right = Image.open(right)
+ if isinstance(right, np.ndarray):
+ right = ComputerVision.revert_image(right)
+ self.left = left
+ self.right = right.resize(self.left.size)
+ self.features = features
+
+ @property
+ def start(self) -> np.ndarray:
+ """
+ Gets the starting image in OpenCV format.
+ """
+ from enfugue.diffusion.util import ComputerVision
+ if not hasattr(self, "_start"):
+ self._start = ComputerVision.convert_image(self.left)
+ return self._start
+
+ @property
+ def end(self) -> np.ndarray:
+ """
+ Gets the ending image in OpenCV format.
+ """
+ from enfugue.diffusion.util import ComputerVision
+ if not hasattr(self, "_end"):
+ self._end = ComputerVision.convert_image(self.right)
+ return self._end
+
+ @property
+ def start_points(self) -> List[List[int]]:
+ """
+ Gets feature points from the left (start)
+ """
+ if not hasattr(self, "_start_points"):
+ self._start_points = self.get_image_feature_points(self.start)
+ return self._start_points
+
+ @property
+ def end_points(self) -> List[List[int]]:
+ """
+ Gets feature points from the right (end)
+ """
+ if not hasattr(self, "_end_points"):
+ self._end_points = self.get_image_feature_points(self.end)
+ return self._end_points
+
+ @property
+ def triangles(self) -> Iterator[Tuple[Triangle, Triangle]]:
+ """
+ Iterate over the tesselated triangles.
+ """
+ start = np.array(self.start_points, np.float64)
+ end = np.array(self.end_points, np.float64)
+
+ tesselated = Delaunay(start)
+
+ start_np = start[tesselated.simplices]
+ end_np = end[tesselated.simplices]
+
+ for x, y in zip(start_np, end_np):
+ yield (Triangle(x), Triangle(y))
+
+ def get_image_feature_points(self, image: np.ndarray) -> List[List[int]]:
+ """
+ Gets tracked features for an image
+ """
+ height, width, channels = image.shape
+ # Initialize points with four corners
+ points = [
+ [0, 0],
+ [width - 1, 0],
+ [0, height - 1],
+ [width - 1, height - 1]
+ ]
+
+ # Get height and width of cells
+ h = int(height / self.features) - 1
+ w = int(width / self.features) - 1
+
+ # Iterate over cells
+ for i in range(self.features):
+ for j in range(self.features):
+ # Crop and find feature point in frame
+ cropped = image[(j*h):(j*h)+h, (i*w):(i*w)+w]
+ monochrome = cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY)
+ features = cv2.goodFeaturesToTrack(monochrome, 1, 0.1, 10) # Tunable
+ if features is None:
+ # If there's nothing worth tracking in this cell, make our point the center
+ features = [[[h/2, w/2]]]
+
+ # Go through features and add coordinates to array
+ features = np.int0(features) # type: ignore[attr-defined]
+ for feature in features:
+ x, y = feature.ravel()
+ x += (i*w)
+ y += (j*h)
+ points.append([x, y])
+
+ # Return 4 + (features ^ 2) points
+ return points
+
+ def morph(
+ self,
+ start: np.ndarray,
+ end: np.ndarray,
+ target: np.ndarray,
+ start_triangle: Triangle,
+ end_triangle: Triangle,
+ target_triangle: Triangle,
+ alpha: float
+ ) -> None:
+ """
+ Morphs into a target array at a certain alpha
+ """
+ # Find the bounding rectangle for each triangle
+ start_bound = cv2.boundingRect(np.float32([start_triangle.vertices])) # type: ignore[arg-type]
+ end_bound = cv2.boundingRect(np.float32([end_triangle.vertices])) # type: ignore[arg-type]
+ target_bound = cv2.boundingRect(np.float32([target_triangle.vertices])) # type: ignore[arg-type]
+
+ # Offset points by left top corner of the respective rectangles
+ start_rect = []
+ end_rect = []
+ target_rect = []
+
+ for i in range(0, 3):
+ target_rect.append((
+ (target_triangle.vertices[i][0] - target_bound[0]),
+ (target_triangle.vertices[i][1] - target_bound[1])
+ ))
+ start_rect.append((
+ (start_triangle.vertices[i][0] - start_bound[0]),
+ (start_triangle.vertices[i][1] - start_bound[1])
+ ))
+ end_rect.append((
+ (end_triangle.vertices[i][0] - end_bound[0]),
+ (end_triangle.vertices[i][1] - end_bound[1])
+ ))
+
+ # Get mask by filling triangle
+ mask = np.zeros((target_bound[3], target_bound[2], 3), dtype = np.float32)
+ cv2.fillConvexPoly(mask, np.int32(target_rect), (1.0, 1.0, 1.0), 16, 0) # type: ignore[arg-type]
+
+ # Apply warpImage to small rectangular patches
+ start_image = start[
+ start_bound[1]:start_bound[1] + start_bound[3],
+ start_bound[0]:start_bound[0] + start_bound[2]
+ ]
+ end_image = end[
+ end_bound[1]:end_bound[1] + end_bound[3],
+ end_bound[0]:end_bound[0] + end_bound[2]
+ ]
+ size = (target_bound[2], target_bound[3])
+
+ warp_start = self.affine(start_image, start_rect, target_rect, size)
+ warp_end = self.affine(end_image, end_rect, target_rect, size)
+
+ blend = (1.0 - alpha) * warp_start + alpha * warp_end
+
+ # Copy triangular region of the rectangular patch to the output image
+ target[
+ target_bound[1]:target_bound[1]+target_bound[3],
+ target_bound[0]:target_bound[0]+target_bound[2]
+ ] = target[
+ target_bound[1]:target_bound[1]+target_bound[3],
+ target_bound[0]:target_bound[0]+target_bound[2]
+ ] * (1 - mask) + blend * mask
+
+ def affine(
+ self,
+ source: np.ndarray,
+ source_rect: List[Tuple[Any, Any]],
+ target_rect: List[Tuple[Any, Any]],
+ size: Tuple[int, int]
+ ) -> np.ndarray:
+ """
+ Applies the affine transform for a section of an image
+ """
+ warp_mat = cv2.getAffineTransform(np.float32(source_rect), np.float32(target_rect)) # type: ignore[arg-type]
+ return cv2.warpAffine(
+ source,
+ warp_mat,
+ (size[0], size[1]),
+ None,
+ flags=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_REFLECT_101
+ )
+
+ def __call__(self, alpha: float, return_type: Literal["pil", "np"] = "pil") -> Union[Image.Image, np.ndarray]:
+ """
+ Gets an image at a particular point between 0 and 1
+ """
+ from enfugue.diffusion.util import ComputerVision
+ start = np.float32(self.start)
+ end = np.float32(self.end)
+ target = np.zeros(start.shape, dtype=start.dtype)
+
+ points = []
+ for ((start_x, start_y), (end_x, end_y)) in zip(self.start_points, self.end_points):
+ points.append([
+ (1 - alpha) * start_x + alpha * end_x,
+ (1 - alpha) * start_y + alpha * end_y,
+ ])
+
+ for start_tri, end_tri in self.triangles:
+ target_tri = Triangle((1 - alpha) * start_tri.vertices + end_tri.vertices * alpha)
+ self.morph(start, end, target, start_tri, end_tri, target_tri, alpha) # type: ignore[arg-type]
+
+ target = np.uint8(target) # type: ignore[assignment]
+ if return_type == "pil":
+ return ComputerVision.revert_image(target)
+ return target
+
+ def save_video(
+ self,
+ path: str,
+ length: int = 20,
+ rate: float = 20.,
+ overwrite: bool = False,
+ ) -> int:
+ """
+ Saves the warped image(s) to an .mp4
+ """
+ from enfugue.diffusion.util import Video
+ return Video([
+ self(i/length)
+ for i in range(length+1)
+ ]).save(
+ path,
+ rate=rate,
+ overwrite=overwrite
+ )
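
The Morpher is self-contained enough to use directly; for example (file names illustrative):

    # Blend two images and render a short morph video (paths are illustrative).
    from enfugue.diffusion.util.morph_util import Morpher

    morpher = Morpher("start.png", "end.png", features=8)
    halfway = morpher(0.5)                     # PIL image midway between the inputs
    morpher.save_video("morph.mp4", length=40, rate=20.0, overwrite=True)
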
diff --git a/src/python/enfugue/diffusion/util/prompt_util.py b/src/python/enfugue/diffusion/util/prompt_util.py
new file mode 100644
index 00000000..cf1496c3
--- /dev/null
+++ b/src/python/enfugue/diffusion/util/prompt_util.py
@@ -0,0 +1,354 @@
+from __future__ import annotations
+from dataclasses import dataclass
+from typing import Optional, Union, Tuple, List, Callable, TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from torch import Tensor, dtype
+
+__all__ = ["Prompt", "EncodedPrompt", "EncodedPrompts"]
+
+@dataclass(frozen=True)
+class Prompt:
+ """
+ This class holds, at a minimum, a prompt string.
+ It can also contain a start frame, end frame, and weight.
+ """
+ positive: Optional[str] = None
+ negative: Optional[str] = None
+ positive_2: Optional[str] = None
+ negative_2: Optional[str] = None
+ start: Optional[int] = None
+ end: Optional[int] = None
+ weight: Optional[float] = None
+
+ def get_frame_overlap(self, frames: List[int]) -> float:
+ """
+ Gets the frame overlap ratio for this prompt
+ """
+ if self.start is None:
+ return 1.0
+ end = self.end
+ if end is None:
+ end = max(frames)
+
+ prompt_frame_list = list(range(self.start, end))
+ return len(set(prompt_frame_list).intersection(set(frames))) / len(frames)
+
+ def __str__(self) -> str:
+ if self.positive is None:
+ return "(none)"
+ return self.positive
+
+@dataclass
+class EncodedPrompt:
+ """
+ After encoding a prompt, this class holds the tensors and provides
+ methods for accessing the encoded tensors.
+ """
+ prompt: Union[Prompt, str]
+ embeds: Tensor
+ negative_embeds: Optional[Tensor]
+ pooled_embeds: Optional[Tensor]
+ negative_pooled_embeds: Optional[Tensor]
+
+ def __str__(self) -> str:
+ return str(self.prompt)
+
+ def check_get_tensor(
+ self,
+ frames: Optional[List[int]],
+ tensor: Optional[Tensor]
+ ) -> Tuple[Optional[Tensor], Union[float, int]]:
+ """
+ Checks if a tensor exists and should be returned and should be scaled.
+ """
+ if frames is None or isinstance(self.prompt, str) or tensor is None:
+ return tensor, 1.0
+ weight = 1.0 if self.prompt.weight is None else self.prompt.weight
+ if frames is None or self.prompt.start is None:
+ return tensor * weight, weight
+ overlap = self.prompt.get_frame_overlap(frames)
+ if overlap == 0 or weight == 0:
+ return None, 0
+ weight *= overlap
+ return tensor * weight, weight
+
+ def get_embeds(
+ self,
+ frames: Optional[List[int]] = None
+ ) -> Tuple[Optional[Tensor], Union[float, int]]:
+ """
+ Gets the encoded embeds.
+ """
+ return self.check_get_tensor(frames, self.embeds)
+
+ def get_negative_embeds(
+ self,
+ frames: Optional[List[int]] = None
+ ) -> Tuple[Optional[Tensor], Union[float, int]]:
+ """
+ Gets the encoded negative embeds.
+ """
+ return self.check_get_tensor(frames, self.negative_embeds)
+
+ def get_pooled_embeds(
+ self,
+ frames: Optional[List[int]] = None
+ ) -> Tuple[Optional[Tensor], Union[float, int]]:
+ """
+ Gets the encoded pooled embeds.
+ """
+ return self.check_get_tensor(frames, self.pooled_embeds)
+
+ def get_negative_pooled_embeds(
+ self,
+ frames: Optional[List[int]] = None
+ ) -> Tuple[Optional[Tensor], Union[float, int]]:
+ """
+ Gets the encoded negative pooled embeds.
+ """
+ return self.check_get_tensor(frames, self.negative_pooled_embeds)
+
+ @property
+ def dtype(self) -> dtype:
+ """
+ Gets the dtype of the encoded prompt.
+ """
+ return self.embeds.dtype
+
+if TYPE_CHECKING:
+ PromptGetterCallable = Callable[
+ [EncodedPrompt, Optional[List[int]]],
+ Tuple[Optional[Tensor], Union[float, int]]
+ ]
+
+@dataclass
+class EncodedImagePrompt:
+ """
+ Holds an encoded image prompt when using IP adapter
+ """
+ prompt_embeds: Tensor
+ uncond_embeds: Tensor
+ scale: float
+
+@dataclass
+class EncodedPrompts:
+ """
+ Holds any number of encoded prompts.
+ """
+ prompts: List[EncodedPrompt]
+ is_sdxl: bool
+ do_classifier_free_guidance: bool
+ image_prompt_embeds: Optional[Tensor] # input, frames, batch, tokens, embeds
+ image_uncond_prompt_embeds: Optional[Tensor] # input, frames, batch, tokens, embeds
+
+ def get_stacked_tensor(
+ self,
+ frames: Optional[List[int]],
+ getter: PromptGetterCallable
+ ) -> Optional[Tensor]:
+ """
+ Gets a tensor from prompts using a callable.
+ """
+ import torch
+ return_tensor = None
+ for prompt in self.prompts:
+ tensor, weight = getter(prompt, frames)
+ if tensor is not None and weight is not None and weight > 0:
+ if return_tensor is None:
+ return_tensor = tensor * weight
+ else:
+ return_tensor = torch.cat([return_tensor, tensor * weight], dim=1) # type: ignore[unreachable]
+ return return_tensor
+
+ def get_mean_tensor(
+ self,
+ frames: Optional[List[int]],
+ getter: PromptGetterCallable
+ ) -> Optional[Tensor]:
+ """
+ Gets a tensor from prompts using a callable.
+ """
+ import torch
+ return_tensor = None
+ total_weight = 0.0
+ for prompt in self.prompts:
+ tensor, weight = getter(prompt, frames)
+ if tensor is not None and weight is not None and weight > 0:
+ total_weight += weight
+ if return_tensor is None:
+ return_tensor = (tensor * weight).unsqueeze(0)
+ else:
+ return_tensor = torch.cat([return_tensor, (tensor * weight).unsqueeze(0)]) # type: ignore[unreachable]
+ if return_tensor is not None:
+ return torch.sum(return_tensor, 0) / total_weight
+ return None
+
+ def get_image_prompt_embeds(
+ self,
+ frames: Optional[List[int]]=None
+ ) -> Tensor:
+ """
+ Gets image prompt embeds.
+ """
+ if self.image_prompt_embeds is None:
+ raise RuntimeError("get_image_prompt_embeds called, but no image prompt embeds present.")
+ import torch
+ return_tensor: Optional[Tensor] = None
+ for image_embeds in self.image_prompt_embeds:
+ if frames is None:
+ image_embeds = image_embeds[0]
+ else:
+ frame_length = image_embeds.shape[0]
+ if frames[-1] <= frames[0]:
+ # Wraparound
+ image_embeds = torch.cat([
+ image_embeds[frames[0]:frame_length],
+ image_embeds[:frames[-1]]
+ ])
+ else:
+ image_embeds = image_embeds[frames]
+ # Collapse along frames
+ image_embeds = image_embeds.mean(0)
+
+ if return_tensor is None:
+ return_tensor = image_embeds
+ else:
+ return_tensor = torch.cat(
+ [return_tensor, image_embeds],
+ dim=1
+ )
+ if return_tensor is None:
+ raise RuntimeError("Prompt embeds could not be retrieved.")
+ return return_tensor
+
+ def get_image_uncond_prompt_embeds(
+ self,
+ frames: Optional[List[int]]=None
+ ) -> Tensor:
+ """
+ Gets image unconditioning prompt embeds.
+ """
+ if self.image_uncond_prompt_embeds is None:
+ raise RuntimeError("get_image_prompt_embeds called, but no image prompt embeds present.")
+ import torch
+ return_tensor: Optional[Tensor] = None
+ for uncond_embeds in self.image_uncond_prompt_embeds:
+ if frames is None:
+ uncond_embeds = uncond_embeds[0]
+ else:
+ frame_length = uncond_embeds.shape[0]
+ if frames[-1] <= frames[0]:
+ # Wraparound
+ uncond_embeds = torch.cat([
+ uncond_embeds[frames[0]:frame_length],
+ uncond_embeds[:frames[-1]]
+ ])
+ else:
+ uncond_embeds = uncond_embeds[frames]
+ # Collapse along frames
+ uncond_embeds = uncond_embeds.mean(0)
+
+ if return_tensor is None:
+ return_tensor = uncond_embeds
+ else:
+ return_tensor = torch.cat(
+ [return_tensor, uncond_embeds],
+ dim=1
+ )
+ if return_tensor is None:
+ raise RuntimeError("Prompt embeds could not be retrieved.")
+ return return_tensor
+
+ def get_embeds(self, frames: Optional[List[int]] = None) -> Optional[Tensor]:
+ """
+ Gets the encoded embeds.
+ """
+ import torch
+ get_embeds: PromptGetterCallable = lambda prompt, frames: prompt.get_embeds(frames)
+ method = self.get_mean_tensor if self.is_sdxl else self.get_stacked_tensor
+ result = method(frames, get_embeds)
+ if result is None:
+ return None
+ if self.is_sdxl and self.image_prompt_embeds is not None:
+ result = torch.cat([result, self.get_image_prompt_embeds(frames)], dim=1)
+ if self.is_sdxl and self.do_classifier_free_guidance:
+ negative_result = self.get_negative_embeds(frames)
+ if negative_result is None:
+ negative_result = torch.zeros_like(result)
+ result = torch.cat([negative_result, result], dim=0)
+ elif not self.is_sdxl and self.image_prompt_embeds is not None and result is not None:
+ if self.do_classifier_free_guidance:
+ negative, positive = result.chunk(2)
+ else:
+ negative, positive = None, result
+ positive = torch.cat([positive, self.get_image_prompt_embeds(frames)], dim=1)
+ if self.do_classifier_free_guidance and negative is not None and self.image_uncond_prompt_embeds is not None:
+ negative = torch.cat([negative, self.get_image_uncond_prompt_embeds(frames)], dim=1)
+ return torch.cat([negative, positive], dim=0)
+ else:
+ return positive
+ return result
+
+ def get_negative_embeds(self, frames: Optional[List[int]] = None) -> Optional[Tensor]:
+ """
+ Gets the encoded negative embeds.
+ """
+ if not self.is_sdxl:
+ return None
+ import torch
+ get_embeds: PromptGetterCallable = lambda prompt, frames: prompt.get_negative_embeds(frames)
+ method = self.get_mean_tensor if self.is_sdxl else self.get_stacked_tensor
+ result = method(frames, get_embeds)
+ if self.is_sdxl and self.image_uncond_prompt_embeds is not None and result is not None:
+ return torch.cat([result, self.get_image_uncond_prompt_embeds(frames)], dim=1)
+ elif self.image_uncond_prompt_embeds is not None and result is not None:
+ if self.do_classifier_free_guidance:
+ negative, positive = result.chunk(2)
+ else:
+ negative, positive = result, None
+ negative = torch.cat([negative, self.get_image_uncond_prompt_embeds(frames)], dim=1)
+ return negative
+ return result
+
+ def get_pooled_embeds(self, frames: Optional[List[int]] = None) -> Optional[Tensor]:
+ """
+ Gets the encoded pooled embeds.
+ """
+ if not self.is_sdxl:
+ return None
+ get_embeds: PromptGetterCallable = lambda prompt, frames: prompt.get_pooled_embeds(frames)
+ return self.get_mean_tensor(frames, get_embeds)
+
+ def get_negative_pooled_embeds(self, frames: Optional[List[int]] = None) -> Optional[Tensor]:
+ """
+ Gets the encoded negative pooled embeds.
+ """
+ if not self.is_sdxl:
+ return None
+ get_embeds: PromptGetterCallable = lambda prompt, frames: prompt.get_negative_pooled_embeds(frames)
+ return self.get_mean_tensor(frames, get_embeds)
+
+ def get_add_text_embeds(self, frames: Optional[List[int]] = None) -> Optional[Tensor]:
+ """
+ Gets added text embeds for SDXL.
+ """
+ if not self.is_sdxl:
+ return None
+ import torch
+ pooled_embeds = self.get_pooled_embeds(frames)
+ if self.do_classifier_free_guidance and pooled_embeds is not None:
+ negative_pooled_embeds = self.get_negative_pooled_embeds()
+ if negative_pooled_embeds is None:
+ negative_pooled_embeds = torch.zeros_like(pooled_embeds)
+ pooled_embeds = torch.cat([negative_pooled_embeds, pooled_embeds], dim=0)
+ return pooled_embeds
+
+ @property
+ def dtype(self) -> dtype:
+ """
+ Gets the dtype of the encoded prompt.
+ """
+ if not self.prompts:
+ raise ValueError("No prompts, cannot determine dtype.")
+ return self.prompts[0].dtype
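
The frame-overlap weighting is the piece most callers will interact with; a small worked example:

    # Worked example of Prompt.get_frame_overlap (values chosen for illustration).
    from enfugue.diffusion.util.prompt_util import Prompt

    intro = Prompt(positive="a sunrise over the ocean", start=0, end=8, weight=1.0)
    outro = Prompt(positive="a starry night sky", start=8, end=16, weight=0.5)
    window = list(range(4, 12))             # an 8-frame temporal chunk
    print(intro.get_frame_overlap(window))  # 0.5 - frames 4-7 overlap
    print(outro.get_frame_overlap(window))  # 0.5 - frames 8-11 overlap
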
diff --git a/src/python/enfugue/diffusion/util/tensorrt_util.py b/src/python/enfugue/diffusion/util/tensorrt_util.py
new file mode 100644
index 00000000..5cf5f200
--- /dev/null
+++ b/src/python/enfugue/diffusion/util/tensorrt_util.py
@@ -0,0 +1,152 @@
+from typing import List, Tuple, Any
+from hashlib import md5
+
+__all__ = [
+ "get_clip_engine_key",
+ "get_unet_engine_key",
+ "get_vae_engine_key",
+ "get_controlled_unet_engine_key",
+]
+
+def get_clip_engine_key(
+ size: int,
+ lora: List[Tuple[str, float]],
+ lycoris: List[Tuple[str, float]],
+ inversion: List[str],
+ **kwargs: Any
+) -> str:
+ """
+ Uses hashlib to generate the unique key for the CLIP engine.
+ CLIP must be rebuilt for each:
+ 1. Model
+ 2. Dimension
+ 3. LoRA
+ 4. LyCORIS
+ 5. Textual Inversion
+ """
+ return md5(
+ "-".join(
+ [
+ str(size),
+ ":".join(
+ "=".join([str(part) for part in lora_weight])
+ for lora_weight in sorted(lora, key=lambda lora_part: lora_part[0])
+ ),
+ ":".join(
+ "=".join([str(part) for part in lycoris_weight])
+ for lycoris_weight in sorted(lycoris, key=lambda lycoris_part: lycoris_part[0])
+ ),
+ ":".join(sorted(inversion)),
+ ]
+ ).encode("utf-8")
+ ).hexdigest()
+
+def get_unet_engine_key(
+ size: int,
+ lora: List[Tuple[str, float]],
+ lycoris: List[Tuple[str, float]],
+ inversion: List[str],
+ **kwargs: Any,
+) -> str:
+ """
+ Uses hashlib to generate the unique key for the UNET engine.
+ UNET must be rebuilt for each:
+ 1. Model
+ 2. Dimension
+ 3. LoRA
+ 4. LyCORIS
+ 5. Textual Inversion
+ """
+ return md5(
+ "-".join(
+ [
+ str(size),
+ ":".join(
+ "=".join([str(part) for part in lora_weight])
+ for lora_weight in sorted(lora, key=lambda lora_part: lora_part[0])
+ ),
+ ":".join(
+ "=".join([str(part) for part in lycoris_weight])
+ for lycoris_weight in sorted(lycoris, key=lambda lycoris_part: lycoris_part[0])
+ ),
+ ":".join(sorted(inversion)),
+ ]
+ ).encode("utf-8")
+ ).hexdigest()
+
+def get_controlled_unet_key(
+ size: int,
+ lora: List[Tuple[str, float]],
+ lycoris: List[Tuple[str, float]],
+ inversion: List[str],
+ **kwargs: Any,
+) -> str:
+ """
+ Uses hashlib to generate the unique key for the UNET engine with controlnet blocks.
+ ControlledUNET must be rebuilt for each:
+ 1. Model
+ 2. Dimension
+ 3. LoRA
+ 4. LyCORIS
+ 5. Textual Inversion
+ """
+ return md5(
+ "-".join(
+ [
+ str(size),
+ ":".join(
+ "=".join([str(part) for part in lora_weight])
+ for lora_weight in sorted(lora, key=lambda lora_part: lora_part[0])
+ ),
+ ":".join(
+ "=".join([str(part) for part in lycoris_weight])
+ for lycoris_weight in sorted(lycoris, key=lambda lycoris_part: lycoris_part[0])
+ ),
+ ":".join(sorted(inversion)),
+ ]
+ ).encode("utf-8")
+ ).hexdigest()
+
+def get_controlled_unet_engine_key(
+ size: int,
+ lora: List[Tuple[str, float]],
+ lycoris: List[Tuple[str, float]],
+ inversion: List[str],
+ **kwargs: Any,
+) -> str:
+ """
+ Uses hashlib to generate the unique key for the UNET engine with controlnet blocks.
+ ControlledUNET must be rebuilt for each:
+ 1. Model
+ 2. Dimension
+ 3. LoRA
+ 4. LyCORIS
+ 5. Textual Inversion
+ """
+ return md5(
+ "-".join(
+ [
+ str(size),
+ ":".join(
+ "=".join([str(part) for part in lora_weight])
+ for lora_weight in sorted(lora, key=lambda lora_part: lora_part[0])
+ ),
+ ":".join(
+ "=".join([str(part) for part in lycoris_weight])
+ for lycoris_weight in sorted(lycoris, key=lambda lycoris_part: lycoris_part[0])
+ ),
+ ":".join(sorted(inversion)),
+ ]
+ ).encode("utf-8")
+ ).hexdigest()
+
+def get_vae_engine_key(
+ size: int,
+ **kwargs: Any,
+) -> str:
+ """
+ Uses hashlib to generate the unique key for the VAE engine. The VAE only needs to be rebuilt for each:
+ 1. Model
+ 2. Dimension
+ """
+ return md5(str(size).encode("utf-8")).hexdigest()
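
These keys are plain hex digests, so a cache layout can be derived from them directly; a sketch (the directory layout is an assumption, not part of this diff):

    # Sketch: derive a hypothetical engine cache path from the key.
    from enfugue.diffusion.util.tensorrt_util import get_unet_engine_key

    key = get_unet_engine_key(size=512, lora=[("add-detail", 0.8)], lycoris=[], inversion=[])
    engine_path = f"/cache/tensorrt/unet/{key}.plan"  # hypothetical layout
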
diff --git a/src/python/enfugue/diffusion/util/torch_util/capability_util.py b/src/python/enfugue/diffusion/util/torch_util/capability_util.py
index 9dd4e7ca..c6a449fb 100644
--- a/src/python/enfugue/diffusion/util/torch_util/capability_util.py
+++ b/src/python/enfugue/diffusion/util/torch_util/capability_util.py
@@ -95,9 +95,11 @@ def empty_cache() -> None:
import torch
import torch.cuda
torch.cuda.empty_cache()
+ torch.cuda.synchronize()
elif mps_available():
import torch
import torch.mps
torch.mps.empty_cache()
+ torch.mps.synchronize()
import gc
gc.collect()
diff --git a/src/python/enfugue/diffusion/util/torch_util/mask_util.py b/src/python/enfugue/diffusion/util/torch_util/mask_util.py
index d77a869b..0f573d6c 100644
--- a/src/python/enfugue/diffusion/util/torch_util/mask_util.py
+++ b/src/python/enfugue/diffusion/util/torch_util/mask_util.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass, field
-from typing import Any, Union, Dict, TYPE_CHECKING
+from typing import Any, Union, Dict, Optional, TYPE_CHECKING
from typing_extensions import Self
from enfugue.diffusion.constants import MASK_TYPE_LITERAL
@@ -16,7 +16,6 @@
__all__ = ["MaskWeightBuilder"]
-
@dataclass(frozen=True)
class DiffusionMask:
"""
@@ -321,6 +320,36 @@ def gaussian(
self.unmasked_weights[unmask]
)
+ def temporal(
+ self,
+ tensor: Tensor,
+ frames: Optional[int] = None,
+ unfeather_start: bool = False,
+ unfeather_end: bool = False
+ ) -> Tensor:
+ """
+ Potentially expands a tensor temporally
+ """
+ import torch
+ if frames is None:
+ return tensor
+ tensor = tensor.unsqueeze(2).repeat(1, 1, frames, 1, 1)
+ if not unfeather_start or not unfeather_end:
+ frame_length = frames // 3
+ for i in range(frame_length):
+ feathered = torch.tensor(i / frame_length)
+ if not unfeather_start:
+ tensor[:, :, i, :, :] = torch.minimum(
+ tensor[:, :, i, :, :],
+ feathered
+ )
+ if not unfeather_end:
+ tensor[:, :, frames - i - 1, :, :] = torch.minimum(
+ tensor[:, :, frames - i - 1, :, :],
+ feathered
+ )
+ return tensor
+
def __call__(
self,
mask_type: MASK_TYPE_LITERAL,
@@ -328,10 +357,13 @@ def __call__(
dim: int,
width: int,
height: int,
+ frames: Optional[int] = None,
unfeather_left: bool = False,
unfeather_top: bool = False,
unfeather_right: bool = False,
unfeather_bottom: bool = False,
+ unfeather_start: bool = False,
+ unfeather_end: bool = False,
**kwargs: Any
) -> Tensor:
"""
@@ -355,4 +387,10 @@ def __call__(
unfeather_bottom=unfeather_bottom,
**kwargs
)
- return mask.unsqueeze(0).unsqueeze(0).repeat(batch, dim, 1, 1)
+
+ return self.temporal(
+ mask.unsqueeze(0).unsqueeze(0).repeat(batch, dim, 1, 1),
+ frames=frames,
+ unfeather_start=unfeather_start,
+ unfeather_end=unfeather_end,
+ )
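
The temporal feathering clamps the first and last third of frames to a linear ramp so chunked animation windows blend at their temporal edges; a standalone sketch of the same arithmetic:

    # Standalone sketch of the feathering math in MaskWeightBuilder.temporal.
    import torch

    frames = 16
    ramp = frames // 3                     # 5 feathered frames at each end
    weights = torch.ones(1, 1, frames, 4, 4)
    for i in range(ramp):
        factor = torch.tensor(i / ramp)
        weights[:, :, i] = torch.minimum(weights[:, :, i], factor)
        weights[:, :, frames - i - 1] = torch.minimum(weights[:, :, frames - i - 1], factor)
    print(weights[0, 0, :, 0, 0])          # 0.0, 0.2, ... 0.8, then 1.0s, then back down
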
diff --git a/src/python/enfugue/diffusion/util/video_util.py b/src/python/enfugue/diffusion/util/video_util.py
new file mode 100644
index 00000000..c41fd90b
--- /dev/null
+++ b/src/python/enfugue/diffusion/util/video_util.py
@@ -0,0 +1,164 @@
+from __future__ import annotations
+import os
+from typing import TYPE_CHECKING, Optional, Iterator, Callable, Iterable
+from enfugue.util import logger
+
+if TYPE_CHECKING:
+ from PIL.Image import Image
+ import cv2
+
+__all__ = ["Video"]
+
+def latent_friendly(number: int) -> int:
+ """
+ Returns a latent-friendly image size (divisible by 8)
+ """
+ return (number // 8) * 8
+
+class Video:
+ """
+ Provides helper methods for video
+ """
+ def __init__(self, frames: Iterable[Image]) -> None:
+ self.frames = frames
+
+ def save(
+ self,
+ path: str,
+ overwrite: bool = False,
+ rate: float = 20.0,
+ encoder: str = "avc1",
+ ) -> int:
+ """
+ Saves PIL image frames to a video.
+ Returns the total size of the video in bytes.
+ """
+ import cv2
+ from enfugue.diffusion.util import ComputerVision
+ if path.startswith("~"):
+ path = os.path.expanduser(path)
+ if os.path.exists(path):
+ if not overwrite:
+ raise IOError(f"File exists at path {path}, pass overwrite=True to write anyway.")
+ os.unlink(path)
+ basename, ext = os.path.splitext(os.path.basename(path))
+ if ext in [".gif", ".png", ".tiff", ".webp"]:
+ frames = [frame for frame in self.frames]
+ frames[0].save(path, loop=0, duration=1000.0/rate, save_all=True, append_images=frames[1:])
+ return os.path.getsize(path)
+ elif ext != ".mp4":
+ raise IOError(f"Unknown file extension {ext}")
+ fourcc = cv2.VideoWriter_fourcc(*encoder) # type: ignore
+ writer = None
+
+ for frame in self.frames:
+ if writer is None:
+ writer = cv2.VideoWriter(path, fourcc, rate, frame.size) # type: ignore[union-attr]
+ writer.write(ComputerVision.convert_image(frame))
+
+ if writer is None:
+ raise IOError(f"No frames written to {path}")
+
+ writer.release()
+
+ if not os.path.exists(path):
+ raise IOError(f"Nothing was written to {path}")
+ return os.path.getsize(path)
+
+ @classmethod
+ def file_to_frames(
+ cls,
+ path: str,
+ skip_frames: Optional[int] = None,
+ maximum_frames: Optional[int] = None,
+ resolution: Optional[int] = None,
+ on_open: Optional[Callable[[cv2.VideoCapture], None]] = None,
+ ) -> Iterator[Image]:
+ """
+ Starts a video capture and yields PIL images for each frame.
+ """
+ import cv2
+ from enfugue.diffusion.util import ComputerVision
+
+ if path.startswith("~"):
+ path = os.path.expanduser(path)
+ if not os.path.exists(path):
+ raise IOError(f"Video at path {path} not found or inaccessible")
+
+ basename, ext = os.path.splitext(os.path.basename(path))
+ if ext in [".gif", ".png", ".apng", ".tiff", ".webp", ".avif"]:
+ from PIL import Image
+ image = Image.open(path)
+ for i in range(image.n_frames):
+ image.seek(i)
+ copied = image.copy()
+ copied = copied.convert("RGBA")
+ yield copied
+ return
+
+ frames = 0
+
+ frame_start = 0 if skip_frames is None else skip_frames
+ frame_end = None if maximum_frames is None else frame_start + maximum_frames - 1
+
+ frame_string = "end-of-video" if frame_end is None else f"frame {frame_end}"
+ logger.debug(f"Reading video file at {path} starting from frame {frame_start} until {frame_string}")
+
+ capture = cv2.VideoCapture(path)
+ if on_open is not None:
+ on_open(capture)
+
+ def resize_image(image: Image.Image) -> Image.Image:
+ """
+ Resizes an image frame if requested.
+ """
+ if resolution is None:
+ return image
+
+ width, height = image.size
+ ratio = float(resolution) / float(min(width, height))
+ height = round(height * ratio)
+ width = round(width * ratio)
+ return image.resize((width, height))
+
+ while capture.isOpened():
+ success, image = capture.read()
+ if not success:
+ break
+ elif frames == 0:
+ logger.debug("First frame captured, iterating.")
+
+ frames += 1
+ if frame_start > frames:
+ continue
+
+ yield resize_image(ComputerVision.revert_image(image))
+
+ if frame_end is not None and frames >= frame_end:
+ break
+
+ capture.release()
+ if frames == 0:
+ raise IOError(f"No frames were read from video at {path}")
+
+ @classmethod
+ def from_file(
+ cls,
+ path: str,
+ skip_frames: Optional[int] = None,
+ maximum_frames: Optional[int] = None,
+ resolution: Optional[int] = None,
+ on_open: Optional[Callable[[cv2.VideoCapture], None]] = None,
+ ) -> Video:
+ """
+ Uses Video.file_to_frames and instantiates a Video object.
+ """
+ return cls(
+ frames=cls.file_to_frames(
+ path=path,
+ skip_frames=skip_frames,
+ maximum_frames=maximum_frames,
+ resolution=resolution,
+ on_open=on_open,
+ )
+ )
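
A round trip with the new Video helper (paths and frame counts illustrative):

    # Read a slice of an input video at reduced resolution and re-encode it.
    from enfugue.diffusion.util import Video

    video = Video.from_file("~/input.mp4", skip_frames=10, maximum_frames=64, resolution=512)
    written_bytes = video.save("~/output.mp4", rate=20.0, overwrite=True)
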
diff --git a/src/python/enfugue/diffusion/util/vision_util.py b/src/python/enfugue/diffusion/util/vision_util.py
index fab11a78..78f6e68c 100644
--- a/src/python/enfugue/diffusion/util/vision_util.py
+++ b/src/python/enfugue/diffusion/util/vision_util.py
@@ -1,28 +1,18 @@
-import os
import cv2
import numpy as np
-from typing import Iterator, Iterable, Optional, Callable
-
-from datetime import datetime
from PIL import Image
-from enfugue.util import logger
-__all__ = ["ComputerVision"]
+from typing import Union, Literal
-def latent_friendly(number: int) -> int:
- """
- Returns a latent-friendly image size (divisible by 8)
- """
- return (number // 8) * 8
+__all__ = ["ComputerVision"]
class ComputerVision:
"""
Provides helper methods for cv2
"""
-
- @staticmethod
- def show(name: str, image: Image.Image) -> None:
+ @classmethod
+ def show(cls, name: str, image: Image.Image) -> None:
"""
Shows an image.
Tries to use the Colab monkeypatch first, in case this is being ran in Colab.
@@ -35,212 +25,78 @@ def show(name: str, image: Image.Image) -> None:
cv2.waitKey(0)
cv2.destroyAllWindows()
- @staticmethod
- def convert_image(image: Image.Image) -> np.ndarray:
+ @classmethod
+ def convert_image(cls, image: Image.Image) -> np.ndarray:
"""
Converts PIL image to OpenCV format.
"""
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
- @staticmethod
- def revert_image(array: np.ndarray) -> Image.Image:
+ @classmethod
+ def revert_image(cls, array: np.ndarray) -> Image.Image:
"""
- Converts PIL image to OpenCV format.
+ Converts OpenCV format to PIL image
"""
return Image.fromarray(cv2.cvtColor(array, cv2.COLOR_BGR2RGB))
- @staticmethod
- def frames_to_video(
- path: str,
- frames: Iterable[Image.Image],
- overwrite: bool = False,
- rate: float = 20.
- ) -> int:
- """
- Saves PIL image frames to an .mp4 video.
- Returns the total size of the video in bytes.
- """
- if path.startswith("~"):
- path = os.path.expanduser(path)
- if os.path.exists(path):
- if not overwrite:
- raise IOError(f"File exists at path {path}, pass overwrite=True to write anyway.")
- os.unlink(path)
- fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
- writer = None
- for frame in frames:
- if writer is None:
- writer = cv2.VideoWriter(path, fourcc, rate, frame.size)
- writer.write(ComputerVision.convert_image(frame))
- if writer is None:
- raise IOError(f"No frames written to {path}")
-
- writer.release()
- return os.path.getsize(path)
-
- @staticmethod
- def frames_from_video(
- path: str,
- skip_frames: Optional[int] = None,
- maximum_frames: Optional[int] = None,
- resolution: Optional[int] = None,
- ) -> Iterator[Image.Image]:
- """
- Starts a video capture and yields PIL images for each frame.
- """
- if path.startswith("~"):
- path = os.path.expanduser(path)
- if not os.path.exists(path):
- raise IOError(f"Video at path {path} not found or inaccessible")
- frames = 0
-
- frame_start = 0 if skip_frames is None else skip_frames
- frame_end = None if maximum_frames is None else frame_start + maximum_frames - 1
-
- frame_string = "end-of-video" if frame_end is None else f"frame {frame_end}"
- logger.debug(f"Reading video file at {path} starting from frame {frame_start} until {frame_string}")
-
- capture = cv2.VideoCapture(path)
-
- def resize_image(image: Image.Image) -> Image.Image:
- """
- Resizes an image frame if requested.
- """
- if resolution is None:
- return image
-
- width, height = image.size
- ratio = float(resolution) / float(min(width, height))
- height = round(height * ratio)
- width = round(width * ratio)
- return image.resize((width, height))
-
- while capture.isOpened():
- success, image = capture.read()
- if not success:
- break
- elif frames == 0:
- logger.debug("First frame captured, iterating.")
-
- frames += 1
- if frame_start > frames:
- continue
-
- yield resize_image(ComputerVision.revert_image(image))
-
- if frame_end is not None and frames >= frame_end:
- break
-
- capture.release()
- if frames == 0:
- raise IOError(f"No frames were read from video at {path}")
-
- @staticmethod
- def video_to_video(
- source_path: str,
- destination_path: str,
- overwrite: bool = False,
- rate: float = 20.,
- skip_frames: Optional[int] = None,
- maximum_frames: Optional[int] = None,
- resolution: Optional[int] = None,
- process_frame: Optional[Callable[[Image.Image], Image.Image]] = None,
- ) -> int:
+ @classmethod
+ def noise(
+ cls,
+ image: Union[np.ndarray, Image.Image],
+ method: Literal["gaussian", "poisson", "speckle", "salt-and-pepper"] = "poisson",
+ gaussian_mean: Union[int, float] = 0.0,
+ gaussian_variance: float = 0.01,
+ poisson_factor: Union[int, float] = 2.25,
+ salt_pepper_ratio: float = 0.5,
+ salt_pepper_amount: float = 0.004,
+ speckle_amount: float = 0.01,
+ ) -> Union[np.ndarray, Image.Image]:
"""
- Saves PIL image frames to an .mp4 video.
- Returns the total size of the video in bytes.
+ Adds noise to an image.
"""
- if destination_path.startswith("~"):
- destination_path = os.path.expanduser(destination_path)
- if os.path.exists(destination_path):
- if not overwrite:
- raise IOError(f"File exists at destination_path {destination_path}, pass overwrite=True to write anyway.")
- os.unlink(destination_path)
-
- if source_path.startswith("~"):
- source_path = os.path.expanduser(source_path)
- if not os.path.exists(source_path):
- raise IOError(f"Video at path {source_path} not found or inaccessible")
-
- frames = 0
-
- frame_start = 0 if skip_frames is None else skip_frames
- frame_end = None if maximum_frames is None else frame_start + maximum_frames - 1
-
- frame_string = "end-of-video" if frame_end is None else f"frame {frame_end}"
- logger.debug(f"Reading video file at {source_path} starting from frame {frame_start} until {frame_string}. Will process and write to {destination_path}")
-
- capture = cv2.VideoCapture(source_path)
- fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
- writer = None
-
- def process_image(image: Image.Image) -> Image.Image:
- """
- Processes an image frame if requested.
- """
- width, height = image.size
- if resolution is not None:
- ratio = float(resolution) / float(min(width, height))
- height = round(height * ratio)
- width = round(width * ratio)
-
- image = image.resize((
- latent_friendly(width),
- latent_friendly(height)
- ))
-
- if process_frame is not None:
- image = process_frame(image)
-
- return image
-
- opened = datetime.now()
- started = opened
- last_log = 0
- processed_frames = 0
-
- while capture.isOpened():
- success, image = capture.read()
- if not success:
- break
- elif frames == 0:
- opened = datetime.now()
- logger.debug("Video opened, iterating through frames.")
-
- frames += 1
- if frame_start > frames:
- continue
- elif frame_start == frames:
- started = datetime.now()
- logger.debug(f"Beginning processing from frame {frames}")
-
- image = process_image(ComputerVision.revert_image(image))
- processed_frames += 1
-
- if writer is None:
- writer = cv2.VideoWriter(destination_path, fourcc, rate, image.size)
-
- writer.write(ComputerVision.convert_image(image))
-
- if last_log < processed_frames - rate:
- unit = "frames/sec"
- process_rate = processed_frames / (datetime.now() - started).total_seconds()
- if process_rate < 1.0:
- unit = "sec/frame"
- process_rate = 1.0 / process_rate
-
- logger.debug(f"Processed {processed_frames} at {process_rate:.2f} {unit}")
- last_log = processed_frames
-
- if frame_end is not None and frames >= frame_end:
- break
- if writer is None:
- raise IOError(f"No frames written to path {destination_path}")
-
- writer.release()
- capture.release()
-
- if frames == 0:
- raise IOError(f"No frames were read from video at {source_path}")
-
- return os.path.getsize(destination_path)
+ return_pil = isinstance(image, Image.Image)
+ if return_pil:
+ image = cls.convert_image(image)
+ image = image.astype(np.float64) / 255.0
+ width, height, channels = image.shape
+ if method == "gaussian":
+ gaussian_sigma = gaussian_variance ** 0.5
+ gaussian = np.random.normal(
+ gaussian_mean,
+ gaussian_sigma,
+ (width, height, channels)
+ )
+ gaussian = gaussian.reshape(width, height, channels)
+ image += gaussian
+ elif method == "salt-and-pepper":
+ output = np.copy(image)
+ # Do salt
+ salt = np.ceil(salt_pepper_amount * image.size * salt_pepper_ratio)
+ coordinates = [
+ np.random.randint(0, i - 1, int(salt))
+ for i in image.shape
+ ]
+ output[tuple(coordinates)] = 1
+ # Do pepper
+ pepper = np.ceil(salt_pepper_amount * image.size * (1.0 - salt_pepper_ratio))
+ coordinates = [
+ np.random.randint(0, i - 1, int(pepper))
+ for i in image.shape
+ ]
+ output[tuple(coordinates)] = 0
+ image = output
+ elif method == "poisson":
+ distinct_values = len(np.unique(image))
+ distinct_values = poisson_factor ** np.ceil(np.log2(distinct_values))
+ image = np.random.poisson(image * distinct_values) / float(distinct_values)
+ elif method == "speckle":
+ speckled = np.random.randn(width, height, channels)
+ speckled = speckled.reshape(width, height, channels)
+ image += (image * speckled * speckle_amount)
+ else:
+ raise ValueError(f"Unknown noise method {method}") # type: ignore[unreachable]
+ image *= 255.0
+ image = image.astype(np.uint8)
+ if return_pil:
+ return cls.revert_image(image)
+ return image
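
The new noise helper accepts either a PIL image or an OpenCV array and returns the same type; for instance:

    # Add mild gaussian noise to a PIL image (file names illustrative).
    from PIL import Image
    from enfugue.diffusion.util import ComputerVision

    image = Image.open("photo.png").convert("RGB")
    noised = ComputerVision.noise(image, method="gaussian", gaussian_variance=0.02)
    noised.save("photo-noised.png")
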
diff --git a/src/python/enfugue/partner/civitai.py b/src/python/enfugue/partner/civitai.py
index e5a03eed..4a21dffd 100644
--- a/src/python/enfugue/partner/civitai.py
+++ b/src/python/enfugue/partner/civitai.py
@@ -107,6 +107,7 @@ def get_models(
"LORA",
"Controlnet",
"Poses",
+ "MotionModule",
]
] = None,
sort: Optional[Literal["Highest Rated", "Most Downloaded", "Newest"]] = None,
diff --git a/src/python/enfugue/server.py b/src/python/enfugue/server.py
index c0650c47..f8c37942 100644
--- a/src/python/enfugue/server.py
+++ b/src/python/enfugue/server.py
@@ -1,5 +1,9 @@
import os
import tempfile
+import multiprocessing
+
+multiprocessing.set_start_method("spawn", force=True)
+
from typing import Dict, Any
from enfugue.api import EnfugueAPIServer
diff --git a/src/python/enfugue/setup.py b/src/python/enfugue/setup.py
index a1ef9368..d2df9d38 100644
--- a/src/python/enfugue/setup.py
+++ b/src/python/enfugue/setup.py
@@ -29,7 +29,7 @@
"colored>=1.4,<1.5",
"diffusers>=0.18", # Minimum, works with 0.20.dev
"albumentations>=0.4.3,<0.5",
- "opencv-python>=4.7.0.72,<4.8",
+ "opencv-python>=4.6.0.66,<5.0",
"pudb==2019.2",
"invisible-watermark>=0.2,<0.3",
"imageio>=2.31.1,<3.0",
@@ -56,9 +56,7 @@
"torchsde>=0.2.5,<0.3",
"timm>=0.9.2,<1.0",
"opensimplex>=0.4.5,<0.5",
- "taming-transformers",
- "clip",
- "latent-diffusion",
+ "tensorflow", # Any version
]
extras_require = {
@@ -70,6 +68,12 @@
"onnx-graphsurgeon==0.3.26",
"tensorrt>=8.6.0,<8.7",
],
+ "source": [
+ # These packages should be installed from source, but we'll put them here too
+ "taming-transformers",
+ "clip",
+ "latent-diffusion",
+ ],
"build": [
"mypy==1.2.0",
"mypy-extensions==1.0.0",
diff --git a/src/python/enfugue/test/0_e2e_diffusion.py b/src/python/enfugue/test/0_e2e_diffusion.py
index 7d7bd3da..d05916ed 100644
--- a/src/python/enfugue/test/0_e2e_diffusion.py
+++ b/src/python/enfugue/test/0_e2e_diffusion.py
@@ -16,8 +16,8 @@
GRID_SIZE = 256
GRID_COLS = 4
CAPTION_HEIGHT = 50
-CHECKPOINT = "epicphotogasm_x.safetensors"
-CHECKPOINT_URL = "https://civitai.com/api/download/models/172306?type=Model&format=SafeTensor&size=pruned&fp=fp16"
+CHECKPOINT = "epicphotogasm_zUniversal.safetensors"
+CHECKPOINT_URL = "https://civitai.com/api/download/models/201259"
INPAINT_IMAGE = "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
INPAINT_MASK = "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
@@ -83,7 +83,7 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
# Two seeds for controlled/non-controlled:
# Controlnet doesn't change enough if it uses the same
# seed as the prompt the image was generated with
- kwargs["seed"] = 12345 if "control_images" in kwargs else 54321
+ kwargs["seed"] = 123456 if "control_images" in kwargs else 654321
if "model" not in kwargs:
kwargs["model"] = CHECKPOINT
kwargs["intermediates"] = False
@@ -127,7 +127,10 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
invoke(
"img2img",
prompt=prompt,
- image=base,
+ layers=[{
+ "image": base,
+ "denoise": True
+ }],
strength=0.8
)
@@ -137,63 +140,72 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
invoke(
"inpaint",
prompt="a handsome man with ray-ban sunglasses",
- image=inpaint_image,
mask=inpaint_mask,
+ layers=[{
+ "image": inpaint_image,
+ "fit": "cover"
+ }],
width=512,
height=512,
- fit="cover"
)
invoke(
"inpaint-4ch",
inpainter=CHECKPOINT,
prompt="a handsome man with ray-ban sunglasses",
- image=inpaint_image,
mask=inpaint_mask,
+ layers=[{
+ "image": inpaint_image,
+ "fit": "cover"
+ }],
width=512,
height=512,
- fit="cover"
)
# Automatic background removal with no inference
invoke(
"background",
- image=inpaint_image,
- remove_background=True,
+ layers=[{
+ "image": inpaint_image,
+ "fit": "cover",
+ "remove_background": True
+ }],
+ outpaint=False
)
# Automatic background removal with outpaint
invoke(
"background-fill",
prompt="a handsome man outside on a sunny day, green forest in the distance",
- image=inpaint_image,
- remove_background=True,
- fill_background=True
+ layers=[{
+ "image": inpaint_image,
+ "fit": "cover",
+ "remove_background": True
+ }]
)
# IP Adapter
invoke(
"ip-adapter",
- ip_adapter_images=[{
+ layers=[{
"image": inpaint_image,
- "scale": 0.3
+ "ip_adapter_scale": 0.3
}]
)
invoke(
"ip-adapter-plus",
- ip_adapter_plus=True,
- ip_adapter_images=[{
+ ip_adapter_model="plus",
+ layers=[{
"image": inpaint_image,
- "scale": 0.3
+ "ip_adapter_scale": 0.3
}]
)
invoke(
"ip-adapter-plus-face",
- ip_adapter_plus=True,
- ip_adapter_face=True,
- ip_adapter_images=[{
- "image": inpaint_image,
- "scale": 0.3
+ ip_adapter_model="plus-face",
+ layers=[{
+ "image": inpaint_image,
+ "ip_adapter_scale": 0.3
}]
)
@@ -202,7 +214,7 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
"outpaint",
prompt="a handsome man walking outside on a boardwalk, distant house, foggy weather, dark clouds overhead",
negative_prompt="frame, framing, comic book paneling, multiple images, awning, roof, shelter, trellice",
- nodes=[
+ layers=[
{
"image": inpaint_image,
"x": 128,
@@ -212,43 +224,18 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
"fit": "cover"
}
],
- strength=1.0
+ outpaint=True
)
- # Regions + multi-diffusion
- invoke(
- "regions",
- prompt="Roses in a bouquet",
- chunking_size=128,
- nodes=[
- {
- "x": 0,
- "y": 0,
- "w": 256,
- "h": 512,
- "prompt": "A single red rose, white background",
- "negative_prompt": "bouquet",
- "remove_background": True
- },
- {
- "x": 256,
- "y": 0,
- "w": 256,
- "h": 512,
- "prompt": "A single white rose, black background",
- "negative_prompt": "bouquet",
- "remove_background": True
- }
- ]
- )
-
# Controlnets
for controlnet in ["canny", "hed", "pidi", "scribble", "depth", "normal", "mlsd", "line", "anime", "pose"]:
invoke(
f"txt2img-controlnet-{controlnet}",
prompt=prompt,
- control_images=[{
- "controlnet": controlnet,
+ layers=[{
+ "control_units": [{
+ "controlnet": controlnet,
+ }],
"image": base,
}]
)
@@ -256,26 +243,28 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
invoke(
f"img2img-controlnet-{controlnet}",
prompt=prompt,
- image=base,
strength=0.8,
- control_images=[{
- "controlnet": controlnet,
+ layers=[{
+ "control_units": [{
+ "controlnet": controlnet,
+ }],
"image": base,
+ "denoise": True
}]
)
invoke(
f"img2img-ip-controlnet-{controlnet}",
prompt=prompt,
- image=base,
strength=0.8,
- ip_adapter_images=[{
- "image": base,
- "scale": 0.5
- }],
- control_images=[{
- "controlnet": controlnet,
+ layers=[{
+ "control_units": [{
+ "controlnet": controlnet,
+ }],
"image": base,
+ "denoise": True,
+ "ip_adapter_scale": 0.5
}]
)
@@ -286,43 +275,48 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
prompt=prompt,
scheduler=scheduler
)
-
- invoke(
- f"txt2img-multi-scheduler-{scheduler}",
- prompt=prompt,
- scheduler=scheduler,
- height=768,
- width=786,
- chunking_size=256,
- )
+ if scheduler != "dpmsde":
+ invoke(
+ f"txt2img-multi-scheduler-{scheduler}",
+ prompt=prompt,
+ scheduler=scheduler,
+ height=768,
+ width=786,
+ tiling_stride=256,
+ )
# Upscalers
invoke(
f"upscale-standalone-esrgan",
- upscale_steps=[{
+ upscale=[{
"amount": 2,
"method": "esrgan"
}],
- image=base
+ layers=[{
+ "image": base
+ }]
)
invoke(
f"upscale-standalone-gfpgan",
- upscale_steps=[{
+ upscale=[{
"amount": 2,
"method": "gfpgan"
}],
- image=base
+ layers=[{
+ "image": base
+ }]
)
+
invoke(
f"upscale-iterative-diffusion",
prompt="A green tree frog",
- upscale_steps=[
+ upscale=[
{
"amount": 2,
"method": "esrgan",
"strength": 0.15,
"controlnets": "tile",
- "chunking_size": 128,
+ "tiling_stride": 128,
"guidance_scale": 8
},
{
@@ -330,10 +324,10 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
"method": "esrgan",
"strength": 0.1,
"controlnets": "tile",
- "chunking_size": 256,
+ "tiling_stride": 256,
"guidance_scale": 8
}
- ],
+ ]
)
# SDXL
@@ -349,20 +343,24 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
"sdxl-inpaint",
prompt="a handsome man with ray-ban sunglasses",
model=DEFAULT_SDXL_MODEL,
- image=fit_image(inpaint_image, width=1024, height=1024, fit="stretch"),
mask=fit_image(inpaint_mask, width=1024, height=1024, fit="stretch"),
width=1024,
height=1024,
+ layers=[{
+ "image": fit_image(inpaint_image, width=1024, height=1024, fit="stretch")
+ }]
)
invoke(
"sdxl-inpaint-4ch",
prompt="a handsome man with ray-ban sunglasses",
inpainter=DEFAULT_SDXL_MODEL,
- image=fit_image(inpaint_image, width=1024, height=1024, fit="stretch"),
mask=fit_image(inpaint_mask, width=1024, height=1024, fit="stretch"),
width=1024,
height=1024,
+ layers=[{
+ "image": fit_image(inpaint_image, width=1024, height=1024, fit="stretch")
+ }]
)
control = invoke(
@@ -379,23 +377,27 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
invoke(
f"sdxl-{controlnet}-txt2img",
model=DEFAULT_SDXL_MODEL,
- control_images=[{
- "controlnet": controlnet,
- "image": control,
- "scale": 0.5
+ layers=[{
+ "control_units": [{
+ "controlnet": controlnet,
+ "scale": 0.5
+ }],
+ "image": control
}],
prompt="A bride and groom on their wedding day",
guidance_scale=6
)[0]
-
+
invoke(
f"sdxl-{controlnet}-txt2img-refined",
model=DEFAULT_SDXL_MODEL,
refiner=DEFAULT_SDXL_REFINER,
- control_images=[{
- "controlnet": controlnet,
- "image": control,
- "scale": 0.5
+ layers=[{
+ "control_units": [{
+ "controlnet": controlnet,
+ "scale": 0.5
+ }],
+ "image": control
}],
prompt="A bride and groom on their wedding day",
refiner_start=0.85,
@@ -405,15 +407,15 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
invoke(
f"sdxl-{controlnet}-img2img",
model=DEFAULT_SDXL_MODEL,
- image=control,
strength=0.8,
- control_images=[
- {
+ layers=[{
+ "control_units": [{
"controlnet": controlnet,
- "image": control,
"scale": 0.5
- }
- ],
+ }],
+ "image": control,
+ "denoise": True
+ }],
prompt="A bride and groom on their wedding day",
guidance_scale=6
)[0]
@@ -422,14 +424,14 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
f"sdxl-{controlnet}-img2img-refined",
model=DEFAULT_SDXL_MODEL,
refiner=DEFAULT_SDXL_REFINER,
- control_images=[
- {
+ layers=[{
+ "control_units": [{
"controlnet": controlnet,
- "image": control,
"scale": 0.5
- }
- ],
- image=control,
+ }],
+ "image": control,
+ "denoise": True
+ }],
strength=0.8,
refiner_start=0.85,
prompt="A bride and groom on their wedding day",
@@ -440,18 +442,15 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
f"sdxl-{controlnet}-img2img-ip-refined",
model=DEFAULT_SDXL_MODEL,
refiner=DEFAULT_SDXL_REFINER,
- control_images=[
- {
+ strength=0.8,
+ layers=[{
+ "control_units": [{
"controlnet": controlnet,
- "image": control,
"scale": 0.5
- }
- ],
- image=control,
- strength=0.8,
- ip_adapter_images=[{
+ }],
"image": control,
- "scale": 0.5
+ "denoise": True,
+ "ip_adapter_scale": 0.5
}],
prompt="A bride and groom on their wedding day",
refiner_start=0.85,
@@ -462,20 +461,17 @@ def invoke(name: str, **kwargs: Any) -> List[Image.Image]:
f"sdxl-{controlnet}-img2img-ip-plus-refined",
model=DEFAULT_SDXL_MODEL,
refiner=DEFAULT_SDXL_REFINER,
- control_images=[
- {
+ strength=0.8,
+ ip_adapter_model="plus",
+ layers=[{
+ "control_units": [{
"controlnet": controlnet,
- "image": control,
"scale": 0.5
- }
- ],
- image=control,
- strength=0.8,
- ip_adapter_images=[{
+ }],
"image": control,
- "scale": 0.5
+ "denoise": True,
+ "ip_adapter_scale": 0.5
}],
- ip_adapter_plus=True,
prompt="A bride and groom on their wedding day",
refiner_start=0.85,
guidance_scale=6
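The recurring change throughout this test is the move from top-level `image=`, `control_images=`, and `ip_adapter_images=` keywords to a single `layers` list (alongside the rename of `chunking_size` to `tiling_stride`). A hedged sketch of the layer shape these calls now pass; the field names are taken from the calls above, the values are placeholders:

```python
# Shape of one entry in the new `layers` list, mirroring the updated calls above.
layer = {
    "image": "boy.png",               # image source (the tests pass PIL images)
    "denoise": True,                  # opt this layer into img2img denoising
    "fit": "cover",                   # optional fit mode when the layer is resized
    "ip_adapter_scale": 0.3,          # optional IP adapter strength for this layer
    "control_units": [                # optional controlnet units applied to this layer
        {"controlnet": "canny", "scale": 0.5},
    ],
}

if __name__ == "__main__":
    import json
    print(json.dumps(layer, indent=2))
```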
diff --git a/src/python/enfugue/test/2_latent_scaling.py b/src/python/enfugue/test/0_latent_scaling.py
similarity index 99%
rename from src/python/enfugue/test/2_latent_scaling.py
rename to src/python/enfugue/test/0_latent_scaling.py
index 8dbea1f4..7c7b6b6f 100644
--- a/src/python/enfugue/test/2_latent_scaling.py
+++ b/src/python/enfugue/test/0_latent_scaling.py
@@ -20,7 +20,6 @@ def main() -> None:
downscale_source = image_from_uri(DOWNSCALE_SOURCE)
manager = DiffusionPipelineManager()
- manager.chunking_size = 0
device = manager.device
dtype = manager.dtype
diff --git a/src/python/enfugue/test/1_interpolate.py b/src/python/enfugue/test/1_interpolate.py
new file mode 100644
index 00000000..322ea764
--- /dev/null
+++ b/src/python/enfugue/test/1_interpolate.py
@@ -0,0 +1,43 @@
+import io
+import os
+import PIL
+import requests
+
+from enfugue.util import image_from_uri, fit_image
+from datetime import datetime
+
+from pibble.util.log import DebugUnifiedLoggingContext
+
+BASE_IMAGE = "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
+
+def main() -> None:
+ with DebugUnifiedLoggingContext():
+ save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test-results", "interpolate")
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ image = image_from_uri(BASE_IMAGE)
+ image = fit_image(image, width=544, height=512, fit="cover")
+ left = fit_image(image, width=512, height=512, anchor="top-left")
+ right = fit_image(image, width=512, height=512, anchor="top-right")
+ left.save(os.path.join(save_dir, "left.png"))
+ right.save(os.path.join(save_dir, "right.png"))
+
+ #manager = DiffusionPipelineManager()
+ frames = 32
+ from enfugue.diffusion.interpolate.interpolator import InterpolationEngine
+ interpolator = InterpolationEngine()
+ with interpolator:
+ start = datetime.now()
+ frame_list = interpolator(
+ [left, right],
+ (2,2)
+ )
+ for i, frame in enumerate(frame_list):
+ frame.save(os.path.join(save_dir, f"interpolated-{i+1}.png"))
+ seconds = (datetime.now() - start).total_seconds()
+ average = seconds / (frames-2)
+ print(f"Interpolated {frames-2} frames in {seconds} ({average}s/frame)")
+
+if __name__ == "__main__":
+ main()
diff --git a/src/python/enfugue/test/1_layers.py b/src/python/enfugue/test/1_layers.py
new file mode 100644
index 00000000..c2d7c30c
--- /dev/null
+++ b/src/python/enfugue/test/1_layers.py
@@ -0,0 +1,178 @@
+"""
+Uses the invocation planner to test parsing of various layer states from the UI
+"""
+import os
+
+from pibble.util.log import DebugUnifiedLoggingContext
+
+from enfugue.diffusion.manager import DiffusionPipelineManager
+from enfugue.diffusion.invocation import LayeredInvocation
+from enfugue.util import logger, save_frames_or_image
+
+from PIL import Image
+
+def main() -> None:
+ HERE = os.path.dirname(os.path.abspath(__file__))
+ with DebugUnifiedLoggingContext():
+ save_dir = os.path.join(HERE, "test-results", "layers")
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ image = Image.open(os.path.join(HERE, "test-images", "small-inpaint.jpg"))
+ mask = Image.open(os.path.join(HERE, "test-images", "small-inpaint-mask-invert.jpg"))
+
+ manager = DiffusionPipelineManager()
+
+ def log_invocation(name, invocation):
+ logger.info(name)
+ processed = invocation.preprocess(manager, raise_when_unused=False)
+ formatted = LayeredInvocation.format_serialization_dict(
+ save_directory=save_dir,
+ save_name=name,
+ **processed
+ )
+ logger.debug(f"{formatted}")
+
+ # Layered, image covers entire canvas
+ log_invocation(
+ "simple",
+ LayeredInvocation(
+ width=512,
+ height=512,
+ layers=[
+ {
+ "image": image,
+ }
+ ]
+ )
+ )
+
+ # Layered, image covers part of canvas, no denoising
+ log_invocation(
+ "outpaint",
+ LayeredInvocation(
+ width=512,
+ height=512,
+ layers=[
+ {
+ "image": image,
+ "x": 0,
+ "y": 0,
+ "w": 256,
+ "h": 256
+ }
+ ]
+ )
+ )
+
+ # Layered, image covers part of canvas, denoising
+ log_invocation(
+ "overpaint",
+ LayeredInvocation(
+ width=512,
+ height=512,
+ layers=[
+ {
+ "image": image,
+ "x": 0,
+ "y": 0,
+ "w": 512,
+ "h": 256,
+ "denoise": True
+ }
+ ]
+ )
+ )
+
+ # Layered, image covers part of canvas, joined with mask
+ log_invocation(
+ "outpaint-merge",
+ LayeredInvocation(
+ width=512,
+ height=512,
+ mask={
+ "image": mask,
+ "invert": True
+ },
+ layers=[
+ {
+ "image": image,
+ "x": 0,
+ "y": 0,
+ "w": 512,
+ "h": 256,
+ }
+ ]
+ )
+ )
+
+ # Layered, image covers entirety of canvas, rembg
+ log_invocation(
+ "remove_background",
+ LayeredInvocation(
+ width=512,
+ height=512,
+ layers=[
+ {
+ "image": image,
+ "remove_background": True
+ }
+ ]
+ )
+ )
+
+ # Layered, two copies of images, merged masks
+ log_invocation(
+ "remove_background_merge",
+ LayeredInvocation(
+ width=512,
+ height=512,
+ mask={
+ "image": mask,
+ "invert": True
+ },
+ layers=[
+ {
+ "image": image,
+ "x": 128,
+ "w": 512-128,
+ "y": 0,
+ "h": 512
+ },
+ {
+ "image": image,
+ "remove_background": True,
+ "x": 0,
+ "y": 0,
+ "h": 256,
+ "w": 256,
+ "fit": "stretch"
+ }
+ ]
+ )
+ )
+
+if __name__ == "__main__":
+ main()
diff --git a/src/python/enfugue/test/1_layers_invocation.py b/src/python/enfugue/test/1_layers_invocation.py
new file mode 100644
index 00000000..f574b715
--- /dev/null
+++ b/src/python/enfugue/test/1_layers_invocation.py
@@ -0,0 +1,31 @@
+"""
+Uses the engine to exercise layered plan instantiation and execution
+"""
+import os
+from pibble.util.log import DebugUnifiedLoggingContext
+from enfugue.util import logger
+from enfugue.diffusion.invocation import LayeredInvocation
+from enfugue.diffusion.manager import DiffusionPipelineManager
+
+def main() -> None:
+ with DebugUnifiedLoggingContext():
+ save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test-results", "base")
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ manager = DiffusionPipelineManager()
+
+ # Base plan
+ plan = LayeredInvocation.assemble(
+ width=512,
+ height=512,
+ prompt="A happy looking puppy"
+ )
+
+ plan.execute(
+ manager,
+ task_callback=lambda arg: logger.info(arg),
+ )["images"][0].save(os.path.join(save_dir, "./puppy-plan.png"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/python/enfugue/test/2_inpaint.py b/src/python/enfugue/test/2_inpaint.py
index a9e762a2..a5e8426d 100644
--- a/src/python/enfugue/test/2_inpaint.py
+++ b/src/python/enfugue/test/2_inpaint.py
@@ -8,8 +8,9 @@
from typing import List
from pibble.util.log import DebugUnifiedLoggingContext
from enfugue.util import logger
-from enfugue.diffusion.plan import DiffusionPlan
from enfugue.diffusion.manager import DiffusionPipelineManager
+from enfugue.diffusion.invocation import LayeredInvocation
+from enfugue.diffusion.constants import DEFAULT_INPAINTING_MODEL
HERE = os.path.dirname(os.path.abspath(__file__))
@@ -24,30 +25,38 @@ def main() -> None:
image = PIL.Image.open(os.path.join(HERE, "test-images", "small-inpaint.jpg"))
mask = PIL.Image.open(os.path.join(HERE, "test-images", "small-inpaint-mask-invert.jpg"))
+
prompt = "a man breakdancing in front of a bright blue sky"
negative_prompt = "tree, skyline, buildings"
+ width, height = image.size
- plan = DiffusionPlan.assemble(
- size=512,
- prompt = prompt,
- negative_prompt = negative_prompt,
- num_inference_steps = 20,
- image = image,
- mask = mask,
- invert_mask = True
+ plan = LayeredInvocation.assemble(
+ width=width,
+ height=height,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=20,
+ image=image,
+ mask={
+ "image": mask,
+ "invert": True
+ }
)
plan.execute(manager)["images"][0].save(os.path.join(save_dir, f"result.png"))
- plan = DiffusionPlan.assemble(
- size=512,
- inpainter = "v1-5-pruned.ckpt", # Force 4-dim inpainting
- prompt = "blue sky and green grass",
- negative_prompt = negative_prompt,
- num_inference_steps = 20,
- image = image,
- mask = mask,
- invert_mask = True
+ plan = LayeredInvocation.assemble(
+ width=width,
+ height=height,
+ inpainter="v1-5-pruned.ckpt", # Force 4-dim inpainting
+ prompt="blue sky and green grass",
+ negative_prompt=negative_prompt,
+ num_inference_steps=20,
+ image=image,
+ mask={
+ "image": mask,
+ "invert": True
+ }
)
plan.execute(manager)["images"][0].save(os.path.join(save_dir, f"result-4-dim.png"))
diff --git a/src/python/enfugue/test/2_large_inpaint.py b/src/python/enfugue/test/2_large_inpaint.py
index 601cc64f..57fdee88 100644
--- a/src/python/enfugue/test/2_large_inpaint.py
+++ b/src/python/enfugue/test/2_large_inpaint.py
@@ -8,7 +8,7 @@
from typing import List
from pibble.util.log import DebugUnifiedLoggingContext
from enfugue.util import logger
-from enfugue.diffusion.plan import DiffusionPlan
+from enfugue.diffusion.invocation import LayeredInvocation
from enfugue.diffusion.manager import DiffusionPipelineManager
HERE = os.path.dirname(os.path.abspath(__file__))
@@ -32,20 +32,15 @@ def main() -> None:
width, height = image.size
prompt, negative_prompt = PROMPTS[size]
- plan = DiffusionPlan.assemble(
- size = 512,
+ plan = LayeredInvocation.assemble(
prompt = prompt,
negative_prompt = negative_prompt,
num_inference_steps = 20,
width = width,
height = height,
- nodes = [
- {
- "image": image,
- "mask": mask,
- "strength": 1.0
- }
- ]
+ mask=mask,
+ image=image,
+ strength=1.0,
)
plan.execute(manager)["images"][0].save(os.path.join(save_dir, f"{size}-result.png"))
diff --git a/src/python/enfugue/test/2_noise.py b/src/python/enfugue/test/2_noise.py
index 21f95f03..442c3402 100644
--- a/src/python/enfugue/test/2_noise.py
+++ b/src/python/enfugue/test/2_noise.py
@@ -16,7 +16,6 @@ def main() -> None:
manager = DiffusionPipelineManager()
manager.seed = 12345
- manager.chunking_size = 0
device = manager.device
dtype = manager.dtype
diff --git a/src/python/enfugue/test/2_pipemanager_xl.py b/src/python/enfugue/test/2_pipemanager_xl.py
index 1c3ae7cc..8548dde9 100644
--- a/src/python/enfugue/test/2_pipemanager_xl.py
+++ b/src/python/enfugue/test/2_pipemanager_xl.py
@@ -22,7 +22,6 @@ def main() -> None:
# Start with absolute defaults.
# Even if there's nothing on your machine, this should work by downloading everything needed.
manager = DiffusionPipelineManager()
- manager.size = 1024
manager.model = DEFAULT_SDXL_MODEL
def run_and_save(filename: str) -> None:
diff --git a/src/python/enfugue/test/2_plan.py b/src/python/enfugue/test/2_plan.py
deleted file mode 100644
index e4a96521..00000000
--- a/src/python/enfugue/test/2_plan.py
+++ /dev/null
@@ -1,78 +0,0 @@
-"""
-Uses the engine to create a simple image using default settings
-"""
-import os
-from pibble.util.log import DebugUnifiedLoggingContext
-from enfugue.diffusion.plan import DiffusionPlan
-from enfugue.diffusion.manager import DiffusionPipelineManager
-
-def main() -> None:
- with DebugUnifiedLoggingContext():
- save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test-results", "base")
- if not os.path.exists(save_dir):
- os.makedirs(save_dir)
- manager = DiffusionPipelineManager()
- # Base plan
- manager.seed = 123456
- plan = DiffusionPlan.assemble(size=512, prompt="A happy looking puppy")
- plan.execute(manager)["images"][0].save(os.path.join(save_dir, "./puppy-plan.png"))
-
- # Inpainting + region prompt + background removal
- plan = DiffusionPlan.assemble(
- size=512,
- prompt="A cat and dog laying on a couch",
- nodes=[
- {
- "x": 0,
- "y": 128,
- "w": 256,
- "h": 256,
- "prompt": "A golden retriever laying down",
- "remove_background": True
- },
- {
- "x": 256,
- "y": 128,
- "w": 256,
- "h": 256,
- "prompt": "A cat laying down",
- "remove_background": True
- }
- ]
- )
- plan.execute(manager)["images"][0].save(os.path.join(save_dir, "./puppy-kitty-inpaint.png"))
-
- # Upscale
- plan.upscale_steps = {
- "amount": 2,
- "method": "esrgan"
- }
- manager.seed = 12345
- plan.execute(manager)["images"][0].save(os.path.join(save_dir, "./puppy-plan-upscale.png"))
-
- # Upscale diffusion
- plan.upscale_steps = {
- "amount": 2,
- "method": "esrgan",
- "strength": 0.2
- }
- manager.seed = 12345
- result = plan.execute(manager)["images"][0]
- result.save(os.path.join(save_dir, "./puppy-plan-upscale-diffusion.png"))
-
- # Upscale again just from the image
- plan = DiffusionPlan.upscale_image(
- size=512,
- image=result,
- upscale_steps=[{
- "method": "esrgan",
- "amount": 2,
- "strength": 0.2,
- "chunking_size": 256,
- }]
- )
- plan.execute(manager)["images"][0].save(os.path.join(save_dir, "./puppy-plan-upscale-solo.png"))
-
-
-if __name__ == "__main__":
- main()
diff --git a/src/python/enfugue/test/2_schedulers.py b/src/python/enfugue/test/2_schedulers.py
index 4e484eb2..5d6c9938 100644
--- a/src/python/enfugue/test/2_schedulers.py
+++ b/src/python/enfugue/test/2_schedulers.py
@@ -3,10 +3,10 @@
"""
import os
import PIL
+import traceback
from typing import List, Any
from pibble.util.log import DebugUnifiedLoggingContext
from enfugue.util import logger
-from enfugue.diffusion.plan import DiffusionPlan
from enfugue.diffusion.manager import DiffusionPipelineManager
SCHEDULERS = [
@@ -48,7 +48,7 @@ def main() -> None:
}
multi_kwargs = {
"width": 768,
- "chunking_size": 64,
+ "tiling_stride": 64,
}
def run_and_save(target_dir: str, **other_kwargs: Any) -> None:
@@ -79,6 +79,7 @@ def intermediate_callback(images: List[PIL.Image.Image]) -> None:
run_and_save(target_dir_multi, **multi_kwargs)
except Exception as ex:
logger.error("Error with scheduler {0}: {1}({2})".format(scheduler, type(ex).__name__, ex))
+ logger.debug(traceback.format_exc())
if __name__ == "__main__":
main()
diff --git a/src/python/enfugue/test/2_upscale_diffusion.py b/src/python/enfugue/test/2_upscale_diffusion.py
new file mode 100644
index 00000000..e816d599
--- /dev/null
+++ b/src/python/enfugue/test/2_upscale_diffusion.py
@@ -0,0 +1,38 @@
+"""
+Uses the engine to exercise upscaling and re-diffusion of a generated image
+"""
+import os
+from pibble.util.log import DebugUnifiedLoggingContext
+from enfugue.util import logger
+from enfugue.diffusion.invocation import LayeredInvocation
+from enfugue.diffusion.manager import DiffusionPipelineManager
+
+def main() -> None:
+ with DebugUnifiedLoggingContext():
+ save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test-results", "base")
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ manager = DiffusionPipelineManager()
+
+ # Re-diffused upscaled image
+ plan = LayeredInvocation.assemble(
+ seed=12345,
+ width=512,
+ height=512,
+ prompt="A happy looking puppy",
+ upscale={
+ "method": "esrgan",
+ "amount": 2,
+ "strength": 0.2
+ }
+ )
+
+ plan.execute(
+ manager,
+ task_callback=lambda arg: logger.info(arg),
+ )["images"][0].save(os.path.join(save_dir, "./puppy-plan-upscale.png"))
+
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/python/enfugue/test/2_vae.py b/src/python/enfugue/test/2_vae.py
index cc9f56cb..2dfd5c6f 100644
--- a/src/python/enfugue/test/2_vae.py
+++ b/src/python/enfugue/test/2_vae.py
@@ -8,7 +8,6 @@
from typing import List
from pibble.util.log import DebugUnifiedLoggingContext
from enfugue.util import logger
-from enfugue.diffusion.plan import DiffusionPlan
from enfugue.diffusion.manager import DiffusionPipelineManager
VAE = [
diff --git a/src/python/enfugue/test/3_inpaint_xl.py b/src/python/enfugue/test/3_inpaint_xl.py
index b7bcd4ac..b7f5d780 100644
--- a/src/python/enfugue/test/3_inpaint_xl.py
+++ b/src/python/enfugue/test/3_inpaint_xl.py
@@ -8,7 +8,7 @@
from typing import List
from pibble.util.log import DebugUnifiedLoggingContext
from enfugue.util import logger
-from enfugue.diffusion.plan import DiffusionPlan
+from enfugue.diffusion.invocation import LayeredInvocation
from enfugue.diffusion.constants import DEFAULT_SDXL_MODEL
from enfugue.diffusion.manager import DiffusionPipelineManager
@@ -19,7 +19,7 @@ def main() -> None:
save_dir = os.path.join(HERE, "test-results", "inpaint")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
-
+
image = PIL.Image.open(os.path.join(HERE, "test-images", "inpaint-xl.jpg"))
mask = PIL.Image.open(os.path.join(HERE, "test-images", "inpaint-xl-mask.jpg"))
@@ -28,7 +28,7 @@ def main() -> None:
prompt = "a huge cactus standing in the desert"
- plan = DiffusionPlan.assemble(
+ plan = LayeredInvocation.assemble(
size = 1024,
prompt = prompt,
model = DEFAULT_SDXL_MODEL,
@@ -39,5 +39,17 @@ def main() -> None:
plan.execute(manager)["images"][0].save(os.path.join(save_dir, f"result-xl.png"))
+ # Force 4-dim
+ plan = LayeredInvocation.assemble(
+ size = 1024,
+ prompt = prompt,
+ inpainter = DEFAULT_SDXL_MODEL,
+ num_inference_steps = 50,
+ image = image,
+ mask = mask
+ )
+
+ plan.execute(manager)["images"][0].save(os.path.join(save_dir, f"result-xl-4dim.png"))
+
if __name__ == "__main__":
main()
diff --git a/src/python/enfugue/test/3_plan_xl.py b/src/python/enfugue/test/3_plan_xl.py
index b33b4b5f..c4a76740 100644
--- a/src/python/enfugue/test/3_plan_xl.py
+++ b/src/python/enfugue/test/3_plan_xl.py
@@ -1,9 +1,9 @@
"""
-Uses the engine to create a simple image using default settings
+Uses the engine to create a simple image using default settings in XL
"""
import os
from pibble.util.log import DebugUnifiedLoggingContext
-from enfugue.diffusion.plan import DiffusionPlan, DiffusionNode, DiffusionStep
+from enfugue.diffusion.invocation import LayeredInvocation
from enfugue.diffusion.manager import DiffusionPipelineManager
from enfugue.diffusion.constants import DEFAULT_SDXL_MODEL, DEFAULT_SDXL_REFINER
@@ -14,51 +14,47 @@ def main() -> None:
os.makedirs(save_dir)
kwargs = {
+ "seed": 12345,
"size": 1024,
"model": DEFAULT_SDXL_MODEL,
- "refiner": DEFAULT_SDXL_REFINER,
"prompt": "A happy looking puppy",
- "upscale_diffusion_guidance_scale": 5.0,
- "upscale_diffusion_strength": 0.3,
- "upscale_diffusion_steps": 50
+ "refiner": DEFAULT_SDXL_REFINER,
+ "refiner_start": 0.85
}
manager = DiffusionPipelineManager()
# Base plan
- manager.seed = 123456
- plan = DiffusionPlan.assemble(**kwargs)
+ plan = LayeredInvocation.assemble(**kwargs)
plan.execute(manager)["images"][0].save(os.path.join(save_dir, "./puppy-plan-xl.png"))
# Upscale
- plan.upscale_steps = [{
+ plan.upscale = [{
"amount": 2,
"method": "esrgan"
}]
- manager.seed = 123456
plan.execute(manager)["images"][0].save(os.path.join(save_dir, "./puppy-plan-xl-upscale.png"))
# Upscale diffusion at 2x
- plan.upscale_steps = [{
+ plan.upscale = [{
"amount": 2,
"method": "esrgan",
"strength": 0.2,
"num_inference_steps": 50,
- "guidance_scale": 10.0
+ "guidance_scale": 10.0,
+ "tiling_stride": 256,
}]
- plan.upscale_diffusion = True
- manager.seed = 123456
result = plan.execute(manager)["images"][0]
result.save(os.path.join(save_dir, "./puppy-plan-xl-upscale-diffusion.png"))
# Upscale again alone
- plan = DiffusionPlan.upscale_image(
+ plan = LayeredInvocation.assemble(
image=result,
- upscale_steps=[{
+ upscale=[{
"amount": 2,
"method": "esrgan",
"strength": 0.2,
- "chunking_size": 512,
+ "tiling_stride": 512,
"strength": 0.2,
"num_inference_steps": 50,
"guidance_scale": 10.0
diff --git a/src/python/enfugue/test/3_vae_xl.py b/src/python/enfugue/test/3_vae_xl.py
index e267a648..20b300e0 100644
--- a/src/python/enfugue/test/3_vae_xl.py
+++ b/src/python/enfugue/test/3_vae_xl.py
@@ -8,7 +8,6 @@
from typing import List
from pibble.util.log import DebugUnifiedLoggingContext
from enfugue.util import logger
-from enfugue.diffusion.plan import DiffusionPlan
from enfugue.diffusion.manager import DiffusionPipelineManager
from enfugue.diffusion.constants import DEFAULT_SDXL_MODEL, DEFAULT_SDXL_REFINER
diff --git a/src/python/enfugue/test/5_animate.py b/src/python/enfugue/test/5_animate.py
new file mode 100644
index 00000000..786001f7
--- /dev/null
+++ b/src/python/enfugue/test/5_animate.py
@@ -0,0 +1,45 @@
+"""
+Tests automatic loading of motion module/animator pipeline
+"""
+import os
+import PIL
+
+from enfugue.util import logger
+from enfugue.diffusion.engine import DiffusionEngine
+from enfugue.diffusion.constants import *
+from enfugue.diffusion.util import Video
+
+from pibble.util.log import DebugUnifiedLoggingContext
+from pibble.util.numeric import human_size
+
+PROMPT = "a beautiful woman smiling, open mouth, bright teeth"
+FRAMES = 16
+RATE = 8.0
+
+def main() -> None:
+ here = os.path.dirname(os.path.abspath(__file__))
+ output_dir = os.path.join(here, "test-results", "animation")
+
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ with DebugUnifiedLoggingContext():
+ with DiffusionEngine.debug() as engine:
+ base_result = engine(
+ animation_frames=FRAMES,
+ seed=123456,
+ prompt=PROMPT,
+ )["images"]
+
+ logger.debug(f"Result is {len(base_result)} frames, writing.")
+
+ target = os.path.join(output_dir, "base.mp4")
+ size = Video(base_result).save(
+ target,
+ rate=RATE,
+ overwrite=True,
+ )
+ print(f"Wrote {FRAMES} frames to {target} ({human_size(size)}")
+
+if __name__ == "__main__":
+ main()
diff --git a/src/python/enfugue/util/downloads.py b/src/python/enfugue/util/downloads.py
index eb429fa1..42c760ad 100644
--- a/src/python/enfugue/util/downloads.py
+++ b/src/python/enfugue/util/downloads.py
@@ -1,7 +1,7 @@
import os
import requests
-from typing import Optional
+from typing import Optional, Callable
from enfugue.util.log import logger
@@ -11,7 +11,8 @@ def check_download(
remote_url: str,
local_path: str,
chunk_size: int=8192,
- check_size: bool=True
+ check_size: bool=True,
+ progress_callback: Optional[Callable[[int, int], None]]=None,
) -> None:
"""
Checks if a file exists.
@@ -30,17 +31,24 @@ def check_download(
if not os.path.exists(local_path):
logger.info(f"Downloading file from {remote_url}. Will write to {local_path}")
response = requests.get(remote_url, allow_redirects=True, stream=True)
+ content_length: Optional[int] = response.headers.get("Content-Length", None) # type: ignore[assignment]
+ if content_length is not None:
+ content_length = int(content_length)
with open(local_path, "wb") as fh:
+ written_bytes = 0
for chunk in response.iter_content(chunk_size=chunk_size):
fh.write(chunk)
-
+ if progress_callback is not None and content_length is not None:
+ written_bytes = min(written_bytes + chunk_size, content_length)
+ progress_callback(written_bytes, content_length)
def check_download_to_dir(
remote_url: str,
local_dir: str,
- file_name: Optional[str] = None,
+ file_name: Optional[str]=None,
chunk_size: int=8192,
- check_size: bool=True
+ check_size: bool=True,
+ progress_callback: Optional[Callable[[int, int], None]]=None,
) -> str:
"""
Checks if a file exists in a directory based on a remote path.
@@ -50,5 +58,11 @@ def check_download_to_dir(
if file_name is None:
file_name = os.path.basename(remote_url)
local_path = os.path.join(local_dir, file_name)
- check_download(remote_url, local_path, chunk_size=chunk_size, check_size=check_size)
+ check_download(
+ remote_url,
+ local_path,
+ chunk_size=chunk_size,
+ check_size=check_size,
+ progress_callback=progress_callback
+ )
return local_path
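A hedged usage sketch of the new `progress_callback` hook: the callback receives `(bytes_written, total_bytes)` and only fires when the server reports a `Content-Length`. The import path is assumed to mirror the file path, and the URL and directory below are placeholders:

```python
import os
from enfugue.util.downloads import check_download_to_dir  # path assumed from the diff

def report(written: int, total: int) -> None:
    # written is clamped to total by the downloader, so this stays within 100%
    print(f"\rdownloaded {written / total * 100:.1f}%", end="", flush=True)

cache_dir = "/tmp/enfugue-cache"              # placeholder directory
os.makedirs(cache_dir, exist_ok=True)
path = check_download_to_dir(
    "https://example.com/model.safetensors",  # placeholder URL
    cache_dir,
    progress_callback=report,
)
print(f"\nsaved to {path}")
```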
diff --git a/src/python/enfugue/util/images.py b/src/python/enfugue/util/images.py
index ec17c82f..a92e57f5 100644
--- a/src/python/enfugue/util/images.py
+++ b/src/python/enfugue/util/images.py
@@ -1,20 +1,30 @@
from __future__ import annotations
import io
+import os
import math
-from typing import Optional, Literal, TYPE_CHECKING
+from typing import Optional, Literal, Union, List, Tuple, Dict, Any, TYPE_CHECKING
if TYPE_CHECKING:
from PIL.Image import Image
from pibble.resources.retriever import Retriever
+from pibble.util.strings import get_uuid
__all__ = [
"fit_image",
"feather_mask",
+ "tile_image",
"image_from_uri",
"images_are_equal",
+ "get_frames_or_image",
+ "get_frames_or_image_from_file",
+ "save_frames_or_image",
+ "create_mask",
+ "scale_image",
+ "get_image_metadata",
+ "redact_images_from_metadata",
"IMAGE_FIT_LITERAL",
"IMAGE_ANCHOR_LITERAL",
]
@@ -32,18 +42,41 @@
"bottom-right",
]
-
def fit_image(
- image: Image,
+ image: Union[Image, List[Image]],
width: int,
height: int,
fit: Optional[IMAGE_FIT_LITERAL] = None,
anchor: Optional[IMAGE_ANCHOR_LITERAL] = None,
+ offset_left: Optional[int] = None,
+ offset_top: Optional[int] = None
) -> Image:
"""
Given an image of unknown size, make it a known size with optional fit parameters.
"""
+ if not isinstance(image, list):
+ if getattr(image, "n_frames", 1) > 1:
+ frames = []
+ for i in range(image.n_frames):
+ image.seek(i)
+ frames.append(image.copy().convert("RGBA"))
+ image = frames
+ if isinstance(image, list):
+ return [
+ fit_image(
+ img,
+ width=width,
+ height=height,
+ fit=fit,
+ anchor=anchor,
+ offset_left=offset_left,
+ offset_top=offset_top,
+ )
+ for img in image
+ ]
+
from PIL import Image
+
if fit is None or fit == "actual":
left, top = 0, 0
crop_left, crop_top = 0, 0
@@ -63,16 +96,25 @@ def fit_image(
left = width - image_width
blank_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
- blank_image.paste(image, (left, top))
+
+ if offset_top is not None:
+ top += offset_top
+ if offset_left is not None:
+ left += offset_left
+ if image.mode == "RGBA":
+ blank_image.paste(image, (left, top), image)
+ else:
+ blank_image.paste(image, (left, top))
return blank_image
+
elif fit == "contain":
image_width, image_height = image.size
width_ratio, height_ratio = width / image_width, height / image_height
- horizontal_image_width, horizontal_image_height = int(image_width * width_ratio), int(
- image_height * width_ratio
- )
- vertical_image_width, vertical_image_height = int(image_width * height_ratio), int(image_height * height_ratio)
+ horizontal_image_width = int(image_width * width_ratio)
+ horizontal_image_height = int(image_height * width_ratio)
+ vertical_image_width = int(image_width * height_ratio)
+ vertical_image_height = int(image_height * height_ratio)
top, left = 0, 0
direction = None
if width >= horizontal_image_width and height >= horizontal_image_height:
@@ -91,19 +133,27 @@ def fit_image(
left = width // 2 - vertical_image_width // 2
elif left_part == "right":
left = width - vertical_image_width
+
+ if offset_top is not None:
+ top += offset_top
+ if offset_left is not None:
+ left += offset_left
+
blank_image = Image.new("RGBA", (width, height))
- blank_image.paste(input_image, (left, top))
+ if input_image.mode == "RGBA":
+ blank_image.paste(input_image, (left, top), input_image)
+ else:
+ blank_image.paste(input_image, (left, top))
return blank_image
+
elif fit == "cover":
image_width, image_height = image.size
width_ratio, height_ratio = width / image_width, height / image_height
- horizontal_image_width, horizontal_image_height = math.ceil(image_width * width_ratio), math.ceil(
- image_height * width_ratio
- )
- vertical_image_width, vertical_image_height = math.ceil(image_width * height_ratio), math.ceil(
- image_height * height_ratio
- )
+ horizontal_image_width = math.ceil(image_width * width_ratio)
+ horizontal_image_height = math.ceil(image_height * width_ratio)
+ vertical_image_width = math.ceil(image_width * height_ratio)
+ vertical_image_height = math.ceil(image_height * height_ratio)
top, left = 0, 0
direction = None
if width <= horizontal_image_width and height <= horizontal_image_height:
@@ -124,20 +174,45 @@ def fit_image(
left = width - vertical_image_width
else:
input_image = image.resize((width, height)) # We're probably off by a pixel
+
+ if offset_top is not None:
+ top += offset_top
+ if offset_left is not None:
+ left += offset_left
+
blank_image = Image.new("RGBA", (width, height))
- blank_image.paste(input_image, (left, top))
+ if input_image.mode == "RGBA":
+ blank_image.paste(input_image, (left, top), input_image)
+ else:
+ blank_image.paste(input_image, (left, top))
return blank_image
+
elif fit == "stretch":
return image.resize((width, height)).convert("RGBA")
+
else:
raise ValueError(f"Unknown fit {fit}")
-
-def feather_mask(image: Image) -> Image:
+def feather_mask(
+ image: Union[Image, List[Image]]
+) -> Union[Image, List[Image]]:
"""
Given an image, create a feathered binarized mask by 'growing' the black/white pixel sections.
"""
+ if not isinstance(image, list):
+ if getattr(image, "n_frames", 1) > 1:
+ frames = []
+ for i in range(image.n_frames):
+ image.seek(i)
+ frames.append(image.copy().convert("RGBA"))
+ image = frames
+ if isinstance(image, list):
+ return [
+ feather_mask(img)
+ for img in image
+ ]
+
width, height = image.size
mask = image.convert("L")
@@ -151,9 +226,25 @@ def feather_mask(image: Image) -> Image:
if 0 <= nx < width and 0 <= ny < height and mask.getpixel((nx, ny)) == 255:
feathered.putpixel((x, y), (255))
break
-
return feathered
+def tile_image(image: Image, tiles: Union[int, Tuple[int, int]]) -> Image:
+ """
+ Given an image and number of tiles, create a tiled image.
+ Accepts either an integer (square tiling) or a tuple of (width tiles, height tiles)
+ """
+ from PIL import Image
+ width, height = image.size
+ if isinstance(tiles, tuple):
+ width_tiles, height_tiles = tiles
+ else:
+ width_tiles, height_tiles = tiles, tiles
+ tiled = Image.new(image.mode, (width * width_tiles, height * height_tiles))
+ for i in range(width_tiles):
+ for j in range(height_tiles):
+ tiled.paste(image, (i * width, j * height))
+ return tiled
+
def image_from_uri(uri: str) -> Image:
"""
Loads an image using the pibble reteiever; works with http, file, ftp, ftps, sftp, and s3
@@ -191,3 +282,143 @@ def image_pixelize(image: Image, factor: int = 2, exact: bool = True) -> None:
image = image.resize((downsample_width, downsample_height), resample=Resampling.NEAREST)
image = image.resize((upsample_width, upsample_height), resample=Resampling.NEAREST)
return image
+
+def get_frames_or_image(image: Union[Image, List[Image]]) -> Union[Image, List[Image]]:
+ """
+ Makes sure an image is a list of images if it has more than one frame
+ """
+ if not isinstance(image, list):
+ if getattr(image, "n_frames", 1) > 1:
+ def get_frame(i: int) -> Image:
+ image.seek(i) # type: ignore[union-attr]
+ return image.copy().convert("RGB") # type: ignore[union-attr]
+ return [
+ get_frame(i)
+ for i in range(image.n_frames)
+ ]
+ return image
+
+def save_frames_or_image(
+ image: Union[Image, List[Image]],
+ directory: str,
+ name: Optional[str]=None,
+ video_format: str="webp",
+ image_format: str="png"
+) -> str:
+ """
+ Saves a single image or a list of frames to an image or video file
+ """
+ image = get_frames_or_image(image)
+ if name is None:
+ name = get_uuid()
+ if isinstance(image, list):
+ from enfugue.diffusion.util.video_util import Video
+ path = os.path.join(directory, f"{name}.{video_format}")
+ Video(image).save(path)
+ else:
+ path = os.path.join(directory, f"{name}.{image_format}")
+ image.save(path)
+ return path
+
+def get_frames_or_image_from_file(path: str) -> Union[Image, List[Image]]:
+ """
+ Opens a file to a single image or multiple
+ """
+ if path.startswith("data:"):
+ # Should be a video
+ if not path.startswith("data:video"):
+ raise IOError(f"Received non-video data in video handler: {path}")
+ # Dump to tempfile
+ from tempfile import mktemp
+ from base64 import b64decode
+ header, _, data = path.partition(",")
+ fmt, _, encoding = header.partition(";")
+ _, _, file_ext = fmt.partition("/")
+ dump_file = mktemp(f".{file_ext}")
+ try:
+ with open(dump_file, "wb") as fh:
+ fh.write(b64decode(data))
+ from enfugue.diffusion.util.video_util import Video
+ return list(Video.file_to_frames(dump_file))
+ finally:
+ os.unlink(dump_file)
+ else:
+ name, ext = os.path.splitext(path)
+
+ if ext in [".webp", ".webm", ".mp4", ".avi", ".mov", ".gif", ".m4v", ".mkv", ".ogg"]:
+ from enfugue.diffusion.util.video_util import Video
+ return list(Video.file_to_frames(path))
+ else:
+ from PIL import Image
+ return Image.open(path)
+
+def create_mask(
+ width: int,
+ height: int,
+ left: int,
+ top: int,
+ right: int,
+ bottom: int
+) -> Image:
+ """
+ Creates a rectangular mask from a canvas size and box coordinates (left, top, right, bottom)
+ """
+ from PIL import Image, ImageDraw
+ image = Image.new("RGB", (width, height))
+ draw = ImageDraw.Draw(image)
+ draw.rectangle([(left, top), (right, bottom)], fill="#ffffff")
+ return image
+
+def scale_image(image: Image, scale: Union[int, float]) -> Image:
+ """
+ Scales an image proportionally.
+ """
+ width, height = image.size
+ scaled_width = 8 * round((width * scale) / 8)
+ scaled_height = 8 * round((height * scale) / 8)
+ return image.resize((scaled_width, scaled_height))
+
+def get_image_metadata(image: Union[str, Image, List[Image]]) -> Dict[str, Any]:
+ """
+ Gets metadata from an image
+ """
+ if isinstance(image, str):
+ return get_image_metadata(get_frames_or_image_from_file(image))
+ elif isinstance(image, list):
+ (width, height) = image[0].size
+ return {
+ "width": width,
+ "height": height,
+ "frames": len(image),
+ "metadata": getattr(image[0], "text", {}),
+ }
+ else:
+ (width, height) = image.size
+ return {
+ "width": width,
+ "height": height,
+ "metadata": getattr(image, "text", {})
+ }
+
+def redact_images_from_metadata(metadata: Dict[str, Any]) -> None:
+ """
+ Replaces image objects in a metadata dictionary with compact image metadata
+ """
+ for key in ["image", "mask"]:
+ image = metadata.get(key, None)
+ if image is not None:
+ if isinstance(image, dict):
+ image["image"] = get_image_metadata(image["image"])
+ elif isinstance(image, str):
+ metadata[key] = get_image_metadata(metadata[key])
+ else:
+ metadata[key] = get_image_metadata(metadata[key])
+ if "control_images" in metadata:
+ for i, control_dict in enumerate(metadata["control_images"]):
+ control_dict["image"] = get_image_metadata(control_dict["image"])
+ if "ip_adapter_images" in metadata:
+ for i, ip_adapter_dict in enumerate(metadata["ip_adapter_images"]):
+ ip_adapter_dict["image"] = get_image_metadata(ip_adapter_dict["image"])
+ if "layers" in metadata:
+ for layer in metadata["layers"]:
+ redact_images_from_metadata(layer)
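A hedged usage sketch for a few of the helpers added above (`tile_image`, `create_mask`, `scale_image`); the import path is assumed to mirror `src/python/enfugue/util/images.py`:

```python
from PIL import Image
from enfugue.util.images import tile_image, create_mask, scale_image  # path assumed

base = Image.new("RGB", (64, 64), "red")

tiled = tile_image(base, (3, 2))             # 3 tiles wide, 2 tall
mask = create_mask(64, 64, 16, 16, 48, 48)   # white rectangle on a black canvas
scaled = scale_image(base, 1.5)              # dimensions rounded to multiples of 8

print(tiled.size, mask.size, scaled.size)    # (192, 128) (64, 64) (96, 96)
```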
diff --git a/src/python/enfugue/util/installation.py b/src/python/enfugue/util/installation.py
index 8f0f4648..c21f6551 100644
--- a/src/python/enfugue/util/installation.py
+++ b/src/python/enfugue/util/installation.py
@@ -6,6 +6,7 @@
from typing import TypedDict, List, Dict, Any, Iterator, Optional, Union, cast
from semantic_version import Version
+from pibble.api.configuration import APIConfiguration
from pibble.util.files import load_yaml, load_json
__all__ = [
@@ -22,17 +23,14 @@
"find_files_in_directory"
]
-
class VersionDict(TypedDict):
"""
The version dictionary.
"""
-
version: Version
release: datetime.date
description: str
-
def get_local_installation_directory() -> str:
"""
Gets where the local installation directory is (i.e. where the package data files are,
@@ -46,7 +44,6 @@ def get_local_installation_directory() -> str:
raise IOError("Couldn't find installation directory.")
return here
-
def get_local_config_directory() -> str:
"""
Gets where the local configuration directory is.
@@ -58,7 +55,6 @@ def get_local_config_directory() -> str:
raise IOError("Couldn't find config directory.")
return os.path.join(here, "config")
-
def get_local_static_directory() -> str:
"""
Gets where the local static directory is.
@@ -70,7 +66,6 @@ def get_local_static_directory() -> str:
raise IOError("Couldn't find static directory.")
return os.path.join(here, "static")
-
def check_make_directory(directory: str) -> None:
"""
Checks if a directory doesn't exist, and makes it.
@@ -80,13 +75,12 @@ def check_make_directory(directory: str) -> None:
try:
os.makedirs(directory)
return
- except:
+ except Exception as ex:
if not os.path.exists(directory):
- raise
+ raise IOError(f"Couldn't create directory `{directory}`: {type(ex).__name__}({ex})")
return
-
-def get_local_configuration() -> Dict[str, Any]:
+def get_local_configuration(as_api_configuration: bool = False) -> Union[Dict[str, Any], APIConfiguration]:
"""
Gets configuration from a file in the environment, or the base config.
"""
@@ -105,6 +99,8 @@ def get_local_configuration() -> Dict[str, Any]:
raise IOError(f"Unknown extension {ext}")
if "configuration" in configuration:
configuration = configuration["configuration"]
+ if as_api_configuration:
+ return APIConfiguration(**configuration)
return configuration
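A brief, hedged example of the new flag: with `as_api_configuration=True` the same dictionary is wrapped in a pibble `APIConfiguration` instead of being returned directly (import path assumed from the file path):

```python
from enfugue.util.installation import get_local_configuration  # path assumed

config_dict = get_local_configuration()                          # plain dict, as before
config_api = get_local_configuration(as_api_configuration=True)  # wrapped configuration
print(type(config_dict).__name__, type(config_api).__name__)     # dict APIConfiguration
```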
diff --git a/src/python/enfugue/util/misc.py b/src/python/enfugue/util/misc.py
index 4097c39d..8b94b7fe 100644
--- a/src/python/enfugue/util/misc.py
+++ b/src/python/enfugue/util/misc.py
@@ -1,9 +1,17 @@
from typing import Dict, Any
__all__ = [
- "merge_into"
+ "noop",
+ "merge_into",
+ "replace_images",
+ "redact_for_log"
]
+def noop(*args: Any, **kwargs: Any) -> None:
+ """
+ Does nothing.
+ """
+
def merge_into(source: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]:
"""
Merges a source dictionary into a target dictionary.
@@ -19,3 +27,48 @@ def merge_into(source: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]:
else:
dest[key] = source[key]
return dest
+
+def replace_images(dictionary: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Replaces images in a dictionary with a metadata dictionary.
+ """
+ from PIL.Image import Image
+ for key, value in dictionary.items():
+ if isinstance(value, Image):
+ width, height = value.size
+ metadata = {"width": width, "height": height, "mode": value.mode}
+ if hasattr(value, "filename"):
+ metadata["filename"] = value.filename
+ if hasattr(value, "text"):
+ metadata["text"] = value.text
+ dictionary[key] = metadata
+ elif isinstance(value, dict):
+ dictionary[key] = replace_images(value)
+ elif isinstance(value, list) or isinstance(value, tuple):
+ dictionary[key] = [
+ replace_images(part) if isinstance(part, dict) else part
+ for part in value
+ ]
+ if isinstance(value, tuple):
+ dictionary[key] = tuple(dictionary[key])
+ return dictionary
+
+def redact_for_log(dictionary: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Redacts prompts from logs to encourage log sharing for troubleshooting.
+ """
+ redacted = {}
+ for key, value in dictionary.items():
+ if isinstance(value, dict):
+ redacted[key] = redact_for_log(value)
+ elif isinstance(value, tuple):
+ redacted[key] = "(" + ", ".join([str(redact_for_log({"v": v})["v"]) for v in value]) + ")" # type: ignore[assignment]
+ elif isinstance(value, list):
+ redacted[key] = "[" + ", ".join([str(redact_for_log({"v": v})["v"]) for v in value]) + "]" # type: ignore[assignment]
+ elif type(value) not in [str, float, int, bool, type(None)]:
+ redacted[key] = type(value).__name__ # type: ignore[assignment]
+ elif "prompt" in key and value is not None:
+ redacted[key] = "***" # type: ignore[assignment]
+ else:
+ redacted[key] = str(value) # type: ignore[assignment]
+ return redacted
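A hedged sketch of `redact_for_log` in use: prompt-like keys are masked, primitives are stringified, and nested containers are flattened to strings (import path assumed from the file path):

```python
from enfugue.util.misc import redact_for_log  # path assumed

payload = {
    "prompt": "a castle at sunset",
    "negative_prompt": "blurry",
    "seed": 12345,
    "layers": [{"denoise": True}],
}
print(redact_for_log(payload))
# {'prompt': '***', 'negative_prompt': '***', 'seed': '12345', 'layers': "[{'denoise': 'True'}]"}
```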
diff --git a/src/python/enfugue/util/tokens.py b/src/python/enfugue/util/tokens.py
index d8df3136..26a0c6a3 100644
--- a/src/python/enfugue/util/tokens.py
+++ b/src/python/enfugue/util/tokens.py
@@ -2,7 +2,10 @@
from typing import Dict, Union, Iterator, Tuple, Optional
-__all__ = ["TokenMerger"]
+__all__ = [
+ "TokenMerger",
+ "merge_tokens"
+]
class TokenMerger:
@@ -12,10 +15,16 @@ class TokenMerger:
tokens: Dict[str, Union[int, float]]
- def __init__(self, *initial_phrases: str) -> None:
+ def __init__(
+ self,
+ *initial_phrases: Union[str, Tuple[str, Union[int, float]]]
+ ) -> None:
self.tokens = {}
for phrase in initial_phrases:
- self.add(phrase)
+ weight = 1.0
+ if isinstance(phrase, tuple):
+ phrase, weight = phrase
+ self.add(phrase, weight)
def add(self, phrase: str, weight: Union[int, float] = 1) -> None:
"""
@@ -64,3 +73,9 @@ def __str__(self) -> str:
Stringifies the tokens.
"""
return ",".join([token for token, weight in iter(self)])
+
+def merge_tokens(**prompt_weights: float) -> str:
+ """
+ Merges any number of weighted phrases into a single prompt string.
+ """
+ return str(TokenMerger(*list(prompt_weights.items())))
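A short, hedged example of the widened constructor and the new `merge_tokens` helper: bare strings keep the default weight of 1.0, while `(phrase, weight)` tuples and keyword arguments carry explicit weights (import path assumed from the file path):

```python
from enfugue.util.tokens import TokenMerger, merge_tokens  # path assumed

merger = TokenMerger("castle", ("sunset", 2.0))  # tuple form sets an explicit weight
print(str(merger))

# merge_tokens is the keyword shorthand for the same thing
print(merge_tokens(castle=1.0, sunset=2.0))
```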