Skip to content

Commit

Permalink
Implement imports on AST
Browse files Browse the repository at this point in the history
  • Loading branch information
tstirrat15 committed Jan 30, 2025
1 parent 3f72881 commit 0cec82e
Show file tree
Hide file tree
Showing 5 changed files with 382 additions and 136 deletions.
67 changes: 31 additions & 36 deletions pkg/composableschemadsl/compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,55 +86,39 @@ type Option func(*config)

type ObjectPrefixOption func(*config)

type compilationContext struct {
// The set of definition names that we've seen as we compile.
// If these collide we throw an error.
existingNames *mapz.Set[string]
// The global set of files we've visited in the import process.
// If these collide we short circuit, preventing duplicate imports.
globallyVisitedFiles *mapz.Set[string]
// The set of files that we've visited on a particular leg of the recursion.
// This allows for detection of circular imports.
// NOTE: This depends on an assumption that a depth-first search will always
// find a cycle, even if we're otherwise marking globally visited nodes.
locallyVisitedFiles *mapz.Set[string]
}

// Compile compilers the input schema into a set of namespace definition protos.
func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) {
cctx := compilationContext{
existingNames: mapz.NewSet[string](),
globallyVisitedFiles: mapz.NewSet[string](),
locallyVisitedFiles: mapz.NewSet[string](),
}
return compileImpl(schema, cctx, prefix, opts...)
}

func compileImpl(schema InputSchema, cctx compilationContext, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) {
cfg := &config{}
prefix(cfg) // required option

for _, fn := range opts {
fn(cfg)
}

mapper := newPositionMapper(schema)
root := parser.Parse(createAstNode, schema.Source, schema.SchemaString).(*dslNode)
errs := root.FindAll(dslshape.NodeTypeError)
if len(errs) > 0 {
err := errorNodeToError(errs[0], mapper)
root, mapper, err := parseSchema(schema)
if err != nil {
return nil, err
}

compiled, err := translate(translationContext{
objectTypePrefix: cfg.objectTypePrefix,
mapper: mapper,
schemaString: schema.SchemaString,
skipValidate: cfg.skipValidation,
// NOTE: import translation is done separately so that partial references
// and definitions defined in separate files can correctly resolve.
err = translateImports(importTranslationContext{
globallyVisitedFiles: mapz.NewSet[string](),
locallyVisitedFiles: mapz.NewSet[string](),
sourceFolder: cfg.sourceFolder,
existingNames: cctx.existingNames,
locallyVisitedFiles: cctx.locallyVisitedFiles,
globallyVisitedFiles: cctx.globallyVisitedFiles,
}, root)
if err != nil {
return nil, err
}

compiled, err := translate(translationContext{
objectTypePrefix: cfg.objectTypePrefix,
mapper: mapper,
schemaString: schema.SchemaString,
skipValidate: cfg.skipValidation,
existingNames: mapz.NewSet[string](),
compiledPartials: make(map[string][]*core.Relation),
partialWaitingRooms: mapz.NewMultiMap[string, *dslNode](),
}, root)
if err != nil {
var withNodeError withNodeError
Expand All @@ -148,6 +132,17 @@ func compileImpl(schema InputSchema, cctx compilationContext, prefix ObjectPrefi
return compiled, nil
}

func parseSchema(schema InputSchema) (*dslNode, input.PositionMapper, error) {
mapper := newPositionMapper(schema)
root := parser.Parse(createAstNode, schema.Source, schema.SchemaString).(*dslNode)
errs := root.FindAll(dslshape.NodeTypeError)
if len(errs) > 0 {
err := errorNodeToError(errs[0], mapper)
return nil, nil, err
}
return root, mapper, nil
}

func errorNodeToError(node *dslNode, mapper input.PositionMapper) error {
if node.GetType() != dslshape.NodeTypeError {
return fmt.Errorf("given none error node")
Expand Down
102 changes: 102 additions & 0 deletions pkg/composableschemadsl/compiler/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,108 @@ func TestCompile(t *testing.T) {
),
},
},
{
"transitive partials",
withTenantPrefix,
`
partial view_partial {
relation user: user;
}
partial transitive_partial {
...view_partial
}
definition simple {
...view_partial
}
`,
"",
[]SchemaDefinition{
namespace.Namespace("sometenant/simple",
namespace.MustRelation("user", nil,
namespace.AllowedRelation("sometenant/user", "..."),
),
),
},
},
{
"transitive partials out of order",
withTenantPrefix,
`
partial transitive_partial {
...view_partial
}
partial view_partial {
relation user: user;
}
definition simple {
...view_partial
}
`,
"",
[]SchemaDefinition{
namespace.Namespace("sometenant/simple",
namespace.MustRelation("user", nil,
namespace.AllowedRelation("sometenant/user", "..."),
),
),
},
},
{
"forking transitive partials out of order",
withTenantPrefix,
`
partial transitive_partial {
...view_partial
...group_partial
}
partial view_partial {
relation user: user;
}
partial group_partial {
relation group: group;
}
definition simple {
...transitive_partial
}
`,
"",
[]SchemaDefinition{
namespace.Namespace("sometenant/simple",
namespace.MustRelation("user", nil,
namespace.AllowedRelation("sometenant/user", "..."),
),
namespace.MustRelation("group", nil,
namespace.AllowedRelation("sometenant/group", "..."),
),
),
},
},
{
"circular reference in partials",
withTenantPrefix,
`
partial one_partial {
...another_partial
}
partial another_partial {
...one_partial
}
definition simple {
...one_partial
}
`,
"could not resolve partials",
[]SchemaDefinition{},
},
{
"explicit relation",
withTenantPrefix,
Expand Down
54 changes: 4 additions & 50 deletions pkg/composableschemadsl/compiler/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package compiler
import (
"fmt"
"os"
"path"
"path/filepath"
"strings"

Expand All @@ -26,63 +25,18 @@ type CircularImportError struct {
filePath string
}

func importFile(importContext importContext) (*CompiledSchema, error) {
if err := validateFilepath(importContext.path); err != nil {
return nil, err
}
filePath := path.Join(importContext.sourceFolder, importContext.path)

newSourceFolder := filepath.Dir(filePath)

currentLocallyVisitedFiles := importContext.locallyVisitedFiles.Copy()

if ok := currentLocallyVisitedFiles.Add(filePath); !ok {
// If we've already visited the file on this particular branch walk, it's
// a circular import issue.
return nil, &CircularImportError{
error: fmt.Errorf("circular import detected: %s has been visited on this branch", filePath),
filePath: filePath,
}
}

if ok := importContext.globallyVisitedFiles.Add(filePath); !ok {
// If the file has already been visited, we short-circuit the import process
// by not reading the schema file in and compiling a schema with an empty string.
// This prevents duplicate definitions from ending up in the output, as well
// as preventing circular imports.
log.Debug().Str("filepath", filePath).Msg("file %s has already been visited in another part of the walk")
return compileImpl(InputSchema{
Source: input.Source(filePath),
SchemaString: "",
},
compilationContext{
existingNames: importContext.names,
locallyVisitedFiles: currentLocallyVisitedFiles,
globallyVisitedFiles: importContext.globallyVisitedFiles,
},
AllowUnprefixedObjectType(),
SourceFolder(newSourceFolder),
)
}

func importFile(filePath string) (*dslNode, error) {
schemaBytes, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read schema file: %w", err)
}
log.Trace().Str("schema", string(schemaBytes)).Str("file", filePath).Msg("read schema from file")

return compileImpl(InputSchema{
parsedSchema, _, err := parseSchema(InputSchema{
Source: input.Source(filePath),
SchemaString: string(schemaBytes),
},
compilationContext{
existingNames: importContext.names,
locallyVisitedFiles: currentLocallyVisitedFiles,
globallyVisitedFiles: importContext.globallyVisitedFiles,
},
AllowUnprefixedObjectType(),
SourceFolder(newSourceFolder),
)
})
return parsedSchema, err
}

// Take a filepath and ensure that it's local to the current context.
Expand Down
9 changes: 9 additions & 0 deletions pkg/composableschemadsl/compiler/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ func (tn *dslNode) Connect(predicate string, other parser.AstNode) {
tn.children[predicate].PushBack(other)
}

// Used to preserve import order when doing import operations on AST
func (tn *dslNode) ConnectAndHoistMany(predicate string, other *list.List) {
if tn.children[predicate] == nil {
tn.children[predicate] = list.New()
}

tn.children[predicate].PushFrontList(other)
}

func (tn *dslNode) MustDecorate(property string, value string) parser.AstNode {
if _, ok := tn.properties[property]; ok {
panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties))
Expand Down
Loading

0 comments on commit 0cec82e

Please sign in to comment.