diff --git a/src/bls12381/Fp.sol b/src/bls12381/Fp.sol index 8f1d72f..9d5e71c 100644 --- a/src/bls12381/Fp.sol +++ b/src/bls12381/Fp.sol @@ -163,81 +163,4 @@ library BLS12FP { } return Bls12Fp(output[0], output[1]); } - - /// @dev base^base % modulus - /// @param base Bls12Fp. - /// @param exp Bls12Fp. - /// @param modulus Bls12Fp. - /// @return Result of mod_exp. - function mod_exp( - Bls12Fp memory base, - Bls12Fp memory exp, - Bls12Fp memory modulus - ) - internal - view - returns (Bls12Fp memory) - { - uint256[9] memory input; - input[0] = 0x40; - input[1] = 0x40; - input[2] = 0x40; - input[3] = base.a; - input[4] = base.b; - input[5] = exp.a; - input[6] = exp.b; - input[7] = modulus.a; - input[8] = modulus.b; - uint256[2] memory output; - - assembly ("memory-safe") { - if iszero(staticcall(gas(), MOD_EXP, input, 288, output, 64)) { - let p := mload(0x40) - returndatacopy(p, 0, returndatasize()) - revert(p, returndatasize()) - } - } - - return Bls12Fp(output[0], output[1]); - } - - /// @dev base^base % modulus - /// @param base Bls12Fp. - /// @param exp uint256. - /// @param modulus Bls12Fp. - /// @return Result of mod_exp. - function mod_exp(Bls12Fp memory base, uint256 exp, Bls12Fp memory modulus) internal view returns (Bls12Fp memory) { - uint256[8] memory input; - input[0] = 0x40; - input[1] = 0x40; - input[2] = 0x40; - input[3] = base.a; - input[4] = base.b; - input[5] = exp; - input[6] = modulus.a; - input[7] = modulus.b; - uint256[2] memory output; - - assembly ("memory-safe") { - if iszero(staticcall(gas(), MOD_EXP, input, 256, output, 64)) { - let p := mload(0x40) - returndatacopy(p, 0, returndatasize()) - revert(p, returndatasize()) - } - } - - return Bls12Fp(output[0], output[1]); - } - - // using quadratic residue - function find_y(Bls12Fp memory x) internal view returns (Bls12Fp memory) { - Bls12Fp memory y_square = add(mod_exp(x, 3, q()), b()); - Bls12Fp memory y = mod_exp(y_square, qr(), q()); - return y; - } - - // pow(y, 2, q) == (x**3 + b.n) % q: - function is_on_curve(Bls12Fp memory x, Bls12Fp memory y) internal view returns (bool) { - return eq(mod_exp(y, 2, q()), add(mod_exp(x, 3, q()), b())); - } } diff --git a/src/bls12381/Fp2.sol b/src/bls12381/Fp2.sol index 77d17e1..847fb8e 100644 --- a/src/bls12381/Fp2.sol +++ b/src/bls12381/Fp2.sol @@ -21,6 +21,14 @@ library BLS12FP2 { bytes private constant Z_PAD = hex"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"; + function zero() internal pure returns (Bls12Fp2 memory) { + return Bls12Fp2(BLS12FP.zero(), BLS12FP.zero()); + } + + function b2() internal pure returns (Bls12Fp2 memory) { + return Bls12Fp2(Bls12Fp(0, 4), Bls12Fp(0, 4)); + } + /// @dev Returns `true` if `x` is equal to `y`. /// @param x Bls12Fp2. /// @param y Bls12Fp2. @@ -33,6 +41,10 @@ library BLS12FP2 { return x.c0.is_zero() && x.c1.is_zero(); } + function is_valid(Bls12Fp2 memory self) internal pure returns (bool) { + return self.c0.is_valid() && self.c1.is_valid(); + } + /// @dev Hash an arbitrary `msg` to `2` elements from field `Fp2`. /// @param message A byte string containing the message to hash. /// @return `2` of field elements. diff --git a/src/bls12381/G1.sol b/src/bls12381/G1.sol index fd27c2b..39b7c2d 100644 --- a/src/bls12381/G1.sol +++ b/src/bls12381/G1.sol @@ -94,8 +94,8 @@ library BLS12G1Affine { } // Take a 96 byte array and convert to a G1 point (x, y) - function deserialize(bytes memory g1) internal view returns (Bls12G1 memory) { - require(g1.length == 48, "!g1"); + function deserialize(bytes memory g1) internal pure returns (Bls12G1 memory) { + require(g1.length == 96, "!g1"); bytes1 byt = g1[0]; bool c_flag = (byt >> 7) & 0x01 == 0x01; bool b_flag = (byt >> 6) & 0x01 == 0x01; @@ -103,24 +103,19 @@ library BLS12G1Affine { if (a_flag && (!c_flag || b_flag)) { revert("!flag"); } - require(c_flag, "uncompressed"); + require(!c_flag, "compressed"); // Zero flags g1[0] = byt & 0x1f; Bls12Fp memory x = Bls12Fp(g1.slice_to_uint(0, 16), g1.slice_to_uint(16, 48)); + Bls12Fp memory y = Bls12Fp(g1.slice_to_uint(48, 64), g1.slice_to_uint(64, 96)); if (b_flag) { - require(x.is_zero(), "!zero"); + require(x.is_zero() && y.is_zero(), "!zero"); return zero(); } - Bls12Fp memory y = x.find_y(); - // Require elements less than field modulus - require(x.is_valid() && y.is_valid(), "!fp"); - - if (y.add(y).gt(BLS12FP.q()) != a_flag) { - y = BLS12FP.q().sub(y); - } + require(x.is_valid() && y.is_valid(), "!pnt"); // Convert to G1 Bls12G1 memory p = Bls12G1(x, y); diff --git a/src/bls12381/G2.sol b/src/bls12381/G2.sol index d171a2d..df75c2b 100644 --- a/src/bls12381/G2.sol +++ b/src/bls12381/G2.sol @@ -28,6 +28,10 @@ library BLS12G2Affine { bytes1 private constant INFINITY_FLAG = bytes1(0x40); bytes1 private constant Y_FLAG = bytes1(0x20); + function zero() internal pure returns (Bls12G2 memory) { + return Bls12G2(BLS12FP2.zero(), BLS12FP2.zero()); + } + /// @dev Returns `true` if `x` is equal to `y`. /// @param a Bls12G2. /// @param b Bls12G2. @@ -126,11 +130,13 @@ library BLS12G2Affine { function deserialize(bytes memory g2) internal pure returns (Bls12G2 memory) { require(g2.length == 192, "!g2"); bytes1 byt = g2[0]; - require(byt & COMPRESION_FLAG == 0, "compressed"); - require(byt & INFINITY_FLAG == 0, "infinity"); - require(byt & Y_FLAG == 0, "y_flag"); - - g2[0] = byt & 0x1f; + bool c_flag = (byt >> 7) & 0x01 == 0x01; + bool b_flag = (byt >> 6) & 0x01 == 0x01; + bool a_flag = (byt >> 5) & 0x01 == 0x01; + if (a_flag && (!c_flag || b_flag)) { + revert("!flag"); + } + require(!c_flag, "compressed"); // Convert from array to FP2 Bls12Fp memory x_imaginary = Bls12Fp(g2.slice_to_uint(0, 16), g2.slice_to_uint(16, 48)); @@ -138,6 +144,11 @@ library BLS12G2Affine { Bls12Fp memory y_imaginary = Bls12Fp(g2.slice_to_uint(96, 112), g2.slice_to_uint(112, 144)); Bls12Fp memory y_real = Bls12Fp(g2.slice_to_uint(144, 160), g2.slice_to_uint(160, 192)); + if (b_flag) { + require(x_imaginary.is_zero() && x_real.is_zero() && y_imaginary.is_zero() && y_real.is_zero(), "!zero"); + return zero(); + } + // Require elements less than field modulus require(x_imaginary.is_valid() && x_real.is_valid() && y_imaginary.is_valid() && y_real.is_valid(), "!pnt");