From c5ff175703d1ff5ba669f2d3eefc0ee6f538ccd8 Mon Sep 17 00:00:00 2001
From: Roy Oursler
Date: Fri, 10 Jan 2025 14:32:05 -0800
Subject: [PATCH] xe: ocl: fix gemm_with_po int accumulation type

---
 src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp | 22 ++++++++++++++-----
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp b/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp
index 36a52fa40c0..c164c91707b 100644
--- a/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp
+++ b/src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp
@@ -155,14 +155,8 @@ status_t gemm_with_post_ops_t::pd_t::init_kernel_ctx(
             kernel_ctx, memory_desc_info_t::create(dst_md(0)), "DST", false);
 
     int ndims = src_info.ndims;
-    bool is_int8 = src_md(1)->data_type == data_type::s8;
     kernel_ctx.set_data_type(c_type);
-    //here SRC is output tensor of gemm call
-    def_data_type(kernel_ctx, is_int8 ? data_type::f32 : desc_.acc_type, "ACC");
 
-    kernel_ctx.define_int("NDIMS", ndims);
-    CHECK(def_attr_info(
-            kernel_ctx, attr_info_, attr()->post_ops_, *gemm_pd_->dst_md()));
     const auto &attr_scales = attr()->scales_;
     const bool with_src_scales
             = !attr_scales.get(DNNL_ARG_SRC).has_default_values();
@@ -170,6 +164,22 @@ status_t gemm_with_post_ops_t::pd_t::init_kernel_ctx(
             = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values();
     const bool with_dst_scales
             = !attr_scales.get(DNNL_ARG_DST).has_default_values();
+    auto is_int_type = [](data_type_t t) {
+        return utils::one_of(t, data_type::s8, data_type::u8, data_type::s32);
+    };
+    data_type_t acc_type = desc_.acc_type;
+    if (desc_.acc_type == data_type::s32) {
+        if (with_src_scales || with_wei_scales
+                || !is_int_type(bias_info.data_type)
+                || !is_int_type(dst_md(0)->data_type)) {
+            acc_type = data_type::f32;
+        }
+    }
+    def_data_type(kernel_ctx, acc_type, "ACC");
+
+    kernel_ctx.define_int("NDIMS", ndims);
+    CHECK(def_attr_info(
+            kernel_ctx, attr_info_, attr()->post_ops_, *gemm_pd_->dst_md()));
     kernel_ctx.define_int("A_SCALES", with_src_scales);
     kernel_ctx.define_int("B_SCALES", with_wei_scales);
     kernel_ctx.define_int("C_SCALES", with_dst_scales);