From 48636e9ccec59e1452dcc28363c24464dcf26894 Mon Sep 17 00:00:00 2001 From: Oleksandr Brezhniev Date: Wed, 20 Sep 2023 22:42:10 +0100 Subject: [PATCH] Return errors instead of panics where possible --- babyjub/babyjub.go | 10 ++-- babyjub/babyjub_test.go | 78 ++++++++++++++-------------- babyjub/babyjub_wrapper.go | 9 ++-- babyjub/babyjub_wrapper_test.go | 20 ++++--- babyjub/eddsa.go | 6 +-- babyjub/eddsa_test.go | 92 ++++++++++++++++----------------- goldenposeidon/poseidon_test.go | 13 ++--- mimc7/mimc7.go | 6 +-- mimc7/mimc7_test.go | 18 ++++--- poseidon/poseidon_test.go | 8 +-- poseidon/poseidon_wrapper.go | 4 +- utils/utils.go | 8 +-- 12 files changed, 140 insertions(+), 132 deletions(-) diff --git a/babyjub/babyjub.go b/babyjub/babyjub.go index 802c770..38daa09 100644 --- a/babyjub/babyjub.go +++ b/babyjub/babyjub.go @@ -34,19 +34,19 @@ var B8 *Point // init initializes global numbers and the subgroup base. func init() { - A = utils.NewIntFromString("168700") - D = utils.NewIntFromString("168696") + A, _ = utils.NewIntFromString("168700") + D, _ = utils.NewIntFromString("168696") Aff = ff.NewElement().SetBigInt(A) Dff = ff.NewElement().SetBigInt(D) - Order = utils.NewIntFromString( + Order, _ = utils.NewIntFromString( "21888242871839275222246405745257275088614511777268538073601725287587578984328") SubOrder = new(big.Int).Rsh(Order, 3) //nolint:gomnd B8 = NewPoint() - B8.X = utils.NewIntFromString( + B8.X, _ = utils.NewIntFromString( "5299619240641551281634865583518297030282874472190772894086521144482721001553") - B8.Y = utils.NewIntFromString( + B8.Y, _ = utils.NewIntFromString( "16950150798460657717958625567821834550301663161624707787222815936182638968203") } diff --git a/babyjub/babyjub_test.go b/babyjub/babyjub_test.go index 462fc96..312ae43 100644 --- a/babyjub/babyjub_test.go +++ b/babyjub/babyjub_test.go @@ -22,15 +22,15 @@ func TestAdd1(t *testing.T) { } func TestAdd2(t *testing.T) { - aX := utils.NewIntFromString( + aX, _ := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - aY := utils.NewIntFromString( + aY, _ := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") a := &Point{X: aX, Y: aY} - bX := utils.NewIntFromString( + bX, _ := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - bY := utils.NewIntFromString( + bY, _ := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") b := &Point{X: bX, Y: bY} @@ -57,15 +57,15 @@ func TestAdd2(t *testing.T) { } func TestAdd3(t *testing.T) { - aX := utils.NewIntFromString( + aX, _ := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - aY := utils.NewIntFromString( + aY, _ := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") a := &Point{X: aX, Y: aY} - bX := utils.NewIntFromString( + bX, _ := utils.NewIntFromString( "16540640123574156134436876038791482806971768689494387082833631921987005038935") - bY := utils.NewIntFromString( + bY, _ := utils.NewIntFromString( "20819045374670962167435360035096875258406992893633759881276124905556507972311") b := &Point{X: bX, Y: bY} @@ -80,15 +80,15 @@ func TestAdd3(t *testing.T) { } func TestAdd4(t *testing.T) { - aX := utils.NewIntFromString( + aX, _ := utils.NewIntFromString( "0") - aY := utils.NewIntFromString( + aY, _ := utils.NewIntFromString( "1") a := &Point{X: aX, Y: aY} - bX := utils.NewIntFromString( + bX, _ := utils.NewIntFromString( "16540640123574156134436876038791482806971768689494387082833631921987005038935") - bY := utils.NewIntFromString( + bY, _ := utils.NewIntFromString( "20819045374670962167435360035096875258406992893633759881276124905556507972311") b := &Point{X: bX, Y: bY} @@ -113,12 +113,12 @@ func TestInCurve2(t *testing.T) { } func TestMul0(t *testing.T) { - x := utils.NewIntFromString( + x, _ := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y, _ := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} - s := utils.NewIntFromString("3") + s, _ := utils.NewIntFromString("3") r2 := NewPoint().Projective().Add(p.Projective(), p.Projective()).Affine() r2 = NewPoint().Projective().Add(r2.Projective(), p.Projective()).Affine() @@ -135,12 +135,12 @@ func TestMul0(t *testing.T) { } func TestMul1(t *testing.T) { - x := utils.NewIntFromString( + x, _ := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y, _ := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} - s := utils.NewIntFromString( + s, _ := utils.NewIntFromString( "14035240266687799601661095864649209771790948434046947201833777492504781204499") r := NewPoint().Mul(s, p) assert.Equal(t, @@ -152,12 +152,12 @@ func TestMul1(t *testing.T) { } func TestMul2(t *testing.T) { - x := utils.NewIntFromString( + x, _ := utils.NewIntFromString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := utils.NewIntFromString( + y, _ := utils.NewIntFromString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} - s := utils.NewIntFromString( + s, _ := utils.NewIntFromString( "20819045374670962167435360035096875258406992893633759881276124905556507972311") r := NewPoint().Mul(s, p) assert.Equal(t, @@ -169,45 +169,45 @@ func TestMul2(t *testing.T) { } func TestInCurve3(t *testing.T) { - x := utils.NewIntFromString( + x, _ := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y, _ := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InCurve()) } func TestInCurve4(t *testing.T) { - x := utils.NewIntFromString( + x, _ := utils.NewIntFromString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := utils.NewIntFromString( + y, _ := utils.NewIntFromString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InCurve()) } func TestInSubGroup1(t *testing.T) { - x := utils.NewIntFromString( + x, _ := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y, _ := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InSubGroup()) } func TestInSubGroup2(t *testing.T) { - x := utils.NewIntFromString( + x, _ := utils.NewIntFromString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := utils.NewIntFromString( + y, _ := utils.NewIntFromString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InSubGroup()) } func TestPointFromSignAndy(t *testing.T) { - x := utils.NewIntFromString( + x, _ := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y, _ := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} @@ -219,9 +219,9 @@ func TestPointFromSignAndy(t *testing.T) { } func TestPackAndUnpackSignY(t *testing.T) { - x := utils.NewIntFromString( + x, _ := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y, _ := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} pComp := p.Compress() @@ -239,9 +239,9 @@ func TestPackAndUnpackSignY(t *testing.T) { } func TestCompressDecompress1(t *testing.T) { - x := utils.NewIntFromString( + x, _ := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := utils.NewIntFromString( + y, _ := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} @@ -257,9 +257,9 @@ func TestCompressDecompress1(t *testing.T) { } func TestCompressDecompress2(t *testing.T) { - x := utils.NewIntFromString( + x, _ := utils.NewIntFromString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := utils.NewIntFromString( + y, _ := utils.NewIntFromString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} @@ -299,9 +299,9 @@ func BenchmarkBabyjub(b *testing.B) { var points [n]*Point var pointsProj [n]*PointProjective - baseX := utils.NewIntFromString( + baseX, _ := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - baseY := utils.NewIntFromString( + baseY, _ := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") base := &Point{X: baseX, Y: baseY} for i := 0; i < n; i++ { diff --git a/babyjub/babyjub_wrapper.go b/babyjub/babyjub_wrapper.go index 66e40ee..d4893cd 100644 --- a/babyjub/babyjub_wrapper.go +++ b/babyjub/babyjub_wrapper.go @@ -38,9 +38,12 @@ func NewBjjWrappedKey(privKey *PrivateKey) *BjjWrappedPrivateKey { } // RandomBjjWrappedKey creates a new BjjWrappedPrivateKey with a random private key. -func RandomBjjWrappedKey() *BjjWrappedPrivateKey { - privKey := NewRandPrivKey() - return NewBjjWrappedKey(&privKey) +func RandomBjjWrappedKey() (*BjjWrappedPrivateKey, error) { + privKey, err := NewRandPrivKey() + if err != nil { + return nil, err + } + return NewBjjWrappedKey(&privKey), nil } // Public returns the public key of the private key. diff --git a/babyjub/babyjub_wrapper_test.go b/babyjub/babyjub_wrapper_test.go index fd605ac..8b33039 100644 --- a/babyjub/babyjub_wrapper_test.go +++ b/babyjub/babyjub_wrapper_test.go @@ -27,7 +27,7 @@ func TestBjjWrappedPrivateKeyInterfaceImpl(t *testing.T) { } func TestBjjWrappedPrivateKey(t *testing.T) { - pk := RandomBjjWrappedKey() + pk, _ := RandomBjjWrappedKey() hasher, err := poseidon.New(16) require.NoError(t, err) @@ -43,13 +43,14 @@ func TestBjjWrappedPrivateKey(t *testing.T) { require.NoError(t, err) digestBI := big.NewInt(0).SetBytes(digest) - pub.pubKey.VerifyPoseidon(digestBI, decomrpessSig) + err = pub.pubKey.VerifyPoseidon(digestBI, decomrpessSig) + require.NoError(t, err) } func TestBjjWrappedPrivateKeyEqual(t *testing.T) { - x1 := RandomBjjWrappedKey() + x1, _ := RandomBjjWrappedKey() require.True(t, x1.Equal(x1)) - x2 := RandomBjjWrappedKey() + x2, _ := RandomBjjWrappedKey() require.False(t, x1.Equal(x2)) } @@ -58,8 +59,11 @@ func TestBjjWrappedPublicKeyInterfaceImpl(t *testing.T) { } func TestBjjWrappedPublicKeyEqual(t *testing.T) { - x1 := RandomBjjWrappedKey().Public().(*BjjWrappedPublicKey) - require.True(t, x1.Equal(x1)) - x2 := RandomBjjWrappedKey().Public() - require.False(t, x1.Equal(x2)) + x1, _ := RandomBjjWrappedKey() + x1pub := x1.Public().(*BjjWrappedPublicKey) + require.True(t, x1pub.Equal(x1pub)) + require.True(t, x1pub.Equal(x1.Public())) + x2, _ := RandomBjjWrappedKey() + x2pub := x2.Public() + require.False(t, x1pub.Equal(x2pub)) } diff --git a/babyjub/eddsa.go b/babyjub/eddsa.go index af77929..836edf8 100644 --- a/babyjub/eddsa.go +++ b/babyjub/eddsa.go @@ -28,13 +28,13 @@ type PrivateKey [32]byte // NewRandPrivKey generates a new random private key (using cryptographically // secure randomness). -func NewRandPrivKey() PrivateKey { +func NewRandPrivKey() (PrivateKey, error) { var k PrivateKey _, err := rand.Read(k[:]) if err != nil { - panic(err) + return PrivateKey{}, err } - return k + return k, nil } // Scalar converts a private key into the scalar value s following the EdDSA diff --git a/babyjub/eddsa_test.go b/babyjub/eddsa_test.go index e048273..16b94e7 100644 --- a/babyjub/eddsa_test.go +++ b/babyjub/eddsa_test.go @@ -28,11 +28,9 @@ func TestSignVerifyMimc7(t *testing.T) { var k PrivateKey _, err := hex.Decode(k[:], []byte("0001020304050607080900010203040506070809000102030405060708090001")) - require.Nil(t, err) + require.NoError(t, err) msgBuf, err := hex.DecodeString("00010203040506070809") - if err != nil { - panic(err) - } + require.NoError(t, err) msg := utils.SetBigIntFromLEBytes(new(big.Int), msgBuf) pk := k.Public() @@ -44,7 +42,7 @@ func TestSignVerifyMimc7(t *testing.T) { pk.Y.String()) sig, err := k.SignMimc7(msg) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "11384336176656855268977457483345535180380036354188103142384839473266348197733", sig.R8.X.String()) @@ -56,11 +54,11 @@ func TestSignVerifyMimc7(t *testing.T) { sig.S.String()) err = pk.VerifyMimc7(msg, sig) - assert.NoError(t, err) + require.NoError(t, err) sigBuf := sig.Compress() sig2, err := new(Signature).Decompress(sigBuf) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, ""+ "dfedb4315d3f2eb4de2d3c510d7a987dcab67089c8ace06308827bf5bcbe02a2"+ @@ -68,18 +66,16 @@ func TestSignVerifyMimc7(t *testing.T) { hex.EncodeToString(sigBuf[:])) err = pk.VerifyMimc7(msg, sig2) - assert.NoError(t, err) + require.NoError(t, err) } func TestSignVerifyPoseidon(t *testing.T) { var k PrivateKey _, err := hex.Decode(k[:], []byte("0001020304050607080900010203040506070809000102030405060708090001")) - require.Nil(t, err) + require.NoError(t, err) msgBuf, err := hex.DecodeString("00010203040506070809") - if err != nil { - panic(err) - } + require.NoError(t, err) msg := utils.SetBigIntFromLEBytes(new(big.Int), msgBuf) pk := k.Public() @@ -91,7 +87,7 @@ func TestSignVerifyPoseidon(t *testing.T) { pk.Y.String()) sig, err := k.SignPoseidon(msg) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "11384336176656855268977457483345535180380036354188103142384839473266348197733", sig.R8.X.String()) @@ -103,11 +99,11 @@ func TestSignVerifyPoseidon(t *testing.T) { sig.S.String()) err = pk.VerifyPoseidon(msg, sig) - assert.NoError(t, err) + require.NoError(t, err) sigBuf := sig.Compress() sig2, err := new(Signature).Decompress(sigBuf) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, ""+ "dfedb4315d3f2eb4de2d3c510d7a987dcab67089c8ace06308827bf5bcbe02a2"+ @@ -115,88 +111,86 @@ func TestSignVerifyPoseidon(t *testing.T) { hex.EncodeToString(sigBuf[:])) err = pk.VerifyPoseidon(msg, sig2) - assert.NoError(t, err) + require.NoError(t, err) } func TestCompressDecompress(t *testing.T) { var k PrivateKey _, err := hex.Decode(k[:], []byte("0001020304050607080900010203040506070809000102030405060708090001")) - require.Nil(t, err) + require.NoError(t, err) pk := k.Public() for i := 0; i < 64; i++ { msgBuf, err := hex.DecodeString(fmt.Sprintf("000102030405060708%02d", i)) - if err != nil { - panic(err) - } + require.NoError(t, err) msg := utils.SetBigIntFromLEBytes(new(big.Int), msgBuf) sig, err := k.SignMimc7(msg) - assert.NoError(t, err) + require.NoError(t, err) sigBuf := sig.Compress() sig2, err := new(Signature).Decompress(sigBuf) - assert.NoError(t, err) + require.NoError(t, err) err = pk.VerifyMimc7(msg, sig2) - assert.NoError(t, err) + require.NoError(t, err) } } func TestSignatureCompScannerValuer(t *testing.T) { - privK := NewRandPrivKey() + privK, _ := NewRandPrivKey() var err error sig, err := privK.SignPoseidon(big.NewInt(674238462)) - assert.NoError(t, err) + require.NoError(t, err) var value driver.Valuer //nolint:gosimple // this is done to ensure interface compatibility value = sig.Compress() sig, err = privK.SignPoseidon(big.NewInt(1)) - assert.NoError(t, err) + require.NoError(t, err) scan := sig.Compress() fromDB, err := value.Value() - assert.NoError(t, err) + require.NoError(t, err) assert.Nil(t, scan.Scan(fromDB)) assert.Equal(t, value, scan) } func TestSignatureScannerValuer(t *testing.T) { - privK := NewRandPrivKey() + privK, _ := NewRandPrivKey() var value driver.Valuer var scan sql.Scanner var err error value, err = privK.SignPoseidon(big.NewInt(674238462)) - assert.NoError(t, err) + require.NoError(t, err) scan, err = privK.SignPoseidon(big.NewInt(1)) - assert.NoError(t, err) + require.NoError(t, err) fromDB, err := value.Value() - assert.NoError(t, err) + require.NoError(t, err) assert.Nil(t, scan.Scan(fromDB)) assert.Equal(t, value, scan) } func TestPublicKeyScannerValuer(t *testing.T) { - privKValue := NewRandPrivKey() + privKValue, _ := NewRandPrivKey() pubKValue := privKValue.Public() - privKScan := NewRandPrivKey() + privKScan, _ := NewRandPrivKey() pubKScan := privKScan.Public() var value driver.Valuer var scan sql.Scanner value = pubKValue scan = pubKScan fromDB, err := value.Value() - assert.Nil(t, err) + require.NoError(t, err) assert.Nil(t, scan.Scan(fromDB)) assert.Equal(t, value, scan) } func TestPublicKeyCompScannerValuer(t *testing.T) { - privKValue := NewRandPrivKey() + privKValue, _ := NewRandPrivKey() pubKCompValue := privKValue.Public().Compress() - privKScan := NewRandPrivKey() + privKScan, _ := NewRandPrivKey() pubKCompScan := privKScan.Public().Compress() var value driver.Valuer var scan sql.Scanner value = &pubKCompValue scan = &pubKCompScan fromDB, err := value.Value() - assert.Nil(t, err) + require.NoError(t, err) assert.Nil(t, scan.Scan(fromDB)) assert.Equal(t, value, scan) } @@ -205,15 +199,13 @@ func BenchmarkBabyjubEddsa(b *testing.B) { var k PrivateKey _, err := hex.Decode(k[:], []byte("0001020304050607080900010203040506070809000102030405060708090001")) - require.Nil(b, err) + require.NoError(b, err) pk := k.Public() const n = 256 msgBuf, err := hex.DecodeString("00010203040506070809") - if err != nil { - panic(err) - } + require.NoError(b, err) msg := utils.SetBigIntFromLEBytes(new(big.Int), msgBuf) var msgs [n]*big.Int for i := 0; i < n; i++ { @@ -223,33 +215,39 @@ func BenchmarkBabyjubEddsa(b *testing.B) { b.Run("SignMimc7", func(b *testing.B) { for i := 0; i < b.N; i++ { - k.SignMimc7(msgs[i%n]) + _, err = k.SignMimc7(msgs[i%n]) + require.NoError(b, err) } }) for i := 0; i < n; i++ { - sigs[i%n], _ = k.SignMimc7(msgs[i%n]) + sigs[i%n], err = k.SignMimc7(msgs[i%n]) + require.NoError(b, err) } b.Run("VerifyMimc7", func(b *testing.B) { for i := 0; i < b.N; i++ { - _ = pk.VerifyMimc7(msgs[i%n], sigs[i%n]) + err = pk.VerifyMimc7(msgs[i%n], sigs[i%n]) + require.NoError(b, err) } }) b.Run("SignPoseidon", func(b *testing.B) { for i := 0; i < b.N; i++ { - k.SignPoseidon(msgs[i%n]) + _, err = k.SignPoseidon(msgs[i%n]) + require.NoError(b, err) } }) for i := 0; i < n; i++ { - sigs[i%n], _ = k.SignPoseidon(msgs[i%n]) + sigs[i%n], err = k.SignPoseidon(msgs[i%n]) + require.NoError(b, err) } b.Run("VerifyPoseidon", func(b *testing.B) { for i := 0; i < b.N; i++ { - _ = pk.VerifyPoseidon(msgs[i%n], sigs[i%n]) + err = pk.VerifyPoseidon(msgs[i%n], sigs[i%n]) + require.NoError(b, err) } }) } diff --git a/goldenposeidon/poseidon_test.go b/goldenposeidon/poseidon_test.go index 1d83c83..3e574ec 100644 --- a/goldenposeidon/poseidon_test.go +++ b/goldenposeidon/poseidon_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const prime uint64 = 18446744069414584321 @@ -16,7 +17,7 @@ func TestPoseidonHashCompare(t *testing.T) { h, err := Hash([NROUNDSF]uint64{b0, b0, b0, b0, b0, b0, b0, b0}, [CAPLEN]uint64{b0, b0, b0, b0}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, [CAPLEN]uint64{ 4330397376401421145, @@ -28,7 +29,7 @@ func TestPoseidonHashCompare(t *testing.T) { h, err = Hash([NROUNDSF]uint64{b1, b1, b1, b1, b1, b1, b1, b1}, [CAPLEN]uint64{b1, b1, b1, b1}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, [CAPLEN]uint64{ 16428316519797902711, @@ -40,7 +41,7 @@ func TestPoseidonHashCompare(t *testing.T) { h, err = Hash([NROUNDSF]uint64{b1, b1, b1, b1, b1, b1, b1, b1}, [CAPLEN]uint64{b1, b1, b1, b1}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, [CAPLEN]uint64{ 16428316519797902711, @@ -54,7 +55,7 @@ func TestPoseidonHashCompare(t *testing.T) { [NROUNDSF]uint64{bm1, bm1, bm1, bm1, bm1, bm1, bm1, bm1}, [CAPLEN]uint64{bm1, bm1, bm1, bm1}, ) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, [CAPLEN]uint64{ 13691089994624172887, @@ -66,7 +67,7 @@ func TestPoseidonHashCompare(t *testing.T) { h, err = Hash([NROUNDSF]uint64{bM, bM, bM, bM, bM, bM, bM, bM}, [CAPLEN]uint64{b0, b0, b0, b0}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, [CAPLEN]uint64{ 4330397376401421145, @@ -86,7 +87,7 @@ func TestPoseidonHashCompare(t *testing.T) { uint64(6254867324987), uint64(2087), }, [CAPLEN]uint64{b0, b0, b0, b0}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, [CAPLEN]uint64{ 1892171027578617759, diff --git a/mimc7/mimc7.go b/mimc7/mimc7.go index 78e2fc0..a0d1a6d 100644 --- a/mimc7/mimc7.go +++ b/mimc7/mimc7.go @@ -141,7 +141,7 @@ func Hash(arr []*big.Int, key *big.Int) (*big.Int, error) { // HashBytes hashes a msg byte slice by blocks of 31 bytes encoded as // little-endian -func HashBytes(b []byte) *big.Int { +func HashBytes(b []byte) (*big.Int, error) { n := 31 bElems := make([]*big.Int, 0, len(b)/n+1) for i := 0; i < len(b)/n; i++ { @@ -156,7 +156,7 @@ func HashBytes(b []byte) *big.Int { } h, err := Hash(bElems, nil) if err != nil { - panic(err) + return nil, err } - return h + return h, nil } diff --git a/mimc7/mimc7_test.go b/mimc7/mimc7_test.go index 3be55a9..3a2b3a5 100644 --- a/mimc7/mimc7_test.go +++ b/mimc7/mimc7_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMIMC7Generic(t *testing.T) { @@ -21,7 +22,7 @@ func TestMIMC7Generic(t *testing.T) { "10594780656576967754230020536574539122676596303354946869887184401991294982664", mhg.String()) hg, err := HashGeneric(big.NewInt(0), bigArray, 91) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "6464402164086696096195815557694604139393321133243036833927490113253119343397", hg.String()) @@ -33,11 +34,11 @@ func TestMIMC7(t *testing.T) { b78 := big.NewInt(int64(78)) b41 := big.NewInt(int64(41)) - // h1, hash of 1 elements + // h1, hash of 1 element bigArray1 := []*big.Int{b12} h1, err := Hash(bigArray1, nil) - assert.Nil(t, err) + require.NoError(t, err) // same hash value than the iden3js and circomlib tests: assert.Equal(t, "0x"+hex.EncodeToString(h1.Bytes()), "0x237c92644dbddb86d8a259e0e923aaab65a93f1ec5758b8799988894ac0958fd") @@ -46,7 +47,7 @@ func TestMIMC7(t *testing.T) { bigArray2a := []*big.Int{b78, b41} h2a, err := Hash(bigArray2a, nil) - assert.Nil(t, err) + require.NoError(t, err) // same hash value than the iden3js and circomlib tests: assert.Equal(t, "0x"+hex.EncodeToString(h2a.Bytes()), "0x067f3202335ea256ae6e6aadcd2d5f7f4b06a00b2d1e0de903980d5ab552dc70") @@ -55,12 +56,12 @@ func TestMIMC7(t *testing.T) { bigArray2b := []*big.Int{b12, b45} mh2b := MIMC7Hash(b12, b45) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "0x"+hex.EncodeToString(mh2b.Bytes()), "0x2ba7ebad3c6b6f5a20bdecba2333c63173ca1a5f2f49d958081d9fa7179c44e4") h2b, err := Hash(bigArray2b, nil) - assert.Nil(t, err) + require.NoError(t, err) // same hash value than the iden3js and circomlib tests: assert.Equal(t, "0x"+hex.EncodeToString(h2b.Bytes()), "0x15ff7fe9793346a17c3150804bcb36d161c8662b110c50f55ccb7113948d8879") @@ -69,13 +70,14 @@ func TestMIMC7(t *testing.T) { bigArray4 := []*big.Int{b12, b45, b78, b41} h4, err := Hash(bigArray4, nil) - assert.Nil(t, err) + require.NoError(t, err) // same hash value than the iden3js and circomlib tests: assert.Equal(t, "0x"+hex.EncodeToString(h4.Bytes()), "0x284bc1f34f335933a23a433b6ff3ee179d682cd5e5e2fcdd2d964afa85104beb") msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.") //nolint:lll - hmsg := HashBytes(msg) + hmsg, err := HashBytes(msg) + require.NoError(t, err) assert.Equal(t, "16855787120419064316734350414336285711017110414939748784029922801367685456065", hmsg.String()) diff --git a/poseidon/poseidon_test.go b/poseidon/poseidon_test.go index 64daa4e..68ff5ff 100644 --- a/poseidon/poseidon_test.go +++ b/poseidon/poseidon_test.go @@ -118,12 +118,12 @@ func TestInputsNotInField(t *testing.T) { var err error // Very big number, should just return error and not go into endless loop - b1 := utils.NewIntFromString("12242166908188651009877250812424843524687801523336557272219921456462821518061999999999999999999999999999999999999999999999999999999999") + b1, _ := utils.NewIntFromString("12242166908188651009877250812424843524687801523336557272219921456462821518061999999999999999999999999999999999999999999999999999999999") _, err = Hash([]*big.Int{b1}) require.Error(t, err, "inputs values not inside Finite Field") // Finite Field const Q, should return error - b2 := utils.NewIntFromString("21888242871839275222246405745257275088548364400416034343698204186575808495617") + b2, _ := utils.NewIntFromString("21888242871839275222246405745257275088548364400416034343698204186575808495617") _, err = Hash([]*big.Int{b2}) require.Error(t, err, "inputs values not inside Finite Field") } @@ -246,8 +246,8 @@ func TestSpongeHashX(t *testing.T) { func BenchmarkPoseidonHash6Inputs(b *testing.B) { b0 := big.NewInt(0) - b1 := utils.NewIntFromString("12242166908188651009877250812424843524687801523336557272219921456462821518061") - b2 := utils.NewIntFromString("12242166908188651009877250812424843524687801523336557272219921456462821518061") + b1, _ := utils.NewIntFromString("12242166908188651009877250812424843524687801523336557272219921456462821518061") + b2, _ := utils.NewIntFromString("12242166908188651009877250812424843524687801523336557272219921456462821518061") bigArray6 := []*big.Int{b1, b2, b0, b0, b0, b0} diff --git a/poseidon/poseidon_wrapper.go b/poseidon/poseidon_wrapper.go index 3198c48..c60534f 100644 --- a/poseidon/poseidon_wrapper.go +++ b/poseidon/poseidon_wrapper.go @@ -37,11 +37,11 @@ func (h *hasher) Write(p []byte) (n int, err error) { // Sum returns the Poseidon digest of the data. func (h *hasher) Sum(b []byte) []byte { - hahs, err := HashBytesX(h.buf.Bytes(), h.frameSize) + res, err := HashBytesX(h.buf.Bytes(), h.frameSize) if err != nil { panic(err) } - return append(b, hahs.Bytes()...) + return append(b, res.Bytes()...) } // Reset resets the Hash to its initial state. diff --git a/utils/utils.go b/utils/utils.go index fe0f887..eb6b1ab 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -12,13 +12,13 @@ import ( ) // NewIntFromString creates a new big.Int from a decimal integer encoded as a -// string. It will panic if the string is not a decimal integer. -func NewIntFromString(s string) *big.Int { +// string. It will return error if the string is not a decimal integer. +func NewIntFromString(s string) (*big.Int, error) { v, ok := new(big.Int).SetString(s, 10) //nolint:gomnd if !ok { - panic(fmt.Sprintf("Bad base 10 string %s", s)) + return nil, fmt.Errorf("bad base 10 string %s", s) } - return v + return v, nil } // SwapEndianness swaps the endianness of the value encoded in xs. If xs is