From c90b705334c3807f4f4042cbb351b59cbe5c6f7b Mon Sep 17 00:00:00 2001 From: JustinMeans <46542161+JustinMeans@users.noreply.github.com> Date: Tue, 20 Dec 2022 23:57:34 -0700 Subject: [PATCH] Adds Negative Prompts (#61) * Synced to main branch and minimizes line changes * Adds negative prompt argument to CLI Co-authored-by: Wanaldino Antimonio --- .../pipeline/StableDiffusionPipeline.swift | 9 ++++++--- swift/StableDiffusionCLI/main.swift | 4 ++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift index 8464f8a7..0cd2253c 100644 --- a/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift +++ b/swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift @@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging { /// /// - Parameters: /// - prompt: Text prompt to guide sampling + /// - negativePrompt: Negative text prompt to guide sampling /// - stepCount: Number of inference steps to perform /// - imageCount: Number of samples/images to generate for the input prompt /// - seed: Random seed which @@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging { /// The images will be nil if safety checks were performed and found the result to be un-safe public func generateImages( prompt: String, + negativePrompt: String = "", imageCount: Int = 1, stepCount: Int = 50, seed: UInt32 = 0, @@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging { progressHandler: (Progress) -> Bool = { _ in true } ) throws -> [CGImage?] { - // Encode the input prompt as well as a blank unconditioned input + // Encode the input prompt and negative prompt let promptEmbedding = try textEncoder.encode(prompt) - let blankEmbedding = try textEncoder.encode("") + let negativePromptEmbedding = try textEncoder.encode(negativePrompt) if reduceMemory { textEncoder.unloadResources() } // Convert to Unet hidden state representation + // Concatenate the prompt and negative prompt embeddings let concatEmbedding = MLShapedArray( - concatenating: [blankEmbedding, promptEmbedding], + concatenating: [negativePromptEmbedding, promptEmbedding], alongAxis: 0 ) diff --git a/swift/StableDiffusionCLI/main.swift b/swift/StableDiffusionCLI/main.swift index df0ad926..5cbe6271 100644 --- a/swift/StableDiffusionCLI/main.swift +++ b/swift/StableDiffusionCLI/main.swift @@ -19,6 +19,9 @@ struct StableDiffusionSample: ParsableCommand { @Argument(help: "Input string prompt") var prompt: String + @Option(help: "Input string negative prompt") + var negativePrompt: String + @Option( help: ArgumentHelp( "Path to stable diffusion resources.", @@ -85,6 +88,7 @@ struct StableDiffusionSample: ParsableCommand { let images = try pipeline.generateImages( prompt: prompt, + negativePrompt: negativePrompt, imageCount: imageCount, stepCount: stepCount, seed: seed,