From 93d6fe5ab62136e42a372b166c6aa171cb564a7d Mon Sep 17 00:00:00 2001 From: thealmarty <“thealmartyblog@gmail.com”> Date: Fri, 15 Mar 2024 12:07:44 -0700 Subject: [PATCH 1/2] Add LTE instruction. --- alu_u32/src/lt/columns.rs | 3 + alu_u32/src/lt/mod.rs | 116 +++++++++++++++++++++++++------- alu_u32/src/lt/stark.rs | 15 ++++- assembler/grammar/assembly.pest | 2 +- assembler/src/lib.rs | 17 +++-- basic/src/lib.rs | 5 +- opcodes/src/lib.rs | 2 +- 7 files changed, 127 insertions(+), 33 deletions(-) diff --git a/alu_u32/src/lt/columns.rs b/alu_u32/src/lt/columns.rs index 0bbd3c9e..633baf25 100644 --- a/alu_u32/src/lt/columns.rs +++ b/alu_u32/src/lt/columns.rs @@ -18,6 +18,9 @@ pub struct Lt32Cols { pub output: T, pub multiplicity: T, + + pub is_lt: T, + pub is_lte: T, } pub const NUM_LT_COLS: usize = size_of::>(); diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index 032f3fc2..c6ea04b6 100644 --- a/alu_u32/src/lt/mod.rs +++ b/alu_u32/src/lt/mod.rs @@ -10,7 +10,7 @@ use valida_cpu::MachineWithCpuChip; use valida_machine::{ instructions, Chip, Instruction, Interaction, Operands, Word, MEMORY_CELL_BYTES, }; -use valida_opcodes::LT32; +use valida_opcodes::{LT32, LTE32}; use p3_air::VirtualPairCol; use p3_field::{AbstractField, Field, PrimeField}; @@ -24,7 +24,8 @@ pub mod stark; #[derive(Clone)] pub enum Operation { - Lt32(Word, Word, Word), // (dst, src1, src2) + Lt32(Word, Word, Word), // (dst, src1, src2) + Lte32(Word, Word, Word), // (dst, src1, src2) } #[derive(Default)] @@ -53,7 +54,13 @@ where } fn global_receives(&self, machine: &M) -> Vec> { - let opcode = VirtualPairCol::constant(SC::Val::from_canonical_u32(LT32)); + let opcode = VirtualPairCol::new_main( + vec![ + (LT_COL_MAP.is_lt, SC::Val::from_canonical_u32(LT32)), + (LT_COL_MAP.is_lte, SC::Val::from_canonical_u32(LTE32)), + ], + SC::Val::zero(), + ); let input_1 = LT_COL_MAP.input_1.0.map(VirtualPairCol::single_main); let input_2 = LT_COL_MAP.input_2.0.map(VirtualPairCol::single_main); let output = (0..MEMORY_CELL_BYTES - 1) @@ -83,29 +90,43 @@ impl Lt32Chip { let cols: &mut Lt32Cols = unsafe { transmute(&mut row) }; match op { - Operation::Lt32(dst, src1, src2) => { - if let Some(n) = src1 - .into_iter() - .zip(src2.into_iter()) - .enumerate() - .find_map(|(n, (x, y))| if x == y { Some(n) } else { None }) - { - let z = 256u16 + src1[n] as u16 - src2[n] as u16; - for i in 0..10 { - cols.bits[i] = F::from_canonical_u16(z >> i & 1); - } - if n < 3 { - cols.byte_flag[n] = F::one(); - } - } - cols.input_1 = src1.transform(F::from_canonical_u8); - cols.input_2 = src2.transform(F::from_canonical_u8); - cols.output = F::from_canonical_u8(dst[3]); - cols.multiplicity = F::one(); + Operation::Lt32(a, b, c) => { + cols.is_lt = F::one(); + self.set_cols(cols, a, b, c); + } + Operation::Lte32(a, b, c) => { + cols.is_lte = F::one(); + self.set_cols(cols, a, b, c); } } row } + + fn set_cols(&self, cols: &mut Lt32Cols, a: &Word, b: &Word, c: &Word) + where + F: PrimeField, + { + // Set the input columns + cols.input_1 = b.transform(F::from_canonical_u8); + cols.input_2 = c.transform(F::from_canonical_u8); + cols.output = F::from_canonical_u8(a[3]); + + if let Some(n) = b + .into_iter() + .zip(c.into_iter()) + .enumerate() + .find_map(|(n, (x, y))| if x == y { Some(n) } else { None }) + { + let z = 256u16 + b[n] as u16 - c[n] as u16; + for i in 0..10 { + cols.bits[i] = F::from_canonical_u16(z >> i & 1); + } + if n < 3 { + cols.byte_flag[n] = F::one(); + } + } + cols.multiplicity = F::one(); + } } pub trait MachineWithLt32Chip: MachineWithCpuChip { @@ -113,7 +134,7 @@ pub trait MachineWithLt32Chip: MachineWithCpuChip { fn lt_u32_mut(&mut self) -> &mut Lt32Chip; } -instructions!(Lt32Instruction); +instructions!(Lt32Instruction, Lte32Instruction); impl Instruction for Lt32Instruction where @@ -163,3 +184,52 @@ where state.cpu_mut().push_bus_op(imm, opcode, ops); } } + +impl Instruction for Lte32Instruction +where + M: MachineWithLt32Chip, + F: Field, +{ + const OPCODE: u32 = LTE32; + + fn execute(state: &mut M, ops: Operands) { + let opcode = >::OPCODE; + let clk = state.cpu().clock; + let pc = state.cpu().pc; + let mut imm: Option> = None; + let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32; + let write_addr = (state.cpu().fp as i32 + ops.a()) as u32; + let src1 = if ops.d() == 1 { + let b = (ops.b() as u32).into(); + imm = Some(b); + b + } else { + state + .mem_mut() + .read(clk, read_addr_1, true, pc, opcode, 0, "") + }; + let src2 = if ops.is_imm() == 1 { + let c = (ops.c() as u32).into(); + imm = Some(c); + c + } else { + let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32; + state + .mem_mut() + .read(clk, read_addr_2, true, pc, opcode, 1, "") + }; + + let dst = if src1 <= src2 { + Word::from(1) + } else { + Word::from(0) + }; + state.mem_mut().write(clk, write_addr, dst, true); + + state + .lt_u32_mut() + .operations + .push(Operation::Lte32(dst, src1, src2)); + state.cpu_mut().push_bus_op(imm, opcode, ops); + } +} diff --git a/alu_u32/src/lt/stark.rs b/alu_u32/src/lt/stark.rs index 0b5b5f22..44e7cb25 100644 --- a/alu_u32/src/lt/stark.rs +++ b/alu_u32/src/lt/stark.rs @@ -51,18 +51,31 @@ where builder.assert_bool(flag_sum.clone()); builder .when_ne(local.multiplicity, AB::Expr::zero()) - .when_ne(flag_sum, AB::Expr::one()) + .when_ne(flag_sum.clone(), AB::Expr::one()) .assert_eq( AB::Expr::from_canonical_u32(256) + local.input_1[3] - local.input_2[3], bit_comp.clone(), ); + builder.assert_bool(local.is_lt); + builder.assert_bool(local.is_lte); + builder.assert_bool(local.is_lt + local.is_lte); + // Output constraints builder.when(local.bits[8]).assert_zero(local.output); builder .when_ne(local.multiplicity, AB::Expr::zero()) .when_ne(local.bits[8], AB::Expr::one()) .assert_one(local.output); + // output should be 1 if is_lte & input_1 == input_2 + builder + .when(local.is_lte) + .when_ne(flag_sum, AB::Expr::one()) + .when_ne( + AB::Expr::from_canonical_u32(256) + local.input_1[3] - local.input_2[3], + AB::Expr::one(), + ) + .assert_one(local.output); // Check bit decomposition for bit in local.bits.into_iter() { diff --git a/assembler/grammar/assembly.pest b/assembler/grammar/assembly.pest index c82955da..530aa40e 100644 --- a/assembler/grammar/assembly.pest +++ b/assembler/grammar/assembly.pest @@ -6,7 +6,7 @@ mnemonic = { "lw" | "sw" | "jalv" | "jal" | "beqi" | "beq" | "bnei" | "bne" | "imm32" | "stop" | "advread" | "advwrite" | "addi" | "add" | "subi" | "sub" | "muli" | "mul" | "mulhsi"| "mulhui"| "mulhs"| "mulhu" | "divi" | "div" | "sdiv" | "sdivi" | - "lti" | "lt" | "shli" | "shl" | "shri" | "shr" | "srai" | "sra" | + "ilte" | "ltei" | "lte" | "ilt" | "lti" | "lt" | "shli" | "shl" | "shri" | "shr" | "srai" | "sra" | "andi" | "and" | "ori" | "or" | "xori" | "xor" | "nei" | "ne" | "eqi" | "eq" | "feadd" | "fesub" | "femul" | "write" diff --git a/assembler/src/lib.rs b/assembler/src/lib.rs index 890949a4..f24e48f5 100644 --- a/assembler/src/lib.rs +++ b/assembler/src/lib.rs @@ -77,7 +77,8 @@ pub fn assemble(input: &str) -> Result, String> { "mulhu" | "mulhui" => MULHU32, "div" | "divi" => DIV32, "sdiv" | "sdivi" => SDIV32, - "lt" | "lti" => LT32, + "ilt" | "lt" | "lti" => LT32, + "ilte" | "lte" | "ltei" => LTE32, "shl" | "shli" => SHL32, "shr" | "shri" => SHR32, "sra" | "srai" => SRA32, @@ -117,18 +118,22 @@ pub fn assemble(input: &str) -> Result, String> { // (0, 0, 0, 0, 0) operands.extend(vec![0; 5]); } - "add" | "sub" | "mul" | "mulhs" | "mulhu" | "div" | "lt" | "shl" | "shr" - | "sra" | "beq" | "bne" | "and" | "or" | "xor" | "ne" | "eq" | "jal" - | "jalv" => { + "add" | "sub" | "mul" | "mulhs" | "mulhu" | "div" | "lt" | "lte" | "shl" + | "shr" | "sra" | "beq" | "bne" | "and" | "or" | "xor" | "ne" | "eq" + | "jal" | "jalv" => { // (a, b, c, 0, 0) operands.extend(vec![0; 2]); } "addi" | "subi" | "muli" | "mulhsi" | "mulhui" | "divi" | "sdivi" | "lti" - | "shli" | "shri" | "srai" | "beqi" | "bnei" | "andi" | "ori" | "xori" - | "nei" | "eqi" => { + | "ltei" | "shli" | "shri" | "srai" | "beqi" | "bnei" | "andi" | "ori" + | "xori" | "nei" | "eqi" => { // (a, b, c, 0, 1) operands.extend(vec![0, 1]); } + "ilt" | "ilte" => { + // (a, b, c, 1, 0) + operands.extend(vec![0, 1]); + } "advread" => { // (a, 0, 0, 0, 0) operands.extend(vec![0; 4]); diff --git a/basic/src/lib.rs b/basic/src/lib.rs index 4c06ce90..8f0b564a 100644 --- a/basic/src/lib.rs +++ b/basic/src/lib.rs @@ -20,7 +20,7 @@ use valida_alu_u32::{ }, com::{Com32Chip, Eq32Instruction, MachineWithCom32Chip, Ne32Instruction}, div::{Div32Chip, Div32Instruction, MachineWithDiv32Chip, SDiv32Instruction}, - lt::{Lt32Chip, Lt32Instruction, MachineWithLt32Chip}, + lt::{Lt32Chip, Lt32Instruction, Lte32Instruction, MachineWithLt32Chip}, mul::{ MachineWithMul32Chip, Mul32Chip, Mul32Instruction, Mulhs32Instruction, Mulhu32Instruction, }, @@ -116,6 +116,9 @@ pub struct BasicMachine { #[instruction(lt_u32)] lt32: Lt32Instruction, + #[instruction(lt_u32)] + lte32: Lte32Instruction, + #[instruction(bitwise_u32)] and32: And32Instruction, diff --git a/opcodes/src/lib.rs b/opcodes/src/lib.rs index 42106938..3bb989b4 100644 --- a/opcodes/src/lib.rs +++ b/opcodes/src/lib.rs @@ -30,7 +30,7 @@ pub const NE32: u32 = 111; pub const MULHU32: u32 = 112; pub const SRA32: u32 = 113; pub const MULHS32: u32 = 114; -pub const LTE32: u32 = 115; //TODO +pub const LTE32: u32 = 115; pub const EQ32: u32 = 116; /// NATIVE FIELD From e1c336b99c5b0e7ebbb2bd00adc957845c206b5d Mon Sep 17 00:00:00 2001 From: thealmarty <“thealmartyblog@gmail.com”> Date: Mon, 18 Mar 2024 09:49:05 -0700 Subject: [PATCH 2/2] Fix constraint and test. --- alu_u32/src/lt/columns.rs | 2 +- alu_u32/src/lt/mod.rs | 2 +- alu_u32/src/lt/stark.rs | 7 ++----- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/alu_u32/src/lt/columns.rs b/alu_u32/src/lt/columns.rs index 633baf25..55fa10cc 100644 --- a/alu_u32/src/lt/columns.rs +++ b/alu_u32/src/lt/columns.rs @@ -10,7 +10,7 @@ pub struct Lt32Cols { pub input_2: Word, /// Boolean flags indicating which byte pair differs - pub byte_flag: [T; 3], + pub byte_flag: [T; 4], /// Bit decomposition of 256 + input_1 - input_2 pub bits: [T; 10], diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index c6ea04b6..59f2aba2 100644 --- a/alu_u32/src/lt/mod.rs +++ b/alu_u32/src/lt/mod.rs @@ -121,7 +121,7 @@ impl Lt32Chip { for i in 0..10 { cols.bits[i] = F::from_canonical_u16(z >> i & 1); } - if n < 3 { + if n < 4 { cols.byte_flag[n] = F::one(); } } diff --git a/alu_u32/src/lt/stark.rs b/alu_u32/src/lt/stark.rs index 44e7cb25..2180ce58 100644 --- a/alu_u32/src/lt/stark.rs +++ b/alu_u32/src/lt/stark.rs @@ -68,13 +68,10 @@ where .when_ne(local.bits[8], AB::Expr::one()) .assert_one(local.output); // output should be 1 if is_lte & input_1 == input_2 + let all_flag_sum = flag_sum + local.byte_flag[3]; builder .when(local.is_lte) - .when_ne(flag_sum, AB::Expr::one()) - .when_ne( - AB::Expr::from_canonical_u32(256) + local.input_1[3] - local.input_2[3], - AB::Expr::one(), - ) + .when_ne(all_flag_sum, AB::Expr::one()) .assert_one(local.output); // Check bit decomposition