Skip to content

Commit

Permalink
xe: ocl: fix gemm_with_po int accumulation type
Browse files Browse the repository at this point in the history
  • Loading branch information
rjoursler committed Jan 10, 2025
1 parent d7861b9 commit 26fc01d
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,17 @@ status_t gemm_with_post_ops_t::pd_t::init_kernel_ctx(
kernel_ctx, memory_desc_info_t::create(dst_md(0)), "DST", false);

int ndims = src_info.ndims;
bool is_int8 = src_md(1)->data_type == data_type::s8;
kernel_ctx.set_data_type(c_type);
// Here SRC is the output tensor of the gemm call
def_data_type(kernel_ctx, is_int8 ? data_type::f32 : desc_.acc_type, "ACC");

auto is_int_type = [](data_type_t t) {
return utils::one_of(t, data_type::s8, data_type::u8, data_type::s32);
};
data_type_t acc_type = desc_.acc_type;
if (desc_.acc_type == data_type::s32
&& !(is_int_type(bias_info.data_type)
&& is_int_type(dst_md(0)->data_type)))
acc_type = data_type::f32;
def_data_type(kernel_ctx, acc_type, "ACC");

kernel_ctx.define_int("NDIMS", ndims);
CHECK(def_attr_info(
Expand Down

0 comments on commit 26fc01d

Please sign in to comment.