Skip to content

Commit

Permalink
Interleave arity and typechecking (#2481)
Browse files Browse the repository at this point in the history
- Closes #2362 

This pr implements a new typechecking algorithm. This algorithm can be
activated using the global flag `--new-typechecker`. This flag will only
take effect on the compilation pipeline but not the repl.

The main difference between the new and old algorithm is that the new
one inserts holes during typechecking. Thus, it does not require the
arity checker pass.

The new algorithm does not yet implement default arguments. The plan is
to make the change in the following steps:
1. Merge this pr.
2. Merge #2506.
3. Implement default arguments for the new algorithm.
4. Remove the arity checker and the old algorithm.

---------

Co-authored-by: Łukasz Czajka <[email protected]>
  • Loading branch information
janmasrovira and lukaszcz authored Nov 12, 2023
1 parent bdb0d9a commit a05586e
Show file tree
Hide file tree
Showing 28 changed files with 1,901 additions and 86 deletions.
24 changes: 16 additions & 8 deletions app/GlobalOptions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ data GlobalOptions = GlobalOptions
_globalNoCoverage :: Bool,
_globalNoStdlib :: Bool,
_globalUnrollLimit :: Int,
_globalOffline :: Bool
_globalOffline :: Bool,
_globalNewTypecheckingAlgorithm :: Bool
}
deriving stock (Eq, Show)

Expand Down Expand Up @@ -60,7 +61,8 @@ defaultGlobalOptions =
_globalNoCoverage = False,
_globalNoStdlib = False,
_globalUnrollLimit = defaultUnrollLimit,
_globalOffline = False
_globalOffline = False,
_globalNewTypecheckingAlgorithm = False
}

-- | Get a parser for global flags which can be hidden or not depending on
Expand All @@ -72,11 +74,6 @@ parseGlobalFlags = do
( long "no-colors"
<> help "Disable ANSI formatting"
)
_globalShowNameIds <-
switch
( long "show-name-ids"
<> help "Show the unique number of each identifier when pretty printing"
)
_globalBuildDir <-
optional
( parseBuildDir
Expand Down Expand Up @@ -126,6 +123,16 @@ parseGlobalFlags = do
( long "offline"
<> help "Disable access to network resources"
)
_globalShowNameIds <-
switch
( long "show-name-ids"
<> help "[DEV] Show the unique number of each identifier when pretty printing"
)
_globalNewTypecheckingAlgorithm <-
switch
( long "new-typechecker"
<> help "[DEV] Use the new experimental typechecker"
)
return GlobalOptions {..}

parseBuildDir :: Mod OptionFields (Prepath Dir) -> Parser (AppPath Dir)
Expand Down Expand Up @@ -158,7 +165,8 @@ entryPointFromGlobalOptions root mainFile opts = do
_entryPointUnrollLimit = opts ^. globalUnrollLimit,
_entryPointGenericOptions = project opts,
_entryPointBuildDir = maybe (def ^. entryPointBuildDir) (CustomBuildDir . Abs) mabsBuildDir,
_entryPointOffline = opts ^. globalOffline
_entryPointOffline = opts ^. globalOffline,
_entryPointNewTypeCheckingAlgorithm = opts ^. globalNewTypecheckingAlgorithm
}
where
optBuildDir :: Maybe (Prepath Dir)
Expand Down
1 change: 0 additions & 1 deletion src/Juvix/Compiler/Core/Translation/FromInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,6 @@ goType ::
Sem r Type
goType ty = do
normTy <- strongNormalizeHelper ty
-- traceM ("ty = " <> Internal.ppTrace ty <> ". Normalized = " <> Internal.ppTrace normTy)
squashApps <$> goExpression normTy

mkFunBody ::
Expand Down
18 changes: 12 additions & 6 deletions src/Juvix/Compiler/Internal/Data/LocalVars.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ data LocalVars = LocalVars
instance Semigroup LocalVars where
(LocalVars a b) <> (LocalVars a' b') = LocalVars (a <> a') (b <> b')

emptyLocalVars :: LocalVars
emptyLocalVars =
LocalVars
{ _localTypes = mempty,
_localTyMap = mempty
}

instance Monoid LocalVars where
mempty = emptyLocalVars

makeLenses ''LocalVars

withLocalTypeMaybe :: (Members '[Reader LocalVars] r) => Maybe VarName -> Expression -> Sem r a -> Sem r a
Expand All @@ -28,9 +38,5 @@ addType v t = over localTypes (HashMap.insert v t)
addTypeMapping :: VarName -> VarName -> LocalVars -> LocalVars
addTypeMapping v v' = over localTyMap (HashMap.insert v v')

emptyLocalVars :: LocalVars
emptyLocalVars =
LocalVars
{ _localTypes = mempty,
_localTyMap = mempty
}
withEmptyLocalVars :: Sem (Reader LocalVars ': r) a -> Sem r a
withEmptyLocalVars = runReader emptyLocalVars
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Internal/Data/TypedHole.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Juvix.Compiler.Internal.Language
import Juvix.Prelude

data TypedHole = TypedHole
{ _typedHoleHole :: Hole,
{ _typedHoleHole :: InstanceHole,
_typedHoleType :: Expression,
_typedHoleLocalVars :: LocalVars
}
Expand Down
35 changes: 30 additions & 5 deletions src/Juvix/Compiler/Internal/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,20 @@ holes = leafExpressions . _ExpressionHole
hasHoles :: (HasExpressions a) => a -> Bool
hasHoles = has holes

subsHoles :: forall a r. (HasExpressions a, Member NameIdGen r) => HashMap Hole Expression -> a -> Sem r a
subsInstanceHoles :: forall r a. (HasExpressions a, Member NameIdGen r) => HashMap InstanceHole Expression -> a -> Sem r a
subsInstanceHoles s = leafExpressions helper
where
helper :: Expression -> Sem r Expression
helper e = case e of
ExpressionInstanceHole h -> clone (fromMaybe e (s ^. at h))
_ -> return e

subsHoles :: forall r a. (HasExpressions a, Member NameIdGen r) => HashMap Hole Expression -> a -> Sem r a
subsHoles s = leafExpressions helper
where
helper :: Expression -> Sem r Expression
helper e = case e of
ExpressionHole h -> clone (fromMaybe e (s ^. at h))
ExpressionInstanceHole h -> clone (fromMaybe e (s ^. at h))
_ -> return e

instance HasExpressions Example where
Expand Down Expand Up @@ -357,10 +364,20 @@ unfoldTypeAbsType t = case t of
foldExplicitApplication :: Expression -> [Expression] -> Expression
foldExplicitApplication f = foldApplication f . map (ApplicationArg Explicit)

foldApplication' :: Expression -> NonEmpty ApplicationArg -> Application
foldApplication' f (arg :| args) =
let ApplicationArg i a = arg
in go (Application f a i) args
where
go :: Application -> [ApplicationArg] -> Application
go acc = \case
[] -> acc
ApplicationArg i a : as -> go (Application (ExpressionApplication acc) a i) as

foldApplication :: Expression -> [ApplicationArg] -> Expression
foldApplication f args = case args of
[] -> f
ApplicationArg i a : as -> foldApplication (ExpressionApplication (Application f a i)) as
foldApplication f args = case nonEmpty args of
Nothing -> f
Just args' -> ExpressionApplication (foldApplication' f args')

unfoldApplication' :: Application -> (Expression, NonEmpty ApplicationArg)
unfoldApplication' (Application l' r' i') = second (|: (ApplicationArg i' r')) (unfoldExpressionApp l')
Expand Down Expand Up @@ -554,6 +571,11 @@ infix 4 ==%
(==%) :: (IsExpression a, IsExpression b) => a -> b -> HashSet Name -> Bool
(==%) a b free = leftEq a b free || leftEq b a free

infixl 9 @@?

(@@?) :: (IsExpression a, IsExpression b) => a -> b -> IsImplicit -> Expression
a @@? b = toExpression . Application (toExpression a) (toExpression b)

infixl 9 @@

(@@) :: (IsExpression a, IsExpression b) => a -> b -> Expression
Expand All @@ -580,6 +602,9 @@ genWildcard loc impl = do
var <- varFromWildcard (Wildcard loc)
return (PatternArg impl Nothing (PatternVariable var))

freshInstanceHole :: (Members '[NameIdGen] r) => Interval -> Sem r InstanceHole
freshInstanceHole l = mkInstanceHole l <$> freshNameId

freshHole :: (Members '[NameIdGen] r) => Interval -> Sem r Hole
freshHole l = mkHole l <$> freshNameId

Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Internal/Extra/Clonable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ instance (Clonable a) => Clonable (WithLoc a) where
instance Clonable Literal where
freshNameIds = return

instance Clonable InstanceHole where
freshNameIds = return

instance Clonable Hole where
freshNameIds = return

Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Internal/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ data Expression
| ExpressionFunction Function
| ExpressionLiteral LiteralLoc
| ExpressionHole Hole
| ExpressionInstanceHole Hole
| ExpressionInstanceHole InstanceHole
| ExpressionLet Let
| ExpressionUniverse SmallUniverse
| ExpressionSimpleLambda SimpleLambda
Expand Down
32 changes: 31 additions & 1 deletion src/Juvix/Compiler/Internal/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ where
import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Internal.Data.InfoTable.Base
import Juvix.Compiler.Internal.Data.InstanceInfo (instanceInfoResult, instanceTableMap)
import Juvix.Compiler.Internal.Data.LocalVars
import Juvix.Compiler.Internal.Data.NameDependencyInfo
import Juvix.Compiler.Internal.Data.TypedHole
import Juvix.Compiler.Internal.Extra
import Juvix.Compiler.Internal.Pretty.Options
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.ArityChecking.Data.Types (Arity)
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.TypeChecking.CheckerNew.Arity qualified as New
import Juvix.Data.CodeAnn
import Juvix.Prelude

Expand Down Expand Up @@ -48,7 +51,7 @@ instance PrettyCode SimpleLambda where
ppCode l = do
b' <- ppCode (l ^. slambdaBody)
v' <- ppCode (l ^. slambdaBinder . sbinderVar)
return $ kwLambda <+> braces (v' <+> kwAssign <+> b')
return $ kwSimpleLambda <+> braces (v' <+> kwAssign <+> b')

instance PrettyCode Application where
ppCode a = do
Expand Down Expand Up @@ -192,6 +195,11 @@ instance PrettyCode Function where
funReturn' <- ppRightExpression funFixity r
return $ funParameter' <+> kwArrow <+> funReturn'

instance PrettyCode InstanceHole where
ppCode h = do
showNameId <- asks (^. optShowNameIds)
return (addNameIdTag showNameId (h ^. iholeId) kwHole)

instance PrettyCode Hole where
ppCode h = do
showNameId <- asks (^. optShowNameIds)
Expand Down Expand Up @@ -323,6 +331,28 @@ instance PrettyCode Module where
instance PrettyCode Interval where
ppCode = return . annotate AnnCode . pretty

instance PrettyCode New.ArityParameter where
ppCode = return . pretty

instance (PrettyCode a, PrettyCode b) => PrettyCode (Either a b) where
ppCode = \case
Left l -> do
l' <- ppCode l
return ("Left" <+> l')
Right r -> do
r' <- ppCode r
return ("Right" <+> r')

instance PrettyCode LocalVars where
ppCode LocalVars {..} = ppCode (HashMap.toList _localTypes)

instance PrettyCode TypedHole where
ppCode TypedHole {..} = do
h <- ppCode _typedHoleHole
ty <- ppCode _typedHoleType
vars <- ppCode _typedHoleLocalVars
return (h <+> kwColon <+> ty <> kwAt <> vars)

instance PrettyCode InfoTable where
ppCode tbl = do
inds <- ppCode (HashMap.keys (tbl ^. infoInductives))
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Internal/Translation/FromConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ goExpression = \case
ExpressionUniverse uni -> return (Internal.ExpressionUniverse (goUniverse uni))
ExpressionFunction func -> Internal.ExpressionFunction <$> goFunction func
ExpressionHole h -> return (Internal.ExpressionHole h)
ExpressionInstanceHole h -> return (Internal.ExpressionInstanceHole h)
ExpressionInstanceHole h -> return (Internal.ExpressionInstanceHole (fromHole h))
ExpressionIterator i -> goIterator i
ExpressionNamedApplication i -> goNamedApplication i
ExpressionNamedApplicationNew i -> goNamedApplicationNew i
Expand Down
44 changes: 43 additions & 1 deletion src/Juvix/Compiler/Internal/Translation/FromInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Juvix.Compiler.Internal.Translation.FromInternal
( module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Reachability,
arityChecking,
typeChecking,
typeCheckingNew,
typeCheckExpression,
typeCheckExpressionType,
arityCheckExpression,
Expand All @@ -15,11 +16,13 @@ import Juvix.Compiler.Builtins.Effect
import Juvix.Compiler.Concrete.Data.Highlight.Input
import Juvix.Compiler.Internal.Language
import Juvix.Compiler.Internal.Pretty
import Juvix.Compiler.Internal.Translation.FromConcrete.Data.Context
import Juvix.Compiler.Internal.Translation.FromConcrete.Data.Context as Internal
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.ArityChecking qualified as ArityChecking
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.ArityChecking.Data.Context (InternalArityResult (InternalArityResult))
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Reachability
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Checker
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.TypeChecking
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.TypeChecking.CheckerNew qualified as New
import Juvix.Compiler.Pipeline.Artifacts
import Juvix.Compiler.Pipeline.EntryPoint
import Juvix.Data.Effect.NameIdGen
Expand Down Expand Up @@ -147,3 +150,42 @@ typeChecking a = do
_resultFunctions = funs,
_resultInfoTable = table
}

typeCheckingNew ::
forall r.
(Members '[HighlightBuilder, Error JuvixError, Builtins, NameIdGen] r) =>
Sem (Termination ': r) InternalResult ->
Sem r InternalTypedResult
typeCheckingNew a = do
(termin, (res, table, (normalized, (idens, (funs, r))))) <- runTermination iniTerminationState $ do
res :: InternalResult <- a
let table :: InfoTable
table = buildTable (res ^. Internal.resultModules)

entryPoint :: EntryPoint
entryPoint = res ^. Internal.internalResultEntryPoint
fmap (res,table,)
. runOutputList
. runReader entryPoint
. runState (mempty :: TypesTable)
. runState (mempty :: FunctionsTable)
. runReader table
. mapError (JuvixError @TypeCheckerError)
. evalCacheEmpty New.checkModuleNoCache
$ checkTable >> mapM New.checkModule (res ^. Internal.resultModules)
let ariRes :: InternalArityResult
ariRes =
InternalArityResult
{ _resultInternalResult = res,
_resultModules = res ^. Internal.resultModules
}
return
InternalTypedResult
{ _resultInternalArityResult = ariRes,
_resultModules = r,
_resultTermination = termin,
_resultNormalized = HashMap.fromList [(e ^. exampleId, e ^. exampleExpression) | e <- normalized],
_resultIdenTypes = idens,
_resultFunctions = funs,
_resultInfoTable = table
}
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,6 @@ checkLhs loc guessedBody ariSignature pats = do

-- This is an heuristic and it can have an undesired result.
-- Sometimes the outcome may even be confusing.
-- TODO default arguments??
tailHelper :: Arity -> Maybe [IsImplicit]
tailHelper a
| 0 < pref = Just pref'
Expand Down Expand Up @@ -894,5 +893,5 @@ newHoleImplicit i loc = case i ^. arityParameterInfo . argInfoDefault of
-- TODO update location
return (True, e)

newHoleInstance :: (Member NameIdGen r) => Interval -> Sem r Hole
newHoleInstance loc = mkHole loc <$> freshNameId
newHoleInstance :: (Member NameIdGen r) => Interval -> Sem r InstanceHole
newHoleInstance loc = mkInstanceHole loc <$> freshNameId
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Che
runTermination,
evalTermination,
execTermination,
functionIsTerminating,
functionSafeToNormalize,
module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.TerminationState,
)
where
Expand All @@ -28,13 +28,8 @@ data Termination m a where

makeSem ''Termination

functionIsTerminating :: (Members '[Termination] r) => FunctionRef -> Sem r Bool
functionIsTerminating = fmap terminates . functionTermination
where
terminates :: IsTerminating -> Bool
terminates = \case
TerminatingCheckedOrMarked -> True
TerminatingFailed -> False
functionSafeToNormalize :: (Members '[Termination] r) => FunctionRef -> Sem r Bool
functionSafeToNormalize = fmap safeToNormalize . functionTermination

runTermination :: forall r a. (Members '[Error JuvixError] r) => TerminationState -> Sem (Termination ': r) a -> Sem r (TerminationState, a)
runTermination ini m = do
Expand Down Expand Up @@ -79,7 +74,7 @@ functionTermination' ::
(Members '[State TerminationState] r) =>
FunctionName ->
Sem r IsTerminating
functionTermination' f = fromMaybe TerminatingCheckedOrMarked <$> gets (^. terminationTable . at f)
functionTermination' f = fromMaybe TerminatingChecked <$> gets (^. terminationTable . at f)

-- | Returns the set of non-terminating functions. Does not go into imports.
checkTerminationShallow' ::
Expand All @@ -102,9 +97,9 @@ checkTerminationShallow' topModule = do
order = findOrder rb
addTerminating funName $
if
| markedTerminating -> TerminatingCheckedOrMarked
| Just {} <- order -> TerminatingChecked
| markedTerminating -> TerminatingFailedMarked
| Nothing <- order -> TerminatingFailed
| Just {} <- order -> TerminatingCheckedOrMarked

scanModule ::
(Members '[State CallMap] r) =>
Expand Down
Loading

0 comments on commit a05586e

Please sign in to comment.