Skip to content

Commit

Permalink
Add support to return multiple states (#66)
Browse files Browse the repository at this point in the history
* Add support to return multiple states

Signed-off-by: Jim Zhang <[email protected]>

* Add checks on the nOuts

Signed-off-by: Jim Zhang <[email protected]>

* Change hashWithState() to HashWithStateEx()

Signed-off-by: Jim Zhang <[email protected]>

* Add tests for HashWithStateEx()

Signed-off-by: Jim Zhang <[email protected]>

---------

Signed-off-by: Jim Zhang <[email protected]>
  • Loading branch information
jimthematrix authored Sep 10, 2024
1 parent fb1d252 commit d59dca8
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 3 deletions.
26 changes: 23 additions & 3 deletions poseidon/poseidon.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,24 @@ func mix(state []*ff.Element, t int, m [][]*ff.Element) []*ff.Element {

// HashWithState computes the Poseidon hash for the given inputs and initState
func HashWithState(inpBI []*big.Int, initState *big.Int) (*big.Int, error) {
res, err := HashWithStateEx(inpBI, initState, 1)
if err != nil {
return nil, err
}
return res[0], nil
}

func HashWithStateEx(inpBI []*big.Int, initState *big.Int, nOuts int) ([]*big.Int, error) {
t := len(inpBI) + 1
if len(inpBI) == 0 || len(inpBI) > len(NROUNDSP) {
return nil, fmt.Errorf("invalid inputs length %d, max %d", len(inpBI), len(NROUNDSP))
}
if !utils.CheckBigIntArrayInField(inpBI) {
return nil, errors.New("inputs values not inside Finite Field")
}
if nOuts < 1 || nOuts > t {
return nil, fmt.Errorf("invalid nOuts %d, min 1, max %d", nOuts, t)
}
inp := utils.BigIntArrayToElementArray(inpBI)

nRoundsF := NROUNDSF
Expand Down Expand Up @@ -125,9 +136,12 @@ func HashWithState(inpBI []*big.Int, initState *big.Int) (*big.Int, error) {
exp5state(state)
state = mix(state, t, M)

rE := state[0]
r := big.NewInt(0)
rE.ToBigIntRegular(r)
r := make([]*big.Int, nOuts)
for i := 0; i < nOuts; i++ {
rE := state[i]
r[i] = big.NewInt(0)
rE.ToBigIntRegular(r[i])
}
return r, nil
}

Expand All @@ -136,6 +150,12 @@ func Hash(inpBI []*big.Int) (*big.Int, error) {
return HashWithState(inpBI, big.NewInt(0))
}

// HashEx computes the Poseidon hash for the given inputs and returns
// the first nOuts outputs that include intermediate states
func HashEx(inpBI []*big.Int, nOuts int) ([]*big.Int, error) {
return HashWithStateEx(inpBI, big.NewInt(0), nOuts)
}

// HashBytes returns a sponge hash of a msg byte slice split into blocks of 31 bytes
func HashBytes(msg []byte) (*big.Int, error) {
return HashBytesX(msg, spongeInputs)
Expand Down
225 changes: 225 additions & 0 deletions poseidon/poseidon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,136 @@ func TestPoseidonHash(t *testing.T) {
h.String())
}

func TestPoseidonHashEx(t *testing.T) {
b0 := big.NewInt(0)
b1 := big.NewInt(1)
b2 := big.NewInt(2)
b3 := big.NewInt(3)
b4 := big.NewInt(4)
b5 := big.NewInt(5)
b6 := big.NewInt(6)
b7 := big.NewInt(7)
b8 := big.NewInt(8)
b9 := big.NewInt(9)
b10 := big.NewInt(10)
b11 := big.NewInt(11)
b12 := big.NewInt(12)
b13 := big.NewInt(13)
b14 := big.NewInt(14)
b15 := big.NewInt(15)
b16 := big.NewInt(16)

h, err := HashEx([]*big.Int{b1}, 1)
assert.Nil(t, err)
assert.Equal(t, 1, len(h))
assert.Equal(t,
"18586133768512220936620570745912940619677854269274689475585506675881198879027",
h[0].String())

h, err = HashEx([]*big.Int{b1, b2}, 2)
assert.Nil(t, err)
assert.Equal(t, 2, len(h))
assert.Equal(t,
"7853200120776062878684798364095072458815029376092732009249414926327459813530",
h[0].String())
assert.Equal(t,
"7142104613055408817911962100316808866448378443474503659992478482890339429929",
h[1].String())

h, err = HashEx([]*big.Int{b1, b2, b0, b0, b0}, 3)
assert.Nil(t, err)
assert.Equal(t,
"1018317224307729531995786483840663576608797660851238720571059489595066344487",
h[0].String())
assert.Equal(t, 3, len(h))

h, err = HashEx([]*big.Int{b1, b2, b0, b0, b0, b0}, 4)
assert.Nil(t, err)
assert.Equal(t,
"15336558801450556532856248569924170992202208561737609669134139141992924267169",
h[0].String())
assert.Equal(t, 4, len(h))

h, err = HashEx([]*big.Int{b3, b4, b0, b0, b0}, 5)
assert.Nil(t, err)
assert.Equal(t,
"5811595552068139067952687508729883632420015185677766880877743348592482390548",
h[0].String())
assert.Equal(t, 5, len(h))

h, err = HashEx([]*big.Int{b3, b4, b0, b0, b0, b0}, 6)
assert.Nil(t, err)
assert.Equal(t,
"12263118664590987767234828103155242843640892839966517009184493198782366909018",
h[0].String())
assert.Equal(t, 6, len(h))

h, err = HashEx([]*big.Int{b1, b2, b3, b4, b5, b6}, 7)
assert.Nil(t, err)
assert.Equal(t,
"20400040500897583745843009878988256314335038853985262692600694741116813247201",
h[0].String())
assert.Equal(t, 7, len(h))

h, err = HashEx([]*big.Int{b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14}, 8)
assert.Nil(t, err)
assert.Equal(t,
"8354478399926161176778659061636406690034081872658507739535256090879947077494",
h[0].String())
assert.Equal(t, 8, len(h))

h, err = HashEx([]*big.Int{b1, b2, b3, b4, b5, b6, b7, b8, b9, b0, b0, b0, b0, b0}, 9)
assert.Nil(t, err)
assert.Equal(t,
"5540388656744764564518487011617040650780060800286365721923524861648744699539",
h[0].String())
assert.Equal(t, 9, len(h))

h, err = HashEx([]*big.Int{b1, b2, b3, b4, b5, b6, b7, b8, b9, b0, b0, b0, b0, b0, b0, b0}, 10)
assert.Nil(t, err)
assert.Equal(t,
"11882816200654282475720830292386643970958445617880627439994635298904836126497",
h[0].String())
assert.Equal(t, 10, len(h))

h, err = HashEx([]*big.Int{b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16}, 11)
assert.Nil(t, err)
assert.Equal(t, 11, len(h))
assert.Equal(t,
"9989051620750914585850546081941653841776809718687451684622678807385399211877",
h[0].String())
assert.Equal(t,
"8319791455060392555425392842391403897548969645190976863995973180967774875286",
h[1].String())
assert.Equal(t,
"21636406227810893698117978732800647815305553312233448361627674958309476058692",
h[2].String())
assert.Equal(t,
"5858261170370825589990804751061473291946977191299454947182890419569833191564",
h[3].String())
assert.Equal(t,
"9379453522659079974536893534601645512603628658741037060370899250203068088821",
h[4].String())
assert.Equal(t,
"473570682425071423656832074606161521036781375454126861176650950315985887926",
h[5].String())
assert.Equal(t,
"6579803930273263668667567320853266118141819373699554146671374489258288008348",
h[6].String())
assert.Equal(t,
"19782381913414087710766737863494215505205430771941455097533197858199467016164",
h[7].String())
assert.Equal(t,
"16057750626779488870446366989248320873718232843994532204040561017822304578116",
h[8].String())
assert.Equal(t,
"18984357576272539606133217260692170661113104846539835604742079547853774113837",
h[9].String())
assert.Equal(t,
"6999414602732066348339779277600222355871064730107676749892229157577448591106",
h[10].String())
}

func TestErrorInputs(t *testing.T) {
b0 := big.NewInt(0)
b1 := big.NewInt(1)
Expand All @@ -112,6 +242,12 @@ func TestErrorInputs(t *testing.T) {
_, err = Hash([]*big.Int{b1, b2, b0, b0, b0, b0, b0, b0, b0, b0, b0, b0, b0, b0, b0, b0, b0, b0})
require.NotNil(t, err)
assert.Equal(t, "invalid inputs length 18, max 16", err.Error())

_, err = HashEx([]*big.Int{b1, b2}, 0)
assert.EqualError(t, err, "invalid nOuts 0, min 1, max 3")

_, err = HashEx([]*big.Int{b1, b2}, 4)
assert.EqualError(t, err, "invalid nOuts 4, min 1, max 3")
}

func TestInputsNotInField(t *testing.T) {
Expand Down Expand Up @@ -169,6 +305,95 @@ func TestHashWithState(t *testing.T) {
h.String())
}

func TestHashWithStateEx(t *testing.T) {
initState0 := big.NewInt(0)
initState1 := big.NewInt(7)

b1 := big.NewInt(1)
b2 := big.NewInt(2)
b3 := big.NewInt(3)
b4 := big.NewInt(4)
b5 := big.NewInt(5)
b6 := big.NewInt(6)
b7 := big.NewInt(7)
b8 := big.NewInt(8)
b9 := big.NewInt(9)
b10 := big.NewInt(10)
b11 := big.NewInt(11)
b12 := big.NewInt(12)
b13 := big.NewInt(13)
b14 := big.NewInt(14)
b15 := big.NewInt(15)
b16 := big.NewInt(16)
b17 := big.NewInt(17)

h, err := HashWithStateEx([]*big.Int{b1, b2, b3, b4, b5, b6}, initState0, 6)
assert.Nil(t, err)
assert.Equal(t, 6, len(h))
assert.Equal(t,
"20400040500897583745843009878988256314335038853985262692600694741116813247201",
h[0].String())

h, err = HashWithStateEx([]*big.Int{b1, b2, b3, b4}, initState1, 4)
assert.Nil(t, err)
assert.Equal(t, 4, len(h))
assert.Equal(t,
"1569211601569591254857354699102545060324851338714426496554851741114291465006",
h[0].String())

h, err = HashWithStateEx([]*big.Int{b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16}, b17, 16)
assert.Nil(t, err)
assert.Equal(t, 16, len(h))
assert.Equal(t,
"7865037705064445207187340054656830232157001572238023180016026650118519857086",
h[0].String())
assert.Equal(t,
"9292383997006336854008325030029058442489692927472584277596649832441082093099",
h[1].String())
assert.Equal(t,
"21700625464938935909463291795162623951575229166945244593449711331894544619498",
h[2].String())
assert.Equal(t,
"1749964961100464837642084889776091157070407086051097880220367435814831060919",
h[3].String())
assert.Equal(t,
"14926884742736943105557530036865339747160219875259470496706517357951967126770",
h[4].String())
assert.Equal(t,
"2039691552066237153485547245250552033884196017621501609319319339955236135906",
h[5].String())
assert.Equal(t,
"15632370980418377873678240526508190824831030254352022226082241110936555130543",
h[6].String())
assert.Equal(t,
"12415717486933552680955550946925876656737401305417786097937904386023163034597",
h[7].String())
assert.Equal(t,
"19518791782429957526810500613963817986723905805167983704284231822835104039583",
h[8].String())
assert.Equal(t,
"3946357499058599914103088366834769377007694643795968939540941315474973940815",
h[9].String())
assert.Equal(t,
"5618081863604788554613937982328324792980580854673130938690864738082655170455",
h[10].String())
assert.Equal(t,
"9119013501536010391475078939286676645280972023937320238963975266387024327421",
h[11].String())
assert.Equal(t,
"8377736769906336164136520530350338558030826788688113957410934156526990238336",
h[12].String())
assert.Equal(t,
"15295058061474937220002017533551270394267030149562824985607747654793981405060",
h[13].String())
assert.Equal(t,
"3767094797637425204201844274463024412131937665868967358407323347727519975724",
h[14].String())
assert.Equal(t,
"11046361685833871233801453306150294246339755171874771935347992312124050338976",
h[15].String())
}

func TestInitStateNotInField(t *testing.T) {
var err error

Expand Down

0 comments on commit d59dca8

Please sign in to comment.