Skip to content

Commit

Permalink
WaveValue respects warp size
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalin committed Jul 7, 2024
1 parent dacfb18 commit 633a096
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 10 deletions.
34 changes: 30 additions & 4 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ impl Value for u32 {
#[derive(Debug, Clone, Copy)]
pub struct WaveValue {
pub value: u32,
pub warp_size: usize,
pub default_lane: Option<usize>,
pub mutations: Option<[bool; 32]>,
}
impl WaveValue {
pub fn new(value: u32) -> Self {
pub fn new(value: u32, warp_size: usize) -> Self {
Self {
value,
warp_size,
default_lane: None,
mutations: None,
}
Expand All @@ -92,7 +94,7 @@ impl WaveValue {
}
pub fn apply_muts(&mut self) {
self.value = 0;
for lane in 0..32 {
for lane in 0..self.warp_size {
if self.mutations.unwrap()[lane] {
self.value |= 1 << lane;
}
Expand All @@ -106,16 +108,40 @@ mod test_state {

#[test]
fn test_wave_value() {
let mut val = WaveValue::new(0b11000000000000011111111111101110);
let mut val = WaveValue::new(0b11000000000000011111111111101110, 32);
val.default_lane = Some(0);
assert!(!val.read());
val.default_lane = Some(31);
assert!(val.read());
}

#[test]
fn test_wave_value_small() {
let mut val = WaveValue::new(0, 1);
val.default_lane = Some(0);
assert!(!val.read());
assert_eq!(val.value, 0);
val.set_lane(true);
val.apply_muts();
assert!(val.read());
assert_eq!(val.value, 1);
}

#[test]
fn test_wave_value_small_alt() {
let mut val = WaveValue::new(0, 2);
val.default_lane = Some(0);
assert!(!val.read());
assert_eq!(val.value, 0);
val.set_lane(true);
val.apply_muts();
assert!(val.read());
assert_eq!(val.value, 1);
}

#[test]
fn test_wave_value_mutations() {
let mut val = WaveValue::new(0b10001);
let mut val = WaveValue::new(0b10001, 32);
val.default_lane = Some(0);
val.set_lane(false);
assert!(val.mutations.unwrap().iter().all(|x| !x));
Expand Down
13 changes: 8 additions & 5 deletions src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub struct Thread<'a> {
pub stream: Vec<u32>,
pub simm: Option<u32>,
pub sgpr_co: &'a mut Option<(usize, WaveValue)>,
pub warp_size: usize,
pub scalar: bool,
}

Expand Down Expand Up @@ -992,7 +993,7 @@ impl<'a> Thread<'a> {
let sdst = ((instr >> 8) & 0x7f) as usize;
let f = |i: u32| -> usize { ((instr >> i) & 0x1ff) as usize };
let (s0, s1, s2) = (f(32), f(41), f(50));
let mut carry_in = WaveValue::new(self.val(s2));
let mut carry_in = WaveValue::new(self.val(s2), self.warp_size);
carry_in.default_lane = self.vcc.default_lane;
let omod = (instr >> 59) & 0x3;
let _neg = (instr >> 61) & 0x7;
Expand Down Expand Up @@ -1352,7 +1353,8 @@ impl<'a> Thread<'a> {
796 => s0 * 2f32.powi(s1.to_bits() as i32),
// cnd_mask isn't a float only ALU but supports neg
257 => {
let mut cond = WaveValue::new(s2.to_bits());
let mut cond =
WaveValue::new(s2.to_bits(), self.warp_size);
cond.default_lane = self.vcc.default_lane;
match cond.read() {
true => s1,
Expand Down Expand Up @@ -1795,7 +1797,7 @@ impl<'a> Thread<'a> {
let mut wv = self
.sgpr_co
.map(|(_, wv)| wv)
.unwrap_or_else(|| WaveValue::new(0));
.unwrap_or_else(|| WaveValue::new(0, self.warp_size));
wv.default_lane = self.vcc.default_lane;
wv.set_lane(val);
*self.sgpr_co = Some((idx, wv));
Expand Down Expand Up @@ -3777,8 +3779,8 @@ fn _helper_test_thread() -> Thread<'static> {
let static_sgpr: &'static mut Vec<u32> = Box::leak(Box::new(vec![0; 256]));
let static_vgpr: &'static mut VGPR = Box::leak(Box::new(VGPR::new()));
let static_scc: &'static mut u32 = Box::leak(Box::new(0));
let static_exec: &'static mut WaveValue = Box::leak(Box::new(WaveValue::new(u32::MAX)));
let static_vcc: &'static mut WaveValue = Box::leak(Box::new(WaveValue::new(0)));
let static_exec: &'static mut WaveValue = Box::leak(Box::new(WaveValue::new(u32::MAX, 32)));
let static_vcc: &'static mut WaveValue = Box::leak(Box::new(WaveValue::new(0, 32)));
let static_sds: &'static mut VecDataStore = Box::leak(Box::new(VecDataStore::new()));
let static_co: &'static mut Option<(usize, WaveValue)> = Box::leak(Box::new(None));

Expand All @@ -3794,6 +3796,7 @@ fn _helper_test_thread() -> Thread<'static> {
pc_offset: 0,
stream: vec![],
sgpr_co: static_co,
warp_size: 32,
scalar: false,
};
thread.vec_reg.default_lane = Some(0);
Expand Down
7 changes: 6 additions & 1 deletion src/work_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ impl<'a> WorkGroup<'a> {
};
let (mut vec_reg, mut vcc, mut exec) = match wave_state {
Some(val) => (val.2.clone(), val.3.clone(), val.4.clone()),
_ => (VGPR::new(), WaveValue::new(0), WaveValue::new(u32::MAX)),
_ => (
VGPR::new(),
WaveValue::new(0, threads.len()),
WaveValue::new((1 << threads.len()) - 1, threads.len()),
),
};

let mut seeded_lanes = vec![];
Expand Down Expand Up @@ -167,6 +171,7 @@ impl<'a> WorkGroup<'a> {
stream: self.kernel[pc..self.kernel.len()].to_vec(),
scalar: false,
simm: None,
warp_size: threads.len(),
sgpr_co: &mut sgpr_co,
};
thread.interpret()?;
Expand Down

0 comments on commit 633a096

Please sign in to comment.