Skip to content

Commit

Permalink
feat: enable controlnet and photo maker for img2img mode
Browse files (browse the repository at this point in the history)
  • Loading branch information
leejet committed Apr 14, 2024
1 parent ec82d52 commit 036ba9e
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 258 deletions.
80 changes: 45 additions & 35 deletions examples/cli/main.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -656,13 +656,16 @@ int main(int argc, const char* argv[]) {
return 1;
}

bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
uint8_t* control_image_buffer = NULL;
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
vae_decode_only = false;

int c = 0;
input_image_buffer = stbi_load(params.input_path.c_str(), &params.width, &params.height, &c, 3);
int width = 0;
int height = 0;
input_image_buffer = stbi_load(params.input_path.c_str(), &width, &height, &c, 3);
if (input_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str());
return 1;
Expand All @@ -672,29 +675,30 @@ int main(int argc, const char* argv[]) {
free(input_image_buffer);
return 1;
}
if (params.width <= 0) {
if (width <= 0) {
fprintf(stderr, "error: the width of image must be greater than 0\n");
free(input_image_buffer);
return 1;
}
if (params.height <= 0) {
if (height <= 0) {
fprintf(stderr, "error: the height of image must be greater than 0\n");
free(input_image_buffer);
return 1;
}

// Resize input image ...
if (params.height % 64 != 0 || params.width % 64 != 0) {
int resized_height = params.height + (64 - params.height % 64);
int resized_width = params.width + (64 - params.width % 64);
if (params.height != height || params.width != width) {
printf("resize input image from %dx%d to %dx%d\n", width, height, params.width, params.height);
int resized_height = params.height;
int resized_width = params.width;

uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
if (resized_image_buffer == NULL) {
fprintf(stderr, "error: allocate memory for resize input image\n");
free(input_image_buffer);
return 1;
}
stbir_resize(input_image_buffer, params.width, params.height, 0,
stbir_resize(input_image_buffer, width, height, 0,
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
Expand All @@ -704,8 +708,6 @@ int main(int argc, const char* argv[]) {
// Save resized result
free(input_image_buffer);
input_image_buffer = resized_image_buffer;
params.height = resized_height;
params.width = resized_width;
}
}

Expand All @@ -732,31 +734,32 @@ int main(int argc, const char* argv[]) {
return 1;
}

sd_image_t* control_image = NULL;
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
int c = 0;
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if (control_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
return 1;
}
control_image = new sd_image_t{(uint32_t)params.width,
(uint32_t)params.height,
3,
control_image_buffer};
if (params.canny_preprocess) { // apply preprocessor
control_image->data = preprocess_canny(control_image->data,
control_image->width,
control_image->height,
0.08f,
0.08f,
0.8f,
1.0f,
false);
}
}

sd_image_t* results;
if (params.mode == TXT2IMG) {
sd_image_t* control_image = NULL;
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
int c = 0;
input_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if (input_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
return 1;
}
control_image = new sd_image_t{(uint32_t)params.width,
(uint32_t)params.height,
3,
input_image_buffer};
if (params.canny_preprocess) { // apply preprocessor
control_image->data = preprocess_canny(control_image->data,
control_image->width,
control_image->height,
0.08f,
0.08f,
0.8f,
1.0f,
false);
}
}
results = txt2img(sd_ctx,
params.prompt.c_str(),
params.negative_prompt.c_str(),
Expand Down Expand Up @@ -828,7 +831,12 @@ int main(int argc, const char* argv[]) {
params.sample_steps,
params.strength,
params.seed,
params.batch_count);
params.batch_count,
control_image,
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str());
}
}

Expand Down Expand Up @@ -881,6 +889,8 @@ int main(int argc, const char* argv[]) {
}
free(results);
free_sd_ctx(sd_ctx);
free(control_image_buffer);
free(input_image_buffer);

return 0;
}
7 changes: 3 additions & 4 deletions ggml_extend.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -752,10 +752,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding(
return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
}


__STATIC_INLINE__ size_t ggml_tensor_num(ggml_context * ctx) {
__STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
size_t num = 0;
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
for (ggml_tensor* t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
num++;
}
return num;
Expand Down Expand Up @@ -851,7 +850,7 @@ struct GGMLModule {
}

public:
virtual std::string get_desc() = 0;
virtual std::string get_desc() = 0;

GGMLModule(ggml_backend_t backend, ggml_type wtype = GGML_TYPE_F32)
: backend(backend), wtype(wtype) {
Expand Down
Loading

0 comments on commit 036ba9e

Please sign in to comment.