diff --git a/src/linalg.rs b/src/linalg.rs
index 18549b59..6f0fdcab 100644
--- a/src/linalg.rs
+++ b/src/linalg.rs
@@ -106,8 +106,8 @@ fn blocks(start: usize, end: usize, step: usize) -> BlockIter {
 /// The tile size depends upon the kernel and is specified by the `MR` and `NR`
 /// associated constants. The MR and NR values are chosen such that an `MR * NR`
 /// tile can fit in registers. NR is generally determined by the width of the
-/// registers used (eg. for SSE, 128 bits = 4x32 floats, so NR is 4) and MR by
-/// the number available.
+/// registers used (eg. for SSE, 128 bits = 4x32 floats, so NR will be a multiple
+/// of 4) and MR by the number available.
 ///
 /// The kernel corresponds to Loop 6 of the algorithm in Page 4 of
 /// https://dl.acm.org/doi/pdf/10.1145/2925987.
@@ -141,8 +141,11 @@ struct FMAKernel {}
 
 #[cfg(target_arch = "x86_64")]
 impl Kernel for FMAKernel {
-    const MR: usize = 8;
-    const NR: usize = 8; // AVX registers are 256 bits wide = 8 x f32
+    const MR: usize = 6;
+
+    // Chosen to fit 2 AVX registers and take advantage of the two FMA
+    // execution ports.
+    const NR: usize = 16;
 
     fn supported() -> bool {
         is_x86_feature_detected!("fma")
@@ -177,9 +180,10 @@ impl FMAKernel {
         const NR: usize = FMAKernel::NR;
 
         use core::arch::x86_64::{
-            _mm256_add_ps, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps,
+            __m256, _mm256_add_ps, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps,
             _mm256_setzero_ps, _mm256_storeu_ps,
         };
+        use std::mem::size_of;
 
         // Check that buffer accesses below are going to be valid.
         assert!(a.len() >= depth * MR);
@@ -189,40 +193,55 @@ impl FMAKernel {
         let a_ptr = a.as_ptr();
         let b_ptr = b.as_ptr();
 
-        // Accumulate into a fixed-sized array to allow the compiler to generate
-        // more efficient code for the loop over `depth`.
-        let mut tmp = [_mm256_setzero_ps(); MR];
+        // NR=16, so two AVX registers are required for the inner array.
+        let mut tmp = [[_mm256_setzero_ps(); 2]; MR];
+        let reg_size = size_of::<__m256>() / size_of::<f32>();
+
         for k in 0..depth {
             let a_off = k * MR;
             let b_off = k * NR;
 
-            let b_row = _mm256_loadu_ps(b_ptr.add(b_off));
+            let b_row_0 = _mm256_loadu_ps(b_ptr.add(b_off));
+            let b_row_1 = _mm256_loadu_ps(b_ptr.add(b_off + reg_size));
+
             for i in 0..MR {
                 let a_val = *a_ptr.add(a_off + i);
                 let a_broadcast = _mm256_set1_ps(a_val);
-                tmp[i] = _mm256_fmadd_ps(a_broadcast, b_row, tmp[i]);
+
+                tmp[i][0] = _mm256_fmadd_ps(a_broadcast, b_row_0, tmp[i][0]);
+                tmp[i][1] = _mm256_fmadd_ps(a_broadcast, b_row_1, tmp[i][1]);
             }
         }
 
         if beta == 0. && alpha == 1. {
             for i in 0..MR {
                 let out_ptr = out.as_mut_ptr().add(out_row_stride * i);
-                _mm256_storeu_ps(out_ptr, tmp[i]);
+                _mm256_storeu_ps(out_ptr, tmp[i][0]);
+                _mm256_storeu_ps(out_ptr.add(reg_size), tmp[i][1]);
             }
         } else if beta == 1. && alpha == 1. {
             for i in 0..MR {
                 let out_ptr = out.as_mut_ptr().add(out_row_stride * i);
-                let out_val = _mm256_add_ps(_mm256_loadu_ps(out_ptr), tmp[i]);
+
+                let out_val = _mm256_add_ps(_mm256_loadu_ps(out_ptr), tmp[i][0]);
                 _mm256_storeu_ps(out_ptr, out_val);
+
+                let out_val = _mm256_add_ps(_mm256_loadu_ps(out_ptr.add(reg_size)), tmp[i][1]);
+                _mm256_storeu_ps(out_ptr.add(reg_size), out_val);
             }
         } else {
             let alpha_broadcast = _mm256_set1_ps(alpha);
             let beta_broadcast = _mm256_set1_ps(beta);
             for i in 0..MR {
                 let out_ptr = out.as_mut_ptr().add(out_row_stride * i);
+
                 let out_val = _mm256_mul_ps(_mm256_loadu_ps(out_ptr), beta_broadcast);
-                let out_val = _mm256_fmadd_ps(tmp[i], alpha_broadcast, out_val);
+                let out_val = _mm256_fmadd_ps(tmp[i][0], alpha_broadcast, out_val);
                 _mm256_storeu_ps(out_ptr, out_val);
+
+                let out_val = _mm256_mul_ps(_mm256_loadu_ps(out_ptr.add(reg_size)), beta_broadcast);
+                let out_val = _mm256_fmadd_ps(tmp[i][1], alpha_broadcast, out_val);
+                _mm256_storeu_ps(out_ptr.add(reg_size), out_val);
             }
         }
     }
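
For reviewers, a scalar sketch of what the updated micro-kernel computes, which may help when checking the two-register accumulator layout. This is illustrative only and not part of the change: the name `reference_kernel`, the const-generic signature, and the bounds checks are assumptions based on the packing used above (`a` holds `depth` column slices of `MR` values, `b` holds `depth` row slices of `NR` values, and the output tile is updated as `out = alpha * (a x b) + beta * out`).

// Scalar reference for the MR x NR micro-kernel update. Illustrative only:
// the function name and exact signature are assumptions, not part of the diff.
fn reference_kernel<const MR: usize, const NR: usize>(
    out: &mut [f32],
    out_row_stride: usize,
    a: &[f32],
    b: &[f32],
    depth: usize,
    alpha: f32,
    beta: f32,
) {
    // Mirror the bounds checks performed by the SIMD kernel.
    assert!(a.len() >= depth * MR);
    assert!(b.len() >= depth * NR);
    assert!(out.len() >= (MR - 1) * out_row_stride + NR);

    for i in 0..MR {
        for j in 0..NR {
            // Dot product over the shared (depth) dimension, reading the
            // packed panels: `a` is laid out as `depth` columns of MR values,
            // `b` as `depth` rows of NR values.
            let mut acc = 0.0;
            for k in 0..depth {
                acc += a[k * MR + i] * b[k * NR + j];
            }
            let out_el = &mut out[out_row_stride * i + j];
            *out_el = alpha * acc + beta * *out_el;
        }
    }
}

The SIMD version special-cases `alpha == 1, beta == 0` and `alpha == 1, beta == 1` to avoid redundant loads and multiplies; the scalar form above collapses all three branches into the general expression.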