Skip to content

Commit

Permalink
Merge pull request #121 from valida-xyz/thealmarty-lte
Browse files Browse the repository at this point in the history
Add `LTE` instruction.
  • Loading branch information
morganthomas authored Mar 18, 2024
2 parents b145f40 + e1c336b commit 7e3f1dd
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 34 deletions.
5 changes: 4 additions & 1 deletion alu_u32/src/lt/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ pub struct Lt32Cols<T> {
pub input_2: Word<T>,

/// Boolean flags indicating which byte pair differs
pub byte_flag: [T; 3],
pub byte_flag: [T; 4],

/// Bit decomposition of 256 + input_1 - input_2
pub bits: [T; 10],

pub output: T,

pub multiplicity: T,

pub is_lt: T,
pub is_lte: T,
}

pub const NUM_LT_COLS: usize = size_of::<Lt32Cols<u8>>();
Expand Down
116 changes: 93 additions & 23 deletions alu_u32/src/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use valida_cpu::MachineWithCpuChip;
use valida_machine::{
instructions, Chip, Instruction, Interaction, Operands, Word, MEMORY_CELL_BYTES,
};
use valida_opcodes::LT32;
use valida_opcodes::{LT32, LTE32};

use p3_air::VirtualPairCol;
use p3_field::{AbstractField, Field, PrimeField};
Expand All @@ -24,7 +24,8 @@ pub mod stark;

#[derive(Clone)]
pub enum Operation {
Lt32(Word<u8>, Word<u8>, Word<u8>), // (dst, src1, src2)
Lt32(Word<u8>, Word<u8>, Word<u8>), // (dst, src1, src2)
Lte32(Word<u8>, Word<u8>, Word<u8>), // (dst, src1, src2)
}

#[derive(Default)]
Expand Down Expand Up @@ -53,7 +54,13 @@ where
}

fn global_receives(&self, machine: &M) -> Vec<Interaction<SC::Val>> {
let opcode = VirtualPairCol::constant(SC::Val::from_canonical_u32(LT32));
let opcode = VirtualPairCol::new_main(
vec![
(LT_COL_MAP.is_lt, SC::Val::from_canonical_u32(LT32)),
(LT_COL_MAP.is_lte, SC::Val::from_canonical_u32(LTE32)),
],
SC::Val::zero(),
);
let input_1 = LT_COL_MAP.input_1.0.map(VirtualPairCol::single_main);
let input_2 = LT_COL_MAP.input_2.0.map(VirtualPairCol::single_main);
let output = (0..MEMORY_CELL_BYTES - 1)
Expand Down Expand Up @@ -83,37 +90,51 @@ impl Lt32Chip {
let cols: &mut Lt32Cols<F> = unsafe { transmute(&mut row) };

match op {
Operation::Lt32(dst, src1, src2) => {
if let Some(n) = src1
.into_iter()
.zip(src2.into_iter())
.enumerate()
.find_map(|(n, (x, y))| if x == y { Some(n) } else { None })
{
let z = 256u16 + src1[n] as u16 - src2[n] as u16;
for i in 0..10 {
cols.bits[i] = F::from_canonical_u16(z >> i & 1);
}
if n < 3 {
cols.byte_flag[n] = F::one();
}
}
cols.input_1 = src1.transform(F::from_canonical_u8);
cols.input_2 = src2.transform(F::from_canonical_u8);
cols.output = F::from_canonical_u8(dst[3]);
cols.multiplicity = F::one();
Operation::Lt32(a, b, c) => {
cols.is_lt = F::one();
self.set_cols(cols, a, b, c);
}
Operation::Lte32(a, b, c) => {
cols.is_lte = F::one();
self.set_cols(cols, a, b, c);
}
}
row
}

fn set_cols<F>(&self, cols: &mut Lt32Cols<F>, a: &Word<u8>, b: &Word<u8>, c: &Word<u8>)
where
F: PrimeField,
{
// Set the input columns
cols.input_1 = b.transform(F::from_canonical_u8);
cols.input_2 = c.transform(F::from_canonical_u8);
cols.output = F::from_canonical_u8(a[3]);

if let Some(n) = b
.into_iter()
.zip(c.into_iter())
.enumerate()
.find_map(|(n, (x, y))| if x == y { Some(n) } else { None })
{
let z = 256u16 + b[n] as u16 - c[n] as u16;
for i in 0..10 {
cols.bits[i] = F::from_canonical_u16(z >> i & 1);
}
if n < 4 {
cols.byte_flag[n] = F::one();
}
}
cols.multiplicity = F::one();
}
}

pub trait MachineWithLt32Chip<F: Field>: MachineWithCpuChip<F> {
fn lt_u32(&self) -> &Lt32Chip;
fn lt_u32_mut(&mut self) -> &mut Lt32Chip;
}

instructions!(Lt32Instruction);
instructions!(Lt32Instruction, Lte32Instruction);

impl<M, F> Instruction<M, F> for Lt32Instruction
where
Expand Down Expand Up @@ -163,3 +184,52 @@ where
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
}

impl<M, F> Instruction<M, F> for Lte32Instruction
where
M: MachineWithLt32Chip<F>,
F: Field,
{
const OPCODE: u32 = LTE32;

fn execute(state: &mut M, ops: Operands<i32>) {
let opcode = <Self as Instruction<M, F>>::OPCODE;
let clk = state.cpu().clock;
let pc = state.cpu().pc;
let mut imm: Option<Word<u8>> = None;
let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32;
let write_addr = (state.cpu().fp as i32 + ops.a()) as u32;
let src1 = if ops.d() == 1 {
let b = (ops.b() as u32).into();
imm = Some(b);
b
} else {
state
.mem_mut()
.read(clk, read_addr_1, true, pc, opcode, 0, "")
};
let src2 = if ops.is_imm() == 1 {
let c = (ops.c() as u32).into();
imm = Some(c);
c
} else {
let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32;
state
.mem_mut()
.read(clk, read_addr_2, true, pc, opcode, 1, "")
};

let dst = if src1 <= src2 {
Word::from(1)
} else {
Word::from(0)
};
state.mem_mut().write(clk, write_addr, dst, true);

state
.lt_u32_mut()
.operations
.push(Operation::Lte32(dst, src1, src2));
state.cpu_mut().push_bus_op(imm, opcode, ops);
}
}
12 changes: 11 additions & 1 deletion alu_u32/src/lt/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,28 @@ where
builder.assert_bool(flag_sum.clone());
builder
.when_ne(local.multiplicity, AB::Expr::zero())
.when_ne(flag_sum, AB::Expr::one())
.when_ne(flag_sum.clone(), AB::Expr::one())
.assert_eq(
AB::Expr::from_canonical_u32(256) + local.input_1[3] - local.input_2[3],
bit_comp.clone(),
);

builder.assert_bool(local.is_lt);
builder.assert_bool(local.is_lte);
builder.assert_bool(local.is_lt + local.is_lte);

// Output constraints
builder.when(local.bits[8]).assert_zero(local.output);
builder
.when_ne(local.multiplicity, AB::Expr::zero())
.when_ne(local.bits[8], AB::Expr::one())
.assert_one(local.output);
// output should be 1 if is_lte & input_1 == input_2
let all_flag_sum = flag_sum + local.byte_flag[3];
builder
.when(local.is_lte)
.when_ne(all_flag_sum, AB::Expr::one())
.assert_one(local.output);

// Check bit decomposition
for bit in local.bits.into_iter() {
Expand Down
2 changes: 1 addition & 1 deletion assembler/grammar/assembly.pest
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mnemonic = {
"lw" | "sw" | "jalv" | "jal" | "beqi" | "beq" | "bnei" | "bne" | "imm32" | "stop" |
"advread" | "advwrite" |
"addi" | "add" | "subi" | "sub" | "muli" | "mul" | "mulhsi"| "mulhui"| "mulhs"| "mulhu" | "divi" | "div" | "sdiv" | "sdivi" |
"lti" | "lt" | "shli" | "shl" | "shri" | "shr" | "srai" | "sra" |
"ilte" | "ltei" | "lte" | "ilt" | "lti" | "lt" | "shli" | "shl" | "shri" | "shr" | "srai" | "sra" |
"andi" | "and" | "ori" | "or" | "xori" | "xor" | "nei" | "ne" | "eqi" | "eq" |
"feadd" | "fesub" | "femul" |
"write"
Expand Down
17 changes: 11 additions & 6 deletions assembler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ pub fn assemble(input: &str) -> Result<Vec<u8>, String> {
"mulhu" | "mulhui" => MULHU32,
"div" | "divi" => DIV32,
"sdiv" | "sdivi" => SDIV32,
"lt" | "lti" => LT32,
"ilt" | "lt" | "lti" => LT32,
"ilte" | "lte" | "ltei" => LTE32,
"shl" | "shli" => SHL32,
"shr" | "shri" => SHR32,
"sra" | "srai" => SRA32,
Expand Down Expand Up @@ -117,18 +118,22 @@ pub fn assemble(input: &str) -> Result<Vec<u8>, String> {
// (0, 0, 0, 0, 0)
operands.extend(vec![0; 5]);
}
"add" | "sub" | "mul" | "mulhs" | "mulhu" | "div" | "lt" | "shl" | "shr"
| "sra" | "beq" | "bne" | "and" | "or" | "xor" | "ne" | "eq" | "jal"
| "jalv" => {
"add" | "sub" | "mul" | "mulhs" | "mulhu" | "div" | "lt" | "lte" | "shl"
| "shr" | "sra" | "beq" | "bne" | "and" | "or" | "xor" | "ne" | "eq"
| "jal" | "jalv" => {
// (a, b, c, 0, 0)
operands.extend(vec![0; 2]);
}
"addi" | "subi" | "muli" | "mulhsi" | "mulhui" | "divi" | "sdivi" | "lti"
| "shli" | "shri" | "srai" | "beqi" | "bnei" | "andi" | "ori" | "xori"
| "nei" | "eqi" => {
| "ltei" | "shli" | "shri" | "srai" | "beqi" | "bnei" | "andi" | "ori"
| "xori" | "nei" | "eqi" => {
// (a, b, c, 0, 1)
operands.extend(vec![0, 1]);
}
"ilt" | "ilte" => {
// (a, b, c, 1, 0)
operands.extend(vec![0, 1]);
}
"advread" => {
// (a, 0, 0, 0, 0)
operands.extend(vec![0; 4]);
Expand Down
5 changes: 4 additions & 1 deletion basic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use valida_alu_u32::{
},
com::{Com32Chip, Eq32Instruction, MachineWithCom32Chip, Ne32Instruction},
div::{Div32Chip, Div32Instruction, MachineWithDiv32Chip, SDiv32Instruction},
lt::{Lt32Chip, Lt32Instruction, MachineWithLt32Chip},
lt::{Lt32Chip, Lt32Instruction, Lte32Instruction, MachineWithLt32Chip},
mul::{
MachineWithMul32Chip, Mul32Chip, Mul32Instruction, Mulhs32Instruction, Mulhu32Instruction,
},
Expand Down Expand Up @@ -116,6 +116,9 @@ pub struct BasicMachine<F: PrimeField32 + TwoAdicField> {
#[instruction(lt_u32)]
lt32: Lt32Instruction,

#[instruction(lt_u32)]
lte32: Lte32Instruction,

#[instruction(bitwise_u32)]
and32: And32Instruction,

Expand Down
2 changes: 1 addition & 1 deletion opcodes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub const NE32: u32 = 111;
pub const MULHU32: u32 = 112;
pub const SRA32: u32 = 113;
pub const MULHS32: u32 = 114;
pub const LTE32: u32 = 115; //TODO
pub const LTE32: u32 = 115;
pub const EQ32: u32 = 116;

/// NATIVE FIELD
Expand Down

0 comments on commit 7e3f1dd

Please sign in to comment.