Skip to content

Commit

Permalink
WIP - Tweak FMA kernel size from 8x8 to 6x16
Browse files Browse the repository at this point in the history
These values were taken from the Blis kernel configuration for Haswell [1],
Setting NR to use 2 AVX registers takes advantage of the 2 FMA execution ports
[2].

[1] https://github.com/flame/blis/blob/f956b79922da412791e4c8b8b846b3aafc0a5ee0/kernels/haswell/bli_kernels_haswell.h#L55

[2] bluss/matrixmultiply#34 (comment)
  • Loading branch information
robertknight committed Jan 3, 2023
1 parent 0c86683 commit 84909b6
Showing 1 changed file with 32 additions and 13 deletions.
45 changes: 32 additions & 13 deletions src/linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ fn blocks(start: usize, end: usize, step: usize) -> BlockIter {
/// The tile size depends upon the kernel and is specified by the `MR` and `NR`
/// associated constants. The MR and NR values are chosen such that an `MR * NR`
/// tile can fit in registers. NR is generally determined by the width of the
/// registers used (eg. for SSE, 128 bits = 4x32 floats, so NR is 4) and MR by
/// the number available.
/// registers used (eg. for SSE, 128 bits = 4x32 floats, so NR will be a multiple
/// of 4) and MR by the number available.
///
/// The kernel corresponds to Loop 6 of the algorithm in Page 4 of
/// https://dl.acm.org/doi/pdf/10.1145/2925987.
Expand Down Expand Up @@ -141,8 +141,11 @@ struct FMAKernel {}

#[cfg(target_arch = "x86_64")]
impl Kernel for FMAKernel {
const MR: usize = 8;
const NR: usize = 8; // AVX registers are 256 bits wide = 8 x f32
const MR: usize = 6;

// Chosen to fit 2 AVX registers and take advantage of the two FMA
// execution ports.
const NR: usize = 16;

fn supported() -> bool {
is_x86_feature_detected!("fma")
Expand Down Expand Up @@ -177,9 +180,10 @@ impl FMAKernel {
const NR: usize = FMAKernel::NR;

use core::arch::x86_64::{
_mm256_add_ps, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps,
__m256, _mm256_add_ps, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps,
_mm256_setzero_ps, _mm256_storeu_ps,
};
use std::mem::size_of;

// Check that buffer accesses below are going to be valid.
assert!(a.len() >= depth * MR);
Expand All @@ -189,40 +193,55 @@ impl FMAKernel {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();

// Accumulate into a fixed-sized array to allow the compiler to generate
// more efficient code for the loop over `depth`.
let mut tmp = [_mm256_setzero_ps(); MR];
// NR=16, so two AVX registers are required for the inner array.
let mut tmp = [[_mm256_setzero_ps(); 2]; MR];
let reg_size = size_of::<__m256>() / size_of::<f32>();

for k in 0..depth {
let a_off = k * MR;
let b_off = k * NR;

let b_row = _mm256_loadu_ps(b_ptr.add(b_off));
let b_row_0 = _mm256_loadu_ps(b_ptr.add(b_off));
let b_row_1 = _mm256_loadu_ps(b_ptr.add(b_off + reg_size));

for i in 0..MR {
let a_val = *a_ptr.add(a_off + i);
let a_broadcast = _mm256_set1_ps(a_val);
tmp[i] = _mm256_fmadd_ps(a_broadcast, b_row, tmp[i]);

tmp[i][0] = _mm256_fmadd_ps(a_broadcast, b_row_0, tmp[i][0]);
tmp[i][1] = _mm256_fmadd_ps(a_broadcast, b_row_1, tmp[i][1]);
}
}

if beta == 0. && alpha == 1. {
for i in 0..MR {
let out_ptr = out.as_mut_ptr().add(out_row_stride * i);
_mm256_storeu_ps(out_ptr, tmp[i]);
_mm256_storeu_ps(out_ptr, tmp[i][0]);
_mm256_storeu_ps(out_ptr.add(reg_size), tmp[i][1]);
}
} else if beta == 1. && alpha == 1. {
for i in 0..MR {
let out_ptr = out.as_mut_ptr().add(out_row_stride * i);
let out_val = _mm256_add_ps(_mm256_loadu_ps(out_ptr), tmp[i]);

let out_val = _mm256_add_ps(_mm256_loadu_ps(out_ptr), tmp[i][0]);
_mm256_storeu_ps(out_ptr, out_val);

let out_val = _mm256_add_ps(_mm256_loadu_ps(out_ptr.add(reg_size)), tmp[i][1]);
_mm256_storeu_ps(out_ptr.add(reg_size), out_val);
}
} else {
let alpha_broadcast = _mm256_set1_ps(alpha);
let beta_broadcast = _mm256_set1_ps(beta);
for i in 0..MR {
let out_ptr = out.as_mut_ptr().add(out_row_stride * i);

let out_val = _mm256_mul_ps(_mm256_loadu_ps(out_ptr), beta_broadcast);
let out_val = _mm256_fmadd_ps(tmp[i], alpha_broadcast, out_val);
let out_val = _mm256_fmadd_ps(tmp[i][0], alpha_broadcast, out_val);
_mm256_storeu_ps(out_ptr, out_val);

let out_val = _mm256_mul_ps(_mm256_loadu_ps(out_ptr.add(reg_size)), beta_broadcast);
let out_val = _mm256_fmadd_ps(tmp[i][1], alpha_broadcast, out_val);
_mm256_storeu_ps(out_ptr.add(reg_size), out_val);
}
}
}
Expand Down

0 comments on commit 84909b6

Please sign in to comment.