Skip to content

Commit

Permalink
Merge pull request cedar-policy#38 from strongdm/idx-172/set-contains…
Browse files Browse the repository at this point in the history
…-performance

Improve `Set.Contains()` performance by making `Set` a hash set-like data structure
  • Loading branch information
patjakdev authored Sep 19, 2024
2 parents d1f59f4 + 2c541fb commit 055335e
Show file tree
Hide file tree
Showing 48 changed files with 1,171 additions and 442 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
run: go build -v ./...

- name: Test
run: go test -v ./...
run: go test ./...

- name: Fuzz
run: mkdir -p testdata && go test -fuzz=FuzzParse -fuzztime 60s && go test -fuzz=FuzzTokenize -fuzztime 60s
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ While in development (0.x.y), each tagged release may contain breaking changes.

## Change log

### New features in 0.4.0

- `types.Set` is now implemented as a hash set, turning `Set.Contains()` into an O(1) operation, on average. This mitigates a worst case quadratic runtime for the evaluation of the `containsAny()` operator.

### Upgrading from 0.3.x to 0.4.x

- `types.Set` is now an immutable type which must be constructed via `types.NewSet()`
- To iterate the values, use `Set.Iterate()`, which takes an iterator callback.
- Duplicates are now removed from `Set`s, so they won't be rendered when calling `Set.MarshalCedar()` or `Set.MarshalJSON`.
- All implementations of `types.Value` are now safe to copy shallowly, so `Set.DeepClone()` has been removed.
- `types.Record` is now an immutable type which must be constructed via `types.NewRecord()`
- To iterate the keys and values, use `Record.Iterate()`, which takes an iterator callback.
- All implementations of `types.Value` are now safe to copy shallowly, so `Record.DeepClone()` has been removed.

### New features in 0.3.2

- An implementation of the `datetime` and `duration` extension types specified in [RFC 80](https://github.com/cedar-policy/rfcs/blob/main/text/0080-datetime-extension.md).
Expand Down
4 changes: 2 additions & 2 deletions ast/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ func TestAstExamples(t *testing.T) {
// when { {x: "value"}.x == "value" }
// when { {x: 1 + context.fooCount}.x == 3 }
// when { [1, (2 + 3) * 4, context.fooCount].contains(1) };
simpleRecord := types.Record{
simpleRecord := types.NewRecord(types.RecordMap{
"x": types.String("value"),
}
})
_ = ast.Forbid().
When(
ast.Value(simpleRecord).Access("x").Equal(ast.String("value")),
Expand Down
16 changes: 8 additions & 8 deletions authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestIsAuthorized(t *testing.T) {
Principal: cuzco,
Action: dropTable,
Resource: types.NewEntityUID("table", "whatever"),
Context: types.Record{"x": types.Long(42)},
Context: types.NewRecord(types.RecordMap{"x": types.Long(42)}),
Want: true,
DiagErr: 0,
},
Expand All @@ -97,7 +97,7 @@ func TestIsAuthorized(t *testing.T) {
Principal: cuzco,
Action: dropTable,
Resource: types.NewEntityUID("table", "whatever"),
Context: types.Record{"x": types.Long(43)},
Context: types.NewRecord(types.RecordMap{"x": types.Long(43)}),
Want: false,
DiagErr: 0,
},
Expand All @@ -107,7 +107,7 @@ func TestIsAuthorized(t *testing.T) {
Entities: types.Entities{
cuzco: &types.Entity{
UID: cuzco,
Attributes: types.Record{"x": types.Long(42)},
Attributes: types.NewRecord(types.RecordMap{"x": types.Long(42)}),
},
},
Principal: cuzco,
Expand All @@ -123,7 +123,7 @@ func TestIsAuthorized(t *testing.T) {
Entities: types.Entities{
cuzco: &types.Entity{
UID: cuzco,
Attributes: types.Record{"x": types.Long(43)},
Attributes: types.NewRecord(types.RecordMap{"x": types.Long(43)}),
},
},
Principal: cuzco,
Expand Down Expand Up @@ -302,7 +302,7 @@ func TestIsAuthorized(t *testing.T) {
Entities: types.Entities{
cuzco: &types.Entity{
UID: cuzco,
Attributes: types.Record{"name": types.String("bob")},
Attributes: types.NewRecord(types.RecordMap{"name": types.String("bob")}),
},
},
Principal: cuzco,
Expand Down Expand Up @@ -678,7 +678,7 @@ func TestIsAuthorized(t *testing.T) {
Name: "negative-unary-op",
Policy: `permit(principal,action,resource) when { -context.value > 0 };`,
Entities: types.Entities{},
Context: types.Record{"value": types.Long(-42)},
Context: types.NewRecord(types.RecordMap{"value": types.Long(-42)}),
Want: true,
DiagErr: 0,
},
Expand Down Expand Up @@ -770,13 +770,13 @@ func TestIsAuthorized(t *testing.T) {
Entities: types.Entities{
types.NewEntityUID("Principal", "1"): &types.Entity{
UID: types.NewEntityUID("Principal", "1"),
Attributes: types.Record{"bar": types.Long(42)},
Attributes: types.NewRecord(types.RecordMap{"bar": types.Long(42)}),
},
},
Principal: types.NewEntityUID("Principal", "1"),
Action: types.NewEntityUID("Action", "action"),
Resource: types.NewEntityUID("Resource", "resource"),
Context: types.Record{"foo": types.Long(43)},
Context: types.NewRecord(types.RecordMap{"foo": types.Long(43)}),
Want: true,
DiagErr: 0,
},
Expand Down
12 changes: 6 additions & 6 deletions internal/ast/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ func TestAstExamples(t *testing.T) {
// when { {x: "value"}.x == "value" }
// when { {x: 1 + context.fooCount}.x == 3 }
// when { [1, (2 + 3) * 4, context.fooCount].contains(1) };
simpleRecord := types.Record{
simpleRecord := types.NewRecord(types.RecordMap{
"x": types.String("value"),
}
})
_ = ast.Forbid().
When(
ast.Value(simpleRecord).Access("x").Equal(ast.String("value")),
Expand Down Expand Up @@ -249,9 +249,9 @@ func TestASTByTable(t *testing.T) {
},
{
"valueSet",
ast.Permit().When(ast.Value(types.Set{types.Long(42), types.Long(43)})),
ast.Permit().When(ast.Value(types.NewSet([]types.Value{types.Long(42), types.Long(43)}))),
ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{},
Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Set{types.Long(42), types.Long(43)}}}},
Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.NewSet([]types.Value{types.Long(42), types.Long(43)})}}},
},
},
{
Expand All @@ -263,9 +263,9 @@ func TestASTByTable(t *testing.T) {
},
{
"valueRecord",
ast.Permit().When(ast.Value(types.Record{"key": types.Long(43)})),
ast.Permit().When(ast.Value(types.NewRecord(types.RecordMap{"key": types.Long(43)}))),
ast.Policy{Effect: ast.EffectPermit, Principal: ast.ScopeTypeAll{}, Action: ast.ScopeTypeAll{}, Resource: ast.ScopeTypeAll{},
Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.Record{"key": types.Long(43)}}}},
Conditions: []ast.ConditionType{{Condition: ast.ConditionWhen, Body: ast.NodeValue{Value: types.NewRecord(types.RecordMap{"key": types.Long(43)})}}},
},
},
{
Expand Down
6 changes: 3 additions & 3 deletions internal/eval/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ func scopeToNode(varNode ast.NodeTypeVariable, in ast.IsScopeNode) ast.Node {
case ast.ScopeTypeIn:
return ast.NewNode(varNode).In(ast.Value(t.Entity))
case ast.ScopeTypeInSet:
set := make(types.Set, len(t.Entities))
vals := make([]types.Value, len(t.Entities))
for i, e := range t.Entities {
set[i] = e
vals[i] = e
}
return ast.NewNode(varNode).In(ast.Value(set))
return ast.NewNode(varNode).In(ast.Value(types.NewSet(vals)))
case ast.ScopeTypeIs:
return ast.NewNode(varNode).Is(t.Type)

Expand Down
2 changes: 1 addition & 1 deletion internal/eval/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func TestScopeToNode(t *testing.T) {
"inSet",
ast.NewActionNode(),
ast.ScopeTypeInSet{Entities: []types.EntityUID{types.NewEntityUID("T", "42")}},
ast.Action().In(ast.Value(types.Set{types.NewEntityUID("T", "42")})),
ast.Action().In(ast.Value(types.NewSet([]types.Value{types.NewEntityUID("T", "42")}))),
},
{
"is",
Expand Down
14 changes: 7 additions & 7 deletions internal/eval/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ func TestToEval(t *testing.T) {
}{
{
"access",
ast.Value(types.Record{"key": types.Long(42)}).Access("key"),
ast.Value(types.NewRecord(types.RecordMap{"key": types.Long(42)})).Access("key"),
types.Long(42),
testutil.OK,
},
{
"has",
ast.Value(types.Record{"key": types.Long(42)}).Has("key"),
ast.Value(types.NewRecord(types.RecordMap{"key": types.Long(42)})).Has("key"),
types.True,
testutil.OK,
},
Expand Down Expand Up @@ -63,13 +63,13 @@ func TestToEval(t *testing.T) {
{
"record",
ast.Record(ast.Pairs{{Key: "key", Value: ast.Long(42)}}),
types.Record{"key": types.Long(42)},
types.NewRecord(types.RecordMap{"key": types.Long(42)}),
testutil.OK,
},
{
"set",
ast.Set(ast.Long(42)),
types.Set{types.Long(42)},
types.NewSet([]types.Value{types.Long(42)}),
testutil.OK,
},
{
Expand Down Expand Up @@ -182,19 +182,19 @@ func TestToEval(t *testing.T) {
},
{
"contains",
ast.Value(types.Set{types.Long(42)}).Contains(ast.Long(42)),
ast.Value(types.NewSet([]types.Value{types.Long(42)})).Contains(ast.Long(42)),
types.True,
testutil.OK,
},
{
"containsAll",
ast.Value(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAll(ast.Value(types.Set{types.Long(42), types.Long(43)})),
ast.Value(types.NewSet([]types.Value{types.Long(42), types.Long(43), types.Long(44)})).ContainsAll(ast.Value(types.NewSet([]types.Value{types.Long(42), types.Long(43)}))),
types.True,
testutil.OK,
},
{
"containsAny",
ast.Value(types.Set{types.Long(42), types.Long(43), types.Long(44)}).ContainsAny(ast.Value(types.Set{types.Long(1), types.Long(42)})),
ast.Value(types.NewSet([]types.Value{types.Long(42), types.Long(43), types.Long(44)})).ContainsAny(ast.Value(types.NewSet([]types.Value{types.Long(1), types.Long(42)}))),
types.True,
testutil.OK,
},
Expand Down
51 changes: 29 additions & 22 deletions internal/eval/evalers.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ func evalString(n Evaler, env *Env) (types.String, error) {
func evalSet(n Evaler, env *Env) (types.Set, error) {
v, err := n.Eval(env)
if err != nil {
return nil, err
return types.Set{}, err
}
s, err := ValueToSet(v)
if err != nil {
return nil, err
return types.Set{}, err
}
return s, nil
}
Expand Down Expand Up @@ -733,15 +733,15 @@ func newSetLiteralEval(elements []Evaler) *setLiteralEval {
}

func (n *setLiteralEval) Eval(env *Env) (types.Value, error) {
var vals types.Set
for _, e := range n.elements {
vals := make([]types.Value, len(n.elements))
for i, e := range n.elements {
v, err := e.Eval(env)
if err != nil {
return zeroValue(), err
}
vals = append(vals, v)
vals[i] = v
}
return vals, nil
return types.NewSet(vals), nil
}

// containsEval
Expand Down Expand Up @@ -790,12 +790,13 @@ func (n *containsAllEval) Eval(env *Env) (types.Value, error) {
return zeroValue(), err
}
result := true
for _, e := range rhs {
rhs.Iterate(func(e types.Value) bool {
if !lhs.Contains(e) {
result = false
break
return false
}
}
return true
})
return types.Boolean(result), nil
}

Expand All @@ -821,12 +822,13 @@ func (n *containsAnyEval) Eval(env *Env) (types.Value, error) {
return zeroValue(), err
}
result := false
for _, e := range rhs {
rhs.Iterate(func(e types.Value) bool {
if lhs.Contains(e) {
result = true
break
return false
}
}
return true
})
return types.Boolean(result), nil
}

Expand All @@ -840,15 +842,15 @@ func newRecordLiteralEval(elements map[types.String]Evaler) *recordLiteralEval {
}

func (n *recordLiteralEval) Eval(env *Env) (types.Value, error) {
vals := types.Record{}
vals := types.RecordMap{}
for k, en := range n.elements {
v, err := en.Eval(env)
if err != nil {
return zeroValue(), err
}
vals[k] = v
}
return vals, nil
return types.NewRecord(vals), nil
}

// attributeAccessEval
Expand Down Expand Up @@ -876,13 +878,13 @@ func (n *attributeAccessEval) Eval(env *Env) (types.Value, error) {
if !ok {
return zeroValue(), fmt.Errorf("entity `%v` %w", vv.String(), errEntityNotExist)
}
val, ok := rec.Attributes[n.attribute]
val, ok := rec.Attributes.Get(n.attribute)
if !ok {
return zeroValue(), fmt.Errorf("`%s` %w `%s`", vv.String(), errAttributeAccess, n.attribute)
}
return val, nil
case types.Record:
val, ok := vv[n.attribute]
val, ok := vv.Get(n.attribute)
if !ok {
return zeroValue(), fmt.Errorf("record %w `%s`", errAttributeAccess, n.attribute)
}
Expand Down Expand Up @@ -918,7 +920,7 @@ func (n *hasEval) Eval(env *Env) (types.Value, error) {
default:
return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, TypeName(v))
}
_, ok := record[n.attribute]
_, ok := record.Get(n.attribute)
return types.Boolean(ok), nil
}

Expand Down Expand Up @@ -1070,13 +1072,18 @@ func doInEval(env *Env, lhs types.EntityUID, rhs types.Value) (types.Value, erro
case types.EntityUID:
return types.Boolean(entityInOne(env, lhs, rhsv)), nil
case types.Set:
query := make(map[types.EntityUID]struct{}, len(rhsv))
for _, rhv := range rhsv {
e, err := ValueToEntity(rhv)
if err != nil {
return zeroValue(), err
query := make(map[types.EntityUID]struct{}, rhsv.Len())
var err error
rhsv.Iterate(func(rhv types.Value) bool {
var e types.EntityUID
if e, err = ValueToEntity(rhv); err != nil {
return false
}
query[e] = struct{}{}
return true
})
if err != nil {
return zeroValue(), err
}
return types.Boolean(entityInSet(env, lhs, query)), nil
}
Expand Down
Loading

0 comments on commit 055335e

Please sign in to comment.