Skip to content

Commit

Permalink
refactor: remove rebuildShares
Browse files Browse the repository at this point in the history
  • Loading branch information
rootulp committed Jun 28, 2023
1 parent 2557620 commit 7664a4d
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 95 deletions.
16 changes: 13 additions & 3 deletions codecs.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package rsmt2d

import "fmt"
import (
"errors"
"fmt"
)

const (
// Leopard is a codec that was originally implemented in the C++ library
Expand All @@ -15,8 +18,11 @@ type Codec interface {
// Encode encodes original data, automatically extracting share size.
// There must be no missing shares. Only returns parity shares.
Encode(data [][]byte) ([][]byte, error)
// Decode decodes sparse original + parity data, automatically extracting share size.
// Missing shares must be nil. Returns original + parity data.
// Decode attempts to reconstruct the missing shards in data. The data
// parameter should contain all original + parity shards where missing
// shards should be `nil`. If reconstruction is successful, the original +
// parity shards are returned. Returns ErrTooFewShards if not enough non-nil
// shards exist in data to reconstruct the missing shards.
Decode(data [][]byte) ([][]byte, error)
// MaxChunks returns the max. number of chunks each code supports in a 2D square.
MaxChunks() int
Expand All @@ -33,3 +39,7 @@ func registerCodec(ct string, codec Codec) {
}
codecs[ct] = codec
}

// ErrTooFewShards is returned by Decode if too few shards exist in the data to
// reconstruct the `nil` shards.
var ErrTooFewShards = errors.New("too few shards given to reconstruct all the shards in data")
46 changes: 18 additions & 28 deletions extendeddatacrossword.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,18 @@ func (eds *ExtendedDataSquare) solveCrosswordRow(
shares[c] = vectorData[c]
}

// Attempt rebuild
rebuiltShares, isDecoded, err := eds.rebuildShares(shares)
// Attempt to rebuild the shards in this row.
rebuiltShares, err := eds.codec.Decode(shares)
if err != nil {
if err == ErrTooFewShards {
// Decode was unsuccessful for this iteration but don't propagate the
// error because that would halt the progress of solveCrossword.
return false, false, nil
}
// Otherwise, Decode was unsuccessful for some other reason and we
// should propagate the error.
return false, false, err
}
if !isDecoded {
return false, false, nil
}

// Check that rebuilt shares matches appropriate root
err = eds.verifyAgainstRowRoots(rowRoots, uint(r), rebuiltShares, noShareInsertion, nil)
Expand Down Expand Up @@ -200,14 +204,18 @@ func (eds *ExtendedDataSquare) solveCrosswordCol(

}

// Attempt rebuild
rebuiltShares, isDecoded, err := eds.rebuildShares(shares)
// Attempt to rebuild the shards in this column.
rebuiltShares, err := eds.codec.Decode(shares)
if err != nil {
if err == ErrTooFewShards {
// Decode was unsuccessful for this iteration but don't propagate the
// error because that would halt the progress of solveCrossword.
return false, false, nil
}
// Otherwise, Decode was unsuccessful for some other reason and we
// should propagate the error.
return false, false, err
}
if !isDecoded {
return false, false, nil
}

// Check that rebuilt shares matches appropriate root
err = eds.verifyAgainstColRoots(colRoots, uint(c), rebuiltShares, noShareInsertion, nil)
Expand Down Expand Up @@ -241,24 +249,6 @@ func (eds *ExtendedDataSquare) solveCrosswordCol(
return true, true, nil
}

// rebuildShares attempts to rebuild a row or column of shares.
// Returns
// 1. An entire row or column of shares so original + parity shares.
// 2. Whether the original shares could be decoded from the shares parameter.
// 3. [Optional] an error.
func (eds *ExtendedDataSquare) rebuildShares(
shares [][]byte,
) ([][]byte, bool, error) {
rebuiltShares, err := eds.codec.Decode(shares)
if err != nil {
// Decode was unsuccessful but don't propagate the error because that
// would halt the progress of solveCrosswordRow or solveCrosswordCol.
return nil, false, nil
}

return rebuiltShares, true, nil
}

func (eds *ExtendedDataSquare) verifyAgainstRowRoots(
rowRoots [][]byte,
r uint,
Expand Down
110 changes: 47 additions & 63 deletions extendeddatacrossword_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// PseudoFraudProof is an example fraud proof.
Expand All @@ -19,70 +20,53 @@ type PseudoFraudProof struct {
}

func TestRepairExtendedDataSquare(t *testing.T) {
bufferSize := 64
tests := []struct {
name string
// Size of each share, in bytes
shareSize int
codec Codec
}{
{"leopard", bufferSize, NewLeoRSCodec()},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
name, codec, shareSize := test.name, test.codec, test.shareSize
original := createTestEds(codec, shareSize)

rowRoots := original.RowRoots()
colRoots := original.ColRoots()

// Verify that an EDS can be repaired after the maximum amount of erasures
t.Run("MaximumErasures", func(t *testing.T) {
flattened := original.Flattened()
flattened[0], flattened[2], flattened[3] = nil, nil, nil
flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil
flattened[8], flattened[9], flattened[10] = nil, nil, nil
flattened[12], flattened[13] = nil, nil

// Re-import the data square.
eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}

err = eds.Repair(rowRoots, colRoots)
if err != nil {
t.Errorf("unexpected err while repairing data square: %v, codec: :%s", err, name)
} else {
assert.Equal(t, original.GetCell(0, 0), bytes.Repeat([]byte{1}, shareSize))
assert.Equal(t, original.GetCell(0, 1), bytes.Repeat([]byte{2}, shareSize))
assert.Equal(t, original.GetCell(1, 0), bytes.Repeat([]byte{3}, shareSize))
assert.Equal(t, original.GetCell(1, 1), bytes.Repeat([]byte{4}, shareSize))
}
})

// Verify that an EDS returns an error when there are too many erasures
t.Run("Unrepairable", func(t *testing.T) {
flattened := original.Flattened()
flattened[0], flattened[2], flattened[3] = nil, nil, nil
flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil
flattened[8], flattened[9], flattened[10] = nil, nil, nil
flattened[12], flattened[13], flattened[14] = nil, nil, nil

// Re-import the data square.
eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree)
if err != nil {
t.Errorf("ImportExtendedDataSquare failed: %v", err)
}
codec := NewLeoRSCodec()
shareSize := 64

err = eds.Repair(rowRoots, colRoots)
if err != ErrUnrepairableDataSquare {
t.Errorf("did not return an error on trying to repair an unrepairable square")
}
})
})
}
// Verify that an EDS can be repaired after the maximum amount of erasures
t.Run("MaximumErasures", func(t *testing.T) {
original := createTestEds(codec, shareSize)
rowRoots := original.RowRoots()
colRoots := original.ColRoots()

flattened := original.Flattened()
flattened[0], flattened[2], flattened[3] = nil, nil, nil
flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil
flattened[8], flattened[9], flattened[10] = nil, nil, nil
flattened[12], flattened[13] = nil, nil

// Re-import the data square.
eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree)
require.NoError(t, err)

err = eds.Repair(rowRoots, colRoots)
require.NoError(t, err)

assert.Equal(t, original.GetCell(0, 0), bytes.Repeat([]byte{1}, shareSize))
assert.Equal(t, original.GetCell(0, 1), bytes.Repeat([]byte{2}, shareSize))
assert.Equal(t, original.GetCell(1, 0), bytes.Repeat([]byte{3}, shareSize))
assert.Equal(t, original.GetCell(1, 1), bytes.Repeat([]byte{4}, shareSize))
})

// Verify that an EDS returns an error when there are too many erasures
t.Run("Unrepairable", func(t *testing.T) {
original := createTestEds(codec, shareSize)
rowRoots := original.RowRoots()
colRoots := original.ColRoots()

flattened := original.Flattened()
flattened[0], flattened[2], flattened[3] = nil, nil, nil
flattened[4], flattened[5], flattened[6], flattened[7] = nil, nil, nil, nil
flattened[8], flattened[9], flattened[10] = nil, nil, nil
flattened[12], flattened[13], flattened[14] = nil, nil, nil

// Re-import the data square.
eds, err := ImportExtendedDataSquare(flattened, codec, NewDefaultTree)
require.NoError(t, err)

err = eds.Repair(rowRoots, colRoots)
assert.ErrorAs(t, err, &ErrUnrepairableDataSquare)
})
}

func TestValidFraudProof(t *testing.T) {
Expand Down
13 changes: 12 additions & 1 deletion leopard.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,25 @@ func (l *leoRSCodec) Encode(data [][]byte) ([][]byte, error) {
return shards[dataLen:], nil
}

// Decode attempts to reconstruct the missing shards in data. The data
// parameter should contain all original + parity shards where missing
// shards should be `nil`. If reconstruction is successful, the original +
// parity shards are returned. Returns ErrTooFewShards if not enough non-nil
// shards exist in data to reconstruct the missing shards.
func (l *leoRSCodec) Decode(data [][]byte) ([][]byte, error) {
half := len(data) / 2
enc, err := l.loadOrInitEncoder(half)
if err != nil {
return nil, err
}
err = enc.Reconstruct(data)
return data, err
if err == reedsolomon.ErrTooFewShards || err == reedsolomon.ErrShardNoData {
return nil, ErrTooFewShards
}
if err != nil {
return nil, err
}
return data, nil
}

func (l *leoRSCodec) loadOrInitEncoder(dataLen int) (reedsolomon.Encoder, error) {
Expand Down

0 comments on commit 7664a4d

Please sign in to comment.