diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 819122b..e1324de 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -2,9 +2,9 @@ name: lint on: push: - branches: [ "master" ] + branches: [ "main" ] pull_request: - branches: [ "master" ] + branches: [ "main" ] jobs: ormolu: diff --git a/.github/workflows/nix-ci.yml b/.github/workflows/nix-ci.yml index 2e0d471..df161ea 100644 --- a/.github/workflows/nix-ci.yml +++ b/.github/workflows/nix-ci.yml @@ -2,9 +2,9 @@ name: "Test" on: push: - branches: [ "master" ] + branches: [ "main" ] pull_request: - branches: [ "master" ] + branches: [ "main" ] jobs: build: diff --git a/src/Snarkl/AST/LambdaExpr.hs b/src/Snarkl/AST/LambdaExpr.hs index 7435b13..411387e 100644 --- a/src/Snarkl/AST/LambdaExpr.hs +++ b/src/Snarkl/AST/LambdaExpr.hs @@ -7,8 +7,10 @@ module Snarkl.AST.LambdaExpr where import Control.Monad.Error.Class (throwError) +import Control.Monad.State (State, evalState, gets, modify) import Data.Field.Galois (GaloisField) import Data.Kind (Type) +import qualified Data.Map as Map import Snarkl.AST.Expr (Variable) import qualified Snarkl.AST.Expr as Core import Snarkl.Common (Op, UnOp) @@ -59,9 +61,34 @@ betaNormalize = \case EAbs var' e -> EAbs var' (substitute (var, e1) e) EApp e2 e3 -> EApp (substitute (var, e1) e2) (substitute (var, e1) e3) +inline :: Exp k -> Exp k +inline comp = evalState (go comp) mempty + where + go :: Exp k -> State (Map.Map Variable (Exp k)) (Exp k) + go = \case + EVar var -> do + ma <- gets (Map.lookup var) + case ma of + Nothing -> do + pure (EVar var) + Just e -> go e + e@(EVal _) -> pure e + EUnit -> pure EUnit + EUnop op e -> EUnop op <$> go e + EBinop op l r -> EBinop op <$> go l <*> go r + EIf b e1 e2 -> EIf <$> go b <*> go e1 <*> go e2 + EAssert (EVar v1) e@(EAbs _ _) -> do + _e <- go e + modify $ Map.insert v1 _e + pure EUnit + EAssert e1 e2 -> EAssert <$> go e1 <*> go e2 + ESeq e1 e2 -> ESeq <$> go e1 <*> go e2 + EAbs var e -> EAbs var <$> go e + EApp e1 e2 -> EApp <$> go e1 <*> go e2 + expOfLambdaExp :: (Show a) => Exp a -> Core.Exp a expOfLambdaExp _exp = - let coreExp = betaNormalize _exp + let coreExp = betaNormalize $ inline _exp in case expOfLambdaExp' coreExp of Left err -> error err Right e -> e