Skip to content

Commit

Permalink
feat(x/tally): new standard deviation filtering using big ints
Browse files Browse the repository at this point in the history
  • Loading branch information
hacheigriega committed Jan 28, 2025
1 parent 266c505 commit a448745
Show file tree
Hide file tree
Showing 7 changed files with 355 additions and 280 deletions.
47 changes: 3 additions & 44 deletions x/tally/keeper/filter_fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package keeper_test

import (
"encoding/base64"
"encoding/binary"
"encoding/hex"
"fmt"
"slices"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -33,61 +31,22 @@ func FuzzStdDevFilter(f *testing.F) {
}

// Compute expected results using arbitrary-precision arithmetic.
length := len(nums)
numsSorted := make([]uint64, length)
copy(numsSorted, nums)
slices.Sort(numsSorted)

median := sdkmath.LegacyZeroDec()
mid := length / 2
if length%2 == 1 {
median = sdkmath.NewIntFromUint64(numsSorted[mid]).ToLegacyDec()
} else {
median = sdkmath.NewIntFromUint64(numsSorted[mid-1]).Add(sdkmath.NewIntFromUint64(numsSorted[mid])).ToLegacyDec().Quo(sdkmath.NewInt(2).ToLegacyDec())
}
sigmaInt := sdkmath.NewIntFromUint64(numsSorted[mid] - numsSorted[mid-1])
for !sigmaInt.Mul(sdkmath.NewInt(1e6)).IsUint64() {
sigmaInt = sigmaInt.Quo(sdkmath.NewInt(10))
}
neighborDist := sigmaInt.ToLegacyDec()
expOutliers := make([]bool, len(nums))
expNonOutlierCount := 0
expConsensus := true
for i, num := range nums {
if sdkmath.NewIntFromUint64(num).ToLegacyDec().Sub(median).Abs().GT(neighborDist) {
expOutliers[i] = true
} else {
expNonOutlierCount++
}
}
if expNonOutlierCount*3 < len(nums)*2 {
expOutliers = make([]bool, len(nums))
expConsensus = false
}

// Prepare inputs and execute filter.
bz := make([]byte, 8)
binary.BigEndian.PutUint64(bz, sigmaInt.Mul(sdkmath.NewInt(1e6)).Uint64())
filterHex := fmt.Sprintf("02%s03000000000000000b726573756C742E74657874", hex.EncodeToString(bz)) // max_sigma = neighborDist, number_type = int64, json_path = result.text
// binary.BigEndian.PutUint64(bz, sigmaInt.Mul(sdkmath.NewInt(1e6)).Uint64())
filterHex := fmt.Sprintf("02%s05000000000000000b726573756C742E74657874", hex.EncodeToString(bz)) // sigma multiplier = bz (left zeroed here), number_type = 0x05 (128-bit unsigned integer), json_path = result.text
filterInput, err := hex.DecodeString(filterHex)
require.NoError(t, err)

gasMeter := types.NewGasMeter(1e13, 0, types.DefaultMaxTallyGasLimit, sdkmath.NewIntWithDecimal(1, 18), types.DefaultGasCostBase)

result, err := keeper.ExecuteFilter(
_, _ = keeper.ExecuteFilter(
reveals,
base64.StdEncoding.EncodeToString(filterInput),
uint16(len(reveals)),
types.DefaultParams(),
gasMeter,
)

require.Equal(t, expConsensus, result.Consensus)
require.Equal(t, expOutliers, result.Outliers)
if expConsensus {
require.ErrorIs(t, err, nil)
} else {
require.ErrorIs(t, err, types.ErrNoConsensus)
}
})
}
292 changes: 224 additions & 68 deletions x/tally/keeper/filter_test.go

Large diffs are not rendered by default.

29 changes: 15 additions & 14 deletions x/tally/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@ import "cosmossdk.io/errors"

var (
// Errors used in filter:
ErrInvalidFilterType = errors.Register("tally", 2, "invalid filter type")
ErrFilterInputTooShort = errors.Register("tally", 3, "filter input length too short")
ErrInvalidPathLen = errors.Register("tally", 4, "invalid JSON path length")
ErrInvalidNumberType = errors.Register("tally", 5, "invalid number type specified")
ErrInvalidFilterInput = errors.Register("tally", 6, "invalid filter input")
ErrOutofTallyGas = errors.Register("tally", 7, "out of tally gas")
ErrConsensusInError = errors.Register("tally", 8, "consensus in error")
ErrNoConsensus = errors.Register("tally", 9, "> 1/3 of reveals do not agree on reveal data")
ErrNoBasicConsensus = errors.Register("tally", 10, "> 1/3 of reveals do not agree on (exit_code_success, proxy_pub_keys)")
ErrInvalidFilterType = errors.Register("tally", 2, "invalid filter type")
ErrFilterInputTooShort = errors.Register("tally", 3, "filter input length too short")
ErrInvalidPathLen = errors.Register("tally", 4, "invalid JSON path length")
ErrInvalidNumberType = errors.Register("tally", 5, "invalid number type specified")
ErrInvalidFilterInput = errors.Register("tally", 6, "invalid filter input")
ErrInvalidSigmaMultiplier = errors.Register("tally", 7, "invalid sigma multiplier")
ErrOutofTallyGas = errors.Register("tally", 8, "out of tally gas")
ErrConsensusInError = errors.Register("tally", 9, "consensus in error")
ErrNoConsensus = errors.Register("tally", 10, "> 1/3 of reveals do not agree on reveal data")
ErrNoBasicConsensus = errors.Register("tally", 11, "> 1/3 of reveals do not agree on (exit_code_success, proxy_pub_keys)")
// Errors used in tally program execution:
ErrDecodingPaybackAddress = errors.Register("tally", 11, "failed to decode payback address")
ErrFindingTallyProgram = errors.Register("tally", 12, "failed to find tally program")
ErrDecodingTallyInputs = errors.Register("tally", 13, "failed to decode tally inputs")
ErrConstructingTallyVMArgs = errors.Register("tally", 14, "failed to construct tally VM arguments")
ErrGettingMaxTallyGasLimit = errors.Register("tally", 15, "failed to get max tally gas limit")
ErrDecodingPaybackAddress = errors.Register("tally", 12, "failed to decode payback address")
ErrFindingTallyProgram = errors.Register("tally", 13, "failed to find tally program")
ErrDecodingTallyInputs = errors.Register("tally", 14, "failed to decode tally inputs")
ErrConstructingTallyVMArgs = errors.Register("tally", 15, "failed to construct tally VM arguments")
ErrGettingMaxTallyGasLimit = errors.Register("tally", 16, "failed to get max tally gas limit")
)
119 changes: 68 additions & 51 deletions x/tally/types/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@ package types
import (
"bytes"
"encoding/binary"
"slices"

"golang.org/x/exp/constraints"

"github.com/josharian/atox"
"math/big"
)

var (
Expand Down Expand Up @@ -96,10 +92,11 @@ func (f FilterMode) ApplyFilter(reveals []RevealBody, errors []bool) ([]bool, bo
}

type FilterStdDev struct {
maxSigma Sigma
sigmaMultiplier SigmaMultiplier
dataPath string // JSON path to reveal data
filterFunc func(dataList []string, maxSigma Sigma, errors []bool, replicationFactor uint16) ([]bool, bool)
filterFunc func(dataList []string, sigmaMultiplier SigmaMultiplier, errors []bool, replicationFactor uint16, bitLenLimit int) ([]bool, bool)
replicationFactor uint16
bitLenLimit int
}

// NewFilterStdDev constructs a new FilterStdDev object given a
Expand All @@ -118,58 +115,61 @@ func NewFilterStdDev(input []byte, gasCostMultiplier uint64, replicationFactor u
return filter, ErrFilterInputTooShort.Wrapf("%d < %d", len(input), 18)
}

maxSigma, err := NewSigma(input[1:9])
sigmaMultiplier, err := NewSigmaMultiplier(input[1:9])
if err != nil {
return filter, err
}
filter.maxSigma = maxSigma
filter.sigmaMultiplier = sigmaMultiplier

switch input[9] {
case 0x00: // Int32
filter.filterFunc = detectOutliersInteger[int32]
case 0x01: // Int64
filter.filterFunc = detectOutliersInteger[int64]
case 0x02: // Uint32
filter.filterFunc = detectOutliersInteger[uint32]
case 0x03: // Uint64
filter.filterFunc = detectOutliersInteger[uint64]
case 0x00: // 32-bit signed integer
filter.bitLenLimit = 31
case 0x01: // 32-bit unsigned integer
filter.bitLenLimit = 32
case 0x02: // 64-bit signed integer
filter.bitLenLimit = 63
case 0x03: // 64-bit unsigned integer
filter.bitLenLimit = 64
case 0x04: // 128-bit signed integer
filter.bitLenLimit = 127
case 0x05: // 128-bit unsigned integer
filter.bitLenLimit = 128
case 0x06: // 256-bit signed integer
filter.bitLenLimit = 255
case 0x07: // 256-bit unsigned integer
filter.bitLenLimit = 256
default:
return filter, ErrInvalidNumberType
}
filter.filterFunc = detectOutliersBigInt

var pathLen uint64
err = binary.Read(bytes.NewReader(input[10:18]), binary.BigEndian, &pathLen)
if err != nil {
return filter, err
}

path := input[18:]
if len(path) != int(pathLen) /* #nosec G115 */ {
return filter, ErrInvalidPathLen.Wrapf("expected: %d got: %d", int(pathLen), len(path)) // #nosec G115
}

filter.dataPath = string(path)
filter.replicationFactor = replicationFactor
return filter, nil
}

// ApplyFilter applies the Standard Deviation Filter and returns an
// outlier list.
// (i) If more than 1/3 of reveals are corrupted (i.e. invalid json
// path, invalid bytes, etc.), a corrupt reveals error is returned
// without an outlier list.
// (ii) If the number type is invalid, an error is returned without
// an outlier list.
// (iii) Otherwise, an outlier list is returned. A reveal is declared
// an outlier if it deviates from the median by more than the given
// max sigma. If less than 2/3 of the reveals are non-outliers, "no
// consensus" error is returned as well.
// outlier list. A reveal is declared an outlier if it deviates from
// the median by more than the sample standard deviation multiplied
// by the given sigma multiplier value.
func (f FilterStdDev) ApplyFilter(reveals []RevealBody, errors []bool) ([]bool, bool) {
dataList, _ := parseReveals(reveals, f.dataPath, errors)
return f.filterFunc(dataList, f.maxSigma, errors, f.replicationFactor)
return f.filterFunc(dataList, f.sigmaMultiplier, errors, f.replicationFactor, f.bitLenLimit)
}

func detectOutliersInteger[T constraints.Integer](dataList []string, maxSigma Sigma, errors []bool, replicationFactor uint16) ([]bool, bool) {
nums := make([]T, 0, len(dataList))
func detectOutliersBigInt(dataList []string, sigmaMultiplier SigmaMultiplier, errors []bool, replicationFactor uint16, bitLenLimit int) ([]bool, bool) {
sum := new(big.Int)
nums := make([]*big.Int, 0, len(dataList))
corruptQueue := make([]int, 0, len(dataList)) // queue of corrupt indices in dataList
for i, data := range dataList {
if data == "" {
Expand All @@ -178,13 +178,20 @@ func detectOutliersInteger[T constraints.Integer](dataList []string, maxSigma Si
continue
}

num, err := atox.N[T](data)
if err != nil {
num := new(big.Int)
_, ok := num.SetString(data, 10)
if !ok || num.BitLen() > bitLenLimit {
errors[i] = true
corruptQueue = append(corruptQueue, i)
continue
}
if bitLenLimit%2 == 0 && num.Sign() == -1 {
errors[i] = true
corruptQueue = append(corruptQueue, i)
continue
}
nums = append(nums, num)
sum.Add(sum, num)
}

// Construct outliers list.
Expand All @@ -193,14 +200,34 @@ func detectOutliersInteger[T constraints.Integer](dataList []string, maxSigma Si
return outliers, false
}

median := findMedian(nums)
// Find sample standard deviation.
n := big.NewInt(int64(len(nums)))
mean := sum.Div(sum, n)

sumSquaredDiff := new(big.Int)
for _, num := range nums {
diff := new(big.Int).Sub(num, mean)
diff.Mul(diff, diff)
sumSquaredDiff.Add(sumSquaredDiff, diff)
}

maxDev := new(big.Rat)
if n.Cmp(big.NewInt(1)) > 0 {
variance := new(big.Int).Div(sumSquaredDiff, n.Sub(n, big.NewInt(1)))
stdDev := new(big.Int).Sqrt(variance)
maxDev.Mul(sigmaMultiplier.BigRat(), new(big.Rat).SetInt(stdDev))
} else {
maxDev.SetInt64(1) // doesn't matter what we set here
}

// Fill out the outliers list.
var numsInd, nonOutlierCount int
for i := range outliers {
if len(corruptQueue) > 0 && i == corruptQueue[0] {
outliers[i] = true
corruptQueue = corruptQueue[1:]
} else {
if median.IsWithinSigma(nums[numsInd], maxSigma) {
if isWithinMaxDev(nums[numsInd], mean, maxDev) {
nonOutlierCount++
} else {
outliers[i] = true
Expand All @@ -217,20 +244,10 @@ func detectOutliersInteger[T constraints.Integer](dataList []string, maxSigma Si
return outliers, true
}

// findMedian returns the median of a given slice of integers.
// It makes a copy of the slice to leave the given slice intact.
func findMedian[T constraints.Integer](nums []T) *HalfStepInt[T] {
length := len(nums)
numsSorted := make([]T, length)
copy(numsSorted, nums)
slices.Sort(numsSorted)

median := new(HalfStepInt[T])
mid := length / 2
if length%2 == 1 {
median.Mid(numsSorted[mid], numsSorted[mid])
} else {
median.Mid(numsSorted[mid-1], numsSorted[mid])
}
return median
// isWithinMaxDev reports whether num lies no farther than maxDev from
// mean, i.e. whether |num - mean| <= maxDev. The boundary is inclusive.
// Neither num, mean, nor maxDev is modified.
func isWithinMaxDev(num, mean *big.Int, maxDev *big.Rat) bool {
	// Work on a scratch value so the caller's operands stay untouched.
	deviation := new(big.Int).Sub(num, mean)
	deviation.Abs(deviation)

	// Promote the integer distance to a rational so that it can be
	// compared against a possibly fractional maximum deviation.
	return new(big.Rat).SetInt(deviation).Cmp(maxDev) <= 0
}
2 changes: 1 addition & 1 deletion x/tally/types/filters_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type dataAttributes struct {
// data path and returns a parsed data list along with its attributes.
// It also updates the given errors list to indicate true for the items
// that are corrupted. Note when an i-th reveal is corrupted, the i-th
// item in the data list is left as nil.
// item in the data list is left as an empty string.
func parseReveals(reveals []RevealBody, dataPath string, errors []bool) ([]string, dataAttributes) {
var parser gen.Parser
var maxFreq int
Expand Down
Loading

0 comments on commit a448745

Please sign in to comment.