Skip to content

Commit

Permalink
fix and improve: VAE tiling
Browse files Browse the repository at this point in the history
- properly handle the upper-left corner by interpolating in both x and y
- refactor lerp out into its own helper
- use smootherstep to preserve more detail and spend less area blending
  • Loading branch information
Green-Sky committed Aug 26, 2024
1 parent 8847114 commit 22e48d8
Showing 1 changed file with 36 additions and 5 deletions.
41 changes: 36 additions & 5 deletions ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,16 @@ __STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,
}
}

// Linear interpolation between a and b, weighted by x.
// Evaluates (1 - x) * a + x * b: x == 0 selects a, x == 1 selects b.
// x is not clamped here, so values outside [0, 1] extrapolate.
__STATIC_INLINE__ float ggml_lerp_f32(const float a, const float b, const float x) {
    const float weight_a = 1 - x;  // share of `a` in the result
    return weight_a * a + x * b;
}

// Smootherstep easing polynomial: 6x^5 - 15x^4 + 10x^3, written in nested form.
// Maps 0 -> 0 and 1 -> 1. The input is NOT clamped; the caller must pass
// x within [0, 1], which is enforced by the assertion below.
__STATIC_INLINE__ float ggml_smootherstep_f32(const float x) {
    GGML_ASSERT(x >= 0.f && x <= 1.f);
    const float x_cubed = x * x * x;
    const float quadratic = x * (6.0f * x - 15.0f) + 10.0f;  // 6x^2 - 15x + 10
    return x_cubed * quadratic;
}

__STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
struct ggml_tensor* output,
int x,
Expand All @@ -364,12 +374,33 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
float new_value = ggml_tensor_get_f32(input, ix, iy, k);
if (overlap > 0) { // blend colors in overlapped area
float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k);
if (x > 0 && ix < overlap) { // in overlapped horizontal
ggml_tensor_set_f32(output, old_value + (new_value - old_value) * (ix / (1.0f * overlap)), x + ix, y + iy, k);
const bool inside_x_overlap = x > 0 && ix < overlap;
const bool inside_y_overlap = y > 0 && iy < overlap;
if (inside_x_overlap && inside_y_overlap) {
// upper left corner needs to be interpolated in both directions
const float x_f = ix / float(overlap);
const float y_f = iy / float(overlap);
// TODO: try `x+y - 1`
const float f = std::min(x_f, y_f); // min of both
ggml_tensor_set_f32(
output,
ggml_lerp_f32(old_value, new_value, ggml_smootherstep_f32(f)),
x + ix, y + iy, k
);
continue;
}
if (y > 0 && iy < overlap) { // in overlapped vertical
ggml_tensor_set_f32(output, old_value + (new_value - old_value) * (iy / (1.0f * overlap)), x + ix, y + iy, k);
} else if (inside_x_overlap) {
ggml_tensor_set_f32(
output,
ggml_lerp_f32(old_value, new_value, ggml_smootherstep_f32(ix / float(overlap))),
x + ix, y + iy, k
);
continue;
} else if (inside_y_overlap) {
ggml_tensor_set_f32(
output,
ggml_lerp_f32(old_value, new_value, ggml_smootherstep_f32(iy / float(overlap))),
x + ix, y + iy, k
);
continue;
}
}
Expand Down

0 comments on commit 22e48d8

Please sign in to comment.