Skip to content

Commit

Permalink
implements constraints for signed inequality instructions
Browse files Browse the repository at this point in the history
  • Loading branch information
tess-eract committed May 4, 2024
1 parent 7a6d813 commit ee93d2d
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 22 deletions.
11 changes: 10 additions & 1 deletion alu_u32/src/lt/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,26 @@ pub struct Lt32Cols<T> {
pub byte_flag: [T; 4],

/// Bit decomposition of 256 + input_1 - input_2
pub bits: [T; 10],
pub bits: [T; 9],

pub output: T,

pub multiplicity: T,

pub is_lt: T,
pub is_lte: T,
pub is_slt: T,
pub is_sle: T,

// inverse of input_1[i] - input_2[i] where i is the first byte that differs
pub diff_inv: T,

// bit decomposition of top bytes for input_1 and input_2
pub top_bits_1: [T; 8],
pub top_bits_2: [T; 8],

// boolean flag for whether the sign of the two inputs is different
pub different_signs: T,
}

pub const NUM_LT_COLS: usize = size_of::<Lt32Cols<u8>>();
Expand Down
24 changes: 17 additions & 7 deletions alu_u32/src/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ where
vec![
(LT_COL_MAP.is_lt, SC::Val::from_canonical_u32(LT32)),
(LT_COL_MAP.is_lte, SC::Val::from_canonical_u32(LTE32)),
(LT_COL_MAP.is_slt, SC::Val::from_canonical_u32(SLT32)),
(LT_COL_MAP.is_sle, SC::Val::from_canonical_u32(SLE32)),
],
SC::Val::zero(),
);
Expand Down Expand Up @@ -101,13 +103,11 @@ impl Lt32Chip {
self.set_cols(cols, a, b, c);
}
Operation::Slt32(a, b, c) => {
// TODO: this is just a placeholder
cols.is_lt = F::one();
cols.is_slt = F::one();
self.set_cols(cols, a, b, c);
}
Operation::Sle32(a, b, c) => {
// TODO: this is just a placeholder
cols.is_lte = F::one();
cols.is_sle = F::one();
self.set_cols(cols, a, b, c);
}
}
Expand All @@ -133,13 +133,25 @@ impl Lt32Chip {
.find_map(|(n, (x, y))| if x == y { None } else { Some(n) })
{
let z = 256u16 + b[n] as u16 - c[n] as u16;
for i in 0..10 {
for i in 0..9 {
cols.bits[i] = F::from_canonical_u16(z >> i & 1);
}
cols.byte_flag[n] = F::one();
// b[n] != c[n] always here, so the difference is never zero.
cols.diff_inv = (cols.input_1[n] - cols.input_2[n]).inverse();
}
// compute (little-endian) bit decomposition of the top bytes
for i in 0..8 {
cols.top_bits_1[i] = F::from_canonical_u8(b[0] >> i & 1);
cols.top_bits_2[i] = F::from_canonical_u8(c[0] >> i & 1);
}
// check if sign bits agree and set different_signs accordingly
cols.different_signs = if cols.top_bits_1[7] != cols.top_bits_2[7] {
F::one()
} else {
F::zero()
};

cols.multiplicity = F::one();
}

Expand Down Expand Up @@ -218,7 +230,6 @@ where
let opcode = <Self as Instruction<M, F>>::OPCODE;
let comp = |a, b| a < b;
let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);

state
.lt_u32_mut()
.operations
Expand Down Expand Up @@ -281,7 +292,6 @@ where
a_i <= b_i
};
let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);

state
.lt_u32_mut()
.operations
Expand Down
87 changes: 75 additions & 12 deletions alu_u32/src/lt/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ where
let main = builder.main();
let local: &Lt32Cols<AB::Var> = main.row_slice(0).borrow();

let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512].map(AB::Expr::from_canonical_u32);
let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256].map(AB::Expr::from_canonical_u32);

let bit_comp: AB::Expr = local
.bits
Expand Down Expand Up @@ -76,26 +76,89 @@ where
builder.assert_bool(local.byte_flag[i]);
}

// Check the bit decomposition of the top bytes:
let top_comp_1: AB::Expr = local
.top_bits_1
.into_iter()
.zip(base_2.iter().cloned())
.map(|(bit, base)| bit * base)
.sum();
let top_comp_2: AB::Expr = local
.top_bits_2
.into_iter()
.zip(base_2.iter().cloned())
.map(|(bit, base)| bit * base)
.sum();
builder.assert_eq(top_comp_1, local.input_1[0]);
builder.assert_eq(top_comp_2, local.input_2[0]);

// Check that `different_signs` is set correctly by comparing sign bits.
builder
.when(local.byte_flag[0])
.when_ne(local.top_bits_1[7], local.top_bits_2[7])
.assert_eq(local.different_signs, AB::Expr::one());
builder
.when(local.different_signs)
.assert_eq(local.byte_flag[0], AB::Expr::one());
// local.top_bits_1[7] and local.top_bits_2[7] are boolean; their sum is 1 iff they are unequal.
builder
.when(local.different_signs)
.assert_eq(local.top_bits_1[7] + local.top_bits_2[7], AB::Expr::one());

builder.assert_bool(local.is_lt);
builder.assert_bool(local.is_lte);
builder.assert_bool(local.is_lt + local.is_lte);
builder.assert_bool(local.is_slt);
builder.assert_bool(local.is_sle);
builder.assert_bool(local.is_lt + local.is_lte + local.is_slt + local.is_sle);

let is_signed = local.is_slt + local.is_sle;
let is_unsigned = AB::Expr::one() - is_signed;
let same_sign = AB::Expr::one() - local.different_signs;
let are_equal = AB::Expr::one() - flag_sum.clone();

// Output constraints
// local.bits[8] is 1 iff input_1 > input_2: output should be 0
builder.when(local.bits[8]).assert_zero(local.output);
// output should be 1 if is_lte & input_1 == input_2
// Case 0: input_1 > input_2 as unsigned ints; equivalently, local.bits[8] == 1
// when both inputs have the same sign, signed and unsigned inequality agree.
builder
.when(local.is_lte)
.when_ne(flag_sum.clone(), AB::Expr::one())
.when(local.bits[8])
.when(is_unsigned.clone() + same_sign.clone())
.assert_zero(local.output);
// when the inputs have different signs, signed inequality is the opposite of unsigned inequality.
builder
.when(local.bits[8])
.when(local.different_signs)
.assert_one(local.output);
// output should be 0 if is_lt & input_1 == input_2

// Case 1: input_1 < input_2 as unsigned ints; equivalently, local.bits[8] == is_equal == 0.
builder
.when(local.is_lt)
.when_ne(flag_sum, AB::Expr::one())
// when are_equal == 1, we have already enforced that local.bits[8] == 0
.when_ne(local.bits[8] + are_equal.clone(), AB::Expr::one())
.when(is_unsigned.clone() + same_sign.clone())
.assert_one(local.output);
builder
.when_ne(local.bits[8] + are_equal.clone(), AB::Expr::one())
.when(local.different_signs)
.assert_zero(local.output);

// Check bit decomposition
for bit in local.bits.into_iter() {
// Case 2: input_1 == input_2; equivalently, are_equal == 1
// output should be 1 if is_lte or is_sle
builder
.when(are_equal.clone())
.when(local.is_lte + local.is_sle)
.assert_one(local.output);
// output should be 0 if is_lt or is_slt
builder
.when(are_equal.clone())
.when(local.is_lt + local.is_slt)
.assert_zero(local.output);

// Check "bit" values are all boolean
for bit in local
.bits
.into_iter()
.chain(local.top_bits_1.into_iter())
.chain(local.top_bits_2.into_iter())
{
builder.assert_bool(bit);
}
}
Expand Down
Loading

0 comments on commit ee93d2d

Please sign in to comment.