This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

qbits support f4 weight repack (#1653)
* qbits support f4 weight repack

* fix
zhewang1-intc authored Jul 10, 2024
1 parent 20765ab commit 3fd99c8
Showing 4 changed files with 47 additions and 28 deletions.
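
At a glance: before this commit, execute_qpack in the qbits packq path hard-wired the integer weight prologue (WeightKBlockNInteger), and bestla_packq rejected every non-integer weight_type up front, so 4-bit float (f4) weights could not be repacked. The commit templates execute_qpack on the prologue and adds a parse_prob layer that routes nf4 / fp4_e2m1 / fp4_e2m1_bnb to WeightKBlockNFloat. A self-contained toy analog of that reshaping; every identifier below is a simplified stand-in, not one of the repo's real types:

#include <iostream>
#include <stdexcept>
#include <string>

// Toy analog: the packing routine takes the weight prologue as a template
// parameter instead of hard-coding the integer one, and a thin dispatcher
// picks the prologue from the weight_type string.
struct IntegerPrologue { static void pack() { std::cout << "pack integer weight\n"; } };
struct FloatPrologue   { static void pack() { std::cout << "pack f4 weight\n"; } };

template <class ProB>  // before the commit this was effectively fixed to the integer prologue
void execute_qpack() { ProB::pack(); }

void parse_prob(const std::string& weight_type) {
  if (weight_type == "int8" || weight_type == "int4_clip")
    return execute_qpack<IntegerPrologue>();
  if (weight_type == "nf4" || weight_type == "fp4_e2m1")
    return execute_qpack<FloatPrologue>();  // the route this commit adds
  throw std::runtime_error("unsupported weight_type: " + weight_type);
}

int main() { parse_prob("nf4"); }  // prints "pack f4 weight"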
@@ -14,7 +14,6 @@
 #pragma once
 #include <ATen/core/TensorBody.h>
 #include <torch/torch.h>
-#include "bestla/bestla_storage.h"
 #include "../include/dispatcher_utils.hpp"
 #include <string.h>
 #include <assert.h>
@@ -16,6 +16,7 @@
 #include <chrono>
 #include <string>
 #include "bestla/bestla_device.h"
+#include "bestla/bestla_storage.h"
 #include "bestla/bestla_utils.h"
 #include "bestla/bestla_parallel.h"
 namespace dispatcher_utils {
@@ -26,6 +27,12 @@ inline bool check_avx_vnni() { return bestla::device::CpuDevice::getInstance()->AVX_VNNI(); }
 inline bool check_avx512f() { return bestla::device::CpuDevice::getInstance()->AVX512F(); }
 inline bool check_avx2() { return bestla::device::CpuDevice::getInstance()->AVX2(); }
 
+template <class GemmCore>
+constexpr bool is_int8_cmpt_gemmcore() {
+  return GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
+         GemmCore::ISA == BTLA_ISA::AVX_VNNI || std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>;
+}
+
 class qbits_threading {
  public:
   static bestla::parallel::IThreading* get() {
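A note on the helper added above: is_int8_cmpt_gemmcore moves into dispatcher_utils (the fourth file in this diff deletes the copy that previously lived in the weight-only dispatcher source) so a single definition can be shared across translation units. A standalone sketch of what the trait expresses, with simplified stand-ins for bestla's ISA tags and core classes:

#include <type_traits>

// Stand-ins for bestla's ISA enum and GEMM core classes (illustrative only).
enum class IsaSim { AVX2, AVX_VNNI, AVX512_VNNI, AMX_INT8, AVX512F };

struct AmxInt8Core  { static constexpr IsaSim ISA = IsaSim::AMX_INT8; };
struct Avx512fCore  { static constexpr IsaSim ISA = IsaSim::AVX512F; };
struct Avx2VnniCore { static constexpr IsaSim ISA = IsaSim::AVX2; };  // vnni emulated on plain AVX2

// Same shape as is_int8_cmpt_gemmcore: "int8 compute" is an ISA property,
// except the AVX2 vnni-emulation core, which is matched by exact type.
template <class GemmCore>
constexpr bool is_int8_compute() {
  return GemmCore::ISA == IsaSim::AMX_INT8 || GemmCore::ISA == IsaSim::AVX512_VNNI ||
         GemmCore::ISA == IsaSim::AVX_VNNI || std::is_same_v<GemmCore, Avx2VnniCore>;
}

static_assert(is_int8_compute<AmxInt8Core>());
static_assert(is_int8_compute<Avx2VnniCore>());  // ISA says AVX2, still int8 compute
static_assert(!is_int8_compute<Avx512fCore>());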
@@ -16,12 +16,19 @@
 #include "../include/bestla_packq_impl.hpp"
 
 namespace woq {
-template <class GemmCore, BTLA_ISA ISA>
+
+template <class proB>
 void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
-  using proB = bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>;
   static proB ker;
-  auto qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type),
-                                  scale2bestladt_map.at(p->scale_type), BTLA_DTYPE::BF16, p->asym);
+  using WType = typename proB::StorageWeight;
+  WType qpackw(0);
+  if constexpr (std::is_same_v<WType, bestla::storage::gemm::StorageWeightKBlockNInteger>) {
+    qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type),
+                               scale2bestladt_map.at(p->scale_type), BTLA_DTYPE::BF16, p->asym);
+  } else {
+    qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type),
+                               scale2bestladt_map.at(p->scale_type));
+  }
   if (p->enable_act_shuffle) ker.enableShuffle(&qpackw);
   ctx->packw_size = qpackw.mSize;
   if (task == WOQ_GET_PACKW_SIZE) return;
@@ -33,6 +40,20 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
                   p->asym ? ctx->zp->data_ptr<int8_t>() : nullptr, &qpackw, dispatcher_utils::qbits_threading::get());
 }
 
+template <class GemmCore, BTLA_ISA ISA>
+void parse_prob(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
+  if (p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
+      p->weight_type == "int2_clip") {
+    return execute_qpack<bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>>(p, ctx, task);
+  }
+  if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1") {
+    TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization.");
+    return execute_qpack<bestla::prologue_b::gemm::WeightKBlockNFloat<GemmCore, ISA>>(p, ctx, task);
+  }
+  TORCH_CHECK(false, "Qbits: unsupported bestla packq config, compute_type: " + p->compute_type +
+                         " weight_type: " + p->weight_type);
+}
+
 std::string get_dtype_str(BTLA_DTYPE dtype) {
   switch (dtype) {
     case BTLA_DTYPE::F32:
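
The upshot at the call site: a repack request whose weight_type is in the f4 family now reaches WeightKBlockNFloat instead of being rejected. A hedged sketch of such a request; the field names follow this diff, but the struct definitions and the op wrapper that fills ctx live elsewhere in qbits, and the values are illustrative:

// Hypothetical caller, not verbatim from the repo.
repack_quantized_weight_param p;
p.weight_type = "nf4";            // f4 family: nf4 / fp4_e2m1 / fp4_e2m1_bnb
p.compute_type = "fp32";          // selects an SCoreRowN* GemmCore in bestla_packq below
p.scale_type = "fp32";
p.blocksize = 128;
p.asym = false;                   // float weights reject asym quantization (TORCH_CHECK above)
p.enable_act_shuffle = false;

repack_quantized_weight_ctx ctx;  // n/k and the tensor pointers are filled by the op wrapper
woq::bestla_packq(&p, &ctx, WOQ_GET_PACKW_SIZE);  // first call only computes ctx.packw_size
// ...allocate the destination from ctx.packw_size, then call bestla_packq again with
// the task value that performs the actual repack (that enum value is not shown in this diff).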
@@ -183,40 +204,38 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
 }
 
 void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
-  // TODO(zhe): elegant impl.
-  TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
-                  p->weight_type == "int2_clip",
-              "Qbits: only support Integer WOQ in PACKQ");
-
   if (p->compute_type == "int8") {
+    TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
+                    p->weight_type == "int2_clip",
+                "Qbits: only support Integer weight-type with int8 compute-type");
     if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::KTILE == 0) {
-      return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
+      return parse_prob<bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
     }
     if (dispatcher_utils::check_avx512_vnni() &&
         p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::KTILE == 0) {
-      return execute_qpack<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
+      return parse_prob<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
     }
     if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>::KTILE == 0) {
-      return execute_qpack<bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
+      return parse_prob<bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
     }
     if (dispatcher_utils::check_avx2() && p->blocksize % bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>::KTILE == 0) {
-      return execute_qpack<bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>, BTLA_ISA::AVX2>(p, ctx, task);
+      return parse_prob<bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>, BTLA_ISA::AVX2>(p, ctx, task);
     }
     TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type, blocksize:", p->blocksize,
                 ", ISA support avx2:", dispatcher_utils::check_avx2());
   }
   if (p->compute_type == "fp32") {
     if (dispatcher_utils::check_avx512f()) {
-      return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
+      return parse_prob<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
     }
     if (dispatcher_utils::check_avx2()) {
-      return execute_qpack<bestla::gemm::SCoreRowNAvx2<24, 4>, BTLA_ISA::AVX2>(p, ctx, task);
+      return parse_prob<bestla::gemm::SCoreRowNAvx2<24, 4>, BTLA_ISA::AVX2>(p, ctx, task);
    }
     TORCH_CHECK(false, "Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32");
   }
   if (p->compute_type == "bf16") {
     if (dispatcher_utils::check_amx()) {
-      return execute_qpack<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
+      return parse_prob<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
     }
     TORCH_CHECK(false, "Qbits: device ISA must support AMX-BF16 when compute_type==bf16");
   }
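One detail worth noting in bestla_packq above: the int8 compute path probes ISAs from the widest (AMX-INT8) down to AVX2, and a candidate core is eligible only when its KTILE evenly divides the quantization blocksize; otherwise dispatch falls through to the next narrower core. The rule in isolation (KTILE is a static member of each bestla GemmCore, as used above; the helper name is ours):

// Minimal restatement of the eligibility check used in the dispatch above.
template <class GemmCore>
constexpr bool blocksize_ok(int blocksize) {
  return blocksize % GemmCore::KTILE == 0;  // K-blockwise quantization must align to the core's K tile
}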
@@ -43,12 +43,6 @@ concept quant_PrologueA = requires {
   requires !std::is_same_v<T, bestla::utils::bf16>;
 };
 
-template <class GemmCore>
-constexpr bool is_int8_cmpt_gemmcore() {
-  return GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
-         GemmCore::ISA == BTLA_ISA::AVX_VNNI || std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>;
-}
-
 template <class Launcher>
 void dequantize_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
   if (dispatcher_utils::initer.verbose) dispatcher_utils::timer.start();
@@ -133,7 +127,7 @@ void do_compute(woq_config_param* p, woq_runtime_ctx* ctx, ParamA param_a) {
   using StorageWeight = typename Launcher::PrologueB::StorageWeight;
   size_t asym_size = 0, shuf_size = 0;
   int8_t* tmpbuf = nullptr;
-  if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
+  if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
     using Parallel = bestla::parallel::gemm::SchedulerKBlockS<GemmCore>;
     bestla::utils::GemmProblem gp(1, ctx->m, ctx->n, ctx->k, p->blocksize);
     StorageWeight* packedw = dynamic_cast<StorageWeight*>(ctx->deseries_wei);
@@ -236,7 +230,7 @@ void execute_task(woq_config_param* p, woq_runtime_ctx* ctx) {
 template <WOQ_TASK TASK, class GemmCore, template <class _T, BTLA_ISA> class PrologueB,
           template <class _T, BTLA_ISA> class PrologueA, template <BTLA_ISA> class Epilogue>
 void parse_launcher(woq_config_param* p, woq_runtime_ctx* ctx) {
-  if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
+  if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
     using Launcher = bestla::wrapper::gemm::LauncherIntKBlock<GemmCore::ISA, GemmCore, PrologueA, PrologueB, Epilogue>;
     return execute_task<TASK, Launcher>(p, ctx);
   } else {
@@ -260,7 +254,7 @@ template <WOQ_TASK TASK, class GemmCore, template <class _T, BTLA_ISA> class PrologueB>
 void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
   using namespace bestla::prologue_a::gemm;
   if (p->src_dt == dispatcher_utils::QBITS_FP32) {
-    if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
+    if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
       return parse_store<TASK, GemmCore, PrologueB, ShuffleActivationKBlockQuantizeF32, dispatcher_utils::QBITS_FP32>(
           p, ctx);
     } else {
@@ -269,7 +263,7 @@ void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
     }
   }
   if (p->src_dt == dispatcher_utils::QBITS_BF16) {
-    if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
+    if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>()) {
       return parse_store<TASK, GemmCore, PrologueB, ShuffleActivationKBlockQuantizeBf16, dispatcher_utils::QBITS_BF16>(
           p, ctx);
     } else {
@@ -289,7 +283,7 @@ void parse_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
   if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1" ||
       p->weight_type == "fp8_e4m3" || p->weight_type == "fp8_e5m2") {
     TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization.");
-    if constexpr (!is_int8_cmpt_gemmcore<GemmCore>())
+    if constexpr (!dispatcher_utils::is_int8_cmpt_gemmcore<GemmCore>())
       return parse_activation<TASK, GemmCore, WeightKBlockNFloat>(p, ctx);
   }
   TORCH_CHECK(false,
