diff --git a/pkg/composableschemadsl/compiler/compiler.go b/pkg/composableschemadsl/compiler/compiler.go index e0ec6946e8..81b4fcfdd3 100644 --- a/pkg/composableschemadsl/compiler/compiler.go +++ b/pkg/composableschemadsl/compiler/compiler.go @@ -86,31 +86,8 @@ type Option func(*config) type ObjectPrefixOption func(*config) -type compilationContext struct { - // The set of definition names that we've seen as we compile. - // If these collide we throw an error. - existingNames *mapz.Set[string] - // The global set of files we've visited in the import process. - // If these collide we short circuit, preventing duplicate imports. - globallyVisitedFiles *mapz.Set[string] - // The set of files that we've visited on a particular leg of the recursion. - // This allows for detection of circular imports. - // NOTE: This depends on an assumption that a depth-first search will always - // find a cycle, even if we're otherwise marking globally visited nodes. - locallyVisitedFiles *mapz.Set[string] -} - // Compile compilers the input schema into a set of namespace definition protos. func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) { - cctx := compilationContext{ - existingNames: mapz.NewSet[string](), - globallyVisitedFiles: mapz.NewSet[string](), - locallyVisitedFiles: mapz.NewSet[string](), - } - return compileImpl(schema, cctx, prefix, opts...) -} - -func compileImpl(schema InputSchema, cctx compilationContext, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) { cfg := &config{} prefix(cfg) // required option @@ -118,23 +95,30 @@ func compileImpl(schema InputSchema, cctx compilationContext, prefix ObjectPrefi fn(cfg) } - mapper := newPositionMapper(schema) - root := parser.Parse(createAstNode, schema.Source, schema.SchemaString).(*dslNode) - errs := root.FindAll(dslshape.NodeTypeError) - if len(errs) > 0 { - err := errorNodeToError(errs[0], mapper) + root, mapper, err := parseSchema(schema) + if err != nil { return nil, err } - compiled, err := translate(translationContext{ - objectTypePrefix: cfg.objectTypePrefix, - mapper: mapper, - schemaString: schema.SchemaString, - skipValidate: cfg.skipValidation, + // NOTE: import translation is done separately so that partial references + // and definitions defined in separate files can correctly resolve. + err = translateImports(importResolutionContext{ + globallyVisitedFiles: mapz.NewSet[string](), + locallyVisitedFiles: mapz.NewSet[string](), sourceFolder: cfg.sourceFolder, - existingNames: cctx.existingNames, - locallyVisitedFiles: cctx.locallyVisitedFiles, - globallyVisitedFiles: cctx.globallyVisitedFiles, + }, root) + if err != nil { + return nil, err + } + + compiled, err := translate(translationContext{ + objectTypePrefix: cfg.objectTypePrefix, + mapper: mapper, + schemaString: schema.SchemaString, + skipValidate: cfg.skipValidation, + existingNames: mapz.NewSet[string](), + compiledPartials: make(map[string][]*core.Relation), + unresolvedPartials: mapz.NewMultiMap[string, *dslNode](), }, root) if err != nil { var withNodeError withNodeError @@ -148,6 +132,17 @@ func compileImpl(schema InputSchema, cctx compilationContext, prefix ObjectPrefi return compiled, nil } +func parseSchema(schema InputSchema) (*dslNode, input.PositionMapper, error) { + mapper := newPositionMapper(schema) + root := parser.Parse(createAstNode, schema.Source, schema.SchemaString).(*dslNode) + errs := root.FindAll(dslshape.NodeTypeError) + if len(errs) > 0 { + err := errorNodeToError(errs[0], mapper) + return nil, nil, err + } + return root, mapper, nil +} + func errorNodeToError(node *dslNode, mapper input.PositionMapper) error { if node.GetType() != dslshape.NodeTypeError { return fmt.Errorf("given none error node") diff --git a/pkg/composableschemadsl/compiler/compiler_test.go b/pkg/composableschemadsl/compiler/compiler_test.go index 512740bdc8..79f65fd896 100644 --- a/pkg/composableschemadsl/compiler/compiler_test.go +++ b/pkg/composableschemadsl/compiler/compiler_test.go @@ -88,6 +88,238 @@ func TestCompile(t *testing.T) { ), }, }, + { + "simple partial", + withTenantPrefix, + `partial view_partial { + relation user: user; + } + + definition simple { + ...view_partial + }`, + "", + []SchemaDefinition{ + namespace.Namespace("sometenant/simple", + namespace.MustRelation("user", nil, + namespace.AllowedRelation("sometenant/user", "..."), + ), + ), + }, + }, + { + "more complex partial", + withTenantPrefix, + ` + definition user {} + definition organization {} + + partial view_partial { + relation user: user; + permission view = user + } + + definition resource { + relation organization: organization + permission manage = organization + + ...view_partial + } + `, + "", + []SchemaDefinition{ + namespace.Namespace("sometenant/user"), + namespace.Namespace("sometenant/organization"), + namespace.Namespace("sometenant/resource", + namespace.MustRelation("organization", nil, + namespace.AllowedRelation("sometenant/organization", "..."), + ), + namespace.MustRelation("manage", + namespace.Union( + namespace.ComputedUserset("organization"), + ), + ), + namespace.MustRelation("user", nil, + namespace.AllowedRelation("sometenant/user", "..."), + ), + namespace.MustRelation("view", + namespace.Union( + namespace.ComputedUserset("user"), + ), + ), + ), + }, + }, + { + "partial defined after reference", + withTenantPrefix, + `definition simple { + ...view_partial + } + + partial view_partial { + relation user: user; + }`, + "", + []SchemaDefinition{ + namespace.Namespace("sometenant/simple", + namespace.MustRelation("user", nil, + namespace.AllowedRelation("sometenant/user", "..."), + ), + ), + }, + }, + { + "transitive partials", + withTenantPrefix, + ` + partial view_partial { + relation user: user; + } + + partial transitive_partial { + ...view_partial + } + + definition simple { + ...view_partial + } + `, + "", + []SchemaDefinition{ + namespace.Namespace("sometenant/simple", + namespace.MustRelation("user", nil, + namespace.AllowedRelation("sometenant/user", "..."), + ), + ), + }, + }, + { + "transitive partials out of order", + withTenantPrefix, + ` + partial transitive_partial { + ...view_partial + } + + partial view_partial { + relation user: user; + } + + definition simple { + ...view_partial + } + `, + "", + []SchemaDefinition{ + namespace.Namespace("sometenant/simple", + namespace.MustRelation("user", nil, + namespace.AllowedRelation("sometenant/user", "..."), + ), + ), + }, + }, + { + "transitive partials in reverse order", + withTenantPrefix, + ` + definition simple { + ...view_partial + } + + partial transitive_partial { + ...view_partial + } + + partial view_partial { + relation user: user; + } + `, + "", + []SchemaDefinition{ + namespace.Namespace("sometenant/simple", + namespace.MustRelation("user", nil, + namespace.AllowedRelation("sometenant/user", "..."), + ), + ), + }, + }, + { + "forking transitive partials out of order", + withTenantPrefix, + ` + partial transitive_partial { + ...view_partial + ...group_partial + } + + partial view_partial { + relation user: user; + } + + partial group_partial { + relation group: group; + } + + definition simple { + ...transitive_partial + } + `, + "", + []SchemaDefinition{ + namespace.Namespace("sometenant/simple", + namespace.MustRelation("user", nil, + namespace.AllowedRelation("sometenant/user", "..."), + ), + namespace.MustRelation("group", nil, + namespace.AllowedRelation("sometenant/group", "..."), + ), + ), + }, + }, + { + "circular reference in partials", + withTenantPrefix, + ` + partial one_partial { + ...another_partial + } + + partial another_partial { + ...one_partial + } + + definition simple { + ...one_partial + } + `, + "could not resolve partials", + []SchemaDefinition{}, + }, + { + "definition reference to nonexistent partial", + withTenantPrefix, + ` + definition simple { + ...some_partial + } + `, + "could not find partial reference", + []SchemaDefinition{}, + }, + { + "definition reference to another definition errors", + withTenantPrefix, + ` + definition some_definition {} + + definition simple { + ...some_definition + } + `, + "could not find partial reference", + []SchemaDefinition{}, + }, { "explicit relation", withTenantPrefix, diff --git a/pkg/composableschemadsl/compiler/importer.go b/pkg/composableschemadsl/compiler/importer.go index 7636af0166..b51790711d 100644 --- a/pkg/composableschemadsl/compiler/importer.go +++ b/pkg/composableschemadsl/compiler/importer.go @@ -3,86 +3,31 @@ package compiler import ( "fmt" "os" - "path" "path/filepath" "strings" "github.com/rs/zerolog/log" "github.com/authzed/spicedb/pkg/composableschemadsl/input" - "github.com/authzed/spicedb/pkg/genutil/mapz" ) -type importContext struct { - path string - sourceFolder string - names *mapz.Set[string] - locallyVisitedFiles *mapz.Set[string] - globallyVisitedFiles *mapz.Set[string] -} - type CircularImportError struct { error filePath string } -func importFile(importContext importContext) (*CompiledSchema, error) { - if err := validateFilepath(importContext.path); err != nil { - return nil, err - } - filePath := path.Join(importContext.sourceFolder, importContext.path) - - newSourceFolder := filepath.Dir(filePath) - - currentLocallyVisitedFiles := importContext.locallyVisitedFiles.Copy() - - if ok := currentLocallyVisitedFiles.Add(filePath); !ok { - // If we've already visited the file on this particular branch walk, it's - // a circular import issue. - return nil, &CircularImportError{ - error: fmt.Errorf("circular import detected: %s has been visited on this branch", filePath), - filePath: filePath, - } - } - - if ok := importContext.globallyVisitedFiles.Add(filePath); !ok { - // If the file has already been visited, we short-circuit the import process - // by not reading the schema file in and compiling a schema with an empty string. - // This prevents duplicate definitions from ending up in the output, as well - // as preventing circular imports. - log.Debug().Str("filepath", filePath).Msg("file %s has already been visited in another part of the walk") - return compileImpl(InputSchema{ - Source: input.Source(filePath), - SchemaString: "", - }, - compilationContext{ - existingNames: importContext.names, - locallyVisitedFiles: currentLocallyVisitedFiles, - globallyVisitedFiles: importContext.globallyVisitedFiles, - }, - AllowUnprefixedObjectType(), - SourceFolder(newSourceFolder), - ) - } - +func importFile(filePath string) (*dslNode, error) { schemaBytes, err := os.ReadFile(filePath) if err != nil { return nil, fmt.Errorf("failed to read schema file: %w", err) } log.Trace().Str("schema", string(schemaBytes)).Str("file", filePath).Msg("read schema from file") - return compileImpl(InputSchema{ + parsedSchema, _, err := parseSchema(InputSchema{ Source: input.Source(filePath), SchemaString: string(schemaBytes), - }, - compilationContext{ - existingNames: importContext.names, - locallyVisitedFiles: currentLocallyVisitedFiles, - globallyVisitedFiles: importContext.globallyVisitedFiles, - }, - AllowUnprefixedObjectType(), - SourceFolder(newSourceFolder), - ) + }) + return parsedSchema, err } // Take a filepath and ensure that it's local to the current context. diff --git a/pkg/composableschemadsl/compiler/node.go b/pkg/composableschemadsl/compiler/node.go index 8e095d677d..62ea56603b 100644 --- a/pkg/composableschemadsl/compiler/node.go +++ b/pkg/composableschemadsl/compiler/node.go @@ -35,6 +35,15 @@ func (tn *dslNode) Connect(predicate string, other parser.AstNode) { tn.children[predicate].PushBack(other) } +// Used to preserve import order when doing import operations on AST +func (tn *dslNode) ConnectAndHoistMany(predicate string, other *list.List) { + if tn.children[predicate] == nil { + tn.children[predicate] = list.New() + } + + tn.children[predicate].PushFrontList(other) +} + func (tn *dslNode) MustDecorate(property string, value string) parser.AstNode { if _, ok := tn.properties[property]; ok { panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties)) diff --git a/pkg/composableschemadsl/compiler/translator.go b/pkg/composableschemadsl/compiler/translator.go index 5e0e6663ac..4ec48d888d 100644 --- a/pkg/composableschemadsl/compiler/translator.go +++ b/pkg/composableschemadsl/compiler/translator.go @@ -2,8 +2,9 @@ package compiler import ( "bufio" - "errors" + "container/list" "fmt" + "path/filepath" "strings" "github.com/ccoveille/go-safecast" @@ -20,14 +21,16 @@ import ( ) type translationContext struct { - objectTypePrefix *string - mapper input.PositionMapper - schemaString string - skipValidate bool - existingNames *mapz.Set[string] - locallyVisitedFiles *mapz.Set[string] - globallyVisitedFiles *mapz.Set[string] - sourceFolder string + objectTypePrefix *string + mapper input.PositionMapper + schemaString string + skipValidate bool + existingNames *mapz.Set[string] + // The mapping of partial name -> relations represented by the partial + compiledPartials map[string][]*core.Relation + // A mapping of partial name -> partial DSL nodes whose resolution depends on + // the resolution of the named partial + unresolvedPartials *mapz.MultiMap[string, *dslNode] } func (tctx translationContext) prefixedPath(definitionName string) (string, error) { @@ -58,6 +61,14 @@ func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error) // as we do our walk names := tctx.existingNames.Copy() + // Do an initial pass to translate partials and add them to the + // translation context. This ensures that they're available for + // subsequent reference in definition compilation. + err := collectPartials(tctx, root) + if err != nil { + return nil, err + } + for _, topLevelNode := range root.GetChildren() { switch topLevelNode.GetType() { case dslshape.NodeTypeCaveatDefinition: @@ -92,17 +103,6 @@ func translate(tctx translationContext, root *dslNode) (*CompiledSchema, error) } orderedDefinitions = append(orderedDefinitions, def) - - case dslshape.NodeTypeImport: - compiled, err := translateImport(tctx, topLevelNode, names) - if err != nil { - return nil, err - } - // NOTE: name collision validation happens in the recursive compilation process, - // so we don't need an explicit check here and we can just add our definitions. - caveatDefinitions = append(caveatDefinitions, compiled.CaveatDefinitions...) - objectDefinitions = append(objectDefinitions, compiled.ObjectDefinitions...) - orderedDefinitions = append(orderedDefinitions, compiled.OrderedDefinitions...) } } @@ -227,18 +227,10 @@ func translateObjectDefinition(tctx translationContext, defNode *dslNode) (*core return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err) } - relationsAndPermissions := []*core.Relation{} - for _, relationOrPermissionNode := range defNode.GetChildren() { - if relationOrPermissionNode.GetType() == dslshape.NodeTypeComment { - continue - } - - relationOrPermission, err := translateRelationOrPermission(tctx, relationOrPermissionNode) - if err != nil { - return nil, err - } - - relationsAndPermissions = append(relationsAndPermissions, relationOrPermission) + errorOnMissingReference := true + relationsAndPermissions, _, err := translateRelationsAndPermissions(tctx, defNode, errorOnMissingReference) + if err != nil { + return nil, err } nspath, err := tctx.prefixedPath(definitionName) @@ -273,6 +265,39 @@ func translateObjectDefinition(tctx translationContext, defNode *dslNode) (*core return ns, nil } +// NOTE: This function behaves differently based on an errorOnMissingReference flag. +// A value of true treats that as an error state, since all partials should be resolved when translating definitions, +// where the false value returns the name of the partial for collection for future processing +// when translating partials. +func translateRelationsAndPermissions(tctx translationContext, astNode *dslNode, errorOnMissingReference bool) ([]*core.Relation, string, error) { + relationsAndPermissions := []*core.Relation{} + for _, definitionChildNode := range astNode.GetChildren() { + if definitionChildNode.GetType() == dslshape.NodeTypeComment { + continue + } + + if definitionChildNode.GetType() == dslshape.NodeTypePartialReference { + partialContents, unresolvedPartial, err := translatePartialReference(tctx, definitionChildNode, errorOnMissingReference) + if err != nil { + return nil, "", err + } + if unresolvedPartial != "" { + return nil, unresolvedPartial, nil + } + relationsAndPermissions = append(relationsAndPermissions, partialContents...) + continue + } + + relationOrPermission, err := translateRelationOrPermission(tctx, definitionChildNode) + if err != nil { + return nil, "", err + } + + relationsAndPermissions = append(relationsAndPermissions, relationOrPermission) + } + return relationsAndPermissions, "", nil +} + func getSourcePosition(dslNode *dslNode, mapper input.PositionMapper) *core.SourcePosition { if !dslNode.Has(dslshape.NodePredicateStartRune) { return nil @@ -696,27 +721,161 @@ func addWithCaveats(tctx translationContext, typeRefNode *dslNode, ref *core.All return nil } -func translateImport(tctx translationContext, importNode *dslNode, names *mapz.Set[string]) (*CompiledSchema, error) { - path, err := importNode.GetString(dslshape.NodeImportPredicatePath) - if err != nil { - return nil, err +type importResolutionContext struct { + // The global set of files we've visited in the import process. + // If these collide we short circuit, preventing duplicate imports. + globallyVisitedFiles *mapz.Set[string] + // The set of files that we've visited on a particular leg of the recursion. + // This allows for detection of circular imports. + // NOTE: This depends on an assumption that a depth-first search will always + // find a cycle, even if we're otherwise marking globally visited nodes. + locallyVisitedFiles *mapz.Set[string] + sourceFolder string +} + +// Takes a parsed schema and recursively translates import syntax and replaces +// import nodes with parsed nodes from the target files +func translateImports(itctx importResolutionContext, root *dslNode) error { + // We create a new list so that we can maintain the order + // of imported nodes + importedDefinitionNodes := list.New() + + for _, topLevelNode := range root.GetChildren() { + // Process import nodes; ignore the others + if topLevelNode.GetType() == dslshape.NodeTypeImport { + // Do the handling of recursive imports etc here + importPath, err := topLevelNode.GetString(dslshape.NodeImportPredicatePath) + if err != nil { + return err + } + + if err := validateFilepath(importPath); err != nil { + return err + } + filePath := filepath.Join(itctx.sourceFolder, importPath) + + newSourceFolder := filepath.Dir(filePath) + + currentLocallyVisitedFiles := itctx.locallyVisitedFiles.Copy() + + if ok := currentLocallyVisitedFiles.Add(filePath); !ok { + // If we've already visited the file on this particular branch walk, it's + // a circular import issue. + return &CircularImportError{ + error: fmt.Errorf("circular import detected: %s has been visited on this branch", filePath), + filePath: filePath, + } + } + + if ok := itctx.globallyVisitedFiles.Add(filePath); !ok { + // If the file has already been visited, we short-circuit the import process + // by not reading the schema file in and compiling a schema with an empty string. + // This prevents duplicate definitions from ending up in the output, as well + // as preventing circular imports. + log.Debug().Str("filepath", filePath).Msg("file %s has already been visited in another part of the walk") + return nil + } + + // Do the actual import here + // This is a new node provided by the translateImport + parsedImportRoot, err := importFile(filePath) + if err != nil { + return err + } + + // We recurse on that node to resolve any further imports + err = translateImports(importResolutionContext{ + sourceFolder: newSourceFolder, + locallyVisitedFiles: currentLocallyVisitedFiles, + globallyVisitedFiles: itctx.globallyVisitedFiles, + }, parsedImportRoot) + if err != nil { + return err + } + + // And then append the definition to the list of definitions to be added + for _, importedNode := range parsedImportRoot.GetChildren() { + if importedNode.GetType() != dslshape.NodeTypeImport { + importedDefinitionNodes.PushBack(importedNode) + } + } + } + } + + // finally, take the list of definitions to add and prepend them + // (effectively hoists definitions) + root.ConnectAndHoistMany(dslshape.NodePredicateChild, importedDefinitionNodes) + + return nil +} + +func collectPartials(tctx translationContext, rootNode *dslNode) error { + for _, topLevelNode := range rootNode.GetChildren() { + if topLevelNode.GetType() == dslshape.NodeTypePartial { + err := translatePartial(tctx, topLevelNode) + if err != nil { + return err + } + } } + if tctx.unresolvedPartials.Len() != 0 { + return fmt.Errorf("could not resolve partials: [%s]. this may indicate a circular reference", strings.Join(tctx.unresolvedPartials.Keys(), ", ")) + } + return nil +} - compiledSchema, err := importFile(importContext{ - names: names, - path: path, - sourceFolder: tctx.sourceFolder, - globallyVisitedFiles: tctx.globallyVisitedFiles, - locallyVisitedFiles: tctx.locallyVisitedFiles, - }) +// This function modifies the translation context, so we don't need to return anything from it. +func translatePartial(tctx translationContext, partialNode *dslNode) error { + partialName, err := partialNode.GetString(dslshape.NodePartialPredicateName) if err != nil { - var circularImportError *CircularImportError - if errors.As(err, &circularImportError) { - // NOTE: The "%s" is an empty format string to keep with the form of WithSourceErrorf - return nil, importNode.WithSourceErrorf(circularImportError.filePath, "%s", circularImportError.error.Error()) + return err + } + // This needs to return the unresolved name. + errorOnMissingReference := false + relationsAndPermissions, unresolvedPartial, err := translateRelationsAndPermissions(tctx, partialNode, errorOnMissingReference) + if err != nil { + return err + } + if unresolvedPartial != "" { + tctx.unresolvedPartials.Add(unresolvedPartial, partialNode) + return nil + } + + tctx.compiledPartials[partialName] = relationsAndPermissions + + // Since we've successfully compiled a partial, check the unresolved partials to see if any other partial was + // waiting on this partial + // NOTE: we're making an assumption here that a partial can't end up back in the same + // list of unresolved partials - if it hangs again in a different spot, it will end up in a different + // list of unresolved partials. + waitingPartials, _ := tctx.unresolvedPartials.Get(partialName) + for _, waitingPartialNode := range waitingPartials { + err := translatePartial(tctx, waitingPartialNode) + if err != nil { + return err } - return nil, err } + // Clear out this partial's key from the unresolved partials if it's not already empty. + tctx.unresolvedPartials.RemoveKey(partialName) + return nil +} - return compiledSchema, nil +// NOTE: we treat partial references in definitions and partials differently because a missing partial +// reference in definition compilation is an error state, where a missing partial reference in a +// partial definition is an indeterminate state. +func translatePartialReference(tctx translationContext, partialReferenceNode *dslNode, errorOnMissingReference bool) ([]*core.Relation, string, error) { + name, err := partialReferenceNode.GetString(dslshape.NodePartialReferencePredicateName) + if err != nil { + return nil, "", err + } + relationsAndPermissions, ok := tctx.compiledPartials[name] + if !ok { + if errorOnMissingReference { + return nil, "", partialReferenceNode.Errorf("could not find partial reference with name %s", name) + } + // If the partial isn't present and we're not throwing an error, we return the name of the missing partial + // This behavior supports partial collection + return nil, name, nil + } + return relationsAndPermissions, "", nil }