Skip to content

Commit

Permalink
Merge commit from fork
Browse files Browse the repository at this point in the history
* Fix valid decimal range

* Review feedback

* Update changelog; retract versions

---------

Co-authored-by: Julien Robert <[email protected]>
  • Loading branch information
alpe and julienrbrt authored Nov 20, 2024
1 parent 1effb80 commit c6522a7
Show file tree
Hide file tree
Showing 4 changed files with 330 additions and 62 deletions.
9 changes: 9 additions & 0 deletions math/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,17 @@ Ref: https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.j

## [Unreleased]

## [math/v1.4.0](https://github.com/cosmos/cosmos-sdk/releases/tag/math/v1.4.0) - 2024-01-20

### Features

* [#20034](https://github.com/cosmos/cosmos-sdk/pull/20034) Significantly speedup LegacyDec.QuoTruncate and LegacyDec.QuoRoundUp.

### Bug fixes

* Fix [ASA-2024-010: Math](https://github.com/cosmos/cosmos-sdk/security/advisories/GHSA-7225-m954-23v7) Bit length differences between Int and Dec


## [math/v1.3.0](https://github.com/cosmos/cosmos-sdk/releases/tag/math/v1.3.0) - 2024-02-22

### Features
Expand Down
108 changes: 50 additions & 58 deletions math/dec.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,20 @@ const (

// LegacyDecimalPrecisionBits bits required to represent the above precision
// Ceiling[Log2[10^Precision - 1]]
// Deprecated: This is unused and will be removed
LegacyDecimalPrecisionBits = 60

// decimalTruncateBits is the minimum number of bits removed
// by a truncate operation. It is equal to
// Floor[Log2[10^Precision - 1]].
decimalTruncateBits = LegacyDecimalPrecisionBits - 1

maxDecBitLen = MaxBitLen + decimalTruncateBits

// maxApproxRootIterations max number of iterations in ApproxRoot function
maxApproxRootIterations = 300
)

var (
precisionReuse = new(big.Int).Exp(big.NewInt(10), big.NewInt(LegacyPrecision), nil)
fivePrecision = new(big.Int).Quo(precisionReuse, big.NewInt(2))
precisionReuse = new(big.Int).Exp(big.NewInt(10), big.NewInt(LegacyPrecision), nil)
fivePrecision = new(big.Int).Quo(precisionReuse, big.NewInt(2))

upperLimit LegacyDec
lowerLimit LegacyDec

precisionMultipliers []*big.Int
zeroInt = big.NewInt(0)
oneInt = big.NewInt(1)
Expand All @@ -58,6 +56,11 @@ func init() {
for i := 0; i <= LegacyPrecision; i++ {
precisionMultipliers[i] = calcPrecisionMultiplier(int64(i))
}
// 2^256 * 10^18 -1
tmp := new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil)
tmp = new(big.Int).Sub(new(big.Int).Mul(tmp, precisionReuse), big.NewInt(1))
upperLimit = LegacyNewDecFromBigIntWithPrec(tmp, LegacyPrecision)
lowerLimit = upperLimit.Neg()
}

func precisionInt() *big.Int {
Expand Down Expand Up @@ -191,14 +194,15 @@ func LegacyNewDecFromStr(str string) (LegacyDec, error) {
if !ok {
return LegacyDec{}, fmt.Errorf("failed to set decimal string with base 10: %s", combinedStr)
}
if combined.BitLen() > maxDecBitLen {
return LegacyDec{}, fmt.Errorf("decimal '%s' out of range; bitLen: got %d, max %d", str, combined.BitLen(), maxDecBitLen)
}
if neg {
combined = new(big.Int).Neg(combined)
}

return LegacyDec{combined}, nil
result := LegacyDec{i: combined}
if !result.IsInValidRange() {
return LegacyDec{}, fmt.Errorf("out of range: %w", ErrLegacyInvalidDecimalStr)
}
return result, nil
}

// LegacyMustNewDecFromStr Decimal from string, panic on error
Expand Down Expand Up @@ -275,9 +279,7 @@ func (d LegacyDec) Add(d2 LegacyDec) LegacyDec {
func (d LegacyDec) AddMut(d2 LegacyDec) LegacyDec {
d.i.Add(d.i, d2.i)

if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
d.assertInValidRange()
return d
}

Expand All @@ -290,10 +292,20 @@ func (d LegacyDec) Sub(d2 LegacyDec) LegacyDec {
func (d LegacyDec) SubMut(d2 LegacyDec) LegacyDec {
d.i.Sub(d.i, d2.i)

if d.i.BitLen() > maxDecBitLen {
d.assertInValidRange()
return d
}

func (d LegacyDec) assertInValidRange() {
if !d.IsInValidRange() {
panic("Int overflow")
}
return d
}

// IsInValidRange returns true when the value is between the upper limit of (2^256 * 10^18)
// and the lower limit of -1*(2^256 * 10^18).
func (d LegacyDec) IsInValidRange() bool {
return !(d.GT(upperLimit) || d.LT(lowerLimit))
}

// Mul multiplication
Expand All @@ -306,10 +318,8 @@ func (d LegacyDec) MulMut(d2 LegacyDec) LegacyDec {
d.i.Mul(d.i, d2.i)
chopped := chopPrecisionAndRound(d.i)

if chopped.BitLen() > maxDecBitLen {
panic("Int overflow")
}
*d.i = *chopped
d.assertInValidRange()
return d
}

Expand All @@ -322,10 +332,7 @@ func (d LegacyDec) MulTruncate(d2 LegacyDec) LegacyDec {
func (d LegacyDec) MulTruncateMut(d2 LegacyDec) LegacyDec {
d.i.Mul(d.i, d2.i)
chopPrecisionAndTruncate(d.i)

if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
d.assertInValidRange()
return d
}

Expand All @@ -339,9 +346,7 @@ func (d LegacyDec) MulRoundUpMut(d2 LegacyDec) LegacyDec {
d.i.Mul(d.i, d2.i)
chopPrecisionAndRoundUp(d.i)

if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
d.assertInValidRange()
return d
}

Expand All @@ -352,9 +357,7 @@ func (d LegacyDec) MulInt(i Int) LegacyDec {

func (d LegacyDec) MulIntMut(i Int) LegacyDec {
d.i.Mul(d.i, i.BigIntMut())
if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
d.assertInValidRange()
return d
}

Expand All @@ -365,10 +368,7 @@ func (d LegacyDec) MulInt64(i int64) LegacyDec {

func (d LegacyDec) MulInt64Mut(i int64) LegacyDec {
d.i.Mul(d.i, big.NewInt(i))

if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
d.assertInValidRange()
return d
}

Expand All @@ -386,9 +386,7 @@ func (d LegacyDec) QuoMut(d2 LegacyDec) LegacyDec {
d.i.Quo(d.i, d2.i)

chopPrecisionAndRound(d.i)
if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
d.assertInValidRange()
return d
}

Expand All @@ -403,9 +401,7 @@ func (d LegacyDec) QuoTruncateMut(d2 LegacyDec) LegacyDec {
d.i.Mul(d.i, precisionReuse)
d.i.Quo(d.i, d2.i)

if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
d.assertInValidRange()
return d
}

Expand All @@ -423,10 +419,7 @@ func (d LegacyDec) QuoRoundupMut(d2 LegacyDec) LegacyDec {
rem.Sign() < 0 && d.IsNegative() != d2.IsNegative() {
d.i.Add(d.i, oneInt)
}

if d.i.BitLen() > maxDecBitLen {
panic("Int overflow")
}
d.assertInValidRange()
return d
}

Expand Down Expand Up @@ -745,17 +738,17 @@ func (d LegacyDec) Ceil() LegacyDec {
quo, rem = quo.QuoRem(tmp, precisionReuse, rem)

// no need to round with a zero remainder regardless of sign
if rem.Sign() == 0 {
return LegacyNewDecFromBigInt(quo)
} else if rem.Sign() == -1 {
return LegacyNewDecFromBigInt(quo)
}

if d.i.BitLen() >= maxDecBitLen {
panic("Int overflow")
var r LegacyDec
switch rem.Sign() {
case 0:
r = LegacyNewDecFromBigInt(quo)
case -1:
r = LegacyNewDecFromBigInt(quo)
default:
r = LegacyNewDecFromBigInt(quo.Add(quo, oneInt))
}

return LegacyNewDecFromBigInt(quo.Add(quo, oneInt))
r.assertInValidRange()
return r
}

// LegacyMaxSortableDec is the largest Dec that can be passed into SortableDecBytes()
Expand Down Expand Up @@ -885,10 +878,9 @@ func (d *LegacyDec) Unmarshal(data []byte) error {
return err
}

if d.i.BitLen() > maxDecBitLen {
return fmt.Errorf("decimal out of range; got: %d, max: %d", d.i.BitLen(), maxDecBitLen)
if !d.IsInValidRange() {
return errors.New("decimal out of range")
}

return nil
}

Expand Down
Loading

0 comments on commit c6522a7

Please sign in to comment.