diff --git a/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp b/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp index 1f33cfc663b..784c512220f 100644 --- a/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp +++ b/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp @@ -14,7 +14,6 @@ #pragma once #include #include -#include "bestla/bestla_storage.h" #include "../include/dispatcher_utils.hpp" #include #include diff --git a/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp b/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp index 8a0c99b3b3a..05a8c718b26 100644 --- a/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp +++ b/intel_extension_for_transformers/qbits/dispatcher/include/dispatcher_utils.hpp @@ -16,6 +16,7 @@ #include #include #include "bestla/bestla_device.h" +#include "bestla/bestla_storage.h" #include "bestla/bestla_utils.h" #include "bestla/bestla_parallel.h" namespace dispatcher_utils { @@ -26,6 +27,12 @@ inline bool check_avx_vnni() { return bestla::device::CpuDevice::getInstance()-> inline bool check_avx512f() { return bestla::device::CpuDevice::getInstance()->AVX512F(); } inline bool check_avx2() { return bestla::device::CpuDevice::getInstance()->AVX2(); } +template +constexpr bool is_int8_cmpt_gemmcore() { + return GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI || + GemmCore::ISA == BTLA_ISA::AVX_VNNI || std::is_same_v>; +} + class qbits_threading { public: static bestla::parallel::IThreading* get() { diff --git a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp index 399deaba7e0..cf6889a9f15 100644 --- a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp +++ b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp @@ -16,12 +16,19 @@ #include "../include/bestla_packq_impl.hpp" namespace woq { -template + +template void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) { - using proB = bestla::prologue_b::gemm::WeightKBlockNInteger; static proB ker; - auto qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type), - scale2bestladt_map.at(p->scale_type), BTLA_DTYPE::BF16, p->asym); + using WType = typename proB::StorageWeight; + WType qpackw(0); + if constexpr (std::is_same_v) { + qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type), + scale2bestladt_map.at(p->scale_type), BTLA_DTYPE::BF16, p->asym); + } else { + qpackw = ker.createStorage(ctx->n, ctx->k, p->blocksize, wei2bestladt_map.at(p->weight_type), + scale2bestladt_map.at(p->scale_type)); + } if (p->enable_act_shuffle) ker.enableShuffle(&qpackw); ctx->packw_size = qpackw.mSize; if (task == WOQ_GET_PACKW_SIZE) return; @@ -33,6 +40,20 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx p->asym ? ctx->zp->data_ptr() : nullptr, &qpackw, dispatcher_utils::qbits_threading::get()); } +template +void parse_prob(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) { + if (p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" || + p->weight_type == "int2_clip") { + return execute_qpack>(p, ctx, task); + } + if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1") { + TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization."); + return execute_qpack>(p, ctx, task); + } + TORCH_CHECK(false, "Qbits: unsupported bestla packq config, compute_type: " + p->compute_type + + " weight_type: " + p->weight_type); +} + std::string get_dtype_str(BTLA_DTYPE dtype) { switch (dtype) { case BTLA_DTYPE::F32: @@ -183,40 +204,38 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) { } void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) { - // TODO(zhe): elegant impl. - TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" || - p->weight_type == "int2_clip", - "Qbits: only support Integer WOQ in PACKQ"); - if (p->compute_type == "int8") { + TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" || + p->weight_type == "int2_clip", + "Qbits: only support Integer weight-type with int8 compute-type"); if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::KTILE == 0) { - return execute_qpack, BTLA_ISA::AMX_INT8>(p, ctx, task); + return parse_prob, BTLA_ISA::AMX_INT8>(p, ctx, task); } if (dispatcher_utils::check_avx512_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::KTILE == 0) { - return execute_qpack, BTLA_ISA::AVX512_VNNI>(p, ctx, task); + return parse_prob, BTLA_ISA::AVX512_VNNI>(p, ctx, task); } if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>::KTILE == 0) { - return execute_qpack, BTLA_ISA::AVX_VNNI>(p, ctx, task); + return parse_prob, BTLA_ISA::AVX_VNNI>(p, ctx, task); } if (dispatcher_utils::check_avx2() && p->blocksize % bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>::KTILE == 0) { - return execute_qpack, BTLA_ISA::AVX2>(p, ctx, task); + return parse_prob, BTLA_ISA::AVX2>(p, ctx, task); } TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type, blocksize:", p->blocksize, ", ISA support avx2:", dispatcher_utils::check_avx2()); } if (p->compute_type == "fp32") { if (dispatcher_utils::check_avx512f()) { - return execute_qpack, BTLA_ISA::AVX512F>(p, ctx, task); + return parse_prob, BTLA_ISA::AVX512F>(p, ctx, task); } if (dispatcher_utils::check_avx2()) { - return execute_qpack, BTLA_ISA::AVX2>(p, ctx, task); + return parse_prob, BTLA_ISA::AVX2>(p, ctx, task); } TORCH_CHECK(false, "Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32"); } if (p->compute_type == "bf16") { if (dispatcher_utils::check_amx()) { - return execute_qpack, BTLA_ISA::AMX_BF16>(p, ctx, task); + return parse_prob, BTLA_ISA::AMX_BF16>(p, ctx, task); } TORCH_CHECK(false, "Qbits: device ISA must support AMX-BF16 when compute_type==bf16"); } diff --git a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp index f9864ddece0..c04e652a4aa 100644 --- a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp +++ b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp @@ -43,12 +43,6 @@ concept quant_PrologueA = requires { requires !std::is_same_v; }; -template -constexpr bool is_int8_cmpt_gemmcore() { - return GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI || - GemmCore::ISA == BTLA_ISA::AVX_VNNI || std::is_same_v>; -} - template void dequantize_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) { if (dispatcher_utils::initer.verbose) dispatcher_utils::timer.start(); @@ -133,7 +127,7 @@ void do_compute(woq_config_param* p, woq_runtime_ctx* ctx, ParamA param_a) { using StorageWeight = typename Launcher::PrologueB::StorageWeight; size_t asym_size = 0, shuf_size = 0; int8_t* tmpbuf = nullptr; - if constexpr (is_int8_cmpt_gemmcore()) { + if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore()) { using Parallel = bestla::parallel::gemm::SchedulerKBlockS; bestla::utils::GemmProblem gp(1, ctx->m, ctx->n, ctx->k, p->blocksize); StorageWeight* packedw = dynamic_cast(ctx->deseries_wei); @@ -236,7 +230,7 @@ void execute_task(woq_config_param* p, woq_runtime_ctx* ctx) { template class PrologueB, template class PrologueA, template class Epilogue> void parse_launcher(woq_config_param* p, woq_runtime_ctx* ctx) { - if constexpr (is_int8_cmpt_gemmcore()) { + if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore()) { using Launcher = bestla::wrapper::gemm::LauncherIntKBlock; return execute_task(p, ctx); } else { @@ -260,7 +254,7 @@ template class Pro void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) { using namespace bestla::prologue_a::gemm; if (p->src_dt == dispatcher_utils::QBITS_FP32) { - if constexpr (is_int8_cmpt_gemmcore()) { + if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore()) { return parse_store( p, ctx); } else { @@ -269,7 +263,7 @@ void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) { } } if (p->src_dt == dispatcher_utils::QBITS_BF16) { - if constexpr (is_int8_cmpt_gemmcore()) { + if constexpr (dispatcher_utils::is_int8_cmpt_gemmcore()) { return parse_store( p, ctx); } else { @@ -289,7 +283,7 @@ void parse_weight(woq_config_param* p, woq_runtime_ctx* ctx) { if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1" || p->weight_type == "fp8_e4m3" || p->weight_type == "fp8_e5m2") { TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization."); - if constexpr (!is_int8_cmpt_gemmcore()) + if constexpr (!dispatcher_utils::is_int8_cmpt_gemmcore()) return parse_activation(p, ctx); } TORCH_CHECK(false,