Skip to content

Commit

Permalink
Merge pull request #4 from luiz-lvj/twisted_edwards
Browse files Browse the repository at this point in the history
Twisted edwards and Weirstrass
  • Loading branch information
luiz-lvj authored Oct 3, 2024
2 parents cb0b681 + c8ee1c9 commit 51cfa1c
Show file tree
Hide file tree
Showing 9 changed files with 371 additions and 15 deletions.
109 changes: 109 additions & 0 deletions tools/garaga_rs/src/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use num_bigint::BigUint;
use std::cmp::PartialEq;
use std::collections::HashMap;

use crate::io::{biguint_from_hex, element_from_biguint};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CurveID {
BN254 = 0,
Expand Down Expand Up @@ -231,3 +233,110 @@ impl CurveParamsProvider<BLS12381PrimeField> for BLS12381PrimeField {
}
}
}

pub trait ToWeierstrassCurve {
fn to_weirstrass(
x_twisted: FieldElement<X25519PrimeField>,
y_twisted: FieldElement<X25519PrimeField>,
) -> (
FieldElement<X25519PrimeField>,
FieldElement<X25519PrimeField>,
);
}

pub trait ToTwistedEdwardsCurve {
fn to_twistededwards(
x_weierstrass: FieldElement<X25519PrimeField>,
y_weierstrass: FieldElement<X25519PrimeField>,
) -> (
FieldElement<X25519PrimeField>,
FieldElement<X25519PrimeField>,
);
}

impl ToWeierstrassCurve for X25519PrimeField {
fn to_weirstrass(
x_twisted: FieldElement<X25519PrimeField>,
y_twisted: FieldElement<X25519PrimeField>,
) -> (
FieldElement<X25519PrimeField>,
FieldElement<X25519PrimeField>,
) {
let a = element_from_biguint::<X25519PrimeField>(&biguint_from_hex(
"0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEC",
)); // Replace with actual a_twisted
let d = element_from_biguint::<X25519PrimeField>(&biguint_from_hex(
"0x52036CEE2B6FFE738CC740797779E89800700A4D4141D8AB75EB4DCA135978A3",
)); // Replace with actual d_twisted

let x = (FieldElement::<X25519PrimeField>::from(5) * a.clone()
+ a.clone() * y_twisted.clone()
- FieldElement::<X25519PrimeField>::from(5) * d.clone() * y_twisted.clone()
- d.clone())
* (FieldElement::<X25519PrimeField>::from(12)
- FieldElement::<X25519PrimeField>::from(12) * y_twisted.clone())
.inv()
.unwrap();
let y = (a.clone() + a * y_twisted.clone() - d.clone() * y_twisted.clone() - d)
* (FieldElement::<X25519PrimeField>::from(4) * x_twisted.clone()
- FieldElement::<X25519PrimeField>::from(4) * x_twisted.clone() * y_twisted)
.inv()
.unwrap();

(x, y)
}
}

impl ToTwistedEdwardsCurve for X25519PrimeField {
fn to_twistededwards(
x_weierstrass: FieldElement<X25519PrimeField>,
y_weierstrass: FieldElement<X25519PrimeField>,
) -> (
FieldElement<X25519PrimeField>,
FieldElement<X25519PrimeField>,
) {
let a = element_from_biguint::<X25519PrimeField>(&biguint_from_hex(
"0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEC",
)); // Replace with actual a_twisted
let d = element_from_biguint::<X25519PrimeField>(&biguint_from_hex(
"0x52036CEE2B6FFE738CC740797779E89800700A4D4141D8AB75EB4DCA135978A3",
)); // Replace with actual d_twisted

let y = (FieldElement::<X25519PrimeField>::from(5) * a.clone()
- FieldElement::<X25519PrimeField>::from(12) * x_weierstrass.clone()
- d.clone())
* (-FieldElement::<X25519PrimeField>::from(12) * x_weierstrass.clone() - a.clone()
+ FieldElement::<X25519PrimeField>::from(5) * d.clone())
.inv()
.unwrap();
let x = (a.clone() + a.clone() * y.clone() - d.clone() * y.clone() - d)
* (FieldElement::<X25519PrimeField>::from(4) * y_weierstrass.clone()
- FieldElement::<X25519PrimeField>::from(4) * y_weierstrass.clone() * y.clone())
.inv()
.unwrap();

(x, y)
}
}

#[cfg(test)]
mod tests {

use super::{CurveParamsProvider, ToTwistedEdwardsCurve, ToWeierstrassCurve, X25519PrimeField};

#[test]
fn test_to_weierstrass_and_back() {
let curve = X25519PrimeField::get_curve_params();

let x_weirstrass = curve.g_x;
let y_weirstrass = curve.g_y;

let (x_twisted, y_twisted) =
X25519PrimeField::to_twistededwards(x_weirstrass.clone(), y_weirstrass.clone());
let (x_weirstrass_back, y_weirstrass_back) =
X25519PrimeField::to_weirstrass(x_twisted, y_twisted);

assert_eq!(x_weirstrass, x_weirstrass_back);
assert_eq!(y_weirstrass, y_weirstrass_back);
}
}
2 changes: 1 addition & 1 deletion tools/garaga_rs/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ pub fn byte_slice_split<const N: usize, const SIZE: usize>(bytes: &[u8]) -> [u12
limbs
}

fn biguint_from_hex(hex: &str) -> BigUint {
pub fn biguint_from_hex(hex: &str) -> BigUint {
let mut s = hex;
if let Some(stripped) = s.strip_prefix("0x") {
s = stripped;
Expand Down
93 changes: 93 additions & 0 deletions tools/garaga_rs/src/wasm_bindings.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use crate::definitions::{
CurveParamsProvider, ToTwistedEdwardsCurve, ToWeierstrassCurve, X25519PrimeField,
};
use crate::io::{element_from_biguint, element_to_biguint};
use num_bigint::BigUint;
use std::str::FromStr;
use wasm_bindgen::prelude::*;
Expand Down Expand Up @@ -49,6 +53,49 @@ fn biguint_to_jsvalue(v: BigUint) -> JsValue {
JsValue::bigint_from_str(&v.to_string())
}

#[wasm_bindgen]
pub fn to_weirstrass(x_twisted: JsValue, y_twisted: JsValue) -> Result<Vec<JsValue>, JsValue> {
let x_twisted_biguint = jsvalue_to_biguint(x_twisted).unwrap();
let x_twisted = element_from_biguint::<X25519PrimeField>(&x_twisted_biguint);

let y_twisted_biguint = jsvalue_to_biguint(y_twisted).unwrap();
let y_twisted = element_from_biguint::<X25519PrimeField>(&y_twisted_biguint);

let result = crate::definitions::X25519PrimeField::to_weirstrass(x_twisted, y_twisted);

let x_weirstrass = element_to_biguint::<X25519PrimeField>(&result.0);
let y_weirstrass = element_to_biguint::<X25519PrimeField>(&result.1);

let result = vec![
biguint_to_jsvalue(x_weirstrass),
biguint_to_jsvalue(y_weirstrass),
];

Ok(result)
}

#[wasm_bindgen]
pub fn to_twistededwards(
x_weirstrass: JsValue,
y_weirstrass: JsValue,
) -> Result<Vec<JsValue>, JsValue> {
let x_weirstrass_biguint = jsvalue_to_biguint(x_weirstrass).unwrap();
let x_weirstrass = element_from_biguint::<X25519PrimeField>(&x_weirstrass_biguint);

let y_weirstrass_biguint = jsvalue_to_biguint(y_weirstrass).unwrap();
let y_weirstrass = element_from_biguint::<X25519PrimeField>(&y_weirstrass_biguint);

let result =
crate::definitions::X25519PrimeField::to_twistededwards(x_weirstrass, y_weirstrass);

let x_twisted = element_to_biguint::<X25519PrimeField>(&result.0);
let y_twisted = element_to_biguint::<X25519PrimeField>(&result.1);

let result = vec![biguint_to_jsvalue(x_twisted), biguint_to_jsvalue(y_twisted)];

Ok(result)
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -65,4 +112,50 @@ mod tests {
BigUint::from(v)
);
}

#[wasm_bindgen_test]
pub fn test_to_weierstrass_and_back() {
let curve = X25519PrimeField::get_curve_params();

let x_weirstrass = curve.g_x;
let y_weirstrass = curve.g_y;

let (x_twisted, y_twisted) =
X25519PrimeField::to_twistededwards(x_weirstrass.clone(), y_weirstrass.clone());
let (x_weirstrass_back, y_weirstrass_back) =
X25519PrimeField::to_weirstrass(x_twisted, y_twisted);

assert_eq!(x_weirstrass, x_weirstrass_back);
assert_eq!(y_weirstrass, y_weirstrass_back);
}

#[wasm_bindgen_test]
pub fn test_to_twistededwards_and_back_from_js() {
let curve = X25519PrimeField::get_curve_params();

let x_weirstrass = curve.g_x;
let y_weirstrass = curve.g_y;

let x_weirstrass_js =
biguint_to_jsvalue(element_to_biguint::<X25519PrimeField>(&x_weirstrass));
let y_weirstrass_js =
biguint_to_jsvalue(element_to_biguint::<X25519PrimeField>(&y_weirstrass));
let result_js = to_twistededwards(x_weirstrass_js, y_weirstrass_js).unwrap();
assert_eq!(result_js.len(), 2);

let x_twisted_js = result_js.get(0).unwrap();
let y_twisted_js = result_js.get(1).unwrap();

let x_twisted_biguint = jsvalue_to_biguint(x_twisted_js.clone()).unwrap();
let y_twisted_biguint = jsvalue_to_biguint(y_twisted_js.clone()).unwrap();

let x_twisted = element_from_biguint::<X25519PrimeField>(&x_twisted_biguint);
let y_twisted = element_from_biguint::<X25519PrimeField>(&y_twisted_biguint);

let (x_weirstrass_back, y_weirstrass_back) =
X25519PrimeField::to_weirstrass(x_twisted, y_twisted);

assert_eq!(x_weirstrass, x_weirstrass_back);
assert_eq!(y_weirstrass, y_weirstrass_back);
}
}
8 changes: 4 additions & 4 deletions tools/npm/garaga_ts/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tools/npm/garaga_ts/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"devDependencies": {
"@rollup/plugin-typescript": "^11.1.6",
"@types/jest": "^29.5.13",
"@types/node": "^22.5.4",
"@types/node": "^22.7.4",
"jest": "^29.7.0",
"rollup": "^4.21.2",
"rollup-plugin-dts": "^6.1.1",
Expand Down
22 changes: 21 additions & 1 deletion tools/npm/garaga_ts/src/node/api.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This files provides a ts-like interface for garaga_rs

import { msm_calldata_builder } from '../wasm/pkg/garaga_rs';
import { msm_calldata_builder, to_twistededwards, to_weirstrass } from '../wasm/pkg/garaga_rs';

export enum CurveId {
BN254 = 0,
Expand All @@ -27,3 +27,23 @@ export function msmCalldataBuilder(points: G1Point[], scalars: bigint[], curveId
const risc0Mode = options.risc0Mode ?? false;
return msm_calldata_builder(values, scalars, curveId, includeDigitsDecomposition, includePointsAndScalars, serializeAsPureFelt252Array, risc0Mode);
}

export function toWeirstrass( x_twisted: bigint, y_twisted: bigint): [bigint, bigint] {
const result = to_weirstrass(x_twisted, y_twisted);

if(result.length !== 2) {
throw new Error('Invalid result length');
}

return [result[0], result[1]];
}

export function toTwistedEdwards( x_weierstrass: bigint, y_weierstrass: bigint): [bigint, bigint] {
const result = to_twistededwards(x_weierstrass, y_weierstrass);

if(result.length !== 2) {
throw new Error('Invalid result length');
}

return [result[0], result[1]];
}
14 changes: 14 additions & 0 deletions tools/npm/garaga_ts/src/wasm/pkg/garaga_rs.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,26 @@
* @returns {any[]}
*/
export function msm_calldata_builder(values: any[], scalars: any[], curve_id: number, include_digits_decomposition: boolean, include_points_and_scalars: boolean, serialize_as_pure_felt252_array: boolean, risc0_mode: boolean): any[];
/**
* @param {any} x_twisted
* @param {any} y_twisted
* @returns {any[]}
*/
export function to_weirstrass(x_twisted: any, y_twisted: any): any[];
/**
* @param {any} x_weirstrass
* @param {any} y_weirstrass
* @returns {any[]}
*/
export function to_twistededwards(x_weirstrass: any, y_weirstrass: any): any[];

export type InitInput = RequestInfo | URL | Response | BufferSource | WebAssembly.Module;

export interface InitOutput {
readonly memory: WebAssembly.Memory;
readonly msm_calldata_builder: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number) => void;
readonly to_weirstrass: (a: number, b: number, c: number) => void;
readonly to_twistededwards: (a: number, b: number, c: number) => void;
readonly __wbindgen_malloc: (a: number, b: number) => number;
readonly __wbindgen_realloc: (a: number, b: number, c: number, d: number) => number;
readonly __wbindgen_add_to_stack_pointer: (a: number) => number;
Expand Down
Loading

0 comments on commit 51cfa1c

Please sign in to comment.