diff --git a/constraint.go b/constraint.go new file mode 100644 index 0000000..81317b2 --- /dev/null +++ b/constraint.go @@ -0,0 +1,148 @@ +package version + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +var constraintRegex = regexp.MustCompile(`^(?:(>=|>|<=|<|!=|==?)\s*)?(.+)$`) + +type constraintFunc func(a, b *Version) bool +type constraint struct { + f constraintFunc + b *Version + original string +} + +// Constraints is a collection of version constraint rules that can be checked against a version. +type Constraints []constraint + +// NewConstraint parses a string into a Constraints object that can be used to check +// if a given version satisfies the constraint. +func NewConstraint(cs string) (Constraints, error) { + parts := strings.Split(cs, ",") + newC := make(Constraints, len(parts)) + for i, p := range parts { + parts[i] = strings.TrimSpace(p) + } + for i, p := range parts { + c, err := newConstraint(p) + if err != nil { + return Constraints{}, err + } + newC[i] = c + } + + return newC, nil +} + +// MustConstraint is like NewConstraint but panics if the constraint is invalid. +func MustConstraint(cs string) Constraints { + c, err := NewConstraint(cs) + if err != nil { + panic("github.com/k0sproject/version: NewConstraint: " + err.Error()) + } + return c +} + +// Check returns true if the given version satisfies all of the constraints. +func (cs Constraints) Check(v *Version) bool { + for _, c := range cs { + if c.b.Prerelease() == "" && v.Prerelease() != "" { + return false + } + if !c.f(c.b, v) { + return false + } + } + + return true +} + +// CheckString is like Check but takes a string version. If the version is invalid, +// it returns false. +func (cs Constraints) CheckString(v string) bool { + vv, err := NewVersion(v) + if err != nil { + return false + } + return cs.Check(vv) +} + +// String returns the original constraint string. +func (c *constraint) String() string { + return c.original +} + +func newConstraint(s string) (constraint, error) { + match := constraintRegex.FindStringSubmatch(s) + if len(match) != 3 { + return constraint{}, errors.New("invalid constraint: " + s) + } + + op := match[1] + f, err := opfunc(op) + if err != nil { + return constraint{}, err + } + + // convert one or two digit constraints to threes digit unless it's an equality operation + if op != "" && op != "=" && op != "==" { + segments := strings.Split(match[2], ".") + if len(segments) < 3 { + lastSegment := segments[len(segments)-1] + var pre string + if strings.Contains(lastSegment, "-") { + parts := strings.Split(lastSegment, "-") + segments[len(segments)-1] = parts[0] + pre = "-" + parts[1] + } + switch len(segments) { + case 1: + // >= 1 becomes >= 1.0.0 + // >= 1-rc.1 becomes >= 1.0.0-rc.1 + return newConstraint(fmt.Sprintf("%s %s.0.0%s", op, segments[0], pre)) + case 2: + // >= 1.1 becomes >= 1.1.0 + // >= 1.1-rc.1 becomes >= 1.1.0-rc.1 + return newConstraint(fmt.Sprintf("%s %s.%s.0%s", op, segments[0], segments[1], pre)) + } + } + } + + target, err := NewVersion(match[2]) + if err != nil { + return constraint{}, err + } + + return constraint{f: f, b: target, original: s}, nil +} + +func opfunc(s string) (constraintFunc, error) { + switch s { + case "", "=", "==": + return eq, nil + case ">": + return gt, nil + case ">=": + return gte, nil + case "<": + return lt, nil + case "<=": + return lte, nil + case "!=": + return neq, nil + default: + return nil, errors.New("invalid operator: " + s) + } +} + +func gt(a, b *Version) bool { return b.GreaterThan(a) } +func lt(a, b *Version) bool { return b.LessThan(a) } +func gte(a, b *Version) bool { return b.GreaterThanOrEqual(a) } +func lte(a, b *Version) bool { return b.LessThanOrEqual(a) } +func eq(a, b *Version) bool { return b.Equal(a) } +func neq(a, b *Version) bool { return !b.Equal(a) } + diff --git a/constraint_test.go b/constraint_test.go new file mode 100644 index 0000000..e4f1b77 --- /dev/null +++ b/constraint_test.go @@ -0,0 +1,153 @@ +package version + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConstraint(t *testing.T) { + type testCase struct { + constraint string + truthTable map[bool][]string + } + + testCases := []testCase{ + { + constraint: ">= 1.1.0-beta.1+k0s.1", + truthTable: map[bool][]string{ + true: { + "1.1.0+k0s.0", + "1.1.0-rc.1+k0s.0", + "1.1.1+k0s.0", + "1.1.1-rc.1+k0s.0", + }, + false: { + "1.1.0-alpha.1+k0s.2", + "1.0.1+k0s.10", + }, + }, + }, + { + constraint: ">= 1.1.0+k0s.1", + truthTable: map[bool][]string{ + true: { + "1.1.0+k0s.1", + "1.1.0+k0s.2", + "1.1.1+k0s.0", + }, + false: { + "1.0.9+k0s.255", + "1.1.0+k0s.0", + }, + }, + }, + // simple operator checks + { + constraint: "= 1.0.0", + truthTable: map[bool][]string{ + true: {"1.0.0"}, + false: {"1.0.1", "0.9.9"}, + }, + }, + { + constraint: "1.0.0", + truthTable: map[bool][]string{ + true: {"1.0.0"}, + false: {"1.0.1", "0.9.9"}, + }, + }, + { + constraint: "!= 1.0.0", + truthTable: map[bool][]string{ + true: {"1.0.1", "0.9.9"}, + false: {"1.0.0"}, + }, + }, + { + constraint: "> 1.0.0", + truthTable: map[bool][]string{ + true: {"1.0.1", "1.1.0"}, + false: {"1.0.0", "0.9.9"}, + }, + }, + { + constraint: "< 1.0.0", + truthTable: map[bool][]string{ + true: {"0.9.9", "0.9.8"}, + false: {"1.0.0", "1.0.1"}, + }, + }, + { + constraint: ">= 1.0.0", + truthTable: map[bool][]string{ + true: {"1.0.0", "1.0.1"}, + false: {"0.9.9"}, + }, + }, + { + constraint: "<= 1.0.0", + truthTable: map[bool][]string{ + true: {"1.0.0", "0.9.9"}, + false: {"1.0.1"}, + }, + }, + // two digit constraints + { + constraint: ">= 1.0", + truthTable: map[bool][]string{ + true: {"1.0.0", "1.0.1", "1.1.0"}, + false: {"0.9.9", "1.0.1-alpha.1"}, + }, + }, + { + constraint: ">= 1.0-a", + truthTable: map[bool][]string{ + true: {"1.0.0", "1.0.1", "1.0.0-alpha.1"}, + false: {"0.9.9"}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.constraint, func(t *testing.T) { + c, err := NewConstraint(tc.constraint) + assert.NoError(t, err) + + for expected, versions := range tc.truthTable { + t.Run(fmt.Sprintf("%t", expected), func(t *testing.T) { + for _, version := range versions { + t.Run(version, func(t *testing.T) { + assert.Equal(t, expected, c.Check(MustParse(version))) + }) + } + }) + } + }) + } +} + +func TestInvalidConstraint(t *testing.T) { + invalidConstraints := []string{ + "", + "==", + ">= ", + "invalid", + ">= abc", + } + + for _, invalidConstraint := range invalidConstraints { + _, err := newConstraint(invalidConstraint) + assert.Error(t, err, "Expected error for invalid constraint: "+invalidConstraint) + } +} + +func TestCheckString(t *testing.T) { + c, err := NewConstraint(">= 1.0.0") + assert.NoError(t, err) + + assert.True(t, c.CheckString("1.0.0")) + assert.False(t, c.CheckString("0.9.9")) + assert.False(t, c.CheckString("x")) +} diff --git a/version.go b/version.go index bf9ef37..672d52b 100644 --- a/version.go +++ b/version.go @@ -147,6 +147,11 @@ func (v *Version) UnmarshalJSON(b []byte) error { }) } +// Satisfies returns true if the version satisfies the supplied constraint +func (v *Version) Satisfies(constraint Constraints) bool { + return constraint.Check(v) +} + // NewVersion returns a new Version created from the supplied string or an error if the string is not a valid version number func NewVersion(v string) (*Version, error) { n, err := goversion.NewVersion(strings.TrimPrefix(v, "v")) diff --git a/version_test.go b/version_test.go index 562075b..d9be71b 100644 --- a/version_test.go +++ b/version_test.go @@ -38,6 +38,20 @@ func TestK0sComparison(t *testing.T) { assert.False(t, b.Equal(a), "version %s should not be equal to %s", b, a) } +func TestSatisfies(t *testing.T) { + v, err := NewVersion("1.23.1+k0s.1") + assert.NoError(t, err) + assert.True(t, v.Satisfies(MustConstraint(">=1.23.1"))) + assert.True(t, v.Satisfies(MustConstraint(">=1.23.1+k0s.0"))) + assert.True(t, v.Satisfies(MustConstraint(">=1.23.1+k0s.1"))) + assert.True(t, v.Satisfies(MustConstraint("=1.23.1+k0s.1"))) + assert.True(t, v.Satisfies(MustConstraint("<1.23.1+k0s.2"))) + assert.False(t, v.Satisfies(MustConstraint(">=1.23.1+k0s.2"))) + assert.False(t, v.Satisfies(MustConstraint(">=1.23.2"))) + assert.False(t, v.Satisfies(MustConstraint(">1.23.1+k0s.1"))) + assert.False(t, v.Satisfies(MustConstraint("<1.23.1+k0s.1"))) +} + func TestURLs(t *testing.T) { a, err := NewVersion("1.23.3+k0s.1") assert.NoError(t, err)