feat: do not convert some tensors
leejet committed Aug 25, 2024
1 parent 28a6147 commit 79c9fe9
Showing 2 changed files with 45 additions and 24 deletions.
model.cpp (68 changes: 44 additions & 24 deletions)
@@ -1397,10 +1397,11 @@ ggml_type ModelLoader::get_sd_wtype() {
             continue;
         }
 
-        if (tensor_storage.name.find(".weight") != std::string::npos &&
-            (tensor_storage.name.find("time_embed") != std::string::npos ||
-             tensor_storage.name.find("context_embedder") != std::string::npos ||
-             tensor_storage.name.find("time_in") != std::string::npos)) {
+        if (ggml_is_quantized(tensor_storage.type)) {
+            return tensor_storage.type;
+        }
+
+        if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
             return tensor_storage.type;
         }
     }
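
Note: the same two-step check recurs in the conditioner, diffusion-model, and VAE getters below. If the file already stores quantized tensors, the first quantized type found is reported as the weight type; otherwise the first tensor that would be eligible for conversion reports its current type, with GGML_TYPE_Q4_K serving only as a stand-in quantized target for the eligibility probe. A minimal usage sketch, assuming ModelLoader::init_from_file from model.h; the file name is made up:

    // Hypothetical file name, for illustration only.
    ModelLoader loader;
    if (loader.init_from_file("sd_v1-5_q8_0.gguf")) {
        ggml_type wtype = loader.get_sd_wtype();
        // A Q8_0 file yields GGML_TYPE_Q8_0 via the ggml_is_quantized()
        // short-circuit; an F16 file yields GGML_TYPE_F16 from the first
        // convertible tensor.
    }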
@@ -1420,7 +1421,11 @@ ggml_type ModelLoader::get_conditioner_wtype() {
             continue;
         }
 
-        if (tensor_storage.name.find(".weight") != std::string::npos) {
+        if (ggml_is_quantized(tensor_storage.type)) {
+            return tensor_storage.type;
+        }
+
+        if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
             return tensor_storage.type;
         }
     }
@@ -1437,10 +1442,11 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
             continue;
         }
 
-        if (tensor_storage.name.find(".weight") != std::string::npos &&
-            (tensor_storage.name.find("time_embed") != std::string::npos ||
-             tensor_storage.name.find("context_embedder") != std::string::npos ||
-             tensor_storage.name.find("time_in") != std::string::npos)) {
+        if (ggml_is_quantized(tensor_storage.type)) {
+            return tensor_storage.type;
+        }
+
+        if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
             return tensor_storage.type;
         }
     }
@@ -1458,7 +1464,11 @@ ggml_type ModelLoader::get_vae_wtype() {
             continue;
         }
 
-        if (tensor_storage.name.find(".weight")) {
+        if (ggml_is_quantized(tensor_storage.type)) {
+            return tensor_storage.type;
+        }
+
+        if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
             return tensor_storage.type;
         }
     }
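
One incidental fix in this hunk: the removed VAE check was broken. std::string::find returns std::string::npos, a nonzero value, when the substring is absent, so the old condition if (tensor_storage.name.find(".weight")) held for nearly every tensor name, not just weights. A tiny standalone demonstration of the pitfall:

    #include <cassert>
    #include <string>

    int main() {
        std::string name = "first_stage_model.decoder.conv_in.bias";
        // find() returns npos (nonzero) when ".weight" is absent, so the
        // old `if (name.find(".weight"))` was true for this bias tensor too.
        assert(name.find(".weight") == std::string::npos);
        assert(name.find(".weight") != 0);  // the test the old code actually did
        return 0;
    }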
@@ -1723,6 +1733,26 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
     return true;
 }
 
+bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
+    const std::string& name = tensor_storage.name;
+    if (type != GGML_TYPE_COUNT) {
+        if (ggml_is_quantized(type) && tensor_storage.ne[0] % ggml_blck_size(type) != 0) {
+            // Pass, do not convert
+        } else if (ends_with(name, ".bias")) {
+            // Pass, do not convert
+        } else if (contains(name, "img_in.") || contains(name, "time_in.in_layer.") || contains(name, "vector_in.in_layer.") || contains(name, "guidance_in.in_layer.") || contains(name, "final_layer.linear.")) {
+            // Pass, do not convert. For FLUX
+        } else if (contains(name, "x_embedder.") || contains(name, "t_embedder.") || contains(name, "y_embedder.") || contains(name, "context_embedder.")) {
+            // Pass, do not convert. For MMDiT
+        } else if (contains(name, "time_embed.") || contains(name, "label_emb.")) {
+            // Pass, do not convert. For Unet
+        } else {
+            return true;
+        }
+    }
+    return false;
+}
+
 bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) {
     auto backend = ggml_backend_cpu_init();
     size_t mem_size = 1 * 1024 * 1024;  // for padding
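
For reading outside the repo, here is a self-contained sketch of the same decision rule. It hard-codes 256 for ggml_blck_size(GGML_TYPE_Q4_K) (its value in current ggml, stated here as an assumption), reimplements the project's ends_with/contains helpers with std::string operations, and abbreviates the name lists:

    #include <cstdint>
    #include <iostream>
    #include <string>

    // Stand-in for ggml_blck_size(GGML_TYPE_Q4_K); 256 in current ggml.
    constexpr int64_t kBlckSize = 256;

    static bool ends_with(const std::string& s, const std::string& suf) {
        return s.size() >= suf.size() &&
               s.compare(s.size() - suf.size(), suf.size(), suf) == 0;
    }

    static bool contains(const std::string& s, const std::string& sub) {
        return s.find(sub) != std::string::npos;
    }

    // Mirrors tensor_should_be_converted() for a quantized target type
    // (abbreviated name list).
    static bool should_convert(const std::string& name, int64_t ne0) {
        if (ne0 % kBlckSize != 0) return false;      // rows not block-aligned
        if (ends_with(name, ".bias")) return false;  // biases stay unconverted
        if (contains(name, "time_embed.") || contains(name, "label_emb.")) return false;  // UNet
        if (contains(name, "context_embedder.")) return false;                            // MMDiT
        if (contains(name, "img_in.")) return false;                                      // FLUX
        return true;
    }

    int main() {
        // 320 is not a multiple of 256, and the name is a UNet embedder: skip.
        std::cout << should_convert("model.diffusion_model.time_embed.0.weight", 320) << "\n";
        // Block-aligned plain weight: convert.
        std::cout << should_convert("model.diffusion_model.input_blocks.1.0.in_layers.2.weight", 2560) << "\n";
        return 0;
    }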
@@ -1737,12 +1767,8 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
         const std::string& name = tensor_storage.name;
 
         ggml_type tensor_type = tensor_storage.type;
-        if (type != GGML_TYPE_COUNT) {
-            if (ggml_is_quantized(type) && tensor_storage.ne[0] % ggml_blck_size(type) != 0) {
-                tensor_type = GGML_TYPE_F16;
-            } else {
-                tensor_type = type;
-            }
+        if (tensor_should_be_converted(tensor_storage, type)) {
+            tensor_type = type;
         }
 
         ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
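
Note the behavior change in this hunk: previously a quantized target whose block size did not divide ne[0] forced the tensor to GGML_TYPE_F16; now any tensor the helper rejects (misaligned rows, biases, embedders) simply keeps its source type, whatever that is. The call site is unchanged; a hedged usage line with a hypothetical output path:

    // Eligible weights are quantized to Q4_K; biases, embedders, and
    // non-block-aligned tensors keep their original storage type.
    loader.save_to_gguf_file("sd-model-q4_k.gguf", GGML_TYPE_Q4_K);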
@@ -1792,15 +1818,9 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
     }
 
     for (auto& tensor_storage : processed_tensor_storages) {
-        ggml_type tensor_type = tensor_storage.type;
-        if (type != GGML_TYPE_COUNT) {
-            if (ggml_is_quantized(type) && tensor_storage.ne[0] % 32 != 0) {
-                tensor_type = GGML_TYPE_F16;
-            } else {
-                tensor_type = type;
-            }
+        if (tensor_should_be_converted(tensor_storage, type)) {
+            tensor_storage.type = type;
         }
-        tensor_storage.type = tensor_type;
         mem_size += tensor_storage.nbytes() + alignment;
     }
 
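This hunk also tightens the size estimate for K-quants: the removed code hard-coded a block size of 32 (ne[0] % 32), while the helper consults ggml_blck_size(type), which is 256 for GGML_TYPE_Q4_K. mem_size then accumulates each tensor's converted byte count plus one alignment pad. A back-of-envelope sketch using ggml's ggml_row_size (bytes needed for one row of a given type), with a made-up tensor shape:

    #include "ggml.h"
    #include <cstdio>

    int main() {
        // Hypothetical 4096 x 4096 weight matrix.
        const int64_t ne0 = 4096, ne1 = 4096;
        size_t f16  = ggml_row_size(GGML_TYPE_F16,  ne0) * ne1;  // 2 B/elem -> 32 MiB
        size_t q4_k = ggml_row_size(GGML_TYPE_Q4_K, ne0) * ne1;  // 144 B per 256 elems -> 9 MiB
        printf("F16: %zu bytes, Q4_K: %zu bytes\n", f16, q4_k);
        return 0;
    }
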
model.h (1 change: 1 addition & 0 deletions)
@@ -157,6 +157,7 @@ class ModelLoader {
                       ggml_backend_t backend,
                       std::set<std::string> ignore_tensors = {});
     bool save_to_gguf_file(const std::string& file_path, ggml_type type);
+    bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
     int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
     ~ModelLoader() = default;
 
