Skip to content

Commit

Permalink
Implement bitonic sort network for join compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed Jul 17, 2024
1 parent 5b9c5be commit aa01552
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ pub enum StmtEnum<T> {
/// Binds an identifier to each value of an array expr, evaluating the body.
ForEachLoop(String, Expr<T>, Vec<Stmt<T>>),
/// Binds an identifier to each joined row of two tables, evaluating the body.
JoinLoop(String, (Expr<T>, Expr<T>), Vec<Stmt<T>>),
JoinLoop(String, T, (Expr<T>, Expr<T>), Vec<Stmt<T>>),
/// An expression (all expressions are statements, but not all statements expressions).
Expr(Expr<T>),
}
Expand Down
5 changes: 3 additions & 2 deletions src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,7 @@ impl UntypedStmt {
meta,
))]);
}
let join_ty = tuple_a[0].clone();
let elem_ty = Type::Tuple(vec![elem_ty_a, elem_ty_b]);
let mut body_typed = Vec::with_capacity(body.len());
env.push();
Expand All @@ -771,7 +772,7 @@ impl UntypedStmt {
}
env.pop();
Ok(Stmt::new(
StmtEnum::JoinLoop(var.clone(), (a, b), body_typed),
StmtEnum::JoinLoop(var.clone(), join_ty, (a, b), body_typed),
meta,
))
}
Expand All @@ -791,7 +792,7 @@ impl UntypedStmt {
))
}
},
ast::StmtEnum::JoinLoop(_, _, _) => {
ast::StmtEnum::JoinLoop(_, _, _, _) => {
unreachable!("Untyped expressions should never be join loops")
}
}
Expand Down
33 changes: 33 additions & 0 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,39 @@ impl CircuitBuilder {
}
(acc_lt, acc_gt)
}

pub fn push_condswap(
&mut self,
s: GateIndex,
x: GateIndex,
y: GateIndex,
) -> (GateIndex, GateIndex) {
if x == y {
return (x, y);
}
let x_xor_y = self.push_xor(x, y);
let swap = self.push_and(x_xor_y, s);
let x_swapped = self.push_xor(x, swap);
let y_swapped = self.push_xor(y, swap);
(x_swapped, y_swapped)
}

pub fn push_sorter(
&mut self,
bits: usize,
x: &[GateIndex],
y: &[GateIndex],
) -> (Vec<GateIndex>, Vec<GateIndex>) {
let (_, gt) = self.push_comparator_circuit(bits, x, false, y, false);
let mut min = vec![];
let mut max = vec![];
for (x, y) in x.iter().zip(y.iter()) {
let (a, b) = self.push_condswap(gt, *x, *y);
min.push(a);
max.push(b);
}
(min, max)
}
}

fn unsigned_as_usize_bits(n: u64) -> [usize; USIZE_BITS] {
Expand Down
64 changes: 63 additions & 1 deletion src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,69 @@ impl TypedStmt {
env.pop();
vec![]
}
StmtEnum::JoinLoop(_, _, _) => {
StmtEnum::JoinLoop(var, join_ty, (a, b), body) => {
let (elem_bits_a, num_elems_a) = match &a.ty {
Type::Array(elem_ty, size) => (
elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()),
*size,
),
Type::ArrayConst(elem_ty, size) => (
elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()),
*circuit.const_sizes().get(size).unwrap(),
),
_ => panic!("Found a non-array value in an array access expr"),
};
let (elem_bits_b, num_elems_b) = match &b.ty {
Type::Array(elem_ty, size) => (
elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()),
*size,
),
Type::ArrayConst(elem_ty, size) => (
elem_ty.size_in_bits_for_defs(prg, circuit.const_sizes()),
*circuit.const_sizes().get(size).unwrap(),
),
_ => panic!("Found a non-array value in an array access expr"),
};
let max_elem_bits = max(elem_bits_a, elem_bits_b);
let num_elems = num_elems_a + num_elems_b;
let join_ty_size = join_ty.size_in_bits_for_defs(prg, circuit.const_sizes());
let a = a.compile(prg, env, circuit);
let b = b.compile(prg, env, circuit);
let mut bitonic = vec![];
for i in 0..num_elems_a {
let mut v = a[i * elem_bits_a..(i + 1) * elem_bits_a].to_vec();
for _ in 0..(max_elem_bits - elem_bits_a) {
v.push(0);
}
bitonic.push(v);
}
for i in (0..num_elems_b).rev() {
let mut v = b[i * elem_bits_b..(i + 1) * elem_bits_b].to_vec();
for _ in 0..(max_elem_bits - elem_bits_b) {
v.push(0);
}
bitonic.push(v);
}
let mut offset = num_elems / 2;
while offset > 0 {
let mut result = vec![];
for _ in 0..num_elems {
result.push(vec![]);
}
let rounds = num_elems / 2 / offset;
for r in 0..rounds {
for i in 0..offset {
let i = i + r * offset * 2;
let x = &bitonic[i];
let y = &bitonic[i + offset];
let (min, max) = circuit.push_sorter(join_ty_size, x, y);
result[i] = min;
result[i + offset] = max;
}
}
offset /= 2;
bitonic = result;
}
todo!("compile join loop")
}
}
Expand Down
12 changes: 11 additions & 1 deletion tests/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2055,7 +2055,7 @@ pub fn main(array: [u16; MY_CONST]) -> u16 {
#[test]
fn compile_join_fn() -> Result<(), Error> {
let prg = "
pub fn main(rows1: [([u8; 3], u16); 4], rows2: [([u8; 3], u16, u16); 3]) -> u16 {
pub fn main(rows1: [([u8; 3], u16); 4], rows2: [([u8; 3], u16, u16); 4]) -> u16 {
let mut result = 0u16;
for row in join(rows1, rows2) {
let ((_, field1), (_, field2, field3)) = row;
Expand Down Expand Up @@ -2091,6 +2091,11 @@ pub fn main(rows1: [([u8; 3], u16); 4], rows2: [([u8; 3], u16, u16); 3]) -> u16
Literal::NumUnsigned(117, UnsignedNumType::U8),
Literal::NumUnsigned(120, UnsignedNumType::U8),
]);
let id_xxx = Literal::Array(vec![
Literal::NumUnsigned(120, UnsignedNumType::U8),
Literal::NumUnsigned(120, UnsignedNumType::U8),
Literal::NumUnsigned(120, UnsignedNumType::U8),
]);
eval.set_literal(Literal::Array(vec![
Literal::Tuple(vec![
id_aaa.clone(),
Expand Down Expand Up @@ -2126,6 +2131,11 @@ pub fn main(rows1: [([u8; 3], u16); 4], rows2: [([u8; 3], u16, u16); 3]) -> u16
Literal::NumUnsigned(8, UnsignedNumType::U16),
Literal::NumUnsigned(9, UnsignedNumType::U16),
]),
Literal::Tuple(vec![
id_xxx.clone(),
Literal::NumUnsigned(10, UnsignedNumType::U16),
Literal::NumUnsigned(11, UnsignedNumType::U16),
]),
]))
.unwrap();
let output = eval.run().map_err(|e| pretty_print(e, prg))?;
Expand Down

0 comments on commit aa01552

Please sign in to comment.