Skip to content

Commit

Permalink
Adds Negative Prompts (#61)
Browse files Browse the repository at this point in the history
* Synced to main branch and minimizes line changes

* Adds negative prompt argument to CLI

Co-authored-by: Wanaldino Antimonio <[email protected]>
  • Loading branch information
JustinMeans and Wanaldino authored Dec 21, 2022
1 parent 4c00b32 commit c90b705
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
9 changes: 6 additions & 3 deletions swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
///
/// - Parameters:
/// - prompt: Text prompt to guide sampling
/// - negativePrompt: Negative text prompt to guide sampling
/// - stepCount: Number of inference steps to perform
/// - imageCount: Number of samples/images to generate for the input prompt
/// - seed: Random seed which
Expand All @@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// The images will be nil if safety checks were performed and found the result to be un-safe
public func generateImages(
prompt: String,
negativePrompt: String = "",
imageCount: Int = 1,
stepCount: Int = 50,
seed: UInt32 = 0,
Expand All @@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
progressHandler: (Progress) -> Bool = { _ in true }
) throws -> [CGImage?] {

// Encode the input prompt as well as a blank unconditioned input
// Encode the input prompt and negative prompt
let promptEmbedding = try textEncoder.encode(prompt)
let blankEmbedding = try textEncoder.encode("")
let negativePromptEmbedding = try textEncoder.encode(negativePrompt)

if reduceMemory {
textEncoder.unloadResources()
}

// Convert to Unet hidden state representation
// Concatenate the prompt and negative prompt embeddings
let concatEmbedding = MLShapedArray<Float32>(
concatenating: [blankEmbedding, promptEmbedding],
concatenating: [negativePromptEmbedding, promptEmbedding],
alongAxis: 0
)

Expand Down
4 changes: 4 additions & 0 deletions swift/StableDiffusionCLI/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ struct StableDiffusionSample: ParsableCommand {
@Argument(help: "Input string prompt")
var prompt: String

@Option(help: "Input string negative prompt")
var negativePrompt: String

@Option(
help: ArgumentHelp(
"Path to stable diffusion resources.",
Expand Down Expand Up @@ -85,6 +88,7 @@ struct StableDiffusionSample: ParsableCommand {

let images = try pipeline.generateImages(
prompt: prompt,
negativePrompt: negativePrompt,
imageCount: imageCount,
stepCount: stepCount,
seed: seed,
Expand Down

0 comments on commit c90b705

Please sign in to comment.