Change FMA kernel size from 8x8 to 6x16
These values were taken from the BLIS kernel configuration for Haswell [1]. Setting NR to use 2 AVX registers takes advantage of the 2 FMA execution ports [2].
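As a rough register-budget sketch (not from the commit; it assumes the 16 YMM registers available to 64-bit AVX code), the 6x16 tile leaves just enough registers for the current B row and the broadcast A value:

// Rough register budget for the 6x16 tile, assuming 64-bit AVX code with
// 16 YMM registers (ymm0..ymm15). Illustrative arithmetic only.
const YMM_REGISTERS: usize = 16;
const F32_PER_YMM: usize = 8; // 256 bits / 32 bits
const MR: usize = 6;
const NR: usize = 16;

fn main() {
    let nr_regs = NR / F32_PER_YMM; // 2 registers per row of the output tile
    let accumulators = MR * nr_regs; // 12 registers hold the 6x16 tile
    let b_row = nr_regs; // 2 registers hold the current row of B
    let a_broadcast = 1; // 1 register holds the broadcast A element
    let used = accumulators + b_row + a_broadcast; // 15
    assert!(used <= YMM_REGISTERS);
    println!("{used} of {YMM_REGISTERS} YMM registers in use");
}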

The clippy lint that warns about the compiler optimizing away assertions on constants
was disabled, because that is exactly the behavior we want here. See also
rust-lang/rust-clippy#8159.
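For illustration (a hypothetical snippet, not from the diff), this is the kind of constant assertion the lint may flag; the point is that the check happens at compile time and disappears from release builds:

// Hypothetical example: an assertion whose condition is entirely constant.
// clippy's `assertions_on_constants` lint may warn that the check is
// optimized away, which is exactly the intent here.
const NR: usize = 16;
const REG_SIZE: usize = 8;

fn main() {
    assert!(NR % REG_SIZE == 0); // evaluable at compile time
}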

[1] https://github.com/flame/blis/blob/f956b79922da412791e4c8b8b846b3aafc0a5ee0/kernels/haswell/bli_kernels_haswell.h#L55

[2] bluss/matrixmultiply#34 (comment)
robertknight committed Jan 4, 2023
1 parent 211e1f2 commit 5f93404
Showing 2 changed files with 43 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -13,7 +13,7 @@ check:

.PHONY: lint
lint:
-cargo clippy -- -Aclippy::needless_range_loop -Aclippy::too_many_arguments -Aclippy::derivable_impls -Aclippy::manual_memcpy
+cargo clippy -- -Aclippy::needless_range_loop -Aclippy::too_many_arguments -Aclippy::derivable_impls -Aclippy::manual_memcpy -Aclippy::assertions_on_constants

.PHONY: wasm
wasm:
66 changes: 42 additions & 24 deletions src/linalg.rs
@@ -105,11 +105,9 @@ fn blocks(start: usize, end: usize, step: usize) -> BlockIter {
///
/// The tile size depends upon the kernel and is specified by the `MR` and `NR`
/// associated constants. The MR and NR values are chosen such that an `MR * NR`
-/// tile can fit in registers. NR is generally determined by the width of the
-/// registers used (eg. for SSE, 128 bits = 4x32 floats, so NR is 4) and MR by
-/// the number available.
+/// tile can fit in registers.
///
-/// The kernel corresponds to Loop 6 of the algorithm in Page 4 of
+/// The kernel corresponds to Loop 6 (the "microkernel") in Page 4 of
/// https://dl.acm.org/doi/pdf/10.1145/2925987.
trait Kernel {
/// Height of output tiles computed by the kernel.
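As a rough scalar illustration of what that microkernel computes (hypothetical code, not part of this diff), each depth step adds one outer product of an MR-element column of packed A and an NR-element row of packed B into the output tile:

// Scalar sketch of the MR x NR microkernel (illustrative only; the real
// kernel uses AVX intrinsics). `a` holds MR packed values per depth step,
// `b` holds NR packed values per depth step.
fn reference_micro_kernel(tile: &mut [[f32; 16]; 6], a: &[f32], b: &[f32], depth: usize) {
    const MR: usize = 6;
    const NR: usize = 16;
    for k in 0..depth {
        for i in 0..MR {
            let a_ik = a[k * MR + i];
            for j in 0..NR {
                // Each of the MR x NR outputs accumulates one product per step.
                tile[i][j] += a_ik * b[k * NR + j];
            }
        }
    }
}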
@@ -141,8 +141,11 @@ struct FMAKernel {}

#[cfg(target_arch = "x86_64")]
impl Kernel for FMAKernel {
-const MR: usize = 8;
-const NR: usize = 8; // AVX registers are 256 bits wide = 8 x f32
+const MR: usize = 6;
+
+// Chosen to fit 2 AVX registers and take advantage of the two FMA
+// execution ports.
+const NR: usize = 16;

fn supported() -> bool {
is_x86_feature_detected!("fma")
@@ -173,13 +174,18 @@ impl FMAKernel {
alpha: f32,
beta: f32,
) {
-const MR: usize = FMAKernel::MR;
-const NR: usize = FMAKernel::NR;
-
use core::arch::x86_64::{
-_mm256_add_ps, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps,
+__m256, _mm256_add_ps, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps,
_mm256_setzero_ps, _mm256_storeu_ps,
};
+use std::mem::size_of;
+
+const MR: usize = FMAKernel::MR;
+const NR: usize = FMAKernel::NR;
+
+const REG_SIZE: usize = size_of::<__m256>() / size_of::<f32>();
+const NR_REGS: usize = NR / REG_SIZE;
+assert!(NR % REG_SIZE == 0);

// Check that buffer accesses below are going to be valid.
assert!(a.len() >= depth * MR);
@@ -189,40 +195,52 @@
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();

// Accumulate into a fixed-sized array to allow the compiler to generate
// more efficient code for the loop over `depth`.
-let mut tmp = [_mm256_setzero_ps(); MR];
+let mut tmp = [[_mm256_setzero_ps(); NR_REGS]; MR];
+let mut b_rows = [_mm256_setzero_ps(); NR_REGS];

for k in 0..depth {
let a_off = k * MR;
let b_off = k * NR;

-let b_row = _mm256_loadu_ps(b_ptr.add(b_off));
+for i in 0..NR_REGS {
+b_rows[i] = _mm256_loadu_ps(b_ptr.add(b_off + i * REG_SIZE));
+}

for i in 0..MR {
let a_val = *a_ptr.add(a_off + i);
let a_broadcast = _mm256_set1_ps(a_val);
-tmp[i] = _mm256_fmadd_ps(a_broadcast, b_row, tmp[i]);
+
+for j in 0..NR_REGS {
+tmp[i][j] = _mm256_fmadd_ps(a_broadcast, b_rows[j], tmp[i][j]);
+}
}
}

if beta == 0. && alpha == 1. {
for i in 0..MR {
-let out_ptr = out.as_mut_ptr().add(out_row_stride * i);
-_mm256_storeu_ps(out_ptr, tmp[i]);
+for j in 0..NR_REGS {
+let out_ptr = out.as_mut_ptr().add(out_row_stride * i + j * REG_SIZE);
+_mm256_storeu_ps(out_ptr, tmp[i][j]);
+}
}
} else if beta == 1. && alpha == 1. {
for i in 0..MR {
-let out_ptr = out.as_mut_ptr().add(out_row_stride * i);
-let out_val = _mm256_add_ps(_mm256_loadu_ps(out_ptr), tmp[i]);
-_mm256_storeu_ps(out_ptr, out_val);
+for j in 0..NR_REGS {
+let out_ptr = out.as_mut_ptr().add(out_row_stride * i + j * REG_SIZE);
+let out_val = _mm256_add_ps(_mm256_loadu_ps(out_ptr), tmp[i][j]);
+_mm256_storeu_ps(out_ptr, out_val);
+}
}
} else {
let alpha_broadcast = _mm256_set1_ps(alpha);
let beta_broadcast = _mm256_set1_ps(beta);
for i in 0..MR {
-let out_ptr = out.as_mut_ptr().add(out_row_stride * i);
-let out_val = _mm256_mul_ps(_mm256_loadu_ps(out_ptr), beta_broadcast);
-let out_val = _mm256_fmadd_ps(tmp[i], alpha_broadcast, out_val);
-_mm256_storeu_ps(out_ptr, out_val);
+for j in 0..NR_REGS {
+let out_ptr = out.as_mut_ptr().add(out_row_stride * i + j * REG_SIZE);
+let out_val = _mm256_mul_ps(_mm256_loadu_ps(out_ptr), beta_broadcast);
+let out_val = _mm256_fmadd_ps(tmp[i][j], alpha_broadcast, out_val);
+_mm256_storeu_ps(out_ptr, out_val);
+}
}
}
}
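For reference (a hypothetical scalar sketch, not part of this diff), the three store paths above implement out = alpha * acc + beta * out, with fast paths for the two most common (alpha, beta) combinations:

// Scalar sketch of the three store paths at the end of the kernel
// (illustrative only). `acc` holds the freshly computed 6x16 product tile.
fn store_tile(out: &mut [f32], row_stride: usize, acc: &[[f32; 16]; 6], alpha: f32, beta: f32) {
    for i in 0..6 {
        for j in 0..16 {
            let o = &mut out[row_stride * i + j];
            *o = if beta == 0. && alpha == 1. {
                acc[i][j] // overwrite the output tile
            } else if beta == 1. && alpha == 1. {
                *o + acc[i][j] // accumulate into the existing output
            } else {
                alpha * acc[i][j] + beta * *o // general scaling
            };
        }
    }
}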
