From f5997a19515cf1eb17e64dc97f558a319ca19363 Mon Sep 17 00:00:00 2001
From: leejet
Date: Sun, 25 Aug 2024 14:07:22 +0800
Subject: [PATCH] fix: do not force using f32 for some flux layers

This sometimes leads to worse results.
---
 flux.hpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/flux.hpp b/flux.hpp
index c837a1d8..73bc345a 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -13,7 +13,7 @@ namespace Flux {
     struct MLPEmbedder : public UnaryBlock {
     public:
         MLPEmbedder(int64_t in_dim, int64_t hidden_dim) {
-            blocks["in_layer"]  = std::shared_ptr<GGMLBlock>(new Linear(in_dim, hidden_dim, true, true));
+            blocks["in_layer"]  = std::shared_ptr<GGMLBlock>(new Linear(in_dim, hidden_dim, true));
             blocks["out_layer"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, hidden_dim, true));
         }
 
@@ -449,7 +449,7 @@ namespace Flux {
                   int64_t patch_size,
                   int64_t out_channels) {
             blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
-            blocks["linear"]     = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels, true, true));
+            blocks["linear"]     = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels));
             blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
         }
 
@@ -634,13 +634,13 @@ namespace Flux {
             int64_t out_channels = params.in_channels;
             int64_t pe_dim       = params.hidden_size / params.num_heads;
 
-            blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true, true));
+            blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
             blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
             blocks["vector_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(params.vec_in_dim, params.hidden_size));
             if (params.guidance_embed) {
                 blocks["guidance_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
             }
-            blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.context_in_dim, params.hidden_size, true, true));
+            blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.context_in_dim, params.hidden_size, true));
 
             for (int i = 0; i < params.depth; i++) {
                 blocks["double_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new DoubleStreamBlock(params.hidden_size,