Skip to content

Commit

Permalink
xe: ocl: fix gemm_with_po int accumulation type
Browse files Browse the repository at this point in the history
  • Loading branch information
rjoursler committed Jan 10, 2025
1 parent 765eced commit 40d9909
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,31 @@ status_t gemm_with_post_ops_t::pd_t::init_kernel_ctx(
kernel_ctx, memory_desc_info_t::create(dst_md(0)), "DST", false);

int ndims = src_info.ndims;
bool is_int8 = src_md(1)->data_type == data_type::s8;
kernel_ctx.set_data_type(c_type);
//here SRC is output tensor of gemm call
def_data_type(kernel_ctx, is_int8 ? data_type::f32 : desc_.acc_type, "ACC");

kernel_ctx.define_int("NDIMS", ndims);
CHECK(def_attr_info(
kernel_ctx, attr_info_, attr()->post_ops_, *gemm_pd_->dst_md()));
const auto &attr_scales = attr()->scales_;
const bool with_src_scales
= !attr_scales.get(DNNL_ARG_SRC).has_default_values();
const bool with_wei_scales
= !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values();
const bool with_dst_scales
= !attr_scales.get(DNNL_ARG_DST).has_default_values();
auto is_int_type = [](data_type_t t) {
return utils::one_of(t, data_type::s8, data_type::u8, data_type::s32);
};
data_type_t acc_type = desc_.acc_type;
if (desc_.acc_type == data_type::s32) {
if (with_src_scales || with_wei_scales
|| !is_int_type(bias_info.data_type)
|| !is_int_type(dst_md(0)->data_type)) {
acc_type = data_type::f32;
}
}
def_data_type(kernel_ctx, acc_type, "ACC");

kernel_ctx.define_int("NDIMS", ndims);
CHECK(def_attr_info(
kernel_ctx, attr_info_, attr()->post_ops_, *gemm_pd_->dst_md()));
kernel_ctx.define_int("A_SCALES", with_src_scales);
kernel_ctx.define_int("B_SCALES", with_wei_scales);
kernel_ctx.define_int("C_SCALES", with_dst_scales);
Expand Down

0 comments on commit 40d9909

Please sign in to comment.