Skip to content

Commit

Permalink
src: cpu: x64: augru: fwd, bwd: fix sse4.1
Browse files Browse the repository at this point in the history
  • Loading branch information
igorsafo committed Sep 15, 2022
1 parent 65e270b commit d8ffd44
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/cpu/x64/rnn/jit_uni_gru_cell_postgemm_1_bwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ struct jit_uni_gru_cell_postgemm_part1_bwd : public jit_uni_rnn_postgemm {
// 1. compute dAttention -= dG0 * G
uni_vfnmadd231ps(diff_attn_acc, dG0, G0, tmp2);
// 2. Compute dG0 *= 1 - Attention
uni_vsubps(tmp1, one_vmm, attn);
uni_vsubps(tmp1, one_vmm, attn, tmp2);
uni_vmulps(dG0, dG0, tmp1);
}

Expand Down
7 changes: 4 additions & 3 deletions src/cpu/x64/rnn/jit_uni_gru_cell_postgemm_2_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ struct jit_uni_gru_cell_postgemm_part2_fwd : public jit_uni_rnn_postgemm {
}
const Vmm tmp1_vmm = Vmm(9);
const Vmm tmp2_vmm = Vmm(10);
const Vmm tmp3_vmm = Vmm(11);

void generate() override {
using namespace Xbyak;
Expand Down Expand Up @@ -219,8 +220,8 @@ struct jit_uni_gru_cell_postgemm_part2_fwd : public jit_uni_rnn_postgemm {
scratch_dt_size);
uni_vbroadcastss(tmp2_vmm, tmp2s_vmm);
// G01 = (1 - a) * G0
compute_vsubps(
tmp2_vmm, tmp1_vmm, tmp2_vmm, current_vlen);
compute_vsubps(tmp2_vmm, tmp1_vmm, tmp2_vmm, tmp3_vmm,
current_vlen);
compute_vmulps(G0(loop_ur_idx), G0(loop_ur_idx),
tmp2_vmm, current_vlen);
to_float(tmp2_vmm,
Expand All @@ -232,7 +233,7 @@ struct jit_uni_gru_cell_postgemm_part2_fwd : public jit_uni_rnn_postgemm {
current_vlen);
// tmp1 = G2 * tmp1
compute_vmulps(tmp1_vmm, G2(loop_ur_idx), tmp1_vmm,
current_vlen);
tmp3_vmm, current_vlen);
// states_t_l = G01 * states_tm1_l + tmp1
compute_vfmadd213ps(G0(loop_ur_idx), tmp2_vmm, tmp1_vmm,
current_vlen);
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/rnn/jit_uni_gru_lbr_cell_postgemm_bwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ struct jit_uni_gru_lbr_cell_postgemm_bwd : public jit_uni_rnn_postgemm {
// 1. compute dAttention = -dG0 * G
uni_vfnmadd231ps(diff_attn_acc, dG0, G0, tmp2);
// 2. Compute dG0 *= 1 - Attention
uni_vsubps(tmp1, one_vmm, attn);
uni_vsubps(tmp1, one_vmm, attn, tmp2);
uni_vmulps(dG0, dG0, tmp1);
}
// compute dG2
Expand Down
8 changes: 5 additions & 3 deletions src/cpu/x64/rnn/jit_uni_gru_lbr_cell_postgemm_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ struct jit_uni_gru_lbr_cell_postgemm_fwd : public jit_uni_rnn_postgemm {
const Reg64 table_reg(rbx); // table is used for data scale and shifts

// We skip vmm0 as it can be used by the injector for masks on sse4.1
const Vmm G0(1), G1(2), G2(3), tmp1_vmm(5), tmp2_vmm(6);
const Vmm G0(1), G1(2), G2(3), tmp1_vmm(5), tmp2_vmm(6), tmp3_vmm(7);

// constant table map
const Address one_addr = ptr[table_reg];
Expand Down Expand Up @@ -189,12 +189,14 @@ struct jit_uni_gru_lbr_cell_postgemm_fwd : public jit_uni_rnn_postgemm {
scratch_dt_size);
uni_vbroadcastss(tmp2_vmm, tmp2s_vmm);
// G01 = (1 - a) * G0
compute_vsubps(tmp2_vmm, tmp1_vmm, tmp2_vmm, current_vlen);
compute_vsubps(tmp2_vmm, tmp1_vmm, tmp2_vmm, tmp3_vmm,
current_vlen);
compute_vmulps(G0, G0, tmp2_vmm, current_vlen);
// tmp1 = 1 - G01
compute_vsubps(tmp1_vmm, tmp1_vmm, G0, current_vlen);
// tmp1 = G2 * tmp1
compute_vmulps(tmp1_vmm, G2, tmp1_vmm, current_vlen);
compute_vmulps(
tmp1_vmm, G2, tmp1_vmm, tmp3_vmm, current_vlen);
// states_t_l = G01 * states_tm1_l + tmp2
to_float(tmp2_vmm, ptr[addr_states_tm1_l_reg], src_data_t,
current_vlen);
Expand Down
22 changes: 22 additions & 0 deletions src/cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,17 @@ struct jit_uni_rnn_postgemm : public jit_generator {
uni_vsubps(v1, v2, v3);
}

template <typename Vmm>
void compute_vsubps(const Vmm &v1, const Vmm &v2, const Vmm &v3,
const Vmm &buf, int vlen_bytes) {
if (vlen_bytes == 4)
// special case for scalar-based tail processing
uni_vsubss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()),
Xbyak::Xmm(v3.getIdx()), Xbyak::Xmm(buf.getIdx()));
else
uni_vsubps(v1, v2, v3, buf);
}

template <typename Vmm>
void compute_vmulps(
const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) {
Expand All @@ -761,6 +772,17 @@ struct jit_uni_rnn_postgemm : public jit_generator {
uni_vmulps(v1, v2, v3);
}

template <typename Vmm>
void compute_vmulps(const Vmm &v1, const Vmm &v2, const Vmm &v3,
const Vmm &buf, int vlen_bytes) {
if (vlen_bytes == 4)
// special case for scalar-based tail processing
uni_vmulss(Xbyak::Xmm(v1.getIdx()), Xbyak::Xmm(v2.getIdx()),
Xbyak::Xmm(v3.getIdx()), Xbyak::Xmm(buf.getIdx()));
else
uni_vmulps(v1, v2, v3, buf);
}

template <typename Vmm>
void compute_vfmadd231ps(
const Vmm &v1, const Vmm &v2, const Vmm &v3, int vlen_bytes) {
Expand Down

0 comments on commit d8ffd44

Please sign in to comment.