WIP: i32 gemm experiment #28

Open · wants to merge 3 commits into base: master
53 changes: 42 additions & 11 deletions benches/benchmarks.rs
@@ -1,6 +1,7 @@
extern crate matrixmultiply;
pub use matrixmultiply::sgemm;
pub use matrixmultiply::dgemm;
pub use matrixmultiply::igemm;

#[macro_use]
extern crate bencher;
@@ -10,7 +11,14 @@ extern crate bencher;
// by flop / s = 2 M N K / time


benchmark_main!(mat_mul_f32, mat_mul_f64, layout_f32_032, layout_f64_032);
benchmark_main!(
mat_mul_f32,
mat_mul_f64,
mat_mul_i32,
layout_f32_032,
layout_f64_032,
layout_i32_032
);

macro_rules! mat_mul {
($modname:ident, $gemm:ident, $(($name:ident, $m:expr, $n:expr, $k:expr))+) => {
@@ -20,17 +28,17 @@ macro_rules! mat_mul {
$(
pub fn $name(bench: &mut Bencher)
{
let a = vec![0.; $m * $n];
let b = vec![0.; $n * $k];
let mut c = vec![0.; $m * $k];
let a = vec![0 as _; $m * $n];
let b = vec![0 as _; $n * $k];
let mut c = vec![0 as _; $m * $k];
bench.iter(|| {
unsafe {
$gemm(
$m, $n, $k,
1.,
1 as _,
a.as_ptr(), $n, 1,
b.as_ptr(), $k, 1,
0.,
0 as _,
c.as_mut_ptr(), $k, 1,
)
}
@@ -106,20 +114,20 @@ macro_rules! gemm_layout {

fn base(bench: &mut Bencher, al: Layout, bl: Layout, cl: Layout)
{
let a = vec![0.; $m * $m];
let b = vec![0.; $m * $m];
let mut c = vec![0.; $m * $m];
let a = vec![0 as _; $m * $m];
let b = vec![0 as _; $m * $m];
let mut c = vec![0 as _; $m * $m];
let (rsa, csa) = al.strides($m, 1);
let (rsb, csb) = bl.strides($m, 1);
let (rsc, csc) = cl.strides($m, 1);
bench.iter(|| {
unsafe {
$gemm(
$m, $m, $m,
1.,
1 as _,
a.as_ptr(), rsa, csa,
b.as_ptr(), rsb, csb,
0.,
0 as _,
c.as_mut_ptr(), rsc, csc,
)
}
@@ -157,6 +165,10 @@ gemm_layout!{layout_f64_032, dgemm,
(m032, 32)
}

gemm_layout!{layout_i32_032, igemm,
(m032, 32)
}


use std::ops::{Add, Mul};

@@ -219,3 +231,22 @@ ref_mat_mul!{ref_mat_mul_f32, f32,
(m032, 32, 32, 32)
(m064, 64, 64, 64)
}

mat_mul!{mat_mul_i32, igemm,
(m004, 4, 4, 4)
(m006, 6, 6, 6)
(m008, 8, 8, 8)
(m012, 12, 12, 12)
(m016, 16, 16, 16)
(m032, 32, 32, 32)
(m064, 64, 64, 64)
(m127, 127, 127, 127)
/*
(m256, 256, 256, 256)
(m512, 512, 512, 512)
(mix16x4, 32, 4, 32)
(mix32x2, 32, 2, 32)
(mix97, 97, 97, 125)
(mix128x10000x128, 128, 10000, 128)
*/
}
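
Reviewer note: the switch from `0.`/`1.` literals to `0 as _`/`1 as _` is what lets one macro body instantiate for f32, f64, and the new i32 path, since the cast target is inferred from the `$gemm` call. A standalone sketch of that inference (the `takes_*` functions are hypothetical, not from this PR):

```rust
fn takes_f64(_x: f64) {}
fn takes_i32(_x: i32) {}

fn main() {
    // `as _` leaves the cast target to type inference, so the same
    // literal token works for float and integer element types alike.
    takes_f64(1 as _); // elaborates to `1 as f64`
    takes_i32(1 as _); // elaborates to `1 as i32`
}
```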
18 changes: 18 additions & 0 deletions src/gemm.rs
@@ -19,6 +19,7 @@ use kernel::GemmKernel;
use kernel::Element;
use sgemm_kernel;
use dgemm_kernel;
use igemm_kernel;
use rawpointer::PointerExt;

/// General matrix multiplication (f32)
@@ -87,6 +88,23 @@ pub unsafe fn dgemm(
c, rsc, csc)
}

/// General matrix multiplication (i32)
///
/// C ← α A B + β C
///
/// Same calling convention as sgemm/dgemm above; integer arithmetic
/// wraps on overflow (two's complement), as implemented by the kernel.
pub unsafe fn igemm(
m: usize, k: usize, n: usize,
alpha: i32,
a: *const i32, rsa: isize, csa: isize,
b: *const i32, rsb: isize, csb: isize,
beta: i32,
c: *mut i32, rsc: isize, csc: isize)
{
gemm_loop::<igemm_kernel::Gemm>(
m, k, n,
alpha,
a, rsa, csa,
b, rsb, csb,
beta,
c, rsc, csc)
}

/// Ensure that GemmKernel parameters are supported
/// (alignment, microkernel size).
///
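
Reviewer note: a minimal, self-contained sketch of calling the new `igemm` entry point on row-major data; the 2×3 and 3×2 operand values below are illustrative only.

```rust
extern crate matrixmultiply;
use matrixmultiply::igemm;

fn main() {
    // C (2x2) = 1 * A (2x3) * B (3x2) + 0 * C, all row-major
    let (m, k, n) = (2, 3, 2);
    let a: Vec<i32> = vec![1, 2, 3,
                           4, 5, 6];
    let b: Vec<i32> = vec![7, 8,
                           9, 10,
                           11, 12];
    let mut c = vec![0i32; m * n];
    unsafe {
        igemm(
            m, k, n,
            1,                         // alpha
            a.as_ptr(), k as isize, 1, // row stride = k for row-major m x k
            b.as_ptr(), n as isize, 1, // row stride = n for row-major k x n
            0,                         // beta = 0: c need not be initialized
            c.as_mut_ptr(), n as isize, 1,
        );
    }
    assert_eq!(c, vec![58, 64, 139, 154]);
}
```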
219 changes: 219 additions & 0 deletions src/igemm_kernel.rs
@@ -0,0 +1,219 @@
// Copyright 2016 - 2018 Ulrik Sverdrup "bluss"
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use kernel::GemmKernel;
use kernel::Element;
use archparam;


#[cfg(target_arch="x86")]
use std::arch::x86::*;
#[cfg(target_arch="x86_64")]
use std::arch::x86_64::*;

pub enum Gemm { }

pub type T = i32;

const MR: usize = 8;
const NR: usize = 4;

macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; }
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }

impl GemmKernel for Gemm {
type Elem = T;

#[inline(always)]
fn align_to() -> usize { 16 }

#[inline(always)]
fn mr() -> usize { MR }
#[inline(always)]
fn nr() -> usize { NR }

#[inline(always)]
fn always_masked() -> bool { true }

// NOTE: reuses the f32 (sgemm) blocking parameters for now; i32-specific
// cache-blocking tuning is presumably left for later in this WIP.
#[inline(always)]
fn nc() -> usize { archparam::S_NC }
#[inline(always)]
fn kc() -> usize { archparam::S_KC }
#[inline(always)]
fn mc() -> usize { archparam::S_MC }

#[inline(always)]
unsafe fn kernel(
k: usize,
alpha: T,
a: *const T,
b: *const T,
beta: T,
c: *mut T, rsc: isize, csc: isize) {
kernel(k, alpha, a, b, beta, c, rsc, csc)
}
}

/// matrix multiplication kernel
///
/// This does the matrix multiplication:
///
/// C ← α A B + β C
///
/// + k: length of data in a, b
/// + a, b are packed
/// + c has general strides
/// + rsc: row stride of c
/// + csc: col stride of c
/// + if beta is 0, then c does not need to be initialized
#[inline(never)]
pub unsafe fn kernel(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
// dispatch to specific compiled versions
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
{
if is_x86_feature_detected_!("avx") {
return kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc);
} else if is_x86_feature_detected_!("sse2") {
return kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc);
}
}
kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc);
}

#[inline]
#[target_feature(enable="avx")]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
}

#[inline]
#[target_feature(enable="sse2")]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
}


#[inline(always)]
unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
let mut ab: [[T; NR]; MR] = [[0; NR]; MR];
let mut a = a;
let mut b = b;
debug_assert_eq!(beta, 0);

// Compute A B into ab[i][j]
unroll_by!(4 => k, {
loop_m!(i, loop_n!(j, {
ab[i][j] = ab[i][j].wrapping_add(at(a, i).wrapping_mul(at(b, j)));
}));

a = a.offset(MR as isize);
b = b.offset(NR as isize);
});

macro_rules! c {
($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
}

// set C = α A B + β C
loop_n!(j, loop_m!(i, *c![i, j] = alpha.wrapping_mul(ab[i][j])));
}

#[inline(always)]
unsafe fn at(ptr: *const T, i: usize) -> T {
*ptr.offset(i as isize)
}

#[cfg(test)]
mod tests {
use super::*;
use aligned_alloc::Alloc;

fn aligned_alloc<T>(elt: T, n: usize) -> Alloc<T> where T: Copy
{
unsafe {
Alloc::new(n, Gemm::align_to()).init_with(elt)
}
}

use super::T;
type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize);

fn test_a_kernel(_name: &str, kernel_fn: KernelFn) {
const K: usize = 4;
let mut a = aligned_alloc(1, MR * K);
let mut b = aligned_alloc(0, NR * K);
for (i, x) in a.iter_mut().enumerate() {
*x = i as _;
}

for i in 0..K {
b[i + i * NR] = 1;
}
let mut c = [0; MR * NR];
unsafe {
kernel_fn(K, 1, &a[0], &b[0], 0, &mut c[0], 1, MR as isize);
// col major C
}
assert_eq!(&a[..], &c[..a.len()]);
}

#[test]
fn test_native_kernel() {
test_a_kernel("kernel", kernel);
}

#[test]
fn test_kernel_fallback_impl() {
test_a_kernel("kernel", kernel_fallback_impl);
}

#[test]
fn test_loop_m_n() {
let mut m = [[0; NR]; MR];
loop_m!(i, loop_n!(j, m[i][j] += 1));
for arr in &m[..] {
for elt in &arr[..] {
assert_eq!(*elt, 1);
}
}
}

mod test_arch_kernels {
use super::test_a_kernel;
macro_rules! test_arch_kernels_x86 {
($($feature_name:tt, $function_name:ident),*) => {
$(
#[test]
fn $function_name() {
if is_x86_feature_detected_!($feature_name) {
test_a_kernel(stringify!($function_name), super::super::$function_name);
} else {
println!("Skipping, host does not have feature: {:?}", $feature_name);
}
}
)*
}
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
test_arch_kernels_x86! {
"avx", kernel_target_avx,
"sse2", kernel_target_sse2
}
}
}
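
For reference while reviewing: a scalar model of what `kernel_fallback_impl` computes over the packed panels, with the same wrapping semantics and the α scaling applied at the end (a standalone sketch; `microkernel_model` is a name invented here, not part of the PR):

```rust
const MR: usize = 8;
const NR: usize = 4;

/// Scalar model of the packed 8x4 microkernel: `a` holds k packed
/// column-slices of MR elements, `b` holds k packed row-slices of NR
/// elements; returns alpha * (A B) with wrapping i32 arithmetic.
fn microkernel_model(k: usize, alpha: i32, a: &[i32], b: &[i32]) -> [[i32; NR]; MR] {
    let mut ab = [[0i32; NR]; MR];
    for l in 0..k {
        for i in 0..MR {
            for j in 0..NR {
                ab[i][j] = ab[i][j]
                    .wrapping_add(a[l * MR + i].wrapping_mul(b[l * NR + j]));
            }
        }
    }
    // the kernel stores alpha * AB (beta is asserted to be 0)
    for row in ab.iter_mut() {
        for elt in row.iter_mut() {
            *elt = alpha.wrapping_mul(*elt);
        }
    }
    ab
}

fn main() {
    // k = 1: A column [1..8] times B row [1..4] is a rank-1 outer product
    let a: Vec<i32> = (1..=8).collect();
    let b: Vec<i32> = (1..=4).collect();
    let ab = microkernel_model(1, 1, &a, &b);
    assert_eq!(ab[2][3], 3 * 4); // row 2, col 3 of the outer product
}
```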
12 changes: 12 additions & 0 deletions src/kernel.rs
@@ -79,3 +79,15 @@ impl Element for f64 {
*self += alpha * a;
}
}

impl Element for i32 {
fn zero() -> Self { 0 }
fn one() -> Self { 1 }
fn is_zero(&self) -> bool { *self == 0 }
fn scale_by(&mut self, x: Self) {
*self = self.wrapping_mul(x);
}
fn scaled_add(&mut self, alpha: Self, a: Self) {
*self = self.wrapping_add(alpha.wrapping_mul(a));
}
}
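
Reviewer note: unlike the float impls, the i32 `Element` impl uses `wrapping_*` ops, so a `scaled_add` that overflows wraps instead of panicking in debug builds. A tiny illustration (values chosen to force overflow):

```rust
fn main() {
    let mut acc: i32 = i32::max_value();
    // `acc += alpha * a` would panic on overflow in debug builds; the
    // wrapping form used by the i32 Element impl wraps to MIN instead.
    acc = acc.wrapping_add(1i32.wrapping_mul(1));
    assert_eq!(acc, i32::min_value());
}
```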
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -62,8 +62,10 @@ mod kernel;
mod gemm;
mod sgemm_kernel;
mod dgemm_kernel;
mod igemm_kernel;
mod util;
mod aligned_alloc;

pub use gemm::sgemm;
pub use gemm::dgemm;
pub use gemm::igemm;