From aa40fa4f6647b5bfc7aae096f75122d9e68ba783 Mon Sep 17 00:00:00 2001 From: martyall Date: Sat, 6 Jan 2024 13:35:31 -0800 Subject: [PATCH 01/19] wip --- flake.lock | 17 ++++++- snarkl.cabal | 2 +- src/Snarkl/Language/LambdaExpr.hs | 82 +++++++++++-------------------- 3 files changed, 47 insertions(+), 54 deletions(-) diff --git a/flake.lock b/flake.lock index f706238..18024b6 100644 --- a/flake.lock +++ b/flake.lock @@ -84,6 +84,20 @@ } }, "flake-compat": { + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "revCount": 57, + "type": "tarball", + "url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz" + }, + "original": { + "type": "tarball", + "url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz" + } + }, + "flake-compat_2": { "flake": false, "locked": { "lastModified": 1672831974, @@ -195,7 +209,7 @@ "cabal-34": "cabal-34", "cabal-36": "cabal-36", "cardano-shell": "cardano-shell", - "flake-compat": "flake-compat", + "flake-compat": "flake-compat_2", "ghc-8.6.5-iohk": "ghc-8.6.5-iohk", "ghc98X": "ghc98X", "ghc99": "ghc99", @@ -594,6 +608,7 @@ }, "root": { "inputs": { + "flake-compat": "flake-compat", "flake-utils": "flake-utils", "haskellNix": "haskellNix", "nixpkgs": [ diff --git a/snarkl.cabal b/snarkl.cabal index f6a9bfa..31d4b3f 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -25,7 +25,7 @@ source-repository head library ghc-options: - -Wall -Wredundant-constraints -Werror -funbox-strict-fields + -Wall -Wredundant-constraints -funbox-strict-fields -optc-O3 -- -threaded diff --git a/src/Snarkl/Language/LambdaExpr.hs b/src/Snarkl/Language/LambdaExpr.hs index c2d6873..faf4623 100644 --- a/src/Snarkl/Language/LambdaExpr.hs +++ b/src/Snarkl/Language/LambdaExpr.hs @@ -4,6 +4,7 @@ module Snarkl.Language.LambdaExpr ( Exp (..), expOfLambdaExp, expBinop, + betaNormalize, ) where @@ -33,58 +34,35 @@ deriving instance (Show a) => Show (Exp a) deriving instance (Eq a) => Eq (Exp a) -type Env a = Map Variable (Exp a) - -defaultEnv :: Env a -defaultEnv = Map.empty - -applyLambdas :: - Exp a -> - (Exp a, Env a) -applyLambdas expression = runState (go expression) defaultEnv - where - go :: Exp a -> State (Env a) (Exp a) - go = \case - EVar var -> do - ma <- gets (Map.lookup var) - case ma of - Nothing -> do - pure (EVar var) - Just e -> go e - e@(EVal _) -> pure e - EUnit -> pure EUnit - EUnop op e -> EUnop op <$> go e - EBinop op es -> EBinop op <$> mapM go es - EIf b e1 e2 -> EIf <$> go b <*> go e1 <*> go e2 - EAssert (EVar v1) e@(EAbs _ _) -> do - _e <- go e - modify $ Map.insert v1 _e - pure EUnit - EAssert e1 e2 -> EAssert <$> go e1 <*> go e2 - ESeq e1 e2 -> ESeq <$> go e1 <*> go e2 - EAbs var e -> do - EAbs var <$> go e - EApp (EAbs var e1) e2 -> do - _e2 <- go e2 - modify $ Map.insert var e2 - go (substitute var e2 e1) - EApp e1 e2 -> do - _e1 <- go e1 - _e2 <- go e2 - go (EApp _e1 _e2) - -substitute :: Variable -> Exp a -> Exp a -> Exp a -substitute var e1 = \case - e@(EVar var') -> if var == var' then e1 else e - e@(EVal _) -> e +betaNormalize :: Exp a -> Exp a +betaNormalize = \case + EVar x -> EVar x + EVal v -> EVal v + EUnop op e -> EUnop op (betaNormalize e) + EBinop op es -> EBinop op (betaNormalize <$> es) + EIf e1 e2 e3 -> EIf (betaNormalize e1) (betaNormalize e2) (betaNormalize e3) + EAssert e1 e2 -> EAssert (betaNormalize e1) (betaNormalize e2) + ESeq e1 e2 -> ESeq (betaNormalize e1) (betaNormalize e2) + EAbs v e -> EAbs v (betaNormalize e) + EApp f a -> + case betaNormalize f of + EAbs v e -> substitute (v, betaNormalize a) e + f' -> EApp f' (betaNormalize a) EUnit -> EUnit - EUnop op e -> EUnop op (substitute var e1 e) - EBinop op es -> EBinop op (map (substitute var e1) es) - EIf b e2 e3 -> EIf (substitute var e1 b) (substitute var e1 e2) (substitute var e1 e3) - EAssert e2 e3 -> EAssert (substitute var e1 e2) (substitute var e1 e3) - ESeq e2 e3 -> ESeq (substitute var e1 e2) (substitute var e1 e3) - EAbs var' e -> EAbs var' (substitute var e1 e) - EApp e2 e3 -> EApp (substitute var e1 e2) (substitute var e1 e3) + where + -- substitute x e1 e2 = e2 [x := e1 ] + substitute :: (Variable, Exp a) -> Exp a -> Exp a + substitute (var, e1) = \case + e@(EVar var') -> if var == var' then e1 else e + e@(EVal _) -> e + EUnit -> EUnit + EUnop op e -> EUnop op (substitute (var, e1) e) + EBinop op es -> EBinop op (substitute (var, e1) <$> es) + EIf b e2 e3 -> EIf (substitute (var, e1) b) (substitute (var, e1) e2) (substitute (var, e1) e3) + EAssert e2 e3 -> EAssert (substitute (var, e1) e2) (substitute (var, e1) e3) + ESeq e2 e3 -> ESeq (substitute (var, e1) e2) (substitute (var, e1) e3) + EAbs var' e -> EAbs var' (substitute (var, e1) e) + EApp e2 e3 -> EApp (substitute (var, e1) e2) (substitute (var, e1) e3) expBinop :: Op -> Exp a -> Exp a -> Exp a expBinop op e1 e2 = @@ -112,7 +90,7 @@ expSeq e1 e2 = expOfLambdaExp :: (Show a) => Exp a -> Core.Exp a expOfLambdaExp _exp = - let (coreExp, _) = applyLambdas _exp + let coreExp = betaNormalize _exp in case expOfLambdaExp' coreExp of Left err -> error err Right e -> e From 0352b0ab641660259ece3a3cbebad69d044b5ffe Mon Sep 17 00:00:00 2001 From: martyall Date: Sat, 6 Jan 2024 15:26:19 -0800 Subject: [PATCH 02/19] restrict exports --- snarkl.cabal | 2 +- src/Snarkl/Language.hs | 80 ++++++++++++++++++++++++++++-- src/Snarkl/Language/LambdaExpr.hs | 5 +- src/Snarkl/Language/SyntaxMonad.hs | 2 +- 4 files changed, 79 insertions(+), 10 deletions(-) diff --git a/snarkl.cabal b/snarkl.cabal index 31d4b3f..875f47a 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -25,7 +25,7 @@ source-repository head library ghc-options: - -Wall -Wredundant-constraints -funbox-strict-fields + -Wall -Werror -Wredundant-constraints -funbox-strict-fields -optc-O3 -- -threaded diff --git a/src/Snarkl/Language.hs b/src/Snarkl/Language.hs index 0f1e150..733a182 100644 --- a/src/Snarkl/Language.hs +++ b/src/Snarkl/Language.hs @@ -1,9 +1,80 @@ module Snarkl.Language ( expOfTExp, - module Snarkl.Language.TExpr, + booleanVarsOfTexp, + TExp, module Snarkl.Language.Expr, - module Snarkl.Language.SyntaxMonad, - module Snarkl.Language.Syntax, + -- | SyntaxMonad + Comp, + runState, + return, + (>>=), + (>>), + Env (..), + -- | Return a fresh input variable. + fresh_input, + -- | Classes + Zippable, + Derive, + -- | Basic values + unit, + false, + true, + fromField, + -- | Sums, products, recursive types + inl, + inr, + case_sum, + pair, + fst_pair, + snd_pair, + roll, + unroll, + fixN, + fix, + -- | Arithmetic and boolean operations + (+), + (-), + (*), + (/), + (&&), + zeq, + not, + xor, + eq, + beq, + exp_of_int, + inc, + dec, + ifThenElse, + negate, + -- | Arrays + arr, + arr2, + arr3, + input_arr, + input_arr2, + input_arr3, + set, + set2, + set3, + set4, + get, + get2, + get3, + get4, + -- | Iteration + iter, + iterM, + bigsum, + times, + forall, + forall2, + forall3, + -- | Function combinators + lambda, + curry, + uncurry, + apply, ) where @@ -14,6 +85,7 @@ import Snarkl.Language.LambdaExpr (expOfLambdaExp) import Snarkl.Language.Syntax import Snarkl.Language.SyntaxMonad import Snarkl.Language.TExpr +import qualified Prelude expOfTExp :: (GaloisField a, Typeable ty) => TExp ty a -> Exp a -expOfTExp = expOfLambdaExp . lambdaExpOfTExp +expOfTExp = expOfLambdaExp Prelude.. lambdaExpOfTExp diff --git a/src/Snarkl/Language/LambdaExpr.hs b/src/Snarkl/Language/LambdaExpr.hs index faf4623..57edbd1 100644 --- a/src/Snarkl/Language/LambdaExpr.hs +++ b/src/Snarkl/Language/LambdaExpr.hs @@ -9,11 +9,8 @@ module Snarkl.Language.LambdaExpr where import Control.Monad.Error.Class (throwError) -import Control.Monad.State (State, gets, modify, runState) import Data.Field.Galois (GaloisField) import Data.Kind (Type) -import Data.Map (Map) -import qualified Data.Map as Map import Snarkl.Common (Op, UnOp, isAssoc) import Snarkl.Language.Expr (Variable) import qualified Snarkl.Language.Expr as Core @@ -105,4 +102,4 @@ expOfLambdaExp _exp = EIf b e1 e2 -> Core.EIf <$> expOfLambdaExp' b <*> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 EAssert e1 e2 -> Core.EAssert <$> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 ESeq e1 e2 -> expSeq <$> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 - e -> throwError ("Impossible after IR simplicifaction: " <> show e) + e -> throwError ("Impossible after lambda simplicifaction: " <> show e) diff --git a/src/Snarkl/Language/SyntaxMonad.hs b/src/Snarkl/Language/SyntaxMonad.hs index ee57287..1aa7ead 100644 --- a/src/Snarkl/Language/SyntaxMonad.hs +++ b/src/Snarkl/Language/SyntaxMonad.hs @@ -58,7 +58,7 @@ import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Language.Expr (Variable (..)) import Snarkl.Language.TExpr ( Loc, - TExp (TEAssert, TEBinop, TEBot, TESeq, TEUnop, TEVal, TEVar), + TExp (..), TLoc (TLoc), TVar (TVar), Ty (TArr, TBool, TProd, TUnit), From 91d5ee5a6ad836456a8028af6d75f9ded3a93c97 Mon Sep 17 00:00:00 2001 From: martyall Date: Sat, 6 Jan 2024 16:16:21 -0800 Subject: [PATCH 03/19] control input types --- examples/Snarkl/Example/Basic.hs | 5 +++- examples/Snarkl/Example/Games.hs | 47 ++++++++++++++++++++---------- examples/Snarkl/Example/List.hs | 6 ++-- examples/Snarkl/Example/Queue.hs | 19 ++++++------ src/Snarkl/Language.hs | 2 +- src/Snarkl/Language/SyntaxMonad.hs | 26 ++++++++++++++--- tests/Test/Snarkl/Unit/Programs.hs | 3 +- 7 files changed, 74 insertions(+), 34 deletions(-) diff --git a/examples/Snarkl/Example/Basic.hs b/examples/Snarkl/Example/Basic.hs index 7a86fd1..c2cd978 100644 --- a/examples/Snarkl/Example/Basic.hs +++ b/examples/Snarkl/Example/Basic.hs @@ -1,4 +1,5 @@ {-# LANGUAGE RebindableSyntax #-} +{-# LANGUAGE TypeApplications #-} module Snarkl.Example.Basic where @@ -48,10 +49,12 @@ desugar1 = compileCompToTexp p1 interp1 :: (GaloisField k) => k interp1 = comp_interp p1 [] +p2 :: (GaloisField k) => Comp 'TField k p2 = do - x <- fresh_input + x <- fresh_input @'TField return $ x + x +desugar2 :: (GaloisField k) => TExpPkg 'TField k desugar2 = compileCompToTexp p2 interp2 :: (GaloisField k) => k diff --git a/examples/Snarkl/Example/Games.hs b/examples/Snarkl/Example/Games.hs index e152e2c..a66c537 100644 --- a/examples/Snarkl/Example/Games.hs +++ b/examples/Snarkl/Example/Games.hs @@ -4,7 +4,7 @@ {-# LANGUAGE RebindableSyntax #-} {-# LANGUAGE ScopedTypeVariables #-} -module Snarkl.Example.Games where +module Snarkl.Example.Games () where import Data.Field.Galois (GaloisField, Prime) import Data.Typeable @@ -42,7 +42,7 @@ data Game :: Ty -> * -> * where Single :: forall (s :: Ty) (t :: Ty) k. ( Typeable s, - Typeable t + InputType s ) => ISO t s k -> Game t k @@ -55,15 +55,17 @@ data Game :: Ty -> * -> * where Zippable t2 k, Zippable t k, Derive t1 k, - Derive t2 k + Derive t2 k, + InputType t1, + InputType t2 ) => ISO t ('TSum t1 t2) k -> Game t1 k -> Game t2 k -> Game t k -decode :: (GaloisField k) => Game t k -> Comp t k -decode (Single (Iso _ bld)) = +decode :: (GaloisField k, InputType t) => Game t k -> Comp t k +decode (Single (Iso _ bld :: ISO t s k)) = do x <- fresh_input bld x @@ -92,7 +94,7 @@ bool_game = unit_game :: (GaloisField k) => Game 'TUnit k unit_game = Single (Iso (\_ -> return (fromField 1)) (\(_ :: TExp 'TField k) -> return unit)) -fail_game :: (Typeable ty) => Game ty p +fail_game :: (Typeable ty, InputType ty) => Game ty p fail_game = Single ( Iso @@ -109,7 +111,9 @@ sum_game :: Zippable t2 k, Derive t1 k, Derive t2 k, - GaloisField k + GaloisField k, + InputType t1, + InputType t2 ) => Game t1 k -> Game t2 k -> @@ -120,11 +124,11 @@ sum_game g1 g2 = basic_game :: (GaloisField k) => Game ('TSum 'TField 'TField) k basic_game = sum_game field_game field_game +{- basic_test :: (GaloisField k) => Comp 'TField k -basic_test = - do - s <- decode basic_game - case_sum return return s +basic_test = do + s <- decode basic_game + case_sum return return s t1 :: F_BN128 t1 = comp_interp basic_test [0, 23, 88] -- 23 @@ -132,11 +136,15 @@ t1 = comp_interp basic_test [0, 23, 88] -- 23 t2 :: F_BN128 t2 = comp_interp basic_test [1, 23, 88] -- 88 +-} + (+>) :: ( Typeable t, Typeable s, Zippable t k, - Zippable s k + Zippable s k, + InputType t, + InputType s ) => Game t k -> ISO s t k -> @@ -179,6 +187,7 @@ seqI (Iso f g) (Iso f' g') = Iso (\a -> f a >>= f') (\c -> g' c >>= g) prodLInputI :: ( Typeable a, + InputType a, Typeable b, GaloisField k ) => @@ -243,7 +252,9 @@ prod_game :: Zippable b k, Derive a k, Derive b k, - GaloisField k + GaloisField k, + InputType a, + InputType b ) => Game a k -> Game b k -> @@ -309,7 +320,9 @@ instance Derive b k, Gameable a k, Gameable b k, - GaloisField k + GaloisField k, + InputType a, + InputType b ) => Gameable ('TProd a b) k where @@ -324,11 +337,13 @@ instance Derive b k, Gameable a k, Gameable b k, - GaloisField k + GaloisField k, + InputType a, + InputType b ) => Gameable ('TSum a b) k where mkGame = sum_game mkGame mkGame -gdecode :: (Gameable t k, GaloisField k) => Comp t k +gdecode :: (Gameable t k, InputType t, GaloisField k) => Comp t k gdecode = decode mkGame diff --git a/examples/Snarkl/Example/List.hs b/examples/Snarkl/Example/List.hs index fe2157b..5dc7a69 100644 --- a/examples/Snarkl/Example/List.hs +++ b/examples/Snarkl/Example/List.hs @@ -1,4 +1,6 @@ +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RebindableSyntax #-} +{-# LANGUAGE TypeApplications #-} module Snarkl.Example.List where @@ -216,11 +218,11 @@ list_comp4 = l <- list2 last_list (fromField 0) l -listN :: (Typeable a, Zippable a k, Derive a k, GaloisField k) => TExp 'TField k -> Comp (TList a) k +listN :: (Typeable 'TField, Zippable 'TField k, Derive 'TField k, GaloisField k) => TExp 'TField k -> Comp (TList 'TField) k listN n = fixN 100 go n where go self n0 = do - x <- fresh_input + x <- fresh_input @TField tl <- self (n0 - fromField 1) if return (eq n0 (fromField 0)) then nil else cons x tl diff --git a/examples/Snarkl/Example/Queue.hs b/examples/Snarkl/Example/Queue.hs index b306692..e00ec44 100644 --- a/examples/Snarkl/Example/Queue.hs +++ b/examples/Snarkl/Example/Queue.hs @@ -1,4 +1,5 @@ {-# LANGUAGE RebindableSyntax #-} +{-# LANGUAGE TypeApplications #-} module Snarkl.Example.Queue where @@ -184,15 +185,15 @@ queue_comp3 = sx <- dequeue q1 (fromField 0) fst_pair sx -queueN :: (Typeable a, Zippable a k, Derive a k, GaloisField k) => TExp 'TField k -> Comp (TQueue a) k -queueN n = fixN 100 go n - where - go self n0 = do - x <- fresh_input - tl <- self (n0 - fromField 1) - if return (eq n0 (fromField 0)) - then empty_queue - else enqueue x tl +queueN :: forall a k. (Typeable a, Zippable a k, Derive a k, GaloisField k, InputType a) => TExp 'TField k -> Comp (TQueue a) k +queueN n = + let go self n0 = do + x <- fresh_input @a + tl <- self (n0 - fromField 1) + if return (eq n0 (fromField 0)) + then empty_queue + else enqueue x tl + in fixN 100 go n test_queueN :: (GaloisField k) => Comp 'TField k test_queueN = do diff --git a/src/Snarkl/Language.hs b/src/Snarkl/Language.hs index 733a182..fce16f7 100644 --- a/src/Snarkl/Language.hs +++ b/src/Snarkl/Language.hs @@ -11,7 +11,7 @@ module Snarkl.Language (>>), Env (..), -- | Return a fresh input variable. - fresh_input, + InputType (..), -- | Classes Zippable, Derive, diff --git a/src/Snarkl/Language/SyntaxMonad.hs b/src/Snarkl/Language/SyntaxMonad.hs index 1aa7ead..1a8fa31 100644 --- a/src/Snarkl/Language/SyntaxMonad.hs +++ b/src/Snarkl/Language/SyntaxMonad.hs @@ -15,7 +15,7 @@ module Snarkl.Language.SyntaxMonad Env (..), State (..), -- | Return a fresh input variable. - fresh_input, + InputType (..), -- | Return a fresh variable. fresh_var, -- | Return a fresh location. @@ -61,7 +61,7 @@ import Snarkl.Language.TExpr TExp (..), TLoc (TLoc), TVar (TVar), - Ty (TArr, TBool, TProd, TUnit), + Ty (TArr, TBool, TField, TProd, TUnit), Val (VFalse, VLoc, VTrue, VUnit), lastSeq, locOfTexp, @@ -443,8 +443,8 @@ fresh_var = ) ) -fresh_input :: State (Env k) (TExp ty a) -fresh_input = +_fresh_input :: State (Env k) (TExp ty a) +_fresh_input = State ( \s -> let (v, nextVar) = runSupply (Variable <$> fresh) (next_variable s) @@ -457,6 +457,24 @@ fresh_input = ) ) +class (Typeable ty) => InputType (ty :: Ty) where + fresh_input :: (GaloisField k) => State (Env k) (TExp ty k) + +instance InputType 'TBool where + fresh_input = _fresh_input + +instance InputType 'TUnit where + fresh_input = _fresh_input + +instance InputType 'TField where + fresh_input = _fresh_input + +instance (InputType ty1, InputType ty2) => InputType ('TProd ty1 ty2) where + fresh_input = do + x1 <- fresh_input + x2 <- fresh_input + pair x1 x2 + fresh_loc :: State (Env k) (TExp ty a) fresh_loc = State diff --git a/tests/Test/Snarkl/Unit/Programs.hs b/tests/Test/Snarkl/Unit/Programs.hs index fa316ef..3098c25 100644 --- a/tests/Test/Snarkl/Unit/Programs.hs +++ b/tests/Test/Snarkl/Unit/Programs.hs @@ -1,4 +1,5 @@ {-# LANGUAGE RebindableSyntax #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} {-# HLINT ignore "Use let" #-} @@ -147,7 +148,7 @@ bool_prog10 = -- | 11. are unused fresh_input variables treated properly? prog11 = do - _ <- fresh_input :: Comp ('TArr 'TField) F_BN128 + _ <- fresh_input @'TField b <- fresh_input return b From d57fd7a807d38c6aefb480984fe4d18806eab974 Mon Sep 17 00:00:00 2001 From: martyall Date: Sat, 6 Jan 2024 22:16:08 -0800 Subject: [PATCH 04/19] Revert "control input types" This reverts commit 91d5ee5a6ad836456a8028af6d75f9ded3a93c97. --- examples/Snarkl/Example/Basic.hs | 5 +--- examples/Snarkl/Example/Games.hs | 47 ++++++++++-------------------- examples/Snarkl/Example/List.hs | 6 ++-- examples/Snarkl/Example/Queue.hs | 19 ++++++------ src/Snarkl/Language.hs | 2 +- src/Snarkl/Language/SyntaxMonad.hs | 26 +++-------------- tests/Test/Snarkl/Unit/Programs.hs | 3 +- 7 files changed, 34 insertions(+), 74 deletions(-) diff --git a/examples/Snarkl/Example/Basic.hs b/examples/Snarkl/Example/Basic.hs index c2cd978..7a86fd1 100644 --- a/examples/Snarkl/Example/Basic.hs +++ b/examples/Snarkl/Example/Basic.hs @@ -1,5 +1,4 @@ {-# LANGUAGE RebindableSyntax #-} -{-# LANGUAGE TypeApplications #-} module Snarkl.Example.Basic where @@ -49,12 +48,10 @@ desugar1 = compileCompToTexp p1 interp1 :: (GaloisField k) => k interp1 = comp_interp p1 [] -p2 :: (GaloisField k) => Comp 'TField k p2 = do - x <- fresh_input @'TField + x <- fresh_input return $ x + x -desugar2 :: (GaloisField k) => TExpPkg 'TField k desugar2 = compileCompToTexp p2 interp2 :: (GaloisField k) => k diff --git a/examples/Snarkl/Example/Games.hs b/examples/Snarkl/Example/Games.hs index a66c537..e152e2c 100644 --- a/examples/Snarkl/Example/Games.hs +++ b/examples/Snarkl/Example/Games.hs @@ -4,7 +4,7 @@ {-# LANGUAGE RebindableSyntax #-} {-# LANGUAGE ScopedTypeVariables #-} -module Snarkl.Example.Games () where +module Snarkl.Example.Games where import Data.Field.Galois (GaloisField, Prime) import Data.Typeable @@ -42,7 +42,7 @@ data Game :: Ty -> * -> * where Single :: forall (s :: Ty) (t :: Ty) k. ( Typeable s, - InputType s + Typeable t ) => ISO t s k -> Game t k @@ -55,17 +55,15 @@ data Game :: Ty -> * -> * where Zippable t2 k, Zippable t k, Derive t1 k, - Derive t2 k, - InputType t1, - InputType t2 + Derive t2 k ) => ISO t ('TSum t1 t2) k -> Game t1 k -> Game t2 k -> Game t k -decode :: (GaloisField k, InputType t) => Game t k -> Comp t k -decode (Single (Iso _ bld :: ISO t s k)) = +decode :: (GaloisField k) => Game t k -> Comp t k +decode (Single (Iso _ bld)) = do x <- fresh_input bld x @@ -94,7 +92,7 @@ bool_game = unit_game :: (GaloisField k) => Game 'TUnit k unit_game = Single (Iso (\_ -> return (fromField 1)) (\(_ :: TExp 'TField k) -> return unit)) -fail_game :: (Typeable ty, InputType ty) => Game ty p +fail_game :: (Typeable ty) => Game ty p fail_game = Single ( Iso @@ -111,9 +109,7 @@ sum_game :: Zippable t2 k, Derive t1 k, Derive t2 k, - GaloisField k, - InputType t1, - InputType t2 + GaloisField k ) => Game t1 k -> Game t2 k -> @@ -124,11 +120,11 @@ sum_game g1 g2 = basic_game :: (GaloisField k) => Game ('TSum 'TField 'TField) k basic_game = sum_game field_game field_game -{- basic_test :: (GaloisField k) => Comp 'TField k -basic_test = do - s <- decode basic_game - case_sum return return s +basic_test = + do + s <- decode basic_game + case_sum return return s t1 :: F_BN128 t1 = comp_interp basic_test [0, 23, 88] -- 23 @@ -136,15 +132,11 @@ t1 = comp_interp basic_test [0, 23, 88] -- 23 t2 :: F_BN128 t2 = comp_interp basic_test [1, 23, 88] -- 88 --} - (+>) :: ( Typeable t, Typeable s, Zippable t k, - Zippable s k, - InputType t, - InputType s + Zippable s k ) => Game t k -> ISO s t k -> @@ -187,7 +179,6 @@ seqI (Iso f g) (Iso f' g') = Iso (\a -> f a >>= f') (\c -> g' c >>= g) prodLInputI :: ( Typeable a, - InputType a, Typeable b, GaloisField k ) => @@ -252,9 +243,7 @@ prod_game :: Zippable b k, Derive a k, Derive b k, - GaloisField k, - InputType a, - InputType b + GaloisField k ) => Game a k -> Game b k -> @@ -320,9 +309,7 @@ instance Derive b k, Gameable a k, Gameable b k, - GaloisField k, - InputType a, - InputType b + GaloisField k ) => Gameable ('TProd a b) k where @@ -337,13 +324,11 @@ instance Derive b k, Gameable a k, Gameable b k, - GaloisField k, - InputType a, - InputType b + GaloisField k ) => Gameable ('TSum a b) k where mkGame = sum_game mkGame mkGame -gdecode :: (Gameable t k, InputType t, GaloisField k) => Comp t k +gdecode :: (Gameable t k, GaloisField k) => Comp t k gdecode = decode mkGame diff --git a/examples/Snarkl/Example/List.hs b/examples/Snarkl/Example/List.hs index 5dc7a69..fe2157b 100644 --- a/examples/Snarkl/Example/List.hs +++ b/examples/Snarkl/Example/List.hs @@ -1,6 +1,4 @@ -{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RebindableSyntax #-} -{-# LANGUAGE TypeApplications #-} module Snarkl.Example.List where @@ -218,11 +216,11 @@ list_comp4 = l <- list2 last_list (fromField 0) l -listN :: (Typeable 'TField, Zippable 'TField k, Derive 'TField k, GaloisField k) => TExp 'TField k -> Comp (TList 'TField) k +listN :: (Typeable a, Zippable a k, Derive a k, GaloisField k) => TExp 'TField k -> Comp (TList a) k listN n = fixN 100 go n where go self n0 = do - x <- fresh_input @TField + x <- fresh_input tl <- self (n0 - fromField 1) if return (eq n0 (fromField 0)) then nil else cons x tl diff --git a/examples/Snarkl/Example/Queue.hs b/examples/Snarkl/Example/Queue.hs index e00ec44..b306692 100644 --- a/examples/Snarkl/Example/Queue.hs +++ b/examples/Snarkl/Example/Queue.hs @@ -1,5 +1,4 @@ {-# LANGUAGE RebindableSyntax #-} -{-# LANGUAGE TypeApplications #-} module Snarkl.Example.Queue where @@ -185,15 +184,15 @@ queue_comp3 = sx <- dequeue q1 (fromField 0) fst_pair sx -queueN :: forall a k. (Typeable a, Zippable a k, Derive a k, GaloisField k, InputType a) => TExp 'TField k -> Comp (TQueue a) k -queueN n = - let go self n0 = do - x <- fresh_input @a - tl <- self (n0 - fromField 1) - if return (eq n0 (fromField 0)) - then empty_queue - else enqueue x tl - in fixN 100 go n +queueN :: (Typeable a, Zippable a k, Derive a k, GaloisField k) => TExp 'TField k -> Comp (TQueue a) k +queueN n = fixN 100 go n + where + go self n0 = do + x <- fresh_input + tl <- self (n0 - fromField 1) + if return (eq n0 (fromField 0)) + then empty_queue + else enqueue x tl test_queueN :: (GaloisField k) => Comp 'TField k test_queueN = do diff --git a/src/Snarkl/Language.hs b/src/Snarkl/Language.hs index fce16f7..733a182 100644 --- a/src/Snarkl/Language.hs +++ b/src/Snarkl/Language.hs @@ -11,7 +11,7 @@ module Snarkl.Language (>>), Env (..), -- | Return a fresh input variable. - InputType (..), + fresh_input, -- | Classes Zippable, Derive, diff --git a/src/Snarkl/Language/SyntaxMonad.hs b/src/Snarkl/Language/SyntaxMonad.hs index 1a8fa31..1aa7ead 100644 --- a/src/Snarkl/Language/SyntaxMonad.hs +++ b/src/Snarkl/Language/SyntaxMonad.hs @@ -15,7 +15,7 @@ module Snarkl.Language.SyntaxMonad Env (..), State (..), -- | Return a fresh input variable. - InputType (..), + fresh_input, -- | Return a fresh variable. fresh_var, -- | Return a fresh location. @@ -61,7 +61,7 @@ import Snarkl.Language.TExpr TExp (..), TLoc (TLoc), TVar (TVar), - Ty (TArr, TBool, TField, TProd, TUnit), + Ty (TArr, TBool, TProd, TUnit), Val (VFalse, VLoc, VTrue, VUnit), lastSeq, locOfTexp, @@ -443,8 +443,8 @@ fresh_var = ) ) -_fresh_input :: State (Env k) (TExp ty a) -_fresh_input = +fresh_input :: State (Env k) (TExp ty a) +fresh_input = State ( \s -> let (v, nextVar) = runSupply (Variable <$> fresh) (next_variable s) @@ -457,24 +457,6 @@ _fresh_input = ) ) -class (Typeable ty) => InputType (ty :: Ty) where - fresh_input :: (GaloisField k) => State (Env k) (TExp ty k) - -instance InputType 'TBool where - fresh_input = _fresh_input - -instance InputType 'TUnit where - fresh_input = _fresh_input - -instance InputType 'TField where - fresh_input = _fresh_input - -instance (InputType ty1, InputType ty2) => InputType ('TProd ty1 ty2) where - fresh_input = do - x1 <- fresh_input - x2 <- fresh_input - pair x1 x2 - fresh_loc :: State (Env k) (TExp ty a) fresh_loc = State diff --git a/tests/Test/Snarkl/Unit/Programs.hs b/tests/Test/Snarkl/Unit/Programs.hs index 3098c25..fa316ef 100644 --- a/tests/Test/Snarkl/Unit/Programs.hs +++ b/tests/Test/Snarkl/Unit/Programs.hs @@ -1,5 +1,4 @@ {-# LANGUAGE RebindableSyntax #-} -{-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} {-# HLINT ignore "Use let" #-} @@ -148,7 +147,7 @@ bool_prog10 = -- | 11. are unused fresh_input variables treated properly? prog11 = do - _ <- fresh_input @'TField + _ <- fresh_input :: Comp ('TArr 'TField) F_BN128 b <- fresh_input return b From fe0d40d4ff2b88d130f3698f43152759eab2fccf Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 15:59:16 -0800 Subject: [PATCH 05/19] tests pass with core program --- app/Main.hs | 3 +- examples/Snarkl/Example/Basic.hs | 11 +-- snarkl.cabal | 6 +- src/Snarkl/Compile.hs | 140 +++++++++++++++-------------- src/Snarkl/Interp.hs | 93 ++++++++++++++++++- src/Snarkl/Language.hs | 14 ++- src/Snarkl/Language/Core.hs | 27 ++++++ src/Snarkl/Language/Expr.hs | 116 +++++++++++++----------- src/Snarkl/Language/LambdaExpr.hs | 35 +++----- src/Snarkl/Language/Syntax.hs | 18 ++++ src/Snarkl/Language/SyntaxMonad.hs | 2 +- src/Snarkl/Language/TExpr.hs | 2 +- src/Snarkl/Toplevel.hs | 4 +- tests/Test/ArkworksBridge.hs | 3 +- tests/Test/Snarkl/UnitSpec.hs | 3 +- 15 files changed, 316 insertions(+), 161 deletions(-) create mode 100644 src/Snarkl/Language/Core.hs diff --git a/app/Main.hs b/app/Main.hs index 5d2adc0..3c55fb4 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -4,6 +4,7 @@ import Control.Monad (unless) import qualified Data.ByteString.Lazy as LBS import Data.Field.Galois (PrimeField) import Data.Typeable (Typeable) +import Prettyprinter import Snarkl.Compile (SimplParam (NoSimplify)) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Field @@ -14,7 +15,7 @@ main :: IO () main = do executeAndWriteArtifacts "./snarkl-output" "prog2" NoSimplify (Programs.prog2 10) [1 :: F_BN128] -executeAndWriteArtifacts :: (Typeable ty, PrimeField k) => FilePath -> String -> SimplParam -> Comp ty k -> [k] -> IO () +executeAndWriteArtifacts :: (Typeable ty, Pretty k, PrimeField k) => FilePath -> String -> SimplParam -> Comp ty k -> [k] -> IO () executeAndWriteArtifacts fp name simpl mf inputs = do let Result {result_sat = isSatisfied, result_r1cs = r1cs, result_witness = wit} = execute simpl mf inputs unless isSatisfied $ failWith $ ErrMsg "R1CS is not satisfied" diff --git a/examples/Snarkl/Example/Basic.hs b/examples/Snarkl/Example/Basic.hs index 7a86fd1..91a9a7c 100644 --- a/examples/Snarkl/Example/Basic.hs +++ b/examples/Snarkl/Example/Basic.hs @@ -5,6 +5,7 @@ module Snarkl.Example.Basic where import Data.Field.Galois (GaloisField, Prime) import Data.Typeable (Typeable) import GHC.TypeLits (KnownNat) +import Prettyprinter (Pretty (pretty)) import Snarkl.Compile import Snarkl.Field (F_BN128) import Snarkl.Language.Syntax @@ -42,10 +43,10 @@ arr_ex x = do p1 :: (GaloisField k) => Comp 'TField k p1 = arr_ex $ fromField 1 -desugar1 :: (GaloisField k) => TExpPkg 'TField k +desugar1 :: (GaloisField k, Pretty k) => TExpPkg 'TField k desugar1 = compileCompToTexp p1 -interp1 :: (GaloisField k) => k +interp1 :: (GaloisField k, Pretty k) => k interp1 = comp_interp p1 [] p2 = do @@ -54,13 +55,13 @@ p2 = do desugar2 = compileCompToTexp p2 -interp2 :: (GaloisField k) => k +interp2 :: (GaloisField k, Pretty k) => k interp2 = comp_interp p2 [] -interp2' :: (GaloisField k) => k +interp2' :: (GaloisField k, Pretty k) => k interp2' = comp_interp p2 [256] -compile1 :: (GaloisField k) => R1CS k +compile1 :: (GaloisField k, Pretty k) => R1CS k compile1 = compileCompToR1CS Simplify p1 comp1 :: (GaloisField k, Typeable a) => Comp ('TSum 'TBool a) k diff --git a/snarkl.cabal b/snarkl.cabal index 875f47a..742f9c4 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -25,7 +25,7 @@ source-repository head library ghc-options: - -Wall -Werror -Wredundant-constraints -funbox-strict-fields + -Wall -Wredundant-constraints -funbox-strict-fields -optc-O3 -- -threaded @@ -52,6 +52,7 @@ library Snarkl.Language.Expr Snarkl.Language.LambdaExpr Snarkl.Language.Syntax + Snarkl.Language.Core Snarkl.Language.SyntaxMonad Snarkl.Language.TExpr Snarkl.Toplevel @@ -80,6 +81,7 @@ library , bytestring , Cabal >=1.22 , containers >=0.5 && <0.7 + , errors , galois-field >=1.0.4 , hspec >=2.0 , jsonl >=0.1.4 @@ -88,6 +90,7 @@ library , parallel >=3.2 && <3.3 , prettyprinter , process >=1.2 + , transformers hs-source-dirs: src default-language: Haskell2010 @@ -133,6 +136,7 @@ test-suite spec , mtl >=2.2 && <2.3 , parallel >=3.2 && <3.3 , process >=1.2 + , prettyprinter , QuickCheck , snarkl >=0.1.0.0 diff --git a/src/Snarkl/Compile.hs b/src/Snarkl/Compile.hs index e17aec1..285beb7 100644 --- a/src/Snarkl/Compile.hs +++ b/src/Snarkl/Compile.hs @@ -22,6 +22,9 @@ import Control.Monad.State import qualified Control.Monad.State as State import Data.Either (fromRight) import Data.Field.Galois (GaloisField) +-- do_const_prop, + +import Data.Foldable (traverse_) import Data.List (sort) import qualified Data.Map as Map import qualified Data.Set as Set @@ -46,15 +49,13 @@ import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Language ( Comp, Env (Env, input_vars, next_variable), - Exp (..), TExp, Variable (Variable), booleanVarsOfTexp, - do_const_prop, expOfTExp, runState, - var_of_exp, ) +import qualified Snarkl.Language.Core as Core ---------------------------------------------------------------- -- @@ -110,13 +111,13 @@ encode_or :: (GaloisField a) => (Var, Var, Var) -> State (CEnv a) () encode_or (x, y, z) = do x_mult_y <- fresh_var - cs_of_exp x_mult_y (EBinop Mult [EVar (_Var # x), EVar (_Var # y)]) + cs_of_exp x_mult_y (Core.EBinop Mult [Core.EVar (_Var # x), Core.EVar (_Var # y)]) cs_of_exp x_mult_y - ( EBinop + ( Core.EBinop Sub - [ EBinop Add [EVar (_Var # x), EVar (_Var # y)], - EVar (_Var # z) + [ Core.EBinop Add [Core.EVar (_Var # x), Core.EVar (_Var # y)], + Core.EVar (_Var # z) ] ) @@ -153,25 +154,24 @@ encode_boolean_eq :: (GaloisField a) => (Var, Var, Var) -> State (CEnv a) () encode_boolean_eq (x, y, z) = cs_of_exp z e where e = - EBinop + Core.EBinop Add - [ EBinop Mult [EVar (_Var # x), EVar (_Var # y)], - EBinop + [ Core.EBinop Mult [Core.EVar (_Var # x), Core.EVar (_Var # y)], + Core.EBinop Mult - [ EBinop Sub [EVal 1, EVar (_Var # x)], - EBinop Sub [EVal 1, EVar (_Var # y)] + [ Core.EBinop Sub [Core.EVal 1, Core.EVar (_Var # x)], + Core.EBinop Sub [Core.EVal 1, Core.EVar (_Var # y)] ] ] -- | Constraint 'x == y = z'. -- The encoding is: z = (x-y == 0). encode_eq :: (GaloisField a) => (Var, Var, Var) -> State (CEnv a) () -encode_eq (x, y, z) = cs_of_exp z e - where - e = - EAssert - (EVar (_Var # z)) - (EUnop ZEq (EBinop Sub [EVar (_Var # x), EVar (_Var # y)])) +encode_eq (x, y, z) = + cs_of_assignment $ + Core.Assignment + (_Var # z) + (Core.EUnop ZEq (Core.EBinop Sub [Core.EVar (_Var # x), Core.EVar (_Var # y)])) -- | Constraint 'y = x!=0 ? 1 : 0'. -- The encoding is: @@ -191,8 +191,8 @@ encode_zneq (x, y) = nm <- fresh_var add_constraint (CMagic nm [x, m] mf) -- END magic. - cs_of_exp y (EBinop Mult [EVar (_Var # x), EVar (_Var # m)]) - cs_of_exp neg_y (EBinop Sub [EVal 1, EVar (_Var # y)]) + cs_of_exp y (Core.EBinop Mult [Core.EVar (_Var # x), Core.EVar (_Var # m)]) + cs_of_exp neg_y (Core.EBinop Sub [Core.EVal 1, Core.EVar (_Var # y)]) add_constraint (CMult (1, neg_y) (1, x) (0, Nothing)) where @@ -219,7 +219,7 @@ encode_zeq (x, y) = do neg_y <- fresh_var encode_zneq (x, neg_y) - cs_of_exp y (EBinop Sub [EVal 1, EVar (_Var # neg_y)]) + cs_of_exp y (Core.EBinop Sub [Core.EVal 1, Core.EVar (_Var # neg_y)]) -- | Encode the constraint 'un_op x = y' encode_unop :: (GaloisField a) => UnOp -> (Var, Var) -> State (CEnv a) () @@ -261,20 +261,20 @@ encode_linear out xs = remove_consts (Left p : l) = p : remove_consts l remove_consts (Right _ : l) = remove_consts l -cs_of_exp :: (GaloisField a) => Var -> Exp a -> State (CEnv a) () +cs_of_exp :: (GaloisField a) => Var -> Core.Exp a -> State (CEnv a) () cs_of_exp out e = case e of - EVar x -> + Core.EVar x -> ensure_equal (out, view _Var x) - EVal c -> + Core.EVal c -> ensure_const (out, c) - EUnop op (EVar x) -> + Core.EUnop op (Core.EVar x) -> encode_unop op (view _Var x, out) - EUnop op e1 -> + Core.EUnop op e1 -> do e1_out <- fresh_var cs_of_exp e1_out e1 encode_unop op (e1_out, out) - EBinop op es -> + Core.EBinop op es -> -- [NOTE linear combination optimization:] cf. also -- 'encode_linear' above. 'go_linear' returns a list of -- (label*coeff + constant) pairs. @@ -287,33 +287,33 @@ cs_of_exp out e = case e of -- We special-case linear combinations in this way to avoid having -- to introduce new multiplication gates for multiplication by -- constant scalars. - let go_linear :: (GaloisField a) => [Exp a] -> State (CEnv a) [Either (Var, a) a] + let go_linear :: (GaloisField a) => [Core.Exp a] -> State (CEnv a) [Either (Var, a) a] go_linear [] = return [] - go_linear (EBinop Mult [EVar x, EVal coeff] : es') = + go_linear (Core.EBinop Mult [Core.EVar x, Core.EVal coeff] : es') = do labels <- go_linear es' return $ Left (x ^. _Var, coeff) : labels - go_linear (EBinop Mult [EVal coeff, EVar y] : es') = + go_linear (Core.EBinop Mult [Core.EVal coeff, Core.EVar y] : es') = do labels <- go_linear es' return $ Left (y ^. _Var, coeff) : labels - go_linear (EBinop Mult [e_left, EVal coeff] : es') = + go_linear (Core.EBinop Mult [e_left, Core.EVal coeff] : es') = do e_left_out <- fresh_var cs_of_exp e_left_out e_left labels <- go_linear es' return $ Left (e_left_out, coeff) : labels - go_linear (EBinop Mult [EVal coeff, e_right] : es') = + go_linear (Core.EBinop Mult [Core.EVal coeff, e_right] : es') = do e_right_out <- fresh_var cs_of_exp e_right_out e_right labels <- go_linear es' return $ Left (e_right_out, coeff) : labels - go_linear (EVal c : es') = + go_linear (Core.EVal c : es') = do labels <- go_linear es' return $ Right c : labels - go_linear (EVar x : es') = + go_linear (Core.EVar x : es') = do labels <- go_linear es' return $ Left (x ^. _Var, 1) : labels @@ -338,9 +338,9 @@ cs_of_exp out e = case e of rev_pol (Left (x, c) : ls) = Left (x, -c) : rev_pol ls rev_pol (Right c : ls) = Right (-c) : rev_pol ls - go_other :: (GaloisField a) => [Exp a] -> State (CEnv a) [Var] + go_other :: (GaloisField a) => [Core.Exp a] -> State (CEnv a) [Var] go_other [] = return [] - go_other (EVar x : es') = + go_other (Core.EVar x : es') = do labels <- go_other es' return $ (x ^. _Var) : labels @@ -378,37 +378,40 @@ cs_of_exp out e = case e of encode_labels labels -- Encoding: out = b*e1 + (1-b)e2 - EIf b e1 e2 -> cs_of_exp out e0 + Core.EIf b e1 e2 -> cs_of_exp out e0 where e0 = - EBinop + Core.EBinop Add - [ EBinop Mult [b, e1], - EBinop Mult [EBinop Sub [EVal 1, b], e2] + [ Core.EBinop Mult [b, e1], + Core.EBinop Mult [Core.EBinop Sub [Core.EVal 1, b], e2] ] - -- NOTE: when compiling assignments, the naive thing to do is - -- to introduce a new var, e2_out, bound to result of e2 and - -- then ensure that e2_out == x. We optimize by passing x to - -- compilation of e2 directly. - EAssert e1 e2 -> - do - let x = var_of_exp e1 - cs_of_exp (x ^. _Var) e2 - ESeq le -> - do - x <- fresh_var -- x is garbage - go x le - where - go _ [] = failWith $ ErrMsg "internal error: empty ESeq" - go _ [e1] = cs_of_exp out e1 - go garbage_var (e1 : le') = - do - cs_of_exp garbage_var e1 - go garbage_var le' - EUnit -> + ---- NOTE: when compiling assignments, the naive thing to do is + ---- to introduce a new var, e2_out, bound to result of e2 and + ---- then ensure that e2_out == x. We optimize by passing x to + ---- compilation of e2 directly. + -- EAssert e1 e2 -> + -- do + -- let x = var_of_exp e1 + -- cs_of_exp (x ^. _Var) e2 + -- ESeq le -> + -- do + -- x <- fresh_var -- x is garbage + -- go x le + -- where + -- go _ [] = failWith $ ErrMsg "internal error: empty ESeq" + -- go _ [e1] = cs_of_exp out e1 + -- go garbage_var (e1 : le') = + -- do + -- cs_of_exp garbage_var e1 + -- go garbage_var le' + Core.EUnit -> -- NOTE: [[ EUnit ]]_{out} = [[ EVal zero ]]_{out}. - cs_of_exp out (EVal 0) + cs_of_exp out (Core.EVal 0) + +cs_of_assignment :: (GaloisField a) => Core.Assignment a -> State (CEnv a) () +cs_of_assignment (Core.Assignment x e) = cs_of_exp (view _Var x) e data SimplParam = NoSimplify @@ -503,7 +506,7 @@ compileCompToTexp mf = -- | Snarkl.Compile 'TExp's to constraint systems. Re-exported from 'Snarkl.Compile.Snarkl.Compile'. compileTexpToConstraints :: - (Typeable ty, GaloisField k) => + (Typeable ty, GaloisField k, Pretty k) => TExpPkg ty k -> ConstraintSystem k compileTexpToConstraints (TExpPkg _out _in_vars te) = @@ -516,10 +519,11 @@ compileTexpToConstraints (TExpPkg _out _in_vars te) = Set.toList $ Set.fromList in_vars `Set.intersection` Set.fromList (map (view _Var) $ booleanVarsOfTexp te) - e0 = expOfTExp te - e = do_const_prop e0 + Core.Program assignments e = expOfTExp te + traverse_ cs_of_assignment assignments + -- e = do_const_prop e0 -- Snarkl.Compile 'e' to constraints 'cs', with output wire 'out'. - cs_of_exp out e + cs_of_assignment $ Core.Assignment (_Var # out) e -- Add boolean constraints mapM_ ensure_boolean boolean_in_vars cs <- get_constraints @@ -536,7 +540,7 @@ compileTexpToConstraints (TExpPkg _out _in_vars te) = -- | Snarkl.Compile Snarkl computations to constraint systems. compileCompToConstraints :: - (Typeable ty, GaloisField k) => + (Typeable ty, GaloisField k, Pretty k) => Comp ty k -> ConstraintSystem k compileCompToConstraints = compileTexpToConstraints . compileCompToTexp @@ -549,7 +553,7 @@ compileCompToConstraints = compileTexpToConstraints . compileCompToTexp -- | Snarkl.Compile 'TExp's to 'R1CS'. compileTExpToR1CS :: - (Typeable ty, GaloisField k) => + (Typeable ty, GaloisField k, Pretty k) => SimplParam -> TExpPkg ty k -> R1CS k @@ -557,7 +561,7 @@ compileTExpToR1CS simpl = compileConstraintsToR1CS simpl . compileTexpToConstrai -- | Snarkl.Compile Snarkl computations to 'R1CS'. compileCompToR1CS :: - (Typeable ty, GaloisField k) => + (Typeable ty, GaloisField k, Pretty k) => SimplParam -> Comp ty k -> R1CS k diff --git a/src/Snarkl/Interp.hs b/src/Snarkl/Interp.hs index 94de305..c0b0176 100644 --- a/src/Snarkl/Interp.hs +++ b/src/Snarkl/Interp.hs @@ -8,11 +8,14 @@ where import Control.Monad (ap, foldM) import Data.Data (Typeable) import Data.Field.Galois (GaloisField) +import Data.Foldable (traverse_) import Data.Map (Map) import qualified Data.Map as Map +import Prettyprinter (Pretty) import Snarkl.Common (Op (..), UnOp (ZEq)) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Language (Exp (..), TExp, Variable, expOfTExp) +import Snarkl.Language (TExp, Variable, expOfTExp) +import qualified Snarkl.Language.Core as Core type Env a = Map Variable (Maybe a) @@ -78,21 +81,23 @@ boolOfField v = interpTExp :: ( GaloisField a, - Typeable ty + Typeable ty, + Pretty a ) => TExp ty a -> InterpM a (Maybe a) interpTExp e = do let _exp = expOfTExp e - interpExpr _exp + interpProg _exp interp :: - (GaloisField a, Typeable ty) => + (GaloisField a, Typeable ty, Pretty a) => Map Variable a -> TExp ty a -> Either ErrMsg (Env a, Maybe a) interp rho e = runInterpM (interpTExp e) $ Map.map Just rho +{- interpExpr :: (GaloisField a) => Exp a -> @@ -159,3 +164,83 @@ interpExpr e = case e of XOr -> return $ fieldOfBool $ (b1 && not b2) || (b2 && not b1) BEq -> return $ fieldOfBool $ b1 == b2 _ -> failWith $ ErrMsg "internal error in interp_binop" +-} + +interpProg :: + (GaloisField a) => + Core.Program a -> + InterpM a (Maybe a) +interpProg (Core.Program assignments finalExp) = + let f (Core.Assignment var e) = do + e' <- interpCoreExpr e + addBinds [(var, e')] + in do + traverse_ f assignments + interpCoreExpr finalExp + +interpCoreExpr :: + (GaloisField a) => + Core.Exp a -> + InterpM a (Maybe a) +interpCoreExpr = \case + Core.EVar x -> lookupVar x + Core.EVal v -> pure $ Just v + Core.EUnop op e2 -> do + v2 <- interpCoreExpr e2 + case v2 of + Nothing -> pure Nothing + Just v2' -> case op of + ZEq -> return $ Just $ fieldOfBool (v2' == 0) + Core.EBinop op _es -> case _es of + [] -> failWith $ ErrMsg "empty binary args" + (a : as) -> do + b <- interpCoreExpr a + foldM (interpBinopExpr op) b as + Core.EIf eb e1 e2 -> + do + mb <- interpCoreExpr eb + case mb of + Nothing -> pure Nothing + Just _b -> boolOfField _b >>= \b -> if b then interpCoreExpr e1 else interpCoreExpr e2 + -- CoreEAssert e1 e2 -> + -- case (e1, e2) of + -- (Core.EVar x, _) -> + -- do + -- v2 <- interpExpr e2 + -- addBinds [(x, v2)] + -- (_, _) -> raiseErr $ ErrMsg $ show e1 ++ " not a variable" + -- CESeq es -> case es of + -- [] -> failWith $ ErrMsg "empty sequence" + -- _ -> last <$> mapM interpExpr es + Core.EUnit -> return $ Just 1 + where + interpBinopExpr :: (GaloisField a) => Op -> Maybe a -> Core.Exp a -> InterpM a (Maybe a) + interpBinopExpr _ Nothing _ = return Nothing + interpBinopExpr _op (Just a1) _exp = do + ma2 <- interpCoreExpr _exp + case ma2 of + Nothing -> return Nothing + Just a2 -> Just <$> op a1 a2 + where + op :: (GaloisField a) => a -> a -> InterpM a a + op a b = case _op of + Add -> pure $ a + b + Sub -> pure $ a - b + Mult -> pure $ a * b + Div -> pure $ a / b + And -> interpBooleanBinop a b + Or -> interpBooleanBinop a b + XOr -> interpBooleanBinop a b + BEq -> interpBooleanBinop a b + Eq -> pure $ fieldOfBool $ a == b + interpBooleanBinop :: (GaloisField a) => a -> a -> InterpM a a + interpBooleanBinop a b = + do + b1 <- boolOfField a + b2 <- boolOfField b + case _op of + And -> return $ fieldOfBool $ b1 && b2 + Or -> return $ fieldOfBool $ b1 || b2 + XOr -> return $ fieldOfBool $ (b1 && not b2) || (b2 && not b1) + BEq -> return $ fieldOfBool $ b1 == b2 + _ -> failWith $ ErrMsg "internal error in interp_binop" \ No newline at end of file diff --git a/src/Snarkl/Language.hs b/src/Snarkl/Language.hs index 733a182..0b188f7 100644 --- a/src/Snarkl/Language.hs +++ b/src/Snarkl/Language.hs @@ -2,7 +2,7 @@ module Snarkl.Language ( expOfTExp, booleanVarsOfTexp, TExp, - module Snarkl.Language.Expr, + module Snarkl.Language.Core, -- | SyntaxMonad Comp, runState, @@ -80,6 +80,9 @@ where import Data.Data (Typeable) import Data.Field.Galois (GaloisField) +import Debug.Trace (trace) +import Prettyprinter (Pretty (pretty)) +import Snarkl.Language.Core import Snarkl.Language.Expr import Snarkl.Language.LambdaExpr (expOfLambdaExp) import Snarkl.Language.Syntax @@ -87,5 +90,10 @@ import Snarkl.Language.SyntaxMonad import Snarkl.Language.TExpr import qualified Prelude -expOfTExp :: (GaloisField a, Typeable ty) => TExp ty a -> Exp a -expOfTExp = expOfLambdaExp Prelude.. lambdaExpOfTExp +expOfTExp :: (Prelude.Show a, GaloisField a, Typeable ty, Pretty a) => TExp ty a -> Program a +expOfTExp te = + trace (Prelude.show te) Prelude.$ + let e = do_const_prop Prelude.. expOfLambdaExp Prelude.. lambdaExpOfTExp Prelude.$ te + in case mkProgram e of + Prelude.Right p -> p + Prelude.Left err -> Prelude.error Prelude.$ "expOfTExp: failed to convert TExp to Program: " Prelude.<> err diff --git a/src/Snarkl/Language/Core.hs b/src/Snarkl/Language/Core.hs new file mode 100644 index 0000000..5dc485f --- /dev/null +++ b/src/Snarkl/Language/Core.hs @@ -0,0 +1,27 @@ +{-# LANGUAGE LambdaCase #-} + +module Snarkl.Language.Core where + +import Data.Field.Galois (GaloisField) +import Data.Kind (Type) +import Prettyprinter (Pretty) +import Snarkl.Common + +newtype Variable = Variable Int deriving (Eq, Ord, Show, Pretty) + +data Exp :: Type -> Type where + EVar :: Variable -> Exp a + EVal :: (GaloisField a) => a -> Exp a + EUnop :: UnOp -> Exp a -> Exp a + EBinop :: Op -> [Exp a] -> Exp a + EIf :: Exp a -> Exp a -> Exp a -> Exp a + EUnit :: Exp a + +deriving instance (Eq a) => Eq (Exp a) + +deriving instance (Show a) => Show (Exp a) + +data Assignment a = Assignment Variable (Exp a) + +data Program :: Type -> Type where + Program :: [Assignment a] -> Exp a -> Program a \ No newline at end of file diff --git a/src/Snarkl/Language/Expr.hs b/src/Snarkl/Language/Expr.hs index d8149fe..8bd7e8c 100644 --- a/src/Snarkl/Language/Expr.hs +++ b/src/Snarkl/Language/Expr.hs @@ -1,19 +1,28 @@ +{-# LANGUAGE LambdaCase #-} + module Snarkl.Language.Expr ( Exp (..), - Variable (..), - exp_binop, - exp_seq, - is_pure, var_of_exp, do_const_prop, + mkProgram, + expSeq, ) where -import Control.Monad.State (State, evalState, gets, modify) +import Control.Error (hoistEither, runExceptT) +import Control.Monad.Except + ( ExceptT, + MonadError (throwError), + MonadPlus (mzero), + ) +import Control.Monad.State (State, evalState, gets, modify, runState) import Data.Field.Galois (GaloisField) +import Data.Foldable (toList) import Data.Kind (Type) import Data.Map (Map) import qualified Data.Map as Map +import Data.Sequence (Seq, (|>)) +import Debug.Trace (trace) import Prettyprinter ( Pretty (pretty), hsep, @@ -21,13 +30,12 @@ import Prettyprinter punctuate, (<+>), ) -import Snarkl.Common (Op, UnOp, isAssoc) +import Snarkl.Common (Op, UnOp) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) - -newtype Variable = Variable Int deriving (Eq, Ord, Show, Pretty) +import qualified Snarkl.Language.Core as Core data Exp :: Type -> Type where - EVar :: Variable -> Exp a + EVar :: Core.Variable -> Exp a EVal :: (GaloisField a) => a -> Exp a EUnop :: UnOp -> Exp a -> Exp a EBinop :: Op -> [Exp a] -> Exp a @@ -40,50 +48,12 @@ deriving instance (Eq a) => Eq (Exp a) deriving instance (Show a) => Show (Exp a) -var_of_exp :: (Show a) => Exp a -> Variable +var_of_exp :: (Show a) => Exp a -> Core.Variable var_of_exp e = case e of EVar x -> x _ -> failWith $ ErrMsg ("var_of_exp: expected variable: " ++ show e) --- | Smart constructor for EBinop, ensuring all expressions (involving --- associative operations) are flattened to top level. -exp_binop :: Op -> Exp a -> Exp a -> Exp a -exp_binop op e1 e2 = - case (e1, e2) of - (EBinop op1 l1, EBinop op2 l2) - | op1 == op2 && op2 == op && isAssoc op -> - EBinop op (l1 ++ l2) - (EBinop op1 l1, _) - | op1 == op && isAssoc op -> - EBinop op (l1 ++ [e2]) - (_, EBinop op2 l2) - | op2 == op && isAssoc op -> - EBinop op (e1 : l2) - (_, _) -> EBinop op [e1, e2] - --- | Smart constructor for sequence, ensuring all expressions are --- flattened to top level. -exp_seq :: Exp a -> Exp a -> Exp a -exp_seq e1 e2 = - case (e1, e2) of - (ESeq l1, ESeq l2) -> ESeq (l1 ++ l2) - (ESeq l1, _) -> ESeq (l1 ++ [e2]) - (_, ESeq l2) -> ESeq (e1 : l2) - (_, _) -> ESeq [e1, e2] - -is_pure :: Exp a -> Bool -is_pure e = - case e of - EVar _ -> True - EVal _ -> True - EUnop _ e1 -> is_pure e1 - EBinop _ es -> all is_pure es - EIf b e1 e2 -> is_pure b && is_pure e1 && is_pure e2 - EAssert _ _ -> False - ESeq es -> all is_pure es - EUnit -> True - -const_prop :: (GaloisField a) => Exp a -> State (Map Variable a) (Exp a) +const_prop :: (GaloisField a) => Exp a -> State (Map Core.Variable a) (Exp a) const_prop e = case e of EVar x -> lookup_var x @@ -114,14 +84,14 @@ const_prop e = return $ ESeq es' EUnit -> return EUnit where - lookup_var :: (GaloisField a) => Variable -> State (Map Variable a) (Exp a) + lookup_var :: (GaloisField a) => Core.Variable -> State (Map Core.Variable a) (Exp a) lookup_var x0 = gets ( \m -> case Map.lookup x0 m of Nothing -> EVar x0 Just c -> EVal c ) - add_bind :: (Variable, a) -> State (Map Variable a) (Exp a) + add_bind :: (Core.Variable, a) -> State (Map Core.Variable a) (Exp a) add_bind (x0, c0) = do modify (Map.insert x0 c0) @@ -141,3 +111,47 @@ instance (Pretty a) => Pretty (Exp a) where pretty (EAssert e1 e2) = pretty e1 <+> ":=" <+> pretty e2 pretty (ESeq es) = parens $ hsep $ punctuate ";" $ map pretty es pretty EUnit = "()" + +mkExpression :: (Show a) => Exp a -> Either String (Core.Exp a) +mkExpression = \case + EVar x -> pure $ Core.EVar x + EVal v -> pure $ Core.EVal v + EUnop op e -> Core.EUnop op <$> mkExpression e + EBinop op es -> Core.EBinop op <$> traverse mkExpression es + EIf e1 e2 e3 -> Core.EIf <$> mkExpression e1 <*> mkExpression e2 <*> mkExpression e3 + EUnit -> pure Core.EUnit + e -> throwError $ "mkExpression: " ++ show e + +-- | Smart constructor for sequence, ensuring all expressions are +-- flattened to top level. +expSeq :: Exp a -> Exp a -> Exp a +expSeq e1 e2 = + case (e1, e2) of + (ESeq l1, ESeq l2) -> ESeq (l1 ++ l2) + (ESeq l1, _) -> ESeq (l1 ++ [e2]) + (_, ESeq l2) -> ESeq (e1 : l2) + (_, _) -> ESeq [e1, e2] + +mkAssignment :: (Show a) => Exp a -> Either String (Core.Assignment a) +mkAssignment (EAssert (EVar v) e) = Core.Assignment v <$> mkExpression e +mkAssignment e = throwError $ "mkAssignment: expected EAssert, got " <> show e + +mkProgram :: (Show a) => Exp a -> Either String (Core.Program a) +mkProgram e@(ESeq es) = trace ("mkProgram ESeq: " <> show e) $ do + let (eexpr, assignments) = runState (runExceptT $ go es) mempty + Core.Program (toList assignments) <$> eexpr + where + go :: (Show a) => [Exp a] -> ExceptT String (State (Seq (Core.Assignment a))) (Core.Exp a) + go = \case + [] -> mzero + [e] -> hoistEither $ mkExpression e + e : rest -> do + case e of + EUnit -> go rest + _ -> do + assignment <- hoistEither $ mkAssignment e + modify (|> assignment) + go rest +mkProgram e = trace ("mkProgram " <> show e) $ do + e' <- mkExpression e + pure $ Core.Program [] e' \ No newline at end of file diff --git a/src/Snarkl/Language/LambdaExpr.hs b/src/Snarkl/Language/LambdaExpr.hs index 57edbd1..832cad4 100644 --- a/src/Snarkl/Language/LambdaExpr.hs +++ b/src/Snarkl/Language/LambdaExpr.hs @@ -12,8 +12,9 @@ import Control.Monad.Error.Class (throwError) import Data.Field.Galois (GaloisField) import Data.Kind (Type) import Snarkl.Common (Op, UnOp, isAssoc) -import Snarkl.Language.Expr (Variable) -import qualified Snarkl.Language.Expr as Core +import Snarkl.Language.Core (Variable) +import Snarkl.Language.Expr (expSeq) +import qualified Snarkl.Language.Expr as E data Exp :: Type -> Type where EVar :: Variable -> Exp a @@ -57,7 +58,7 @@ betaNormalize = \case EBinop op es -> EBinop op (substitute (var, e1) <$> es) EIf b e2 e3 -> EIf (substitute (var, e1) b) (substitute (var, e1) e2) (substitute (var, e1) e3) EAssert e2 e3 -> EAssert (substitute (var, e1) e2) (substitute (var, e1) e3) - ESeq e2 e3 -> ESeq (substitute (var, e1) e2) (substitute (var, e1) e3) + ESeq l r -> ESeq (substitute (var, e1) l) (substitute (var, e1) r) EAbs var' e -> EAbs var' (substitute (var, e1) e) EApp e2 e3 -> EApp (substitute (var, e1) e2) (substitute (var, e1) e3) @@ -75,31 +76,21 @@ expBinop op e1 e2 = EBinop op (e1 : l2) (_, _) -> EBinop op [e1, e2] --- | Smart constructor for sequence, ensuring all expressions are --- flattened to top level. -expSeq :: Core.Exp a -> Core.Exp a -> Core.Exp a -expSeq e1 e2 = - case (e1, e2) of - (Core.ESeq l1, Core.ESeq l2) -> Core.ESeq (l1 ++ l2) - (Core.ESeq l1, _) -> Core.ESeq (l1 ++ [e2]) - (_, Core.ESeq l2) -> Core.ESeq (e1 : l2) - (_, _) -> Core.ESeq [e1, e2] - -expOfLambdaExp :: (Show a) => Exp a -> Core.Exp a +expOfLambdaExp :: (Show a) => Exp a -> E.Exp a expOfLambdaExp _exp = let coreExp = betaNormalize _exp in case expOfLambdaExp' coreExp of Left err -> error err Right e -> e where - expOfLambdaExp' :: (Show a) => Exp a -> Either String (Core.Exp a) + expOfLambdaExp' :: (Show a) => Exp a -> Either String (E.Exp a) expOfLambdaExp' = \case - EVar var -> pure $ Core.EVar var - EVal v -> pure $ Core.EVal v - EUnit -> pure Core.EUnit - EUnop op e -> Core.EUnop op <$> expOfLambdaExp' e - EBinop op es -> Core.EBinop op <$> mapM expOfLambdaExp' es - EIf b e1 e2 -> Core.EIf <$> expOfLambdaExp' b <*> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 - EAssert e1 e2 -> Core.EAssert <$> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 + EVar var -> pure $ E.EVar var + EVal v -> pure $ E.EVal v + EUnit -> pure E.EUnit + EUnop op e -> E.EUnop op <$> expOfLambdaExp' e + EBinop op es -> E.EBinop op <$> mapM expOfLambdaExp' es + EIf b e1 e2 -> E.EIf <$> expOfLambdaExp' b <*> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 + EAssert e1 e2 -> E.EAssert <$> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 ESeq e1 e2 -> expSeq <$> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 e -> throwError ("Impossible after lambda simplicifaction: " <> show e) diff --git a/src/Snarkl/Language/Syntax.hs b/src/Snarkl/Language/Syntax.hs index fe741d6..a7c82bf 100644 --- a/src/Snarkl/Language/Syntax.hs +++ b/src/Snarkl/Language/Syntax.hs @@ -104,6 +104,7 @@ import Snarkl.Language.TExpr TUnop (TUnop), Ty (TArr, TBool, TField, TFun, TMu, TProd, TSum, TUnit), Val (VFalse, VField, VTrue, VUnit), + teSeq, ) import Unsafe.Coerce (unsafeCoerce) import Prelude hiding @@ -762,6 +763,23 @@ lambda f = do ) _ -> error "impossible: lambda" +-- lambda :: +-- (Typeable a) => +-- (Typeable b) => +-- (TExp a k -> Comp b k) -> +-- Comp ('TFun a b) k +-- lambda f = +-- +-- State +-- ( \s -> +-- case runState fresh_var s of +-- Left err -> Left err +-- Right (e, s') -> +-- case runState (f e) s' of +-- Left err -> Left err +-- Right (e', s'') -> Right (e `teSeq` (Abs )', s'') +-- ) + curry :: (Typeable a) => (Typeable b) => diff --git a/src/Snarkl/Language/SyntaxMonad.hs b/src/Snarkl/Language/SyntaxMonad.hs index 1aa7ead..236812f 100644 --- a/src/Snarkl/Language/SyntaxMonad.hs +++ b/src/Snarkl/Language/SyntaxMonad.hs @@ -55,7 +55,7 @@ import qualified Data.Map.Strict as Map import Data.String (IsString (..)) import Data.Typeable (Typeable) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Language.Expr (Variable (..)) +import Snarkl.Language.Core (Variable (..)) import Snarkl.Language.TExpr ( Loc, TExp (..), diff --git a/src/Snarkl/Language/TExpr.hs b/src/Snarkl/Language/TExpr.hs index 7690c79..746e050 100644 --- a/src/Snarkl/Language/TExpr.hs +++ b/src/Snarkl/Language/TExpr.hs @@ -28,7 +28,7 @@ import Data.Typeable (Proxy (..), Typeable, eqT, typeOf, typeRep, type (:~:) (Re import Prettyprinter (Pretty (pretty), line, parens, (<+>)) import Snarkl.Common (Op, UnOp) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Language.Expr (Variable) +import Snarkl.Language.Core (Variable) import qualified Snarkl.Language.LambdaExpr as LE data TFunct where diff --git a/src/Snarkl/Toplevel.hs b/src/Snarkl/Toplevel.hs index 53503ab..4e3f6f1 100644 --- a/src/Snarkl/Toplevel.hs +++ b/src/Snarkl/Toplevel.hs @@ -41,7 +41,7 @@ import Prelude -- | Using the executable semantics for the 'TExp' language, execute -- the computation on the provided inputs, returning the 'k' result. comp_interp :: - (Typeable ty, GaloisField k) => + (Typeable ty, Pretty k, GaloisField k) => Comp ty k -> [k] -> k @@ -86,7 +86,7 @@ instance (Pretty k) => Pretty (Result k) where -- (3) Check whether 'w' satisfies the constraint system produced in (1). -- (4) Check whether the R1CS result matches the interpreter result. -- (5) Return the 'Result'. -execute :: (Typeable ty, PrimeField k) => SimplParam -> Comp ty k -> [k] -> Result k +execute :: (Typeable ty, PrimeField k, Pretty k) => SimplParam -> Comp ty k -> [k] -> Result k execute simpl mf inputs = let TExpPkg nv in_vars e = compileCompToTexp mf r1cs = compileTExpToR1CS simpl (TExpPkg nv in_vars e) diff --git a/tests/Test/ArkworksBridge.hs b/tests/Test/ArkworksBridge.hs index e1e4703..8ac8dba 100644 --- a/tests/Test/ArkworksBridge.hs +++ b/tests/Test/ArkworksBridge.hs @@ -3,6 +3,7 @@ module Test.ArkworksBridge where import qualified Data.ByteString.Lazy as LBS import Data.Field.Galois (GaloisField, PrimeField) import Data.Typeable (Typeable) +import Prettyprinter (Pretty) import Snarkl.Backend.R1CS import Snarkl.Compile (SimplParam, compileCompToR1CS) import Snarkl.Language (Comp) @@ -14,7 +15,7 @@ data CMD k where CreateProof :: (Typeable ty, GaloisField k) => FilePath -> String -> SimplParam -> Comp ty k -> [k] -> CMD k RunR1CS :: (Typeable ty, GaloisField k) => FilePath -> String -> SimplParam -> Comp ty k -> [k] -> CMD k -runCMD :: (PrimeField k) => CMD k -> IO GHC.ExitCode +runCMD :: (PrimeField k, Pretty k) => CMD k -> IO GHC.ExitCode runCMD (CreateTrustedSetup rootDir name simpl c) = do let r1cs = compileCompToR1CS simpl c r1csFilePath = mkR1CSFilePath rootDir name diff --git a/tests/Test/Snarkl/UnitSpec.hs b/tests/Test/Snarkl/UnitSpec.hs index 0a3701a..25bf1ff 100644 --- a/tests/Test/Snarkl/UnitSpec.hs +++ b/tests/Test/Snarkl/UnitSpec.hs @@ -5,6 +5,7 @@ module Test.Snarkl.UnitSpec where import Data.Field.Galois (PrimeField) import Data.Typeable (Typeable) +import Prettyprinter (Pretty) import Snarkl.Compile import Snarkl.Example.Keccak import Snarkl.Example.Lam @@ -21,7 +22,7 @@ import Test.Hspec (Spec, describe, it, shouldBe, shouldReturn) import Test.Snarkl.Unit.Programs import Prelude -test_comp :: (Typeable ty, PrimeField k) => SimplParam -> Comp ty k -> [k] -> IO (Either ExitCode k) +test_comp :: (Typeable ty, Pretty k, PrimeField k) => SimplParam -> Comp ty k -> [k] -> IO (Either ExitCode k) test_comp simpl mf args = do exit_code <- runCMD $ RunR1CS "./scripts" "hspec" simpl mf args From 449d2bbb4b22f42de776815c674b6426447d2ca2 Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 17:08:12 -0800 Subject: [PATCH 06/19] clean up language modules --- examples/Snarkl/Example/Lam.hs | 10 +- examples/Snarkl/Example/Peano.hs | 4 +- examples/Snarkl/Example/Tree.hs | 6 +- snarkl.cabal | 1 + src/Snarkl/Compile.hs | 36 ++----- src/Snarkl/Interp.hs | 89 +----------------- src/Snarkl/Language.hs | 40 +++++--- src/Snarkl/Language/Core.hs | 3 +- src/Snarkl/Language/Expr.hs | 84 +++++++++-------- src/Snarkl/Language/LambdaExpr.hs | 36 +++---- src/Snarkl/Language/Syntax.hs | 23 +---- src/Snarkl/Language/SyntaxMonad.hs | 2 +- src/Snarkl/Language/TExpr.hs | 146 ++++++++++------------------- src/Snarkl/Language/Type.hs | 77 +++++++++++++++ src/Snarkl/Toplevel.hs | 4 +- 15 files changed, 240 insertions(+), 321 deletions(-) create mode 100644 src/Snarkl/Language/Type.hs diff --git a/examples/Snarkl/Example/Lam.hs b/examples/Snarkl/Example/Lam.hs index 7f7c184..32e53d6 100644 --- a/examples/Snarkl/Example/Lam.hs +++ b/examples/Snarkl/Example/Lam.hs @@ -10,9 +10,7 @@ import Data.Field.Galois (GaloisField, Prime) import Data.Typeable import GHC.TypeLits (KnownNat) import Snarkl.Errors -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr +import Snarkl.Language import Prelude hiding ( fromRational, negate, @@ -35,7 +33,7 @@ type TFSubst = 'TFSum ('TFConst 'TField) ('TFProd ('TFConst TTerm) 'TFId) type TSubst = 'TMu TFSubst -subst_nil :: (GaloisField k) => TExp 'TField k -> State (Env k) (TExp TSubst k) +subst_nil :: (GaloisField k) => TExp 'TField k -> Comp TSubst k subst_nil n = do n' <- inl n @@ -159,7 +157,7 @@ shift n t = fix go t app t1' t2' ) -compose :: (GaloisField k) => TExp TSubst k -> TExp ('TMu TFSubst) k -> State (Env k) (TExp ('TMu TFSubst) k) +compose :: (GaloisField k) => TExp TSubst k -> TExp ('TMu TFSubst) k -> Comp ('TMu TFSubst) k compose sigma1 sigma2 = do p <- pair sigma1 sigma2 @@ -192,7 +190,7 @@ compose sigma1 sigma2 = subst_cons t'' s2'' ) -subst_term :: (GaloisField k) => TExp ('TMu TFSubst) k -> TExp TTerm k -> State (Env k) (TExp ('TMu TF) k) +subst_term :: (GaloisField k) => TExp ('TMu TFSubst) k -> TExp TTerm k -> Comp ('TMu TF) k subst_term sigma t = do p <- pair sigma t diff --git a/examples/Snarkl/Example/Peano.hs b/examples/Snarkl/Example/Peano.hs index 9d47e7f..cfbbdb3 100644 --- a/examples/Snarkl/Example/Peano.hs +++ b/examples/Snarkl/Example/Peano.hs @@ -4,9 +4,7 @@ module Snarkl.Example.Peano where import Data.Field.Galois (GaloisField, Prime) import GHC.TypeLits (KnownNat) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr +import Snarkl.Language import Prelude hiding ( fromRational, negate, diff --git a/examples/Snarkl/Example/Tree.hs b/examples/Snarkl/Example/Tree.hs index 643ba0c..b5a1af7 100644 --- a/examples/Snarkl/Example/Tree.hs +++ b/examples/Snarkl/Example/Tree.hs @@ -5,9 +5,7 @@ module Snarkl.Example.Tree where import Data.Field.Galois (GaloisField, Prime) import Data.Typeable import GHC.TypeLits (KnownNat) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr +import Snarkl.Language import Prelude hiding ( fromRational, negate, @@ -69,7 +67,7 @@ map_tree :: Derive a1 k, GaloisField k ) => - (TExp a k -> State (Env k) (TExp a1 k)) -> + (TExp a k -> Comp a1 k) -> TExp (TTree a) k -> Comp (TTree a1) k map_tree f t = diff --git a/snarkl.cabal b/snarkl.cabal index 742f9c4..343cce7 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -55,6 +55,7 @@ library Snarkl.Language.Core Snarkl.Language.SyntaxMonad Snarkl.Language.TExpr + Snarkl.Language.Type Snarkl.Toplevel default-extensions: diff --git a/src/Snarkl/Compile.hs b/src/Snarkl/Compile.hs index 285beb7..a34f069 100644 --- a/src/Snarkl/Compile.hs +++ b/src/Snarkl/Compile.hs @@ -52,7 +52,7 @@ import Snarkl.Language TExp, Variable (Variable), booleanVarsOfTexp, - expOfTExp, + compileTExpToProgram, runState, ) import qualified Snarkl.Language.Core as Core @@ -386,30 +386,14 @@ cs_of_exp out e = case e of [ Core.EBinop Mult [b, e1], Core.EBinop Mult [Core.EBinop Sub [Core.EVal 1, b], e2] ] - - ---- NOTE: when compiling assignments, the naive thing to do is - ---- to introduce a new var, e2_out, bound to result of e2 and - ---- then ensure that e2_out == x. We optimize by passing x to - ---- compilation of e2 directly. - -- EAssert e1 e2 -> - -- do - -- let x = var_of_exp e1 - -- cs_of_exp (x ^. _Var) e2 - -- ESeq le -> - -- do - -- x <- fresh_var -- x is garbage - -- go x le - -- where - -- go _ [] = failWith $ ErrMsg "internal error: empty ESeq" - -- go _ [e1] = cs_of_exp out e1 - -- go garbage_var (e1 : le') = - -- do - -- cs_of_exp garbage_var e1 - -- go garbage_var le' Core.EUnit -> -- NOTE: [[ EUnit ]]_{out} = [[ EVal zero ]]_{out}. cs_of_exp out (Core.EVal 0) +---- NOTE: when compiling assignments, the naive thing to do is +---- to introduce a new var, e2_out, bound to result of e2 and +---- then ensure that e2_out == x. We optimize by passing x to +---- compilation of e2 directly. cs_of_assignment :: (GaloisField a) => Core.Assignment a -> State (CEnv a) () cs_of_assignment (Core.Assignment x e) = cs_of_exp (view _Var x) e @@ -506,7 +490,7 @@ compileCompToTexp mf = -- | Snarkl.Compile 'TExp's to constraint systems. Re-exported from 'Snarkl.Compile.Snarkl.Compile'. compileTexpToConstraints :: - (Typeable ty, GaloisField k, Pretty k) => + (Typeable ty, GaloisField k) => TExpPkg ty k -> ConstraintSystem k compileTexpToConstraints (TExpPkg _out _in_vars te) = @@ -519,7 +503,7 @@ compileTexpToConstraints (TExpPkg _out _in_vars te) = Set.toList $ Set.fromList in_vars `Set.intersection` Set.fromList (map (view _Var) $ booleanVarsOfTexp te) - Core.Program assignments e = expOfTExp te + Core.Program assignments e = compileTExpToProgram te traverse_ cs_of_assignment assignments -- e = do_const_prop e0 -- Snarkl.Compile 'e' to constraints 'cs', with output wire 'out'. @@ -540,7 +524,7 @@ compileTexpToConstraints (TExpPkg _out _in_vars te) = -- | Snarkl.Compile Snarkl computations to constraint systems. compileCompToConstraints :: - (Typeable ty, GaloisField k, Pretty k) => + (Typeable ty, GaloisField k) => Comp ty k -> ConstraintSystem k compileCompToConstraints = compileTexpToConstraints . compileCompToTexp @@ -553,7 +537,7 @@ compileCompToConstraints = compileTexpToConstraints . compileCompToTexp -- | Snarkl.Compile 'TExp's to 'R1CS'. compileTExpToR1CS :: - (Typeable ty, GaloisField k, Pretty k) => + (Typeable ty, GaloisField k) => SimplParam -> TExpPkg ty k -> R1CS k @@ -561,7 +545,7 @@ compileTExpToR1CS simpl = compileConstraintsToR1CS simpl . compileTexpToConstrai -- | Snarkl.Compile Snarkl computations to 'R1CS'. compileCompToR1CS :: - (Typeable ty, GaloisField k, Pretty k) => + (Typeable ty, GaloisField k) => SimplParam -> Comp ty k -> R1CS k diff --git a/src/Snarkl/Interp.hs b/src/Snarkl/Interp.hs index c0b0176..2989d22 100644 --- a/src/Snarkl/Interp.hs +++ b/src/Snarkl/Interp.hs @@ -11,10 +11,9 @@ import Data.Field.Galois (GaloisField) import Data.Foldable (traverse_) import Data.Map (Map) import qualified Data.Map as Map -import Prettyprinter (Pretty) import Snarkl.Common (Op (..), UnOp (ZEq)) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Language (TExp, Variable, expOfTExp) +import Snarkl.Language (TExp, Variable, compileTExpToProgram) import qualified Snarkl.Language.Core as Core type Env a = Map Variable (Maybe a) @@ -81,91 +80,21 @@ boolOfField v = interpTExp :: ( GaloisField a, - Typeable ty, - Pretty a + Typeable ty ) => TExp ty a -> InterpM a (Maybe a) interpTExp e = do - let _exp = expOfTExp e + let _exp = compileTExpToProgram e interpProg _exp interp :: - (GaloisField a, Typeable ty, Pretty a) => + (GaloisField a, Typeable ty) => Map Variable a -> TExp ty a -> Either ErrMsg (Env a, Maybe a) interp rho e = runInterpM (interpTExp e) $ Map.map Just rho -{- -interpExpr :: - (GaloisField a) => - Exp a -> - InterpM a (Maybe a) -interpExpr e = case e of - EVar x -> lookupVar x - EVal v -> pure $ Just v - EUnop op e2 -> do - v2 <- interpExpr e2 - case v2 of - Nothing -> pure Nothing - Just v2' -> case op of - ZEq -> return $ Just $ fieldOfBool (v2' == 0) - EBinop op _es -> case _es of - [] -> failWith $ ErrMsg "empty binary args" - (a : as) -> do - b <- interpExpr a - foldM (interpBinopExpr op) b as - EIf eb e1 e2 -> - do - mb <- interpExpr eb - case mb of - Nothing -> pure Nothing - Just _b -> boolOfField _b >>= \b -> if b then interpExpr e1 else interpExpr e2 - EAssert e1 e2 -> - case (e1, e2) of - (EVar x, _) -> - do - v2 <- interpExpr e2 - addBinds [(x, v2)] - (_, _) -> raiseErr $ ErrMsg $ show e1 ++ " not a variable" - ESeq es -> case es of - [] -> failWith $ ErrMsg "empty sequence" - _ -> last <$> mapM interpExpr es - EUnit -> return $ Just 1 - where - interpBinopExpr :: (GaloisField a) => Op -> Maybe a -> Exp a -> InterpM a (Maybe a) - interpBinopExpr _ Nothing _ = return Nothing - interpBinopExpr _op (Just a1) _exp = do - ma2 <- interpExpr _exp - case ma2 of - Nothing -> return Nothing - Just a2 -> Just <$> op a1 a2 - where - op :: (GaloisField a) => a -> a -> InterpM a a - op a b = case _op of - Add -> pure $ a + b - Sub -> pure $ a - b - Mult -> pure $ a * b - Div -> pure $ a / b - And -> interpBooleanBinop a b - Or -> interpBooleanBinop a b - XOr -> interpBooleanBinop a b - BEq -> interpBooleanBinop a b - Eq -> pure $ fieldOfBool $ a == b - interpBooleanBinop :: (GaloisField a) => a -> a -> InterpM a a - interpBooleanBinop a b = - do - b1 <- boolOfField a - b2 <- boolOfField b - case _op of - And -> return $ fieldOfBool $ b1 && b2 - Or -> return $ fieldOfBool $ b1 || b2 - XOr -> return $ fieldOfBool $ (b1 && not b2) || (b2 && not b1) - BEq -> return $ fieldOfBool $ b1 == b2 - _ -> failWith $ ErrMsg "internal error in interp_binop" --} - interpProg :: (GaloisField a) => Core.Program a -> @@ -202,16 +131,6 @@ interpCoreExpr = \case case mb of Nothing -> pure Nothing Just _b -> boolOfField _b >>= \b -> if b then interpCoreExpr e1 else interpCoreExpr e2 - -- CoreEAssert e1 e2 -> - -- case (e1, e2) of - -- (Core.EVar x, _) -> - -- do - -- v2 <- interpExpr e2 - -- addBinds [(x, v2)] - -- (_, _) -> raiseErr $ ErrMsg $ show e1 ++ " not a variable" - -- CESeq es -> case es of - -- [] -> failWith $ ErrMsg "empty sequence" - -- _ -> last <$> mapM interpExpr es Core.EUnit -> return $ Just 1 where interpBinopExpr :: (GaloisField a) => Op -> Maybe a -> Core.Exp a -> InterpM a (Maybe a) diff --git a/src/Snarkl/Language.hs b/src/Snarkl/Language.hs index 0b188f7..684573d 100644 --- a/src/Snarkl/Language.hs +++ b/src/Snarkl/Language.hs @@ -1,9 +1,16 @@ module Snarkl.Language - ( expOfTExp, + ( compileTExpToProgram, + -- | Snarkl.Language.TExpr, booleanVarsOfTexp, TExp, - module Snarkl.Language.Core, - -- | SyntaxMonad + -- | Snarkl.Language.Core, + Variable (..), + Program (..), + Assignment (..), + Exp (..), + -- types + module Snarkl.Language.Type, + -- | SyntaxMonad and Syntax Comp, runState, return, @@ -80,20 +87,23 @@ where import Data.Data (Typeable) import Data.Field.Galois (GaloisField) -import Debug.Trace (trace) -import Prettyprinter (Pretty (pretty)) import Snarkl.Language.Core -import Snarkl.Language.Expr + ( Assignment (..), + Exp (..), + Program (..), + Variable (..), + ) +import Snarkl.Language.Expr (mkProgram) import Snarkl.Language.LambdaExpr (expOfLambdaExp) import Snarkl.Language.Syntax import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import qualified Prelude +import Snarkl.Language.TExpr (TExp, booleanVarsOfTexp, tExpToLambdaExp) +import Snarkl.Language.Type +import Prelude (Either (..), error, ($), (.), (<>)) -expOfTExp :: (Prelude.Show a, GaloisField a, Typeable ty, Pretty a) => TExp ty a -> Program a -expOfTExp te = - trace (Prelude.show te) Prelude.$ - let e = do_const_prop Prelude.. expOfLambdaExp Prelude.. lambdaExpOfTExp Prelude.$ te - in case mkProgram e of - Prelude.Right p -> p - Prelude.Left err -> Prelude.error Prelude.$ "expOfTExp: failed to convert TExp to Program: " Prelude.<> err +compileTExpToProgram :: (GaloisField a, Typeable ty) => TExp ty a -> Program a +compileTExpToProgram te = + let eprog = mkProgram . expOfLambdaExp . tExpToLambdaExp $ te + in case eprog of + Right p -> p + Left err -> error $ "compileTExpToProgram: failed to convert TExp to Program: " <> err \ No newline at end of file diff --git a/src/Snarkl/Language/Core.hs b/src/Snarkl/Language/Core.hs index 5dc485f..3709b9d 100644 --- a/src/Snarkl/Language/Core.hs +++ b/src/Snarkl/Language/Core.hs @@ -4,6 +4,7 @@ module Snarkl.Language.Core where import Data.Field.Galois (GaloisField) import Data.Kind (Type) +import Data.Sequence (Seq) import Prettyprinter (Pretty) import Snarkl.Common @@ -24,4 +25,4 @@ deriving instance (Show a) => Show (Exp a) data Assignment a = Assignment Variable (Exp a) data Program :: Type -> Type where - Program :: [Assignment a] -> Exp a -> Program a \ No newline at end of file + Program :: Seq (Assignment a) -> Exp a -> Program a \ No newline at end of file diff --git a/src/Snarkl/Language/Expr.hs b/src/Snarkl/Language/Expr.hs index 8bd7e8c..7624f4c 100644 --- a/src/Snarkl/Language/Expr.hs +++ b/src/Snarkl/Language/Expr.hs @@ -1,11 +1,11 @@ {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} module Snarkl.Language.Expr ( Exp (..), - var_of_exp, - do_const_prop, mkProgram, expSeq, + expBinop, ) where @@ -13,7 +13,6 @@ import Control.Error (hoistEither, runExceptT) import Control.Monad.Except ( ExceptT, MonadError (throwError), - MonadPlus (mzero), ) import Control.Monad.State (State, evalState, gets, modify, runState) import Data.Field.Galois (GaloisField) @@ -21,8 +20,7 @@ import Data.Foldable (toList) import Data.Kind (Type) import Data.Map (Map) import qualified Data.Map as Map -import Data.Sequence (Seq, (|>)) -import Debug.Trace (trace) +import Data.Sequence (Seq, fromList, (<|), (><), (|>), pattern Empty, pattern (:<|)) import Prettyprinter ( Pretty (pretty), hsep, @@ -30,8 +28,7 @@ import Prettyprinter punctuate, (<+>), ) -import Snarkl.Common (Op, UnOp) -import Snarkl.Errors (ErrMsg (ErrMsg), failWith) +import Snarkl.Common (Op, UnOp, isAssoc) import qualified Snarkl.Language.Core as Core data Exp :: Type -> Type where @@ -41,18 +38,13 @@ data Exp :: Type -> Type where EBinop :: Op -> [Exp a] -> Exp a EIf :: Exp a -> Exp a -> Exp a -> Exp a EAssert :: Exp a -> Exp a -> Exp a - ESeq :: [Exp a] -> Exp a + ESeq :: Seq (Exp a) -> Exp a EUnit :: Exp a deriving instance (Eq a) => Eq (Exp a) deriving instance (Show a) => Show (Exp a) -var_of_exp :: (Show a) => Exp a -> Core.Variable -var_of_exp e = case e of - EVar x -> x - _ -> failWith $ ErrMsg ("var_of_exp: expected variable: " ++ show e) - const_prop :: (GaloisField a) => Exp a -> State (Map Core.Variable a) (Exp a) const_prop e = case e of @@ -109,7 +101,7 @@ instance (Pretty a) => Pretty (Exp a) where pretty (EIf b e1 e2) = "if" <+> pretty b <+> "then" <+> pretty e1 <+> "else" <+> pretty e2 pretty (EAssert e1 e2) = pretty e1 <+> ":=" <+> pretty e2 - pretty (ESeq es) = parens $ hsep $ punctuate ";" $ map pretty es + pretty (ESeq es) = parens $ hsep $ punctuate ";" $ map pretty (toList es) pretty EUnit = "()" mkExpression :: (Show a) => Exp a -> Either String (Core.Exp a) @@ -127,31 +119,49 @@ mkExpression = \case expSeq :: Exp a -> Exp a -> Exp a expSeq e1 e2 = case (e1, e2) of - (ESeq l1, ESeq l2) -> ESeq (l1 ++ l2) - (ESeq l1, _) -> ESeq (l1 ++ [e2]) - (_, ESeq l2) -> ESeq (e1 : l2) - (_, _) -> ESeq [e1, e2] + (ESeq l1, ESeq l2) -> ESeq (l1 >< l2) + (ESeq l1, _) -> ESeq (l1 |> e2) + (_, ESeq l2) -> ESeq (e1 <| l2) + (_, _) -> ESeq (fromList [e1, e2]) + +expBinop :: Op -> Exp a -> Exp a -> Exp a +expBinop op e1 e2 = + case (e1, e2) of + (EBinop op1 l1, EBinop op2 l2) + | op1 == op2 && op2 == op && isAssoc op -> + EBinop op (l1 ++ l2) + (EBinop op1 l1, _) + | op1 == op && isAssoc op -> + EBinop op (l1 ++ [e2]) + (_, EBinop op2 l2) + | op2 == op && isAssoc op -> + EBinop op (e1 : l2) + (_, _) -> EBinop op [e1, e2] mkAssignment :: (Show a) => Exp a -> Either String (Core.Assignment a) mkAssignment (EAssert (EVar v) e) = Core.Assignment v <$> mkExpression e mkAssignment e = throwError $ "mkAssignment: expected EAssert, got " <> show e -mkProgram :: (Show a) => Exp a -> Either String (Core.Program a) -mkProgram e@(ESeq es) = trace ("mkProgram ESeq: " <> show e) $ do - let (eexpr, assignments) = runState (runExceptT $ go es) mempty - Core.Program (toList assignments) <$> eexpr - where - go :: (Show a) => [Exp a] -> ExceptT String (State (Seq (Core.Assignment a))) (Core.Exp a) - go = \case - [] -> mzero - [e] -> hoistEither $ mkExpression e - e : rest -> do - case e of - EUnit -> go rest - _ -> do - assignment <- hoistEither $ mkAssignment e - modify (|> assignment) - go rest -mkProgram e = trace ("mkProgram " <> show e) $ do - e' <- mkExpression e - pure $ Core.Program [] e' \ No newline at end of file +-- At this point the expression should be either: +-- 1. A sequence of assignments, followed by an expression +-- 2. An expression +mkProgram :: (GaloisField a) => Exp a -> Either String (Core.Program a) +mkProgram _exp = do + let e' = do_const_prop _exp + case e' of + ESeq es -> do + let (eexpr, assignments) = runState (runExceptT $ go es) mempty + Core.Program assignments <$> eexpr + where + go :: (Show a) => Seq (Exp a) -> ExceptT String (State (Seq (Core.Assignment a))) (Core.Exp a) + go = \case + Empty -> throwError "mkProgram: empty sequence" + e :<| Empty -> hoistEither $ mkExpression e + e :<| rest -> do + case e of + EUnit -> go rest + _ -> do + assignment <- hoistEither $ mkAssignment e + modify (|> assignment) + go rest + _ -> Core.Program Empty <$> mkExpression e' \ No newline at end of file diff --git a/src/Snarkl/Language/LambdaExpr.hs b/src/Snarkl/Language/LambdaExpr.hs index 832cad4..a7bf8e7 100644 --- a/src/Snarkl/Language/LambdaExpr.hs +++ b/src/Snarkl/Language/LambdaExpr.hs @@ -3,24 +3,30 @@ module Snarkl.Language.LambdaExpr ( Exp (..), expOfLambdaExp, - expBinop, - betaNormalize, ) where import Control.Monad.Error.Class (throwError) import Data.Field.Galois (GaloisField) import Data.Kind (Type) -import Snarkl.Common (Op, UnOp, isAssoc) +import Snarkl.Common (Op, UnOp) +import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Language.Core (Variable) import Snarkl.Language.Expr (expSeq) import qualified Snarkl.Language.Expr as E +-- This expression language is just the untyped version of the typed +-- expression language TEExp. It is used to remove lamba application +-- and abstraction before passing on to the next expression language. +-- There is also a certain amount of "flattening" between this representation +-- and the underlying Expr language -- we reassociate all of the Seq constructors +-- to the right and then flatten them. Similarly nested Binops of the same +-- operator are flattened into a single list if that operator is associative. data Exp :: Type -> Type where EVar :: Variable -> Exp a EVal :: (GaloisField a) => a -> Exp a EUnop :: UnOp -> Exp a -> Exp a - EBinop :: Op -> [Exp a] -> Exp a + EBinop :: Op -> Exp a -> Exp a -> Exp a EIf :: Exp a -> Exp a -> Exp a -> Exp a EAssert :: Exp a -> Exp a -> Exp a ESeq :: Exp a -> Exp a -> Exp a @@ -37,7 +43,7 @@ betaNormalize = \case EVar x -> EVar x EVal v -> EVal v EUnop op e -> EUnop op (betaNormalize e) - EBinop op es -> EBinop op (betaNormalize <$> es) + EBinop op l r -> EBinop op (betaNormalize l) (betaNormalize r) EIf e1 e2 e3 -> EIf (betaNormalize e1) (betaNormalize e2) (betaNormalize e3) EAssert e1 e2 -> EAssert (betaNormalize e1) (betaNormalize e2) ESeq e1 e2 -> ESeq (betaNormalize e1) (betaNormalize e2) @@ -55,32 +61,18 @@ betaNormalize = \case e@(EVal _) -> e EUnit -> EUnit EUnop op e -> EUnop op (substitute (var, e1) e) - EBinop op es -> EBinop op (substitute (var, e1) <$> es) + EBinop op l r -> EBinop op (substitute (var, e1) l) (substitute (var, e1) r) EIf b e2 e3 -> EIf (substitute (var, e1) b) (substitute (var, e1) e2) (substitute (var, e1) e3) EAssert e2 e3 -> EAssert (substitute (var, e1) e2) (substitute (var, e1) e3) ESeq l r -> ESeq (substitute (var, e1) l) (substitute (var, e1) r) EAbs var' e -> EAbs var' (substitute (var, e1) e) EApp e2 e3 -> EApp (substitute (var, e1) e2) (substitute (var, e1) e3) -expBinop :: Op -> Exp a -> Exp a -> Exp a -expBinop op e1 e2 = - case (e1, e2) of - (EBinop op1 l1, EBinop op2 l2) - | op1 == op2 && op2 == op && isAssoc op -> - EBinop op (l1 ++ l2) - (EBinop op1 l1, _) - | op1 == op && isAssoc op -> - EBinop op (l1 ++ [e2]) - (_, EBinop op2 l2) - | op2 == op && isAssoc op -> - EBinop op (e1 : l2) - (_, _) -> EBinop op [e1, e2] - expOfLambdaExp :: (Show a) => Exp a -> E.Exp a expOfLambdaExp _exp = let coreExp = betaNormalize _exp in case expOfLambdaExp' coreExp of - Left err -> error err + Left err -> failWith $ ErrMsg err Right e -> e where expOfLambdaExp' :: (Show a) => Exp a -> Either String (E.Exp a) @@ -89,7 +81,7 @@ expOfLambdaExp _exp = EVal v -> pure $ E.EVal v EUnit -> pure E.EUnit EUnop op e -> E.EUnop op <$> expOfLambdaExp' e - EBinop op es -> E.EBinop op <$> mapM expOfLambdaExp' es + EBinop op l r -> E.expBinop op <$> expOfLambdaExp' l <*> expOfLambdaExp' r EIf b e1 e2 -> E.EIf <$> expOfLambdaExp' b <*> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 EAssert e1 e2 -> E.EAssert <$> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 ESeq e1 e2 -> expSeq <$> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 diff --git a/src/Snarkl/Language/Syntax.hs b/src/Snarkl/Language/Syntax.hs index a7c82bf..391e12c 100644 --- a/src/Snarkl/Language/Syntax.hs +++ b/src/Snarkl/Language/Syntax.hs @@ -98,14 +98,12 @@ import Snarkl.Language.SyntaxMonad (>>=), ) import Snarkl.Language.TExpr - ( Rep, - TExp (TEAbs, TEApp, TEBinop, TEBot, TEIf, TEUnop, TEVal, TEVar), + ( TExp (TEAbs, TEApp, TEBinop, TEBot, TEIf, TEUnop, TEVal, TEVar), TOp (TOp), TUnop (TUnop), - Ty (TArr, TBool, TField, TFun, TMu, TProd, TSum, TUnit), Val (VFalse, VField, VTrue, VUnit), - teSeq, ) +import Snarkl.Language.Type (Rep, Ty (..)) import Unsafe.Coerce (unsafeCoerce) import Prelude hiding ( curry, @@ -763,23 +761,6 @@ lambda f = do ) _ -> error "impossible: lambda" --- lambda :: --- (Typeable a) => --- (Typeable b) => --- (TExp a k -> Comp b k) -> --- Comp ('TFun a b) k --- lambda f = --- --- State --- ( \s -> --- case runState fresh_var s of --- Left err -> Left err --- Right (e, s') -> --- case runState (f e) s' of --- Left err -> Left err --- Right (e', s'') -> Right (e `teSeq` (Abs )', s'') --- ) - curry :: (Typeable a) => (Typeable b) => diff --git a/src/Snarkl/Language/SyntaxMonad.hs b/src/Snarkl/Language/SyntaxMonad.hs index 236812f..f460ca3 100644 --- a/src/Snarkl/Language/SyntaxMonad.hs +++ b/src/Snarkl/Language/SyntaxMonad.hs @@ -61,13 +61,13 @@ import Snarkl.Language.TExpr TExp (..), TLoc (TLoc), TVar (TVar), - Ty (TArr, TBool, TProd, TUnit), Val (VFalse, VLoc, VTrue, VUnit), lastSeq, locOfTexp, teSeq, varOfTExp, ) +import Snarkl.Language.Type (Ty (..)) import Prelude hiding ( fromRational, negate, diff --git a/src/Snarkl/Language/TExpr.hs b/src/Snarkl/Language/TExpr.hs index 746e050..d2f6d65 100644 --- a/src/Snarkl/Language/TExpr.hs +++ b/src/Snarkl/Language/TExpr.hs @@ -4,21 +4,17 @@ module Snarkl.Language.TExpr ( Val (..), TExp (..), - TFunct (..), - Ty (..), - Rep, TUnop (..), TOp (..), TVar (..), Loc, TLoc (..), + tExpToLambdaExp, booleanVarsOfTexp, - lambdaExpOfTExp, varOfTExp, locOfTexp, teSeq, lastSeq, - -- expOfTExp, ) where @@ -30,72 +26,7 @@ import Snarkl.Common (Op, UnOp) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Language.Core (Variable) import qualified Snarkl.Language.LambdaExpr as LE - -data TFunct where - TFConst :: Ty -> TFunct - TFId :: TFunct - TFProd :: TFunct -> TFunct -> TFunct - TFSum :: TFunct -> TFunct -> TFunct - TFComp :: TFunct -> TFunct -> TFunct - deriving (Typeable) - -instance Pretty TFunct where - pretty f = case f of - TFConst ty -> "Const" <+> pretty ty - TFId -> "Id" - TFProd f1 f2 -> parens (pretty f1 <+> "⊗" <+> pretty f2) - TFSum f1 f2 -> parens (pretty f1 <+> "⊕" <+> pretty f2) - TFComp f1 f2 -> parens (pretty f1 <+> "∘" <+> pretty f2) - -data Ty where - TField :: Ty - TBool :: Ty - TArr :: Ty -> Ty - TProd :: Ty -> Ty -> Ty - TSum :: Ty -> Ty -> Ty - TMu :: TFunct -> Ty - TUnit :: Ty - TFun :: Ty -> Ty -> Ty - deriving (Typeable) - -deriving instance Typeable 'TField - -deriving instance Typeable 'TBool - -deriving instance Typeable 'TArr - -deriving instance Typeable 'TProd - -deriving instance Typeable 'TSum - -deriving instance Typeable 'TMu - -deriving instance Typeable 'TUnit - -deriving instance Typeable 'TFun - -instance Pretty Ty where - pretty ty = case ty of - TField -> "Field" - TBool -> "Bool" - TArr _ty -> "Array" <+> pretty _ty - TProd ty1 ty2 -> parens (pretty ty1 <+> "⨉" <+> pretty ty2) - TSum ty1 ty2 -> parens (pretty ty1 <+> "+" <+> pretty ty2) - TMu f -> "μ" <> parens (pretty f) - TUnit -> "()" - TFun ty1 ty2 -> parens (pretty ty1 <+> "->" <+> pretty ty2) - -type family Rep (f :: TFunct) (x :: Ty) :: Ty - -type instance Rep ('TFConst ty) x = ty - -type instance Rep 'TFId x = x - -type instance Rep ('TFProd f g) x = 'TProd (Rep f x) (Rep g x) - -type instance Rep ('TFSum f g) x = 'TSum (Rep f x) (Rep g x) - -type instance Rep ('TFComp f g) x = Rep f (Rep g x) +import Snarkl.Language.Type (Ty (TBool, TField, TFun, TUnit)) newtype TVar (ty :: Ty) = TVar Variable deriving (Eq, Show) @@ -194,22 +125,38 @@ instance (Eq a) => Eq (TExp (b :: Ty) a) where TEBot == TEBot = True _ == _ = False -lambdaExpOfTExp :: (GaloisField a, Typeable ty) => TExp ty a -> LE.Exp a -lambdaExpOfTExp te = case te of +instance (Pretty a, Typeable ty) => Pretty (TExp ty a) where + pretty (TEVar var) = pretty var + pretty (TEVal val) = pretty val + pretty (TEUnop unop _exp) = pretty unop <+> pretty _exp + pretty (TEBinop binop exp1 exp2) = pretty exp1 <+> pretty binop <+> pretty exp2 + pretty (TEIf condExp thenExp elseExp) = "if" <+> pretty condExp <+> "then" <+> pretty thenExp <+> "else" <+> pretty elseExp + pretty (TEAssert exp1 exp2) = pretty exp1 <+> ":=" <+> pretty exp2 + pretty (TESeq exp1 exp2) = parens (pretty exp1 <+> ";" <> line <> pretty exp2) + pretty TEBot = "⊥" + pretty (TEAbs var _exp) = parens ("\\" <> pretty var <+> "->" <+> pretty _exp) + pretty (TEApp exp1 exp2) = parens (pretty exp1 <+> pretty exp2) + +tExpToLambdaExp :: + (GaloisField a) => + (Typeable ty) => + TExp ty a -> + LE.Exp a +tExpToLambdaExp te = case te of TEVar (TVar x) -> LE.EVar x TEVal v -> lambdaExpOfVal v TEUnop (TUnop op) te1 -> - LE.EUnop op (lambdaExpOfTExp te1) + LE.EUnop op (tExpToLambdaExp te1) TEBinop (TOp op) te1 te2 -> - LE.expBinop op (lambdaExpOfTExp te1) (lambdaExpOfTExp te2) + LE.EBinop op (tExpToLambdaExp te1) (tExpToLambdaExp te2) TEIf te1 te2 te3 -> - LE.EIf (lambdaExpOfTExp te1) (lambdaExpOfTExp te2) (lambdaExpOfTExp te3) + LE.EIf (tExpToLambdaExp te1) (tExpToLambdaExp te2) (tExpToLambdaExp te3) TEAssert te1 te2 -> - LE.EAssert (lambdaExpOfTExp te1) (lambdaExpOfTExp te2) - TESeq te1 te2 -> LE.ESeq (lambdaExpOfTExp te1) (lambdaExpOfTExp te2) + LE.EAssert (tExpToLambdaExp te1) (tExpToLambdaExp te2) + TESeq te1 te2 -> LE.ESeq (tExpToLambdaExp te1) (tExpToLambdaExp te2) TEBot -> LE.EUnit - TEAbs (TVar v) e -> LE.EAbs v (lambdaExpOfTExp e) - TEApp e1 e2 -> LE.EApp (lambdaExpOfTExp e1) (lambdaExpOfTExp e2) + TEAbs (TVar v) e -> LE.EAbs v (tExpToLambdaExp e) + TEApp e1 e2 -> LE.EApp (tExpToLambdaExp e1) (tExpToLambdaExp e2) where lambdaExpOfVal :: (GaloisField a) => Val ty a -> LE.Exp a lambdaExpOfVal v = case v of @@ -222,13 +169,20 @@ lambdaExpOfTExp te = case te of -- | Smart constructor for 'TESeq'. Simplify 'TESeq te1 te2' to 'te2' -- whenever the normal form of 'te1' (with seq's reassociated right) -- is *not* equal 'TEAssert _ _'. -teSeq :: (Typeable ty1) => TExp ty1 a -> TExp ty2 a -> TExp ty2 a +teSeq :: + (Typeable ty1) => + TExp ty1 a -> + TExp ty2 a -> + TExp ty2 a teSeq te1 te2 = case (te1, te2) of (TEAssert _ _, _) -> TESeq te1 te2 (TESeq tx ty, _) -> teSeq tx (teSeq ty te2) (_, _) -> te2 -booleanVarsOfTexp :: (Typeable ty) => TExp ty a -> [Variable] +booleanVarsOfTexp :: + (Typeable ty) => + TExp ty a -> + [Variable] booleanVarsOfTexp = go [] where go :: (Typeable ty) => [Variable] -> TExp ty a -> [Variable] @@ -247,29 +201,25 @@ booleanVarsOfTexp = go [] go vars (TEAbs _ e) = go vars e go vars (TEApp e1 e2) = go (go vars e1) e2 -varOfTExp :: (Show (TExp ty a)) => TExp ty a -> Variable +varOfTExp :: + (Show a) => + TExp ty a -> + Variable varOfTExp te = case lastSeq te of TEVar (TVar x) -> x _ -> failWith $ ErrMsg ("varOfTExp: expected var: " ++ show te) -locOfTexp :: (Show (TExp ty a)) => TExp ty a -> Loc +locOfTexp :: + (Show a) => + TExp ty a -> + Loc locOfTexp te = case lastSeq te of TEVal (VLoc (TLoc l)) -> l _ -> failWith $ ErrMsg ("locOfTexp: expected loc: " ++ show te) -lastSeq :: TExp ty a -> TExp ty a +lastSeq :: + TExp ty a -> + TExp ty a lastSeq te = case te of TESeq _ te2 -> lastSeq te2 - _ -> te - -instance (Pretty a, Typeable ty) => Pretty (TExp ty a) where - pretty (TEVar var) = pretty var - pretty (TEVal val) = pretty val - pretty (TEUnop unop _exp) = pretty unop <+> pretty _exp - pretty (TEBinop binop exp1 exp2) = pretty exp1 <+> pretty binop <+> pretty exp2 - pretty (TEIf condExp thenExp elseExp) = "if" <+> pretty condExp <+> "then" <+> pretty thenExp <+> "else" <+> pretty elseExp - pretty (TEAssert exp1 exp2) = pretty exp1 <+> ":=" <+> pretty exp2 - pretty (TESeq exp1 exp2) = parens (pretty exp1 <+> ";" <> line <> pretty exp2) - pretty TEBot = "⊥" - pretty (TEAbs var _exp) = parens ("\\" <> pretty var <+> "->" <+> pretty _exp) - pretty (TEApp exp1 exp2) = parens (pretty exp1 <+> pretty exp2) + _ -> te \ No newline at end of file diff --git a/src/Snarkl/Language/Type.hs b/src/Snarkl/Language/Type.hs new file mode 100644 index 0000000..fa3a489 --- /dev/null +++ b/src/Snarkl/Language/Type.hs @@ -0,0 +1,77 @@ +{-# LANGUAGE UndecidableInstances #-} + +module Snarkl.Language.Type + ( TFunct (..), + Ty (..), + Rep, + ) +where + +import Data.Typeable (Typeable) +import Prettyprinter (Pretty (pretty), parens, (<+>)) + +data TFunct where + TFConst :: Ty -> TFunct + TFId :: TFunct + TFProd :: TFunct -> TFunct -> TFunct + TFSum :: TFunct -> TFunct -> TFunct + TFComp :: TFunct -> TFunct -> TFunct + deriving (Typeable) + +instance Pretty TFunct where + pretty f = case f of + TFConst ty -> "Const" <+> pretty ty + TFId -> "Id" + TFProd f1 f2 -> parens (pretty f1 <+> "⊗" <+> pretty f2) + TFSum f1 f2 -> parens (pretty f1 <+> "⊕" <+> pretty f2) + TFComp f1 f2 -> parens (pretty f1 <+> "∘" <+> pretty f2) + +data Ty where + TField :: Ty + TBool :: Ty + TArr :: Ty -> Ty + TProd :: Ty -> Ty -> Ty + TSum :: Ty -> Ty -> Ty + TMu :: TFunct -> Ty + TUnit :: Ty + TFun :: Ty -> Ty -> Ty + deriving (Typeable) + +deriving instance Typeable 'TField + +deriving instance Typeable 'TBool + +deriving instance Typeable 'TArr + +deriving instance Typeable 'TProd + +deriving instance Typeable 'TSum + +deriving instance Typeable 'TMu + +deriving instance Typeable 'TUnit + +deriving instance Typeable 'TFun + +instance Pretty Ty where + pretty ty = case ty of + TField -> "Field" + TBool -> "Bool" + TArr _ty -> "Array" <+> pretty _ty + TProd ty1 ty2 -> parens (pretty ty1 <+> "⨉" <+> pretty ty2) + TSum ty1 ty2 -> parens (pretty ty1 <+> "+" <+> pretty ty2) + TMu f -> "μ" <> parens (pretty f) + TUnit -> "()" + TFun ty1 ty2 -> parens (pretty ty1 <+> "->" <+> pretty ty2) + +type family Rep (f :: TFunct) (x :: Ty) :: Ty + +type instance Rep ('TFConst ty) x = ty + +type instance Rep 'TFId x = x + +type instance Rep ('TFProd f g) x = 'TProd (Rep f x) (Rep g x) + +type instance Rep ('TFSum f g) x = 'TSum (Rep f x) (Rep g x) + +type instance Rep ('TFComp f g) x = Rep f (Rep g x) \ No newline at end of file diff --git a/src/Snarkl/Toplevel.hs b/src/Snarkl/Toplevel.hs index 4e3f6f1..53503ab 100644 --- a/src/Snarkl/Toplevel.hs +++ b/src/Snarkl/Toplevel.hs @@ -41,7 +41,7 @@ import Prelude -- | Using the executable semantics for the 'TExp' language, execute -- the computation on the provided inputs, returning the 'k' result. comp_interp :: - (Typeable ty, Pretty k, GaloisField k) => + (Typeable ty, GaloisField k) => Comp ty k -> [k] -> k @@ -86,7 +86,7 @@ instance (Pretty k) => Pretty (Result k) where -- (3) Check whether 'w' satisfies the constraint system produced in (1). -- (4) Check whether the R1CS result matches the interpreter result. -- (5) Return the 'Result'. -execute :: (Typeable ty, PrimeField k, Pretty k) => SimplParam -> Comp ty k -> [k] -> Result k +execute :: (Typeable ty, PrimeField k) => SimplParam -> Comp ty k -> [k] -> Result k execute simpl mf inputs = let TExpPkg nv in_vars e = compileCompToTexp mf r1cs = compileTExpToR1CS simpl (TExpPkg nv in_vars e) From 1e964784a2d7232e5c328b34a7364113e35d7cab Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 17:35:27 -0800 Subject: [PATCH 07/19] Clean up exports --- snarkl.cabal | 29 +++++++------- src/Snarkl/Errors.hs | 4 +- src/Snarkl/Interp.hs | 4 +- src/Snarkl/Language.hs | 67 ++++++++++++++++++++++++++++++++- src/Snarkl/Language/Core.hs | 6 +-- src/Snarkl/Language/Expr.hs | 3 +- src/Snarkl/Language/TExpr.hs | 3 +- src/Snarkl/Language/Type.hs | 2 +- tests/Test/ArkworksBridge.hs | 3 +- tests/Test/Snarkl/LambdaSpec.hs | 2 +- 10 files changed, 91 insertions(+), 32 deletions(-) diff --git a/snarkl.cabal b/snarkl.cabal index 343cce7..50b1947 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -25,8 +25,7 @@ source-repository head library ghc-options: - -Wall -Wredundant-constraints -funbox-strict-fields - -optc-O3 + -Wall -Wredundant-constraints -funbox-strict-fields -optc-O3 -- -threaded exposed-modules: @@ -49,10 +48,10 @@ library Snarkl.Field Snarkl.Interp Snarkl.Language + Snarkl.Language.Core Snarkl.Language.Expr Snarkl.Language.LambdaExpr Snarkl.Language.Syntax - Snarkl.Language.Core Snarkl.Language.SyntaxMonad Snarkl.Language.TExpr Snarkl.Language.Type @@ -67,11 +66,13 @@ library GADTs GeneralizedNewtypeDeriving KindSignatures + LambdaCase OverloadedStrings PolyKinds RankNTypes ScopedTypeVariables StandaloneDeriving + TypeApplications TypeFamilies TypeSynonymInstances UndecidableInstances @@ -127,24 +128,24 @@ test-suite spec hs-source-dirs: tests examples default-language: Haskell2010 build-depends: - base >=4.7 + base >=4.7 , bytestring - , Cabal >=1.22 - , containers >=0.5 && <0.6 - , criterion >=1.0 - , galois-field >=1.0.4 - , hspec >=2.0 - , mtl >=2.2 && <2.3 - , parallel >=3.2 && <3.3 - , process >=1.2 + , Cabal >=1.22 + , containers >=0.5 && <0.6 + , criterion >=1.0 + , galois-field >=1.0.4 + , hspec >=2.0 + , mtl >=2.2 && <2.3 + , parallel >=3.2 && <3.3 , prettyprinter + , process >=1.2 , QuickCheck - , snarkl >=0.1.0.0 + , snarkl >=0.1.0.0 benchmark criterion type: exitcode-stdio-1.0 main-is: Main.hs - ghc-options: -threaded -O2 + ghc-options: -threaded -O2 other-modules: Harness Snarkl.Example.Basic diff --git a/src/Snarkl/Errors.hs b/src/Snarkl/Errors.hs index 2a7c076..f917226 100644 --- a/src/Snarkl/Errors.hs +++ b/src/Snarkl/Errors.hs @@ -1,8 +1,8 @@ module Snarkl.Errors where -import Control.Exception +import Control.Exception (Exception, throw) import Data.String (IsString) -import Data.Typeable +import Data.Typeable (Typeable) newtype ErrMsg = ErrMsg {errMsg :: String} deriving (Typeable, IsString) diff --git a/src/Snarkl/Interp.hs b/src/Snarkl/Interp.hs index 2989d22..ba269e7 100644 --- a/src/Snarkl/Interp.hs +++ b/src/Snarkl/Interp.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE LambdaCase #-} - module Snarkl.Interp ( interp, ) @@ -162,4 +160,4 @@ interpCoreExpr = \case Or -> return $ fieldOfBool $ b1 || b2 XOr -> return $ fieldOfBool $ (b1 && not b2) || (b2 && not b1) BEq -> return $ fieldOfBool $ b1 == b2 - _ -> failWith $ ErrMsg "internal error in interp_binop" \ No newline at end of file + _ -> failWith $ ErrMsg "internal error in interp_binop" diff --git a/src/Snarkl/Language.hs b/src/Snarkl/Language.hs index 684573d..886eee7 100644 --- a/src/Snarkl/Language.hs +++ b/src/Snarkl/Language.hs @@ -96,7 +96,72 @@ import Snarkl.Language.Core import Snarkl.Language.Expr (mkProgram) import Snarkl.Language.LambdaExpr (expOfLambdaExp) import Snarkl.Language.Syntax + ( Derive, + Zippable, + apply, + arr, + arr2, + arr3, + beq, + bigsum, + case_sum, + curry, + dec, + eq, + exp_of_int, + fix, + fixN, + forall, + forall2, + forall3, + fromField, + fst_pair, + get, + get2, + get3, + get4, + ifThenElse, + inc, + inl, + input_arr, + input_arr2, + input_arr3, + inr, + iter, + iterM, + lambda, + negate, + not, + pair, + roll, + set, + set2, + set3, + set4, + snd_pair, + times, + uncurry, + unroll, + xor, + zeq, + (&&), + (*), + (+), + (-), + (/), + ) import Snarkl.Language.SyntaxMonad + ( Comp, + Env (..), + false, + fresh_input, + return, + runState, + true, + unit, + (>>), + (>>=), + ) import Snarkl.Language.TExpr (TExp, booleanVarsOfTexp, tExpToLambdaExp) import Snarkl.Language.Type import Prelude (Either (..), error, ($), (.), (<>)) @@ -106,4 +171,4 @@ compileTExpToProgram te = let eprog = mkProgram . expOfLambdaExp . tExpToLambdaExp $ te in case eprog of Right p -> p - Left err -> error $ "compileTExpToProgram: failed to convert TExp to Program: " <> err \ No newline at end of file + Left err -> error $ "compileTExpToProgram: failed to convert TExp to Program: " <> err diff --git a/src/Snarkl/Language/Core.hs b/src/Snarkl/Language/Core.hs index 3709b9d..d851af1 100644 --- a/src/Snarkl/Language/Core.hs +++ b/src/Snarkl/Language/Core.hs @@ -1,12 +1,10 @@ -{-# LANGUAGE LambdaCase #-} - module Snarkl.Language.Core where import Data.Field.Galois (GaloisField) import Data.Kind (Type) import Data.Sequence (Seq) import Prettyprinter (Pretty) -import Snarkl.Common +import Snarkl.Common (Op, UnOp) newtype Variable = Variable Int deriving (Eq, Ord, Show, Pretty) @@ -25,4 +23,4 @@ deriving instance (Show a) => Show (Exp a) data Assignment a = Assignment Variable (Exp a) data Program :: Type -> Type where - Program :: Seq (Assignment a) -> Exp a -> Program a \ No newline at end of file + Program :: Seq (Assignment a) -> Exp a -> Program a diff --git a/src/Snarkl/Language/Expr.hs b/src/Snarkl/Language/Expr.hs index 7624f4c..01f8944 100644 --- a/src/Snarkl/Language/Expr.hs +++ b/src/Snarkl/Language/Expr.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE PatternSynonyms #-} module Snarkl.Language.Expr @@ -164,4 +163,4 @@ mkProgram _exp = do assignment <- hoistEither $ mkAssignment e modify (|> assignment) go rest - _ -> Core.Program Empty <$> mkExpression e' \ No newline at end of file + _ -> Core.Program Empty <$> mkExpression e' diff --git a/src/Snarkl/Language/TExpr.hs b/src/Snarkl/Language/TExpr.hs index d2f6d65..ac6a16b 100644 --- a/src/Snarkl/Language/TExpr.hs +++ b/src/Snarkl/Language/TExpr.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE UndecidableInstances #-} module Snarkl.Language.TExpr @@ -222,4 +221,4 @@ lastSeq :: TExp ty a lastSeq te = case te of TESeq _ te2 -> lastSeq te2 - _ -> te \ No newline at end of file + _ -> te diff --git a/src/Snarkl/Language/Type.hs b/src/Snarkl/Language/Type.hs index fa3a489..e7203f3 100644 --- a/src/Snarkl/Language/Type.hs +++ b/src/Snarkl/Language/Type.hs @@ -74,4 +74,4 @@ type instance Rep ('TFProd f g) x = 'TProd (Rep f x) (Rep g x) type instance Rep ('TFSum f g) x = 'TSum (Rep f x) (Rep g x) -type instance Rep ('TFComp f g) x = Rep f (Rep g x) \ No newline at end of file +type instance Rep ('TFComp f g) x = Rep f (Rep g x) diff --git a/tests/Test/ArkworksBridge.hs b/tests/Test/ArkworksBridge.hs index 8ac8dba..e1e4703 100644 --- a/tests/Test/ArkworksBridge.hs +++ b/tests/Test/ArkworksBridge.hs @@ -3,7 +3,6 @@ module Test.ArkworksBridge where import qualified Data.ByteString.Lazy as LBS import Data.Field.Galois (GaloisField, PrimeField) import Data.Typeable (Typeable) -import Prettyprinter (Pretty) import Snarkl.Backend.R1CS import Snarkl.Compile (SimplParam, compileCompToR1CS) import Snarkl.Language (Comp) @@ -15,7 +14,7 @@ data CMD k where CreateProof :: (Typeable ty, GaloisField k) => FilePath -> String -> SimplParam -> Comp ty k -> [k] -> CMD k RunR1CS :: (Typeable ty, GaloisField k) => FilePath -> String -> SimplParam -> Comp ty k -> [k] -> CMD k -runCMD :: (PrimeField k, Pretty k) => CMD k -> IO GHC.ExitCode +runCMD :: (PrimeField k) => CMD k -> IO GHC.ExitCode runCMD (CreateTrustedSetup rootDir name simpl c) = do let r1cs = compileCompToR1CS simpl c r1csFilePath = mkR1CSFilePath rootDir name diff --git a/tests/Test/Snarkl/LambdaSpec.hs b/tests/Test/Snarkl/LambdaSpec.hs index ecf636e..5feedb8 100644 --- a/tests/Test/Snarkl/LambdaSpec.hs +++ b/tests/Test/Snarkl/LambdaSpec.hs @@ -10,6 +10,7 @@ import qualified Data.Map as Map import GHC.TypeLits (KnownNat) import Snarkl.Field import Snarkl.Interp (interp) +import Snarkl.Language (TExp, Ty (TField, TFun, TProd)) import Snarkl.Language.Syntax ( apply, curry, @@ -20,7 +21,6 @@ import Snarkl.Language.Syntax (+), ) import qualified Snarkl.Language.SyntaxMonad as SM -import Snarkl.Language.TExpr (TExp, Ty (TField, TFun, TProd)) import Snarkl.Toplevel (comp_interp) import Test.Hspec (Spec, describe, it, shouldBe) import Test.QuickCheck (Testable (property)) From 6c2971d9991ad15cb45ad8a22f12d287a333e55a Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 19:08:37 -0800 Subject: [PATCH 08/19] bump cachix versions --- .github/workflows/nix-ci.yml | 4 ++-- dev-profile | 1 + dev-profile-1-link | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) create mode 120000 dev-profile create mode 120000 dev-profile-1-link diff --git a/.github/workflows/nix-ci.yml b/.github/workflows/nix-ci.yml index 8f62193..0177596 100644 --- a/.github/workflows/nix-ci.yml +++ b/.github/workflows/nix-ci.yml @@ -20,11 +20,11 @@ jobs: repo: torsion-labs/arkworks-bridge tag: v0.2.0 - - uses: cachix/install-nix-action@v22 + - uses: cachix/install-nix-action@v24 with: nix_path: nixpkgs=channel:nixos-unstable - - uses: cachix/cachix-action@v12 + - uses: cachix/cachix-action@v13 with: name: martyall authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' diff --git a/dev-profile b/dev-profile new file mode 120000 index 0000000..0de5ccd --- /dev/null +++ b/dev-profile @@ -0,0 +1 @@ +dev-profile-1-link \ No newline at end of file diff --git a/dev-profile-1-link b/dev-profile-1-link new file mode 120000 index 0000000..1deb0e4 --- /dev/null +++ b/dev-profile-1-link @@ -0,0 +1 @@ +/nix/store/q5hddqvwpq1r39vq0b8r29n0kfrcw90g-ghc-shell-for-packages-env \ No newline at end of file From a5870ccefbfa47f89d15fee1959c1b98260114be Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 19:08:56 -0800 Subject: [PATCH 09/19] Remove generated files --- dev-profile | 1 - dev-profile-1-link | 1 - 2 files changed, 2 deletions(-) delete mode 120000 dev-profile delete mode 120000 dev-profile-1-link diff --git a/dev-profile b/dev-profile deleted file mode 120000 index 0de5ccd..0000000 --- a/dev-profile +++ /dev/null @@ -1 +0,0 @@ -dev-profile-1-link \ No newline at end of file diff --git a/dev-profile-1-link b/dev-profile-1-link deleted file mode 120000 index 1deb0e4..0000000 --- a/dev-profile-1-link +++ /dev/null @@ -1 +0,0 @@ -/nix/store/q5hddqvwpq1r39vq0b8r29n0kfrcw90g-ghc-shell-for-packages-env \ No newline at end of file From b1eb42ff9c5cc74f9ffb17d3d6eb90fc9116dbe0 Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 21:10:24 -0800 Subject: [PATCH 10/19] clean up compile --- app/Main.hs | 17 +++++++++++++---- examples/Snarkl/Example/Basic.hs | 13 ++++++------- snarkl.cabal | 2 -- tests/Test/Snarkl/Unit/Programs.hs | 1 - 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/app/Main.hs b/app/Main.hs index 3c55fb4..40b9008 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -4,18 +4,27 @@ import Control.Monad (unless) import qualified Data.ByteString.Lazy as LBS import Data.Field.Galois (PrimeField) import Data.Typeable (Typeable) -import Prettyprinter -import Snarkl.Compile (SimplParam (NoSimplify)) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Field +import Snarkl.Field (F_BN128) import Snarkl.Toplevel + ( Comp, + Result (..), + SimplParam (..), + execute, + mkInputsFilePath, + mkR1CSFilePath, + mkWitnessFilePath, + serializeInputsAsJson, + serializeR1CSAsJson, + serializeWitnessAsJson, + ) import qualified Test.Snarkl.Unit.Programs as Programs main :: IO () main = do executeAndWriteArtifacts "./snarkl-output" "prog2" NoSimplify (Programs.prog2 10) [1 :: F_BN128] -executeAndWriteArtifacts :: (Typeable ty, Pretty k, PrimeField k) => FilePath -> String -> SimplParam -> Comp ty k -> [k] -> IO () +executeAndWriteArtifacts :: (Typeable ty, PrimeField k) => FilePath -> String -> SimplParam -> Comp ty k -> [k] -> IO () executeAndWriteArtifacts fp name simpl mf inputs = do let Result {result_sat = isSatisfied, result_r1cs = r1cs, result_witness = wit} = execute simpl mf inputs unless isSatisfied $ failWith $ ErrMsg "R1CS is not satisfied" diff --git a/examples/Snarkl/Example/Basic.hs b/examples/Snarkl/Example/Basic.hs index 91a9a7c..6572b3d 100644 --- a/examples/Snarkl/Example/Basic.hs +++ b/examples/Snarkl/Example/Basic.hs @@ -5,7 +5,6 @@ module Snarkl.Example.Basic where import Data.Field.Galois (GaloisField, Prime) import Data.Typeable (Typeable) import GHC.TypeLits (KnownNat) -import Prettyprinter (Pretty (pretty)) import Snarkl.Compile import Snarkl.Field (F_BN128) import Snarkl.Language.Syntax @@ -43,10 +42,10 @@ arr_ex x = do p1 :: (GaloisField k) => Comp 'TField k p1 = arr_ex $ fromField 1 -desugar1 :: (GaloisField k, Pretty k) => TExpPkg 'TField k +desugar1 :: (GaloisField k) => TExpPkg 'TField k desugar1 = compileCompToTexp p1 -interp1 :: (GaloisField k, Pretty k) => k +interp1 :: (GaloisField k) => k interp1 = comp_interp p1 [] p2 = do @@ -55,13 +54,13 @@ p2 = do desugar2 = compileCompToTexp p2 -interp2 :: (GaloisField k, Pretty k) => k +interp2 :: (GaloisField k) => k interp2 = comp_interp p2 [] -interp2' :: (GaloisField k, Pretty k) => k +interp2' :: (GaloisField k) => k interp2' = comp_interp p2 [256] -compile1 :: (GaloisField k, Pretty k) => R1CS k +compile1 :: (GaloisField k) => R1CS k compile1 = compileCompToR1CS Simplify p1 comp1 :: (GaloisField k, Typeable a) => Comp ('TSum 'TBool a) k @@ -74,4 +73,4 @@ test1 :: (GaloisField k) => State (Env k) (TExp 'TBool k) test1 = do b <- fresh_input z <- if return b then comp1 else comp2 - case_sum return (const $ return false) z + case_sum return (const $ return false) z \ No newline at end of file diff --git a/snarkl.cabal b/snarkl.cabal index 50b1947..7d50800 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -249,6 +249,4 @@ executable compile , bytestring , containers , galois-field >=1.0.4 - , hspec >=2.0 - , prettyprinter , snarkl >=0.1.0.0 diff --git a/tests/Test/Snarkl/Unit/Programs.hs b/tests/Test/Snarkl/Unit/Programs.hs index fa316ef..0202751 100644 --- a/tests/Test/Snarkl/Unit/Programs.hs +++ b/tests/Test/Snarkl/Unit/Programs.hs @@ -20,7 +20,6 @@ import Snarkl.Language.Syntax import Snarkl.Language.SyntaxMonad import Snarkl.Language.TExpr import Snarkl.Toplevel -import Test.Hspec (Spec, describe, it, shouldBe, shouldReturn) import Prelude hiding ( fromRational, negate, From 4050e1c038db70af6e4d105cf0d64694f2eeb60e Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 22:00:45 -0800 Subject: [PATCH 11/19] clean up constraints in language module --- print-examples/Main.hs | 4 +- snarkl.cabal | 7 ++- src/Snarkl/Backend/R1CS/Poly.hs | 21 ++++----- src/Snarkl/Backend/R1CS/R1CS.hs | 20 ++++---- src/Snarkl/Common.hs | 4 +- src/Snarkl/Compile.hs | 4 +- src/Snarkl/Field.hs | 8 +--- src/Snarkl/Language/Core.hs | 18 +++---- src/Snarkl/Language/Expr.hs | 48 +++++++++---------- src/Snarkl/Language/LambdaExpr.hs | 28 +++++------ src/Snarkl/Language/Syntax.hs | 75 +++++++++++------------------- src/Snarkl/Language/SyntaxMonad.hs | 39 ++++++---------- src/Snarkl/Language/TExpr.hs | 42 ++++++++--------- src/Snarkl/Language/Type.hs | 2 +- src/Snarkl/Toplevel.hs | 4 +- tests/Test/Snarkl/UnitSpec.hs | 2 +- 16 files changed, 141 insertions(+), 185 deletions(-) diff --git a/print-examples/Main.hs b/print-examples/Main.hs index 63a7904..91c579d 100644 --- a/print-examples/Main.hs +++ b/print-examples/Main.hs @@ -3,11 +3,11 @@ module Main where import Data.Foldable (traverse_) -import Prettyprinter -import Prettyprinter.Render.String (renderString) import Snarkl.Field () import Snarkl.Toplevel (compileCompToTexp) import Test.Snarkl.Unit.Programs +import Text.PrettyPrint.Leijen.Text +import Text.PrettyPrint.Leijen.Text.Render.String (renderString) main :: IO () main = do diff --git a/snarkl.cabal b/snarkl.cabal index 7d50800..183b982 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -67,6 +67,7 @@ library GeneralizedNewtypeDeriving KindSignatures LambdaCase + MultiParamTypeClasses OverloadedStrings PolyKinds RankNTypes @@ -90,9 +91,9 @@ library , lens , mtl >=2.2 && <2.3 , parallel >=3.2 && <3.3 - , prettyprinter , process >=1.2 , transformers + , wl-pprint-text hs-source-dirs: src default-language: Haskell2010 @@ -137,7 +138,6 @@ test-suite spec , hspec >=2.0 , mtl >=2.2 && <2.3 , parallel >=3.2 && <3.3 - , prettyprinter , process >=1.2 , QuickCheck , snarkl >=0.1.0.0 @@ -217,7 +217,6 @@ executable print-examples , containers , galois-field >=1.0.4 , hspec >=2.0 - , prettyprinter , snarkl >=0.1.0.0 executable compile @@ -249,4 +248,4 @@ executable compile , bytestring , containers , galois-field >=1.0.4 - , snarkl >=0.1.0.0 + , snarkl >=0.1.0.0 \ No newline at end of file diff --git a/src/Snarkl/Backend/R1CS/Poly.hs b/src/Snarkl/Backend/R1CS/Poly.hs index 72b07ca..23e5a32 100644 --- a/src/Snarkl/Backend/R1CS/Poly.hs +++ b/src/Snarkl/Backend/R1CS/Poly.hs @@ -5,15 +5,15 @@ module Snarkl.Backend.R1CS.Poly where import qualified Data.Aeson as A import Data.Field.Galois (GaloisField, PrimeField, fromP) import qualified Data.Map as Map -import Prettyprinter (Pretty (..)) import Snarkl.Common +import Text.PrettyPrint.Leijen.Text (Pretty (..)) -data Poly a where - Poly :: (GaloisField a) => Assgn a -> Poly a +data Poly k where + Poly :: (GaloisField k) => Assgn k -> Poly k -deriving instance (Show a) => Show (Poly a) +deriving instance Show (Poly k) -instance (Pretty a) => Pretty (Poly a) where +instance Pretty (Poly k) where pretty (Poly m) = pretty $ Map.toList m -- The reason we use incVar is that we want to use -1 internally as the constant @@ -21,22 +21,21 @@ instance (Pretty a) => Pretty (Poly a) where -- harder to work with downstream where e.g. arkworks expects positive indices). -- The reason we use show is because it's hard to deserialize large integers -- in certain langauges (e.g. javascript, even rust). -instance (PrimeField a) => A.ToJSON (Poly a) where - toJSON :: Poly a -> A.Value +instance (PrimeField k) => A.ToJSON (Poly k) where toJSON (Poly m) = let kvs = map (\(var, coeff) -> (show $ fromP coeff, incVar var)) $ Map.toList m in A.toJSON kvs -- | The constant polynomial equal 'c' -const_poly :: (GaloisField a) => a -> Poly a +const_poly :: (GaloisField k) => k -> Poly k const_poly c = Poly $ Map.insert (Var (-1)) c Map.empty -- | The polynomial equal variable 'x' var_poly :: - (GaloisField a) => + (GaloisField k) => -- | Variable, with coeff - (a, Var) -> + (k, Var) -> -- | Resulting polynomial - Poly a + Poly k var_poly (coeff, x) = Poly $ Map.insert x coeff Map.empty diff --git a/src/Snarkl/Backend/R1CS/R1CS.hs b/src/Snarkl/Backend/R1CS/R1CS.hs index 0e47e12..e077fd6 100644 --- a/src/Snarkl/Backend/R1CS/R1CS.hs +++ b/src/Snarkl/Backend/R1CS/R1CS.hs @@ -11,19 +11,19 @@ import Control.Parallel.Strategies import qualified Data.Aeson as A import Data.Field.Galois (GaloisField, PrimeField) import qualified Data.Map as Map -import Prettyprinter (Pretty (..), (<+>)) import Snarkl.Backend.R1CS.Poly import Snarkl.Common import Snarkl.Errors +import Text.PrettyPrint.Leijen.Text (Pretty (..), (<+>)) ---------------------------------------------------------------- -- Rank-1 Constraint Systems -- ---------------------------------------------------------------- -data R1C a where - R1C :: (GaloisField a) => (Poly a, Poly a, Poly a) -> R1C a +data R1C k where + R1C :: (GaloisField k) => (Poly k, Poly k, Poly k) -> R1C k -deriving instance (Show a) => Show (R1C a) +deriving instance (Show k) => Show (R1C k) instance (PrimeField k) => A.ToJSON (R1C k) where toJSON (R1C (a, b, c)) = @@ -33,7 +33,7 @@ instance (PrimeField k) => A.ToJSON (R1C k) where "C" A..= c ] -instance (Pretty a) => Pretty (R1C a) where +instance Pretty (R1C k) where pretty (R1C (aV, bV, cV)) = pretty aV <+> "*" <+> pretty bV <+> "==" <+> pretty cV data R1CS a = R1CS @@ -44,19 +44,19 @@ data R1CS a = R1CS r1cs_gen_witness :: Assgn a -> Assgn a } -instance (Show a) => Show (R1CS a) where +instance (Show k) => Show (R1CS k) where show (R1CS cs nvs ivs ovs _) = show (cs, nvs, ivs, ovs) -num_constraints :: R1CS a -> Int +num_constraints :: R1CS k -> Int num_constraints = length . r1cs_clauses -- sat_r1c: Does witness 'w' satisfy constraint 'c'? -sat_r1c :: (GaloisField a) => Assgn a -> R1C a -> Bool +sat_r1c :: (GaloisField k) => Assgn k -> R1C k -> Bool sat_r1c w c | R1C (aV, bV, cV) <- c = inner aV w * inner bV w == inner cV w where - inner :: (GaloisField a) => Poly a -> Assgn a -> a + inner :: (GaloisField k) => Poly k -> Assgn k -> k inner (Poly v) w' = let c0 = Map.findWithDefault 0 (Var (-1)) v in Map.foldlWithKey (f w') c0 v @@ -65,7 +65,7 @@ sat_r1c w c (v_val * Map.findWithDefault 0 v_key w') + acc -- sat_r1cs: Does witness 'w' satisfy constraint set 'cs'? -sat_r1cs :: (GaloisField a) => Assgn a -> R1CS a -> Bool +sat_r1cs :: (GaloisField k) => Assgn k -> R1CS k -> Bool sat_r1cs w cs = and $ is_sat (r1cs_clauses cs) where is_sat cs0 = map g cs0 `using` parListChunk (chunk_sz cs0) rseq diff --git a/src/Snarkl/Common.hs b/src/Snarkl/Common.hs index 708a954..374e523 100644 --- a/src/Snarkl/Common.hs +++ b/src/Snarkl/Common.hs @@ -4,7 +4,7 @@ module Snarkl.Common where import qualified Data.Aeson as A import qualified Data.Map as Map -import Prettyprinter (Pretty (pretty)) +import Text.PrettyPrint.Leijen.Text (Pretty (pretty)) newtype Var = Var Int deriving (Eq, Ord, Show, A.ToJSON) @@ -67,4 +67,4 @@ isAssoc op = case op of Or -> True XOr -> True Eq -> True - BEq -> True + BEq -> True \ No newline at end of file diff --git a/src/Snarkl/Compile.hs b/src/Snarkl/Compile.hs index a34f069..124eca4 100644 --- a/src/Snarkl/Compile.hs +++ b/src/Snarkl/Compile.hs @@ -29,7 +29,6 @@ import Data.List (sort) import qualified Data.Map as Map import qualified Data.Set as Set import Data.Typeable (Typeable) -import Prettyprinter (Pretty (..)) import Snarkl.Backend.R1CS.R1CS (R1CS) import Snarkl.Common (Op (..), UnOp (..), Var (Var), incVar) import Snarkl.Constraint @@ -56,6 +55,7 @@ import Snarkl.Language runState, ) import qualified Snarkl.Language.Core as Core +import Text.PrettyPrint.Leijen.Text (Pretty (..)) ---------------------------------------------------------------- -- @@ -457,7 +457,7 @@ data TExpPkg ty k = TExpPkg } deriving (Show) -instance (Typeable ty, Pretty k) => Pretty (TExpPkg ty k) where +instance (Typeable ty) => Pretty (TExpPkg ty k) where pretty (TExpPkg _ _ e) = pretty e deriving instance (Eq (TExp ty k)) => Eq (TExpPkg ty k) diff --git a/src/Snarkl/Field.hs b/src/Snarkl/Field.hs index f9616ca..f8f6e9f 100644 --- a/src/Snarkl/Field.hs +++ b/src/Snarkl/Field.hs @@ -2,12 +2,8 @@ module Snarkl.Field where -import Data.Field.Galois (Prime, fromP) -import Prettyprinter (Pretty (..)) +import Data.Field.Galois (Prime) type P_BN128 = 21888242871839275222246405745257275088548364400416034343698204186575808495617 -type F_BN128 = Prime P_BN128 - -instance Pretty F_BN128 where - pretty = pretty . fromP +type F_BN128 = Prime P_BN128 \ No newline at end of file diff --git a/src/Snarkl/Language/Core.hs b/src/Snarkl/Language/Core.hs index d851af1..4a75db2 100644 --- a/src/Snarkl/Language/Core.hs +++ b/src/Snarkl/Language/Core.hs @@ -3,22 +3,18 @@ module Snarkl.Language.Core where import Data.Field.Galois (GaloisField) import Data.Kind (Type) import Data.Sequence (Seq) -import Prettyprinter (Pretty) import Snarkl.Common (Op, UnOp) +import Text.PrettyPrint.Leijen.Text (Pretty) newtype Variable = Variable Int deriving (Eq, Ord, Show, Pretty) data Exp :: Type -> Type where - EVar :: Variable -> Exp a - EVal :: (GaloisField a) => a -> Exp a - EUnop :: UnOp -> Exp a -> Exp a - EBinop :: Op -> [Exp a] -> Exp a - EIf :: Exp a -> Exp a -> Exp a -> Exp a - EUnit :: Exp a - -deriving instance (Eq a) => Eq (Exp a) - -deriving instance (Show a) => Show (Exp a) + EVar :: Variable -> Exp k + EVal :: (GaloisField k) => k -> Exp k + EUnop :: UnOp -> Exp k -> Exp k + EBinop :: Op -> [Exp k] -> Exp k + EIf :: Exp k -> Exp a -> Exp k -> Exp k + EUnit :: Exp k data Assignment a = Assignment Variable (Exp a) diff --git a/src/Snarkl/Language/Expr.hs b/src/Snarkl/Language/Expr.hs index 01f8944..4480513 100644 --- a/src/Snarkl/Language/Expr.hs +++ b/src/Snarkl/Language/Expr.hs @@ -20,31 +20,31 @@ import Data.Kind (Type) import Data.Map (Map) import qualified Data.Map as Map import Data.Sequence (Seq, fromList, (<|), (><), (|>), pattern Empty, pattern (:<|)) -import Prettyprinter +import Snarkl.Common (Op, UnOp, isAssoc) +import qualified Snarkl.Language.Core as Core +import Text.PrettyPrint.Leijen.Text ( Pretty (pretty), hsep, parens, punctuate, (<+>), ) -import Snarkl.Common (Op, UnOp, isAssoc) -import qualified Snarkl.Language.Core as Core data Exp :: Type -> Type where - EVar :: Core.Variable -> Exp a - EVal :: (GaloisField a) => a -> Exp a - EUnop :: UnOp -> Exp a -> Exp a - EBinop :: Op -> [Exp a] -> Exp a - EIf :: Exp a -> Exp a -> Exp a -> Exp a - EAssert :: Exp a -> Exp a -> Exp a - ESeq :: Seq (Exp a) -> Exp a - EUnit :: Exp a + EVar :: Core.Variable -> Exp k + EVal :: (GaloisField k) => k -> Exp k + EUnop :: UnOp -> Exp k -> Exp k + EBinop :: Op -> [Exp k] -> Exp k + EIf :: Exp k -> Exp k -> Exp k -> Exp k + EAssert :: Exp k -> Exp k -> Exp k + ESeq :: Seq (Exp k) -> Exp k + EUnit :: Exp k -deriving instance (Eq a) => Eq (Exp a) +deriving instance Eq (Exp k) -deriving instance (Show a) => Show (Exp a) +deriving instance Show (Exp k) -const_prop :: (GaloisField a) => Exp a -> State (Map Core.Variable a) (Exp a) +const_prop :: (GaloisField k) => Exp k -> State (Map Core.Variable k) (Exp k) const_prop e = case e of EVar x -> lookup_var x @@ -75,23 +75,23 @@ const_prop e = return $ ESeq es' EUnit -> return EUnit where - lookup_var :: (GaloisField a) => Core.Variable -> State (Map Core.Variable a) (Exp a) + lookup_var :: (GaloisField k) => Core.Variable -> State (Map Core.Variable k) (Exp k) lookup_var x0 = gets ( \m -> case Map.lookup x0 m of Nothing -> EVar x0 Just c -> EVal c ) - add_bind :: (Core.Variable, a) -> State (Map Core.Variable a) (Exp a) + add_bind :: (Core.Variable, k) -> State (Map Core.Variable k) (Exp k) add_bind (x0, c0) = do modify (Map.insert x0 c0) return EUnit -do_const_prop :: (GaloisField a) => Exp a -> Exp a +do_const_prop :: (GaloisField k) => Exp k -> Exp k do_const_prop e = evalState (const_prop e) Map.empty -instance (Pretty a) => Pretty (Exp a) where +instance Pretty (Exp k) where pretty (EVar x) = "var_" <> pretty x pretty (EVal c) = pretty c pretty (EUnop op e1) = pretty op <> parens (pretty e1) @@ -103,7 +103,7 @@ instance (Pretty a) => Pretty (Exp a) where pretty (ESeq es) = parens $ hsep $ punctuate ";" $ map pretty (toList es) pretty EUnit = "()" -mkExpression :: (Show a) => Exp a -> Either String (Core.Exp a) +mkExpression :: Exp k -> Either String (Core.Exp k) mkExpression = \case EVar x -> pure $ Core.EVar x EVal v -> pure $ Core.EVal v @@ -115,7 +115,7 @@ mkExpression = \case -- | Smart constructor for sequence, ensuring all expressions are -- flattened to top level. -expSeq :: Exp a -> Exp a -> Exp a +expSeq :: Exp k -> Exp k -> Exp k expSeq e1 e2 = case (e1, e2) of (ESeq l1, ESeq l2) -> ESeq (l1 >< l2) @@ -123,7 +123,7 @@ expSeq e1 e2 = (_, ESeq l2) -> ESeq (e1 <| l2) (_, _) -> ESeq (fromList [e1, e2]) -expBinop :: Op -> Exp a -> Exp a -> Exp a +expBinop :: Op -> Exp k -> Exp k -> Exp k expBinop op e1 e2 = case (e1, e2) of (EBinop op1 l1, EBinop op2 l2) @@ -137,14 +137,14 @@ expBinop op e1 e2 = EBinop op (e1 : l2) (_, _) -> EBinop op [e1, e2] -mkAssignment :: (Show a) => Exp a -> Either String (Core.Assignment a) +mkAssignment :: Exp k -> Either String (Core.Assignment k) mkAssignment (EAssert (EVar v) e) = Core.Assignment v <$> mkExpression e mkAssignment e = throwError $ "mkAssignment: expected EAssert, got " <> show e -- At this point the expression should be either: -- 1. A sequence of assignments, followed by an expression -- 2. An expression -mkProgram :: (GaloisField a) => Exp a -> Either String (Core.Program a) +mkProgram :: (GaloisField k) => Exp k -> Either String (Core.Program k) mkProgram _exp = do let e' = do_const_prop _exp case e' of @@ -152,7 +152,7 @@ mkProgram _exp = do let (eexpr, assignments) = runState (runExceptT $ go es) mempty Core.Program assignments <$> eexpr where - go :: (Show a) => Seq (Exp a) -> ExceptT String (State (Seq (Core.Assignment a))) (Core.Exp a) + go :: (Show k) => Seq (Exp k) -> ExceptT String (State (Seq (Core.Assignment k))) (Core.Exp k) go = \case Empty -> throwError "mkProgram: empty sequence" e :<| Empty -> hoistEither $ mkExpression e diff --git a/src/Snarkl/Language/LambdaExpr.hs b/src/Snarkl/Language/LambdaExpr.hs index a7bf8e7..897383f 100644 --- a/src/Snarkl/Language/LambdaExpr.hs +++ b/src/Snarkl/Language/LambdaExpr.hs @@ -23,20 +23,20 @@ import qualified Snarkl.Language.Expr as E -- to the right and then flatten them. Similarly nested Binops of the same -- operator are flattened into a single list if that operator is associative. data Exp :: Type -> Type where - EVar :: Variable -> Exp a - EVal :: (GaloisField a) => a -> Exp a - EUnop :: UnOp -> Exp a -> Exp a - EBinop :: Op -> Exp a -> Exp a -> Exp a - EIf :: Exp a -> Exp a -> Exp a -> Exp a - EAssert :: Exp a -> Exp a -> Exp a - ESeq :: Exp a -> Exp a -> Exp a - EUnit :: Exp a - EAbs :: Variable -> Exp a -> Exp a - EApp :: Exp a -> Exp a -> Exp a + EVar :: Variable -> Exp k + EVal :: (GaloisField k) => k -> Exp k + EUnop :: UnOp -> Exp k -> Exp k + EBinop :: Op -> Exp k -> Exp k -> Exp k + EIf :: Exp k -> Exp k -> Exp k -> Exp k + EAssert :: Exp k -> Exp k -> Exp k + ESeq :: Exp k -> Exp k -> Exp k + EUnit :: Exp k + EAbs :: Variable -> Exp k -> Exp k + EApp :: Exp k -> Exp k -> Exp k -deriving instance (Show a) => Show (Exp a) +deriving instance Show (Exp k) -deriving instance (Eq a) => Eq (Exp a) +deriving instance Eq (Exp k) betaNormalize :: Exp a -> Exp a betaNormalize = \case @@ -68,14 +68,14 @@ betaNormalize = \case EAbs var' e -> EAbs var' (substitute (var, e1) e) EApp e2 e3 -> EApp (substitute (var, e1) e2) (substitute (var, e1) e3) -expOfLambdaExp :: (Show a) => Exp a -> E.Exp a +expOfLambdaExp :: Exp k -> E.Exp k expOfLambdaExp _exp = let coreExp = betaNormalize _exp in case expOfLambdaExp' coreExp of Left err -> failWith $ ErrMsg err Right e -> e where - expOfLambdaExp' :: (Show a) => Exp a -> Either String (E.Exp a) + expOfLambdaExp' :: Exp k -> Either String (E.Exp k) expOfLambdaExp' = \case EVar var -> pure $ E.EVar var EVal v -> pure $ E.EVal v diff --git a/src/Snarkl/Language/Syntax.hs b/src/Snarkl/Language/Syntax.hs index 391e12c..7cae059 100644 --- a/src/Snarkl/Language/Syntax.hs +++ b/src/Snarkl/Language/Syntax.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RebindableSyntax #-} module Snarkl.Language.Syntax @@ -136,7 +135,7 @@ dec n = (P.-) n 1 -- | 2-d arrays. 'width' is the size, in "bits" (#field elements), of -- each array element. -arr2 :: (Typeable ty, GaloisField k) => Int -> Int -> Comp ('TArr ('TArr ty)) k +arr2 :: (Typeable ty) => Int -> Int -> Comp ('TArr ('TArr ty)) k arr2 len width = do a <- arr len @@ -151,7 +150,7 @@ arr2 len width = return a -- | 3-d arrays. -arr3 :: (Typeable ty, GaloisField k) => Int -> Int -> Int -> Comp ('TArr ('TArr ('TArr ty))) k +arr3 :: (Typeable ty) => Int -> Int -> Int -> Comp ('TArr ('TArr ('TArr ty))) k arr3 len width height = do a <- arr2 len width @@ -165,7 +164,7 @@ arr3 len width height = ) return a -input_arr2 :: (Typeable ty, GaloisField k) => Int -> Int -> Comp ('TArr ('TArr ty)) k +input_arr2 :: (Typeable ty) => Int -> Int -> Comp ('TArr ('TArr ty)) k input_arr2 0 _ = raise_err $ ErrMsg "array must have size > 0" input_arr2 len width = do @@ -180,7 +179,7 @@ input_arr2 len width = ) return a -input_arr3 :: (Typeable ty, GaloisField k) => Int -> Int -> Int -> Comp ('TArr ('TArr ('TArr ty))) k +input_arr3 :: (Typeable ty) => Int -> Int -> Int -> Comp ('TArr ('TArr ('TArr ty))) k input_arr3 len width height = do a <- arr2 len width @@ -194,13 +193,13 @@ input_arr3 len width height = ) return a -set2 :: (Typeable ty2, GaloisField k) => (TExp ('TArr ('TArr ty2)) k, Int, Int) -> TExp ty2 k -> Comp 'TUnit k +set2 :: (Typeable ty2) => (TExp ('TArr ('TArr ty2)) k, Int, Int) -> TExp ty2 k -> Comp 'TUnit k set2 (a, i, j) e = do a' <- get (a, i) set (a', j) e set3 :: - (Typeable ty, GaloisField k) => + (Typeable ty) => ( TExp ('TArr ('TArr ('TArr ty))) k, Int, Int, @@ -213,7 +212,7 @@ set3 (a, i, j, k) e = do set (a', k) e set4 :: - (Typeable ty, GaloisField k) => + (Typeable ty) => ( TExp ('TArr ('TArr ('TArr ('TArr ty)))) k, Int, Int, @@ -226,13 +225,13 @@ set4 (a, i, j, k, l) e = do a' <- get3 (a, i, j, k) set (a', l) e -get2 :: (Typeable ty2, GaloisField k) => (TExp ('TArr ('TArr ty2)) k, Int, Int) -> State (Env k) (TExp ty2 k) +get2 :: (Typeable ty2) => (TExp ('TArr ('TArr ty2)) k, Int, Int) -> State (Env k) (TExp ty2 k) get2 (a, i, j) = do a' <- get (a, i) get (a', j) get3 :: - (Typeable ty, GaloisField k) => + (Typeable ty) => ( TExp ('TArr ('TArr ('TArr ty))) k, Int, Int, @@ -244,7 +243,7 @@ get3 (a, i, j, k) = do get (a', k) get4 :: - (Typeable ty, GaloisField k) => + (Typeable ty) => ( TExp ('TArr ('TArr ('TArr ('TArr ty)))) k, Int, Int, @@ -273,8 +272,7 @@ unrep_sum :: unrep_sum = unsafe_cast inl :: - (GaloisField k) => - forall ty1 ty2. + forall ty1 ty2 k. ( Typeable ty1, Typeable ty2 ) => @@ -294,8 +292,7 @@ inl te1 = inr :: forall ty1 ty2 k. ( Typeable ty1, - Typeable ty2, - GaloisField k + Typeable ty2 ) => TExp ty2 k -> Comp ('TSum ty1 ty2) k @@ -314,9 +311,7 @@ case_sum :: forall ty1 ty2 ty k. ( Typeable ty1, Typeable ty2, - Typeable ty, - Zippable ty k, - GaloisField k + Zippable ty k ) => (TExp ty1 k -> Comp ty k) -> (TExp ty2 k -> Comp ty k) -> @@ -356,7 +351,7 @@ instance Derive 'TBool k where instance (GaloisField k) => Derive 'TField k where derive _ = return $ TEVal (VField 0) -instance (Typeable ty, Derive ty k, GaloisField k) => Derive ('TArr ty) k where +instance (Typeable ty, Derive ty k) => Derive ('TArr ty) k where derive n = do a <- arr 1 @@ -368,8 +363,7 @@ instance ( Typeable ty1, Derive ty1 k, Typeable ty2, - Derive ty2 k, - GaloisField k + Derive ty2 k ) => Derive ('TProd ty1 ty2) k where @@ -382,8 +376,7 @@ instance instance ( Typeable ty1, Derive ty1 k, - Typeable ty2, - GaloisField k + Typeable ty2 ) => Derive ('TSum ty1 ty2) k where @@ -393,10 +386,7 @@ instance inl v1 instance - ( Typeable f, - Typeable (Rep f ('TMu f)), - Derive (Rep f ('TMu f)) k, - GaloisField k + ( Derive (Rep f ('TMu f)) k ) => Derive ('TMu f) k where @@ -426,7 +416,7 @@ class Zippable ty k where instance Zippable 'TUnit k where zip_vals _ _ _ = return unit -zip_base :: (Typeable ty, GaloisField k) => TExp 'TBool k -> TExp ty k -> TExp ty k -> Comp ty k +zip_base :: (Typeable ty) => TExp 'TBool k -> TExp ty k -> TExp ty k -> Comp ty k zip_base TEBot _ _ = return TEBot zip_base _ TEBot e2 = return e2 zip_base _ e1 TEBot = return e1 @@ -450,18 +440,17 @@ zip_base b e1 e2 = ) b -instance (GaloisField k) => Zippable 'TBool k where +instance Zippable 'TBool k where zip_vals b b1 b2 = zip_base b b1 b2 -instance (GaloisField k) => Zippable 'TField k where +instance Zippable 'TField k where zip_vals b e1 e2 = zip_base b e1 e2 fuel :: Int fuel = 1 check_bots :: - ( Derive ty k, - GaloisField k + ( Derive ty k ) => Comp ty k -> TExp 'TBool k -> @@ -491,8 +480,7 @@ instance Derive ty1 k, Zippable ty2 k, Typeable ty2, - Derive ty2 k, - GaloisField k + Derive ty2 k ) => Zippable ('TProd ty1 ty2) k where @@ -513,8 +501,7 @@ instance Derive ty1 k, Zippable ty2 k, Typeable ty2, - Derive ty2 k, - GaloisField k + Derive ty2 k ) => Zippable ('TSum ty1 ty2) k where @@ -527,11 +514,8 @@ instance return $ unrep_sum p' instance - ( Typeable f, - Typeable (Rep f ('TMu f)), - Zippable (Rep f ('TMu f)) k, - Derive (Rep f ('TMu f)) k, - GaloisField k + ( Zippable (Rep f ('TMu f)) k, + Derive (Rep f ('TMu f)) k ) => Zippable ('TMu f) k where @@ -631,7 +615,7 @@ fix = fixN 100 zeq :: TExp 'TField k -> TExp 'TBool k zeq e = TEUnop (TUnop ZEq) e -not :: (Eq k) => TExp 'TBool k -> TExp 'TBool k +not :: TExp 'TBool k -> TExp 'TBool k not e = ifThenElse_aux e false true xor :: TExp 'TBool k -> TExp 'TBool k -> TExp 'TBool k @@ -650,7 +634,6 @@ exp_of_int :: (GaloisField k) => Int -> TExp 'TField k exp_of_int i = TEVal (VField $ fromIntegral i) ifThenElse_aux :: - (Eq a) => TExp 'TBool a -> TExp ty a -> TExp ty a -> @@ -664,8 +647,7 @@ ifThenElse_aux b e1 e2 _ -> TEIf b e1 e2 ifThenElse :: - ( Zippable ty k, - Typeable ty + ( Zippable ty k ) => Comp 'TBool k -> Comp ty k -> @@ -698,7 +680,6 @@ iter n f e = g n f e g m f' e' = f' m $ g (dec m) f' e' iterM :: - (Typeable ty) => Int -> (Int -> TExp ty k -> Comp ty k) -> TExp ty k -> @@ -765,7 +746,6 @@ curry :: (Typeable a) => (Typeable b) => (Typeable c) => - (GaloisField k) => (TExp ('TProd a b) k -> Comp c k) -> TExp a k -> Comp ('TFun b c) k @@ -778,7 +758,6 @@ uncurry :: (Typeable a) => (Typeable b) => (Typeable c) => - (GaloisField k) => (TExp a k -> Comp ('TFun b c) k) -> TExp ('TProd a b) k -> Comp c k diff --git a/src/Snarkl/Language/SyntaxMonad.hs b/src/Snarkl/Language/SyntaxMonad.hs index f460ca3..2bf94a1 100644 --- a/src/Snarkl/Language/SyntaxMonad.hs +++ b/src/Snarkl/Language/SyntaxMonad.hs @@ -50,7 +50,6 @@ where import Control.Monad (forM, replicateM) import Control.Monad.Supply (Supply, runSupply) import Control.Monad.Supply.Class (MonadSupply (fresh)) -import Data.Field.Galois (GaloisField) import qualified Data.Map.Strict as Map import Data.String (IsString (..)) import Data.Typeable (Typeable) @@ -103,7 +102,6 @@ raise_err msg = State (\_ -> Left msg) -- of the results of 'mf', 'g' (not just whatever 'g' returns) (>>=) :: forall (ty1 :: Ty) (ty2 :: Ty) s a. - (Typeable ty1) => State s (TExp ty1 a) -> (TExp ty1 a -> State s (TExp ty2 a)) -> State s (TExp ty2 a) @@ -119,7 +117,6 @@ raise_err msg = State (\_ -> Left msg) (>>) :: forall (ty1 :: Ty) (ty2 :: Ty) s a. - (Typeable ty1) => State s (TExp ty1 a) -> State s (TExp ty2 a) -> State s (TExp ty2 a) @@ -286,7 +283,6 @@ get_addr (l, i) = guard :: (Typeable ty2) => - (GaloisField k) => (TExp ty k -> State (Env k) (TExp ty2 k)) -> TExp ty k -> State (Env k) (TExp ty2 k) @@ -300,19 +296,18 @@ guard f e = guarded_get_addr :: (Typeable ty2) => - (GaloisField k) => TExp ty k -> Int -> State (Env k) (TExp ty2 k) guarded_get_addr e i = guard (\e0 -> get_addr (locOfTexp e0, i)) e -get :: (Typeable ty) => (GaloisField k) => (TExp ('TArr ty) k, Int) -> Comp ty k +get :: (Typeable ty) => (TExp ('TArr ty) k, Int) -> Comp ty k get (TEBot, _) = return TEBot get (a, i) = guarded_get_addr a i -- | Smart constructor for TEAssert -te_assert :: (Typeable ty) => (GaloisField k) => TExp ty k -> TExp ty k -> Comp 'TUnit k +te_assert :: (Typeable ty) => TExp ty k -> TExp ty k -> Comp 'TUnit k te_assert x@(TEVar _) e = do e_bot <- is_bot e @@ -333,7 +328,6 @@ te_assert _ e = -- in the object map. set_addr :: (Typeable ty) => - (GaloisField k) => (TExp ('TArr ty) k, Int) -> TExp ty k -> Comp 'TUnit k @@ -361,7 +355,7 @@ set_addr (TEVal (VLoc (TLoc l)), i) e = set_addr (e1, _) _ = raise_err $ ErrMsg ("expected " ++ show e1 ++ " a loc") -set :: (Typeable ty, GaloisField k) => (TExp ('TArr ty) k, Int) -> TExp ty k -> Comp 'TUnit k +set :: (Typeable ty) => (TExp ('TArr ty) k, Int) -> TExp ty k -> Comp 'TUnit k set (a, i) e = set_addr (a, i) e {----------------------------------------------- @@ -370,8 +364,7 @@ set (a, i) e = set_addr (a, i) e pair :: ( Typeable ty1, - Typeable ty2, - GaloisField k + Typeable ty2 ) => TExp ty1 k -> TExp ty2 k -> @@ -411,15 +404,13 @@ pair te1 te2 = fst_pair :: (Typeable ty1) => - (GaloisField k) => TExp ('TProd ty1 ty2) k -> Comp ty1 k fst_pair TEBot = return TEBot fst_pair e = guarded_get_addr e 0 snd_pair :: - ( Typeable ty2, - GaloisField k + ( Typeable ty2 ) => TExp ('TProd ty1 ty2) k -> Comp ty2 k @@ -494,7 +485,7 @@ add_statics binds = ) -- | Does boolean expression 'e' resolve (statically) to 'b'? -is_bool :: (GaloisField k) => TExp ty k -> Bool -> Comp 'TBool k +is_bool :: TExp ty k -> Bool -> Comp 'TBool k is_bool (TEVal VFalse) False = return true is_bool (TEVal VTrue) True = return true is_bool e@(TEVar _) b = @@ -511,24 +502,24 @@ is_bool e@(TEVar _) b = ) is_bool _ _ = return false -is_false :: (GaloisField k) => TExp ty k -> Comp 'TBool k +is_false :: TExp ty k -> Comp 'TBool k is_false = flip is_bool False -is_true :: (GaloisField k) => TExp ty k -> Comp 'TBool k +is_true :: TExp ty k -> Comp 'TBool k is_true = flip is_bool True -- | Add binding 'x = b'. -assert_bool :: (GaloisField k) => TExp ty k -> Bool -> Comp 'TUnit k +assert_bool :: TExp ty k -> Bool -> Comp 'TUnit k assert_bool (TEVar (TVar x)) b = add_statics [(x, AnalBool b)] assert_bool e _ = raise_err $ ErrMsg $ "expected " ++ show e ++ " a variable" -assert_false :: (GaloisField k) => TExp ty k -> Comp 'TUnit k +assert_false :: TExp ty k -> Comp 'TUnit k assert_false = flip assert_bool False -assert_true :: (GaloisField k) => TExp ty k -> Comp 'TUnit k +assert_true :: TExp ty k -> Comp 'TUnit k assert_true = flip assert_bool True -var_is_bot :: (GaloisField k) => TExp ty k -> Comp 'TBool k +var_is_bot :: TExp ty k -> Comp 'TBool k var_is_bot e@(TEVar (TVar _)) = State ( \s -> @@ -542,7 +533,7 @@ var_is_bot e@(TEVar (TVar _)) = ) var_is_bot _ = return false -is_bot :: (GaloisField k) => TExp ty k -> Comp 'TBool k +is_bot :: TExp ty k -> Comp 'TBool k is_bot e = case e of e0@(TEVar _) -> var_is_bot e0 @@ -552,7 +543,7 @@ is_bot e = TEBot -> return true _ -> return false where - either_is_bot :: (GaloisField k) => TExp ty1 k -> TExp ty2 k -> Comp 'TBool k + either_is_bot :: TExp ty1 k -> TExp ty2 k -> Comp 'TBool k either_is_bot e10 e20 = do e1_bot <- is_bot e10 @@ -562,6 +553,6 @@ is_bot e = (_, TEVal VTrue) -> return true _ -> return false -assert_bot :: (GaloisField k) => TExp ty k -> Comp 'TUnit k +assert_bot :: TExp ty k -> Comp 'TUnit k assert_bot (TEVar (TVar x)) = add_statics [(x, AnalBot)] assert_bot e = raise_err $ ErrMsg $ "in assert_bot, expected " ++ show e ++ " a variable" diff --git a/src/Snarkl/Language/TExpr.hs b/src/Snarkl/Language/TExpr.hs index ac6a16b..7d97d81 100644 --- a/src/Snarkl/Language/TExpr.hs +++ b/src/Snarkl/Language/TExpr.hs @@ -20,12 +20,12 @@ where import Data.Field.Galois (GaloisField) import Data.Kind (Type) import Data.Typeable (Proxy (..), Typeable, eqT, typeOf, typeRep, type (:~:) (Refl)) -import Prettyprinter (Pretty (pretty), line, parens, (<+>)) import Snarkl.Common (Op, UnOp) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Language.Core (Variable) import qualified Snarkl.Language.LambdaExpr as LE import Snarkl.Language.Type (Ty (TBool, TField, TFun, TUnit)) +import Text.PrettyPrint.Leijen.Text (Pretty (pretty), line, parens, (<+>)) newtype TVar (ty :: Ty) = TVar Variable deriving (Eq, Show) @@ -58,17 +58,17 @@ instance Pretty (TOp ty1 ty2 ty3) where pretty (TOp op) = pretty op data Val :: Ty -> Type -> Type where - VField :: (GaloisField a) => a -> Val 'TField a - VTrue :: Val 'TBool a - VFalse :: Val 'TBool a - VUnit :: Val 'TUnit a - VLoc :: TLoc ty -> Val ty a + VField :: (GaloisField k) => k -> Val 'TField k + VTrue :: Val 'TBool k + VFalse :: Val 'TBool k + VUnit :: Val 'TUnit k + VLoc :: TLoc ty -> Val ty k -deriving instance (Eq a) => Eq (Val (b :: Ty) a) +deriving instance Eq (Val (b :: Ty) k) -deriving instance (Show a) => Show (Val (b :: Ty) a) +deriving instance Show (Val (b :: Ty) k) -instance (Pretty a) => Pretty (Val ty a) where +instance Pretty (Val ty k) where pretty v = case v of VField a -> pretty a VTrue -> "true" @@ -100,22 +100,22 @@ data TExp :: Ty -> Type -> Type where TEAbs :: (Typeable ty, Typeable ty1) => TVar ty -> TExp ty1 a -> TExp ('TFun ty ty1) a TEApp :: (Typeable ty, Typeable ty1) => TExp ('TFun ty ty1) a -> TExp ty a -> TExp ty1 a -deriving instance (Show a) => Show (TExp (b :: Ty) a) +deriving instance Show (TExp (b :: Ty) k) -instance (Eq a) => Eq (TExp (b :: Ty) a) where +instance Eq (TExp (b :: Ty) k) where TEVar x == TEVar y = x == y TEVal a == TEVal b = a == b TEUnop (op :: TUnop ty1 ty) e1 == TEUnop (op' :: TUnop ty2 ty) e1' = case eqT @ty1 @ty2 of Just Refl -> op == op' && e1 == e1' Nothing -> False - TEBinop (op :: TOp ty1 ty2 ty) (e1 :: TExp ty1 a) (e2 :: TExp ty2 a) == TEBinop (op' :: TOp ty3 ty4 ty) e1' e2' = + TEBinop (op :: TOp ty1 ty2 ty) (e1 :: TExp ty1 k) (e2 :: TExp ty2 k) == TEBinop (op' :: TOp ty3 ty4 ty) e1' e2' = case (eqT @ty1 @ty3, eqT @ty2 @ty4) of (Just Refl, Just Refl) -> op == op' && e1 == e1' && e2 == e2' _ -> False TEIf e e1 e2 == TEIf e' e1' e2' = e == e' && e1 == e1' && e2 == e2' - TEAssert (e1 :: TExp ty1 a) (e2 :: TExp ty1 a) == TEAssert (e1' :: TExp ty2 a) (e2' :: TExp ty2 a) = + TEAssert (e1 :: TExp ty1 k) (e2 :: TExp ty1 a) == TEAssert (e1' :: TExp ty2 k) (e2' :: TExp ty2 k) = case eqT @ty1 @ty2 of Just Refl -> e1 == e1' && e2 == e2' Nothing -> False @@ -124,7 +124,7 @@ instance (Eq a) => Eq (TExp (b :: Ty) a) where TEBot == TEBot = True _ == _ = False -instance (Pretty a, Typeable ty) => Pretty (TExp ty a) where +instance Pretty (TExp ty a) where pretty (TEVar var) = pretty var pretty (TEVal val) = pretty val pretty (TEUnop unop _exp) = pretty unop <+> pretty _exp @@ -137,10 +137,9 @@ instance (Pretty a, Typeable ty) => Pretty (TExp ty a) where pretty (TEApp exp1 exp2) = parens (pretty exp1 <+> pretty exp2) tExpToLambdaExp :: - (GaloisField a) => - (Typeable ty) => - TExp ty a -> - LE.Exp a + (GaloisField k) => + TExp ty k -> + LE.Exp k tExpToLambdaExp te = case te of TEVar (TVar x) -> LE.EVar x TEVal v -> lambdaExpOfVal v @@ -157,7 +156,7 @@ tExpToLambdaExp te = case te of TEAbs (TVar v) e -> LE.EAbs v (tExpToLambdaExp e) TEApp e1 e2 -> LE.EApp (tExpToLambdaExp e1) (tExpToLambdaExp e2) where - lambdaExpOfVal :: (GaloisField a) => Val ty a -> LE.Exp a + lambdaExpOfVal :: (GaloisField k) => Val ty k -> LE.Exp k lambdaExpOfVal v = case v of VField c -> LE.EVal c VTrue -> LE.EVal 1 @@ -169,7 +168,6 @@ tExpToLambdaExp te = case te of -- whenever the normal form of 'te1' (with seq's reassociated right) -- is *not* equal 'TEAssert _ _'. teSeq :: - (Typeable ty1) => TExp ty1 a -> TExp ty2 a -> TExp ty2 a @@ -201,7 +199,6 @@ booleanVarsOfTexp = go [] go vars (TEApp e1 e2) = go (go vars e1) e2 varOfTExp :: - (Show a) => TExp ty a -> Variable varOfTExp te = case lastSeq te of @@ -209,7 +206,6 @@ varOfTExp te = case lastSeq te of _ -> failWith $ ErrMsg ("varOfTExp: expected var: " ++ show te) locOfTexp :: - (Show a) => TExp ty a -> Loc locOfTexp te = case lastSeq te of @@ -221,4 +217,4 @@ lastSeq :: TExp ty a lastSeq te = case te of TESeq _ te2 -> lastSeq te2 - _ -> te + _ -> te \ No newline at end of file diff --git a/src/Snarkl/Language/Type.hs b/src/Snarkl/Language/Type.hs index e7203f3..145b2f9 100644 --- a/src/Snarkl/Language/Type.hs +++ b/src/Snarkl/Language/Type.hs @@ -8,7 +8,7 @@ module Snarkl.Language.Type where import Data.Typeable (Typeable) -import Prettyprinter (Pretty (pretty), parens, (<+>)) +import Text.PrettyPrint.Leijen.Text (Pretty (pretty), parens, (<+>)) data TFunct where TFConst :: Ty -> TFunct diff --git a/src/Snarkl/Toplevel.hs b/src/Snarkl/Toplevel.hs index 53503ab..99df998 100644 --- a/src/Snarkl/Toplevel.hs +++ b/src/Snarkl/Toplevel.hs @@ -22,7 +22,6 @@ import Data.Field.Galois (GaloisField, PrimeField) import Data.List (intercalate) import qualified Data.Map as Map import Data.Typeable (Typeable) -import Prettyprinter (Pretty (..), line, (<+>)) import Snarkl.Backend.R1CS import Snarkl.Common (Assgn) import Snarkl.Compile @@ -30,6 +29,7 @@ import Snarkl.Constraint import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Interp (interp) import Snarkl.Language +import Text.PrettyPrint.Leijen.Text (Pretty (..), line, (<+>)) import Prelude ---------------------------------------------------- @@ -64,7 +64,7 @@ data Result k = Result } deriving (Show) -instance (Pretty k) => Pretty (Result k) where +instance Pretty (Result k) where pretty (Result sat vars constraints result _ _) = mconcat $ intercalate diff --git a/tests/Test/Snarkl/UnitSpec.hs b/tests/Test/Snarkl/UnitSpec.hs index 25bf1ff..04c9800 100644 --- a/tests/Test/Snarkl/UnitSpec.hs +++ b/tests/Test/Snarkl/UnitSpec.hs @@ -5,7 +5,6 @@ module Test.Snarkl.UnitSpec where import Data.Field.Galois (PrimeField) import Data.Typeable (Typeable) -import Prettyprinter (Pretty) import Snarkl.Compile import Snarkl.Example.Keccak import Snarkl.Example.Lam @@ -20,6 +19,7 @@ import System.Exit (ExitCode (..)) import Test.ArkworksBridge (CMD (RunR1CS), runCMD) import Test.Hspec (Spec, describe, it, shouldBe, shouldReturn) import Test.Snarkl.Unit.Programs +import Text.PrettyPrint.Leijen.Text (Pretty) import Prelude test_comp :: (Typeable ty, Pretty k, PrimeField k) => SimplParam -> Comp ty k -> [k] -> IO (Either ExitCode k) From 438a71dcb4090e7b83846fcd1de8be829a22c1aa Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 22:32:40 -0800 Subject: [PATCH 12/19] Redundant constraints --- examples/Snarkl/Example/Basic.hs | 2 +- examples/Snarkl/Example/Games.hs | 30 +++++++++++------------------- examples/Snarkl/Example/Queue.hs | 6 ++---- snarkl.cabal | 2 ++ src/Snarkl/Common.hs | 2 +- src/Snarkl/Compile.hs | 14 +++++++------- src/Snarkl/Field.hs | 2 +- src/Snarkl/Interp.hs | 16 +++++++--------- src/Snarkl/Language.hs | 3 +-- src/Snarkl/Language/Core.hs | 6 +++++- src/Snarkl/Language/Expr.hs | 2 +- src/Snarkl/Language/LambdaExpr.hs | 4 ++-- src/Snarkl/Language/Syntax.hs | 2 +- src/Snarkl/Language/TExpr.hs | 2 +- src/Snarkl/Toplevel.hs | 5 +++-- tests/Test/ArkworksBridge.hs | 10 +++++++++- tests/Test/Snarkl/DataflowSpec.hs | 8 +++----- tests/Test/Snarkl/LambdaSpec.hs | 29 +++++++++++++---------------- tests/Test/Snarkl/Unit/Programs.hs | 5 +---- tests/Test/Snarkl/UnitSpec.hs | 16 +++++----------- tests/Test/UnionFindSpec.hs | 16 +++++++++++++--- 21 files changed, 90 insertions(+), 92 deletions(-) diff --git a/examples/Snarkl/Example/Basic.hs b/examples/Snarkl/Example/Basic.hs index 6572b3d..7a86fd1 100644 --- a/examples/Snarkl/Example/Basic.hs +++ b/examples/Snarkl/Example/Basic.hs @@ -73,4 +73,4 @@ test1 :: (GaloisField k) => State (Env k) (TExp 'TBool k) test1 = do b <- fresh_input z <- if return b then comp1 else comp2 - case_sum return (const $ return false) z \ No newline at end of file + case_sum return (const $ return false) z diff --git a/examples/Snarkl/Example/Games.hs b/examples/Snarkl/Example/Games.hs index e152e2c..0d0bf4e 100644 --- a/examples/Snarkl/Example/Games.hs +++ b/examples/Snarkl/Example/Games.hs @@ -6,9 +6,9 @@ module Snarkl.Example.Games where -import Data.Field.Galois (GaloisField, Prime) +import Data.Field.Galois (GaloisField) +import Data.Kind (Type) import Data.Typeable -import GHC.TypeLits (KnownNat, Nat) import Snarkl.Errors import Snarkl.Field (F_BN128) import Snarkl.Language.Syntax @@ -38,7 +38,7 @@ data ISO (t :: Ty) (s :: Ty) k = Iso from :: TExp s k -> Comp t k } -data Game :: Ty -> * -> * where +data Game :: Ty -> Type -> Type where Single :: forall (s :: Ty) (t :: Ty) k. ( Typeable s, @@ -108,8 +108,7 @@ sum_game :: Zippable t1 k, Zippable t2 k, Derive t1 k, - Derive t2 k, - GaloisField k + Derive t2 k ) => Game t1 k -> Game t2 k -> @@ -133,9 +132,7 @@ t2 :: F_BN128 t2 = comp_interp basic_test [1, 23, 88] -- 88 (+>) :: - ( Typeable t, - Typeable s, - Zippable t k, + ( Typeable s, Zippable s k ) => Game t k -> @@ -151,8 +148,7 @@ prodI :: ( Typeable a, Typeable b, Typeable c, - Typeable d, - GaloisField k + Typeable d ) => ISO a b k -> ISO c d k -> @@ -174,13 +170,12 @@ prodI (Iso f g) (Iso f' g') = pair y1 y2 ) -seqI :: (Typeable b) => ISO a b p -> ISO b c p -> ISO a c p +seqI :: ISO a b p -> ISO b c p -> ISO a c p seqI (Iso f g) (Iso f' g') = Iso (\a -> f a >>= f') (\c -> g' c >>= g) prodLInputI :: ( Typeable a, - Typeable b, - GaloisField k + Typeable b ) => ISO ('TProd a b) b k prodLInputI = @@ -200,8 +195,7 @@ prodLSumI :: Zippable c k, Derive a k, Derive b k, - Derive c k, - GaloisField k + Derive c k ) => ISO ('TProd ('TSum b c) a) ('TSum ('TProd b a) ('TProd c a)) k prodLSumI = @@ -301,8 +295,7 @@ instance (GaloisField k) => Gameable 'TUnit k where mkGame = unit_game instance - ( Typeable a, - Typeable b, + ( Typeable b, Zippable a k, Zippable b k, Derive a k, @@ -323,8 +316,7 @@ instance Derive a k, Derive b k, Gameable a k, - Gameable b k, - GaloisField k + Gameable b k ) => Gameable ('TSum a b) k where diff --git a/examples/Snarkl/Example/Queue.hs b/examples/Snarkl/Example/Queue.hs index b306692..ed06e12 100644 --- a/examples/Snarkl/Example/Queue.hs +++ b/examples/Snarkl/Example/Queue.hs @@ -2,10 +2,8 @@ module Snarkl.Example.Queue where -import Data.Field.Galois (GaloisField, Prime) +import Data.Field.Galois (GaloisField) import Data.Typeable -import GHC.TypeLits (KnownNat) -import Snarkl.Compile import Snarkl.Example.List import Snarkl.Example.Stack import Snarkl.Language.Syntax @@ -36,7 +34,7 @@ empty_queue = do pair l r enqueue :: - (Zippable a k, Derive a k, Typeable a, GaloisField k) => + (Typeable a, GaloisField k) => TExp a k -> Queue a k -> Comp (TQueue a) k diff --git a/snarkl.cabal b/snarkl.cabal index 183b982..8a7091c 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -141,6 +141,8 @@ test-suite spec , process >=1.2 , QuickCheck , snarkl >=0.1.0.0 + ghc-options: + -Wredundant-constraints benchmark criterion type: exitcode-stdio-1.0 diff --git a/src/Snarkl/Common.hs b/src/Snarkl/Common.hs index 374e523..953d841 100644 --- a/src/Snarkl/Common.hs +++ b/src/Snarkl/Common.hs @@ -67,4 +67,4 @@ isAssoc op = case op of Or -> True XOr -> True Eq -> True - BEq -> True \ No newline at end of file + BEq -> True diff --git a/src/Snarkl/Compile.hs b/src/Snarkl/Compile.hs index 124eca4..d11957a 100644 --- a/src/Snarkl/Compile.hs +++ b/src/Snarkl/Compile.hs @@ -249,19 +249,19 @@ encode_binop op (x, y, z) = go op add_constraint $ CMult (1, y) (1, z) (1, Just x) -encode_linear :: (GaloisField a) => Var -> [Either (Var, a) a] -> State (CEnv a) () +encode_linear :: (GaloisField k) => Var -> [Either (Var, k) k] -> State (CEnv k) () encode_linear out xs = let c = foldl (flip (+)) 0 $ map (fromRight 0) xs in add_constraint $ cadd c $ (out, -1) : remove_consts xs where - remove_consts :: [Either (Var, a) a] -> [(Var, a)] + remove_consts :: [Either (Var, k) k] -> [(Var, k)] remove_consts [] = [] remove_consts (Left p : l) = p : remove_consts l remove_consts (Right _ : l) = remove_consts l -cs_of_exp :: (GaloisField a) => Var -> Core.Exp a -> State (CEnv a) () +cs_of_exp :: (GaloisField k) => Var -> Core.Exp k -> State (CEnv k) () cs_of_exp out e = case e of Core.EVar x -> ensure_equal (out, view _Var x) @@ -287,7 +287,7 @@ cs_of_exp out e = case e of -- We special-case linear combinations in this way to avoid having -- to introduce new multiplication gates for multiplication by -- constant scalars. - let go_linear :: (GaloisField a) => [Core.Exp a] -> State (CEnv a) [Either (Var, a) a] + let go_linear :: (GaloisField k) => [Core.Exp k] -> State (CEnv k) [Either (Var, k) k] go_linear [] = return [] go_linear (Core.EBinop Mult [Core.EVar x, Core.EVal coeff] : es') = do @@ -338,7 +338,7 @@ cs_of_exp out e = case e of rev_pol (Left (x, c) : ls) = Left (x, -c) : rev_pol ls rev_pol (Right c : ls) = Right (-c) : rev_pol ls - go_other :: (GaloisField a) => [Core.Exp a] -> State (CEnv a) [Var] + go_other :: (GaloisField k) => [Core.Exp k] -> State (CEnv k) [Var] go_other [] = return [] go_other (Core.EVar x : es') = do @@ -457,10 +457,10 @@ data TExpPkg ty k = TExpPkg } deriving (Show) -instance (Typeable ty) => Pretty (TExpPkg ty k) where +instance Pretty (TExpPkg ty k) where pretty (TExpPkg _ _ e) = pretty e -deriving instance (Eq (TExp ty k)) => Eq (TExpPkg ty k) +deriving instance Eq (TExpPkg ty k) -- | Desugar a 'Comp'utation to a pair of: -- the total number of vars, diff --git a/src/Snarkl/Field.hs b/src/Snarkl/Field.hs index f8f6e9f..1f193af 100644 --- a/src/Snarkl/Field.hs +++ b/src/Snarkl/Field.hs @@ -6,4 +6,4 @@ import Data.Field.Galois (Prime) type P_BN128 = 21888242871839275222246405745257275088548364400416034343698204186575808495617 -type F_BN128 = Prime P_BN128 \ No newline at end of file +type F_BN128 = Prime P_BN128 diff --git a/src/Snarkl/Interp.hs b/src/Snarkl/Interp.hs index ba269e7..4d66dd5 100644 --- a/src/Snarkl/Interp.hs +++ b/src/Snarkl/Interp.hs @@ -4,7 +4,6 @@ module Snarkl.Interp where import Control.Monad (ap, foldM) -import Data.Data (Typeable) import Data.Field.Galois (GaloisField) import Data.Foldable (traverse_) import Data.Map (Map) @@ -77,20 +76,19 @@ boolOfField v = ) interpTExp :: - ( GaloisField a, - Typeable ty + ( GaloisField k ) => - TExp ty a -> - InterpM a (Maybe a) + TExp ty k -> + InterpM k (Maybe k) interpTExp e = do let _exp = compileTExpToProgram e interpProg _exp interp :: - (GaloisField a, Typeable ty) => - Map Variable a -> - TExp ty a -> - Either ErrMsg (Env a, Maybe a) + (GaloisField k) => + Map Variable k -> + TExp ty k -> + Either ErrMsg (Env k, Maybe k) interp rho e = runInterpM (interpTExp e) $ Map.map Just rho interpProg :: diff --git a/src/Snarkl/Language.hs b/src/Snarkl/Language.hs index 886eee7..ca413c9 100644 --- a/src/Snarkl/Language.hs +++ b/src/Snarkl/Language.hs @@ -85,7 +85,6 @@ module Snarkl.Language ) where -import Data.Data (Typeable) import Data.Field.Galois (GaloisField) import Snarkl.Language.Core ( Assignment (..), @@ -166,7 +165,7 @@ import Snarkl.Language.TExpr (TExp, booleanVarsOfTexp, tExpToLambdaExp) import Snarkl.Language.Type import Prelude (Either (..), error, ($), (.), (<>)) -compileTExpToProgram :: (GaloisField a, Typeable ty) => TExp ty a -> Program a +compileTExpToProgram :: (GaloisField k) => TExp ty k -> Program k compileTExpToProgram te = let eprog = mkProgram . expOfLambdaExp . tExpToLambdaExp $ te in case eprog of diff --git a/src/Snarkl/Language/Core.hs b/src/Snarkl/Language/Core.hs index 4a75db2..9dffa2e 100644 --- a/src/Snarkl/Language/Core.hs +++ b/src/Snarkl/Language/Core.hs @@ -13,9 +13,13 @@ data Exp :: Type -> Type where EVal :: (GaloisField k) => k -> Exp k EUnop :: UnOp -> Exp k -> Exp k EBinop :: Op -> [Exp k] -> Exp k - EIf :: Exp k -> Exp a -> Exp k -> Exp k + EIf :: Exp k -> Exp k -> Exp k -> Exp k EUnit :: Exp k +deriving instance Eq (Exp k) + +deriving instance Show (Exp k) + data Assignment a = Assignment Variable (Exp a) data Program :: Type -> Type where diff --git a/src/Snarkl/Language/Expr.hs b/src/Snarkl/Language/Expr.hs index 4480513..287e34e 100644 --- a/src/Snarkl/Language/Expr.hs +++ b/src/Snarkl/Language/Expr.hs @@ -152,7 +152,7 @@ mkProgram _exp = do let (eexpr, assignments) = runState (runExceptT $ go es) mempty Core.Program assignments <$> eexpr where - go :: (Show k) => Seq (Exp k) -> ExceptT String (State (Seq (Core.Assignment k))) (Core.Exp k) + go :: Seq (Exp k) -> ExceptT String (State (Seq (Core.Assignment k))) (Core.Exp k) go = \case Empty -> throwError "mkProgram: empty sequence" e :<| Empty -> hoistEither $ mkExpression e diff --git a/src/Snarkl/Language/LambdaExpr.hs b/src/Snarkl/Language/LambdaExpr.hs index 897383f..d1b79cf 100644 --- a/src/Snarkl/Language/LambdaExpr.hs +++ b/src/Snarkl/Language/LambdaExpr.hs @@ -38,7 +38,7 @@ deriving instance Show (Exp k) deriving instance Eq (Exp k) -betaNormalize :: Exp a -> Exp a +betaNormalize :: Exp k -> Exp k betaNormalize = \case EVar x -> EVar x EVal v -> EVal v @@ -55,7 +55,7 @@ betaNormalize = \case EUnit -> EUnit where -- substitute x e1 e2 = e2 [x := e1 ] - substitute :: (Variable, Exp a) -> Exp a -> Exp a + substitute :: (Variable, Exp k) -> Exp k -> Exp k substitute (var, e1) = \case e@(EVar var') -> if var == var' then e1 else e e@(EVal _) -> e diff --git a/src/Snarkl/Language/Syntax.hs b/src/Snarkl/Language/Syntax.hs index 7cae059..1f5e419 100644 --- a/src/Snarkl/Language/Syntax.hs +++ b/src/Snarkl/Language/Syntax.hs @@ -768,4 +768,4 @@ uncurry f p = do return $ TEApp g y apply :: (Typeable a, Typeable b) => TExp ('TFun a b) k -> TExp a k -> Comp b k -apply f x = return $ TEApp f x +apply f x = return $ TEApp f x \ No newline at end of file diff --git a/src/Snarkl/Language/TExpr.hs b/src/Snarkl/Language/TExpr.hs index 7d97d81..36c294e 100644 --- a/src/Snarkl/Language/TExpr.hs +++ b/src/Snarkl/Language/TExpr.hs @@ -217,4 +217,4 @@ lastSeq :: TExp ty a lastSeq te = case te of TESeq _ te2 -> lastSeq te2 - _ -> te \ No newline at end of file + _ -> te diff --git a/src/Snarkl/Toplevel.hs b/src/Snarkl/Toplevel.hs index 99df998..4f24429 100644 --- a/src/Snarkl/Toplevel.hs +++ b/src/Snarkl/Toplevel.hs @@ -41,7 +41,8 @@ import Prelude -- | Using the executable semantics for the 'TExp' language, execute -- the computation on the provided inputs, returning the 'k' result. comp_interp :: - (Typeable ty, GaloisField k) => + forall ty k. + (GaloisField k) => Comp ty k -> [k] -> k @@ -64,7 +65,7 @@ data Result k = Result } deriving (Show) -instance Pretty (Result k) where +instance (Pretty k) => Pretty (Result k) where pretty (Result sat vars constraints result _ _) = mconcat $ intercalate diff --git a/tests/Test/ArkworksBridge.hs b/tests/Test/ArkworksBridge.hs index e1e4703..f847a9b 100644 --- a/tests/Test/ArkworksBridge.hs +++ b/tests/Test/ArkworksBridge.hs @@ -4,8 +4,16 @@ import qualified Data.ByteString.Lazy as LBS import Data.Field.Galois (GaloisField, PrimeField) import Data.Typeable (Typeable) import Snarkl.Backend.R1CS + ( mkInputsFilePath, + mkR1CSFilePath, + mkWitnessFilePath, + serializeInputsAsJson, + serializeR1CSAsJson, + serializeWitnessAsJson, + wit_of_r1cs, + ) import Snarkl.Compile (SimplParam, compileCompToR1CS) -import Snarkl.Language (Comp) +import Snarkl.Language.SyntaxMonad (Comp) import qualified System.Exit as GHC import System.Process (createProcess, shell, waitForProcess) diff --git a/tests/Test/Snarkl/DataflowSpec.hs b/tests/Test/Snarkl/DataflowSpec.hs index e4a5384..5dc2d1b 100644 --- a/tests/Test/Snarkl/DataflowSpec.hs +++ b/tests/Test/Snarkl/DataflowSpec.hs @@ -2,9 +2,7 @@ module Test.Snarkl.DataflowSpec where -import Data.Field.Galois (GaloisField, Prime, PrimeField, toP) -import qualified Data.IntMap as IntMap -import qualified Data.Map as Map +import Data.Field.Galois (Prime, toP) import qualified Data.Set as Set import GHC.TypeLits (KnownNat) import Snarkl.Common (Var (Var)) @@ -30,8 +28,8 @@ constraint2 :: (KnownNat p) => Constraint (Prime p) constraint2 = CMult (toP 2, Var 1) (toP 3, Var 2) (toP 4, Just $ Var 3) -- NOTE: notice 4 doesn't count as a variable here, WHY? -constraint3 :: (GaloisField k) => Constraint k -constraint3 = CMagic (Var 4) [Var 2, Var 3] $ \vars -> return True +constraint3 :: Constraint k +constraint3 = CMagic (Var 4) [Var 2, Var 3] $ \_ -> return True -- 4 is independent from 1,2,3 constraint4 :: (KnownNat p) => Constraint (Prime p) diff --git a/tests/Test/Snarkl/LambdaSpec.hs b/tests/Test/Snarkl/LambdaSpec.hs index 5feedb8..08f1f5d 100644 --- a/tests/Test/Snarkl/LambdaSpec.hs +++ b/tests/Test/Snarkl/LambdaSpec.hs @@ -5,12 +5,7 @@ {-# HLINT ignore "Redundant uncurry" #-} module Test.Snarkl.LambdaSpec where -import Data.Field.Galois (GaloisField, Prime) -import qualified Data.Map as Map -import GHC.TypeLits (KnownNat) -import Snarkl.Field -import Snarkl.Interp (interp) -import Snarkl.Language (TExp, Ty (TField, TFun, TProd)) +import Snarkl.Field (F_BN128) import Snarkl.Language.Syntax ( apply, curry, @@ -21,27 +16,29 @@ import Snarkl.Language.Syntax (+), ) import qualified Snarkl.Language.SyntaxMonad as SM +import Snarkl.Language.TExpr (TExp) +import Snarkl.Language.Type (Ty (TField, TFun, TProd)) import Snarkl.Toplevel (comp_interp) -import Test.Hspec (Spec, describe, it, shouldBe) +import Test.Hspec (Spec, describe, it) import Test.QuickCheck (Testable (property)) -import Prelude hiding (apply, curry, return, uncurry, (*), (+)) +import Prelude hiding (curry, return, uncurry, (*), (+)) spec :: Spec spec = do describe "Snarkl.Lambda" $ do describe "curry/uncurry identities for simply operations" $ do it "curry . uncurry == id" $ do - let f :: (GaloisField k) => TExp 'TField k -> SM.Comp ('TFun 'TField 'TField) k + let f :: TExp 'TField k -> SM.Comp ('TFun 'TField 'TField) k f x = lambda $ \y -> SM.return (x + y) - g :: (GaloisField k) => TExp 'TField k -> SM.Comp ('TFun 'TField 'TField) k + g :: TExp 'TField k -> SM.Comp ('TFun 'TField 'TField) k g = curry (uncurry f) - prog1 :: (GaloisField k) => SM.Comp 'TField k + prog1 :: SM.Comp 'TField k prog1 = SM.fresh_input SM.>>= \a -> SM.fresh_input SM.>>= \b -> f a SM.>>= \k -> apply k b - prog2 :: (GaloisField k) => SM.Comp 'TField k + prog2 :: SM.Comp 'TField k prog2 = SM.fresh_input SM.>>= \a -> SM.fresh_input SM.>>= \b -> @@ -51,21 +48,21 @@ spec = do comp_interp @_ @F_BN128 prog1 [a, b] == comp_interp prog2 [a, b] it "uncurry . curry == id" $ do - let f :: (GaloisField k) => TExp ('TProd 'TField 'TField) k -> SM.Comp 'TField k + let f :: TExp ('TProd 'TField 'TField) k -> SM.Comp 'TField k f p = SM.fst_pair p SM.>>= \x -> SM.snd_pair p SM.>>= \y -> SM.return (x * y) - g :: (GaloisField k) => TExp ('TProd 'TField 'TField) k -> SM.Comp 'TField k + g :: TExp ('TProd 'TField 'TField) k -> SM.Comp 'TField k g = uncurry (curry f) - prog1 :: (GaloisField k) => SM.Comp 'TField k + prog1 :: SM.Comp 'TField k prog1 = SM.fresh_input SM.>>= \a -> SM.fresh_input SM.>>= \b -> pair a b SM.>>= \p -> f p - prog2 :: (GaloisField k) => SM.Comp 'TField k + prog2 :: SM.Comp 'TField k prog2 = SM.fresh_input SM.>>= \a -> SM.fresh_input SM.>>= \b -> diff --git a/tests/Test/Snarkl/Unit/Programs.hs b/tests/Test/Snarkl/Unit/Programs.hs index 0202751..df1d889 100644 --- a/tests/Test/Snarkl/Unit/Programs.hs +++ b/tests/Test/Snarkl/Unit/Programs.hs @@ -8,14 +8,11 @@ module Test.Snarkl.Unit.Programs where -import Data.Field.Galois (GaloisField, Prime) -import Snarkl.Compile -import Snarkl.Example.Keccak import Snarkl.Example.Lam import Snarkl.Example.List import Snarkl.Example.Peano import Snarkl.Example.Tree -import Snarkl.Field (F_BN128, P_BN128) +import Snarkl.Field (F_BN128) import Snarkl.Language.Syntax import Snarkl.Language.SyntaxMonad import Snarkl.Language.TExpr diff --git a/tests/Test/Snarkl/UnitSpec.hs b/tests/Test/Snarkl/UnitSpec.hs index 04c9800..6e356ac 100644 --- a/tests/Test/Snarkl/UnitSpec.hs +++ b/tests/Test/Snarkl/UnitSpec.hs @@ -7,22 +7,16 @@ import Data.Field.Galois (PrimeField) import Data.Typeable (Typeable) import Snarkl.Compile import Snarkl.Example.Keccak -import Snarkl.Example.Lam -import Snarkl.Example.List -import Snarkl.Example.Peano -import Snarkl.Example.Tree import Snarkl.Field -import Snarkl.Language (Comp) -import Snarkl.Language.Syntax hiding (negate) +import Snarkl.Language (Comp, Ty (..)) import Snarkl.Toplevel (Result (result_result), execute) import System.Exit (ExitCode (..)) import Test.ArkworksBridge (CMD (RunR1CS), runCMD) -import Test.Hspec (Spec, describe, it, shouldBe, shouldReturn) +import Test.Hspec (Spec, describe, it, shouldReturn) import Test.Snarkl.Unit.Programs -import Text.PrettyPrint.Leijen.Text (Pretty) import Prelude -test_comp :: (Typeable ty, Pretty k, PrimeField k) => SimplParam -> Comp ty k -> [k] -> IO (Either ExitCode k) +test_comp :: forall ty k. (Typeable ty, PrimeField k) => SimplParam -> Comp ty k -> [k] -> IO (Either ExitCode k) test_comp simpl mf args = do exit_code <- runCMD $ RunR1CS "./scripts" "hspec" simpl mf args @@ -72,7 +66,7 @@ spec = do it "8-1" $ test_comp @_ @F_BN128 Simplify prog8 [] `shouldReturn` Right 29 describe "unused inputs" $ do - it "11-1" $ test_comp @_ @F_BN128 Simplify prog11 [1, 1] `shouldReturn` Right 1 + it "11-1" $ test_comp @'TField @F_BN128 Simplify prog11 [1, 1] `shouldReturn` Right 1 describe "multiplicative identity" $ do it "13-1" $ test_comp @_ @F_BN128 Simplify prog13 [1] `shouldReturn` Right 1 @@ -147,7 +141,7 @@ spec = do it "8-1" $ test_comp @_ @F_BN128 Simplify prog8 [] `shouldReturn` Right 29 describe "unused inputs" $ do - it "11-1" $ test_comp @_ @F_BN128 Simplify prog11 [1, 1] `shouldReturn` Right 1 + it "11-1" $ test_comp @'TField @F_BN128 Simplify prog11 [1, 1] `shouldReturn` Right 1 describe "multiplicative identity" $ do it "13-1" $ test_comp @_ @F_BN128 Simplify prog13 [1] `shouldReturn` Right 1 diff --git a/tests/Test/UnionFindSpec.hs b/tests/Test/UnionFindSpec.hs index a932eb9..ee9774a 100644 --- a/tests/Test/UnionFindSpec.hs +++ b/tests/Test/UnionFindSpec.hs @@ -6,11 +6,21 @@ module Test.UnionFindSpec where -import Snarkl.Common +import Snarkl.Common (Var (..)) import Snarkl.Constraint.UnionFind -import Snarkl.Errors -import Test.Hspec + ( UnionFind, + empty, + insert, + root, + unite, + ) +import Test.Hspec (Spec, describe, it, shouldBe) import Test.QuickCheck + ( Arbitrary (arbitrary), + Testable (property), + forAll, + suchThat, + ) spec :: Spec spec = do From d48e0b0249f2c7a386e01cec299ed9a0fa42d5cc Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 22:37:20 -0800 Subject: [PATCH 13/19] remove print exec --- flake.nix | 2 -- print-examples/Main.hs | 69 ------------------------------------ snarkl.cabal | 79 +++++++++++++----------------------------- 3 files changed, 24 insertions(+), 126 deletions(-) delete mode 100644 print-examples/Main.hs diff --git a/flake.nix b/flake.nix index 13af1e5..3aa2fef 100644 --- a/flake.nix +++ b/flake.nix @@ -30,12 +30,10 @@ packages = { lib = flake.packages."snarkl:lib:snarkl"; - print = flake.packages."snarkl:exe:print-examples"; all = pkgs.symlinkJoin { name = "all"; paths = with packages; [ lib - print ]; }; default = packages.all; diff --git a/print-examples/Main.hs b/print-examples/Main.hs deleted file mode 100644 index 91c579d..0000000 --- a/print-examples/Main.hs +++ /dev/null @@ -1,69 +0,0 @@ -{-# LANGUAGE FlexibleContexts #-} - -module Main where - -import Data.Foldable (traverse_) -import Snarkl.Field () -import Snarkl.Toplevel (compileCompToTexp) -import Test.Snarkl.Unit.Programs -import Text.PrettyPrint.Leijen.Text -import Text.PrettyPrint.Leijen.Text.Render.String (renderString) - -main :: IO () -main = do - traverse_ - printProg - [ ("Program 1", prog1), - ("Program 2", prog2 42), - ("Program 3", prog3), - ("Program 4", prog4), - ("Program 5", prog5), - ("Program 6", prog6), - ("Program 7", prog7), - ("Program 8", prog8), - ("Program 11", prog11), - ("Program 13", prog13), - ("Program 14", prog14), - ("Program 15", prog15), - ("Program 26", prog26), - ("Program 27", prog27), - ("Program 28", prog28), - ("Program 29", prog29), - ("Program 30", prog30), - ("Program 31", prog31), - ("Program 34", prog34), - ("Program 35", prog35), - ("Program 36", prog36) - ] - traverse_ - printProg - [ ("Bool Program 10", bool_prog10), - ("Bool Program 12", bool_prog12), - ("Bool Program 16", bool_prog16), - ("Bool Program 17", bool_prog17), - ("Bool Program 18", bool_prog18), - ("Bool Program 19", bool_prog19), - ("Bool Program 20", bool_prog20), - ("Bool Program 21", bool_prog21), - ("Bool Program 22", bool_prog22), - ("Bool Program 23", bool_prog23), - ("Bool Program 24", bool_prog24), - ("Bool Program 25", bool_prog25), - ("Bool Program 32", bool_prog32), - ("Bool Program 33", bool_prog33) - ] - where - printProg (name, prog) = do - let texp = compileCompToTexp prog - -- this is just a sanity check because of the new Eq instances - if texp /= texp - then putStrLn ("--| " <> name <> " (FAILED)") - else do - let doc = pretty texp - layout = layoutPretty defaultLayoutOptions doc - putStrLn (replicate 80 '-') - putStrLn "\n" - putStrLn ("--| " <> name) - putStrLn "\n" - putStrLn $ renderString layout - putStrLn "\n" diff --git a/snarkl.cabal b/snarkl.cabal index 8a7091c..dddc05b 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -80,18 +80,18 @@ library build-depends: aeson - , base >=4.7 + , base >=4.7 , bytestring - , Cabal >=1.22 - , containers >=0.5 && <0.7 + , Cabal >=1.22 + , containers >=0.5 && <0.7 , errors - , galois-field >=1.0.4 - , hspec >=2.0 - , jsonl >=0.1.4 + , galois-field >=1.0.4 + , hspec >=2.0 + , jsonl >=0.1.4 , lens - , mtl >=2.2 && <2.3 - , parallel >=3.2 && <3.3 - , process >=1.2 + , mtl >=2.2 && <2.3 + , parallel >=3.2 && <3.3 + , process >=1.2 , transformers , wl-pprint-text @@ -129,20 +129,20 @@ test-suite spec hs-source-dirs: tests examples default-language: Haskell2010 build-depends: - base >=4.7 + base >=4.7 , bytestring - , Cabal >=1.22 - , containers >=0.5 && <0.6 - , criterion >=1.0 - , galois-field >=1.0.4 - , hspec >=2.0 - , mtl >=2.2 && <2.3 - , parallel >=3.2 && <3.3 - , process >=1.2 + , Cabal >=1.22 + , containers >=0.5 && <0.6 + , criterion >=1.0 + , galois-field >=1.0.4 + , hspec >=2.0 + , mtl >=2.2 && <2.3 + , parallel >=3.2 && <3.3 + , process >=1.2 , QuickCheck - , snarkl >=0.1.0.0 - ghc-options: - -Wredundant-constraints + , snarkl >=0.1.0.0 + + ghc-options: -Wredundant-constraints benchmark criterion type: exitcode-stdio-1.0 @@ -190,37 +190,6 @@ benchmark criterion , process >=1.2 , snarkl >=0.1.0.0 -executable print-examples - main-is: Main.hs - other-modules: - Snarkl.Example.Basic - Snarkl.Example.Games - Snarkl.Example.Keccak - Snarkl.Example.Lam - Snarkl.Example.List - Snarkl.Example.Matrix - Snarkl.Example.Peano - Snarkl.Example.Queue - Snarkl.Example.Stack - Snarkl.Example.Tree - Test.Snarkl.Unit.Programs - - default-extensions: - DataKinds - GADTs - KindSignatures - RankNTypes - ScopedTypeVariables - - hs-source-dirs: print-examples examples tests - default-language: Haskell2010 - build-depends: - base >=4.7 - , containers - , galois-field >=1.0.4 - , hspec >=2.0 - , snarkl >=0.1.0.0 - executable compile main-is: Main.hs other-modules: @@ -246,8 +215,8 @@ executable compile hs-source-dirs: app examples tests default-language: Haskell2010 build-depends: - base >=4.7 + base >=4.7 , bytestring , containers - , galois-field >=1.0.4 - , snarkl >=0.1.0.0 \ No newline at end of file + , galois-field >=1.0.4 + , snarkl >=0.1.0.0 From 7f06f7ae7771cffad2d751edf6983cce57d51b16 Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 22:37:59 -0800 Subject: [PATCH 14/19] lint --- src/Snarkl/Language/Syntax.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Snarkl/Language/Syntax.hs b/src/Snarkl/Language/Syntax.hs index 1f5e419..7cae059 100644 --- a/src/Snarkl/Language/Syntax.hs +++ b/src/Snarkl/Language/Syntax.hs @@ -768,4 +768,4 @@ uncurry f p = do return $ TEApp g y apply :: (Typeable a, Typeable b) => TExp ('TFun a b) k -> TExp a k -> Comp b k -apply f x = return $ TEApp f x \ No newline at end of file +apply f x = return $ TEApp f x From 62c2fdbc73b231e3270dd50f12ea0eeffa8a2e7d Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 22:45:35 -0800 Subject: [PATCH 15/19] remove redundant constraints --- examples/Snarkl/Example/Basic.hs | 6 +++--- examples/Snarkl/Example/Keccak.hs | 3 +-- examples/Snarkl/Example/Lam.hs | 11 +++-------- examples/Snarkl/Example/List.hs | 28 +++++++++------------------- examples/Snarkl/Example/Peano.hs | 4 ++-- examples/Snarkl/Example/Queue.hs | 10 +++++----- examples/Snarkl/Example/Stack.hs | 10 +++++----- examples/Snarkl/Example/Tree.hs | 9 +++------ 8 files changed, 31 insertions(+), 50 deletions(-) diff --git a/examples/Snarkl/Example/Basic.hs b/examples/Snarkl/Example/Basic.hs index 7a86fd1..f601911 100644 --- a/examples/Snarkl/Example/Basic.hs +++ b/examples/Snarkl/Example/Basic.hs @@ -31,7 +31,7 @@ mult_ex :: Comp 'TField k mult_ex x y = return $ x * y -arr_ex :: (GaloisField k) => TExp 'TField k -> Comp 'TField k +arr_ex :: TExp 'TField k -> Comp 'TField k arr_ex x = do a <- arr 2 forall [0 .. 1] (\i -> set (a, i) x) @@ -63,13 +63,13 @@ interp2' = comp_interp p2 [256] compile1 :: (GaloisField k) => R1CS k compile1 = compileCompToR1CS Simplify p1 -comp1 :: (GaloisField k, Typeable a) => Comp ('TSum 'TBool a) k +comp1 :: (Typeable a) => Comp ('TSum 'TBool a) k comp1 = inl false comp2 :: (GaloisField k, Typeable a) => Comp ('TSum a 'TField) k comp2 = inr (fromField 0) -test1 :: (GaloisField k) => State (Env k) (TExp 'TBool k) +test1 :: (GaloisField k) => Comp 'TBool k test1 = do b <- fresh_input z <- if return b then comp1 else comp2 diff --git a/examples/Snarkl/Example/Keccak.hs b/examples/Snarkl/Example/Keccak.hs index f0573e3..ce8f136 100644 --- a/examples/Snarkl/Example/Keccak.hs +++ b/examples/Snarkl/Example/Keccak.hs @@ -35,7 +35,6 @@ ln_width :: Int ln_width = 32 round1 :: - (GaloisField k) => (Int -> TExp 'TBool k) -> -- | 'i'th bit of round constant TExp ('TArr ('TArr ('TArr 'TBool))) k -> @@ -198,7 +197,7 @@ keccak_f1 num_rounds a = ) -- num_rounds = 12+2l, where 2^l = ln_width -keccak1 :: (GaloisField k) => Int -> Comp 'TBool k +keccak1 :: Int -> Comp 'TBool k keccak1 num_rounds = do a <- input_arr3 5 5 ln_width diff --git a/examples/Snarkl/Example/Lam.hs b/examples/Snarkl/Example/Lam.hs index 32e53d6..c33e7e8 100644 --- a/examples/Snarkl/Example/Lam.hs +++ b/examples/Snarkl/Example/Lam.hs @@ -33,7 +33,7 @@ type TFSubst = 'TFSum ('TFConst 'TField) ('TFProd ('TFConst TTerm) 'TFId) type TSubst = 'TMu TFSubst -subst_nil :: (GaloisField k) => TExp 'TField k -> Comp TSubst k +subst_nil :: TExp 'TField k -> Comp TSubst k subst_nil n = do n' <- inl n @@ -66,7 +66,6 @@ type TF = 'TFSum ('TFConst 'TField) ('TFSum 'TFId ('TFProd 'TFId 'TFId)) type TTerm = 'TMu TF varN :: - (GaloisField k) => TExp 'TField k -> Comp TTerm k varN e = @@ -84,7 +83,6 @@ varN' i = roll v lam :: - (GaloisField k) => TExp TTerm k -> Comp TTerm k lam t = @@ -94,7 +92,6 @@ lam t = roll v app :: - (GaloisField k) => TExp TTerm k -> TExp TTerm k -> Comp TTerm k @@ -106,9 +103,7 @@ app t1 t2 = roll v case_term :: - ( Typeable ty, - Zippable ty k, - GaloisField k + ( Zippable ty k ) => TExp TTerm k -> (TExp 'TField k -> Comp ty k) -> @@ -126,7 +121,7 @@ case_term t f_var f_lam f_app = e2 <- fst_pair p f_app e1 e2 -is_lam :: (GaloisField k) => TExp TTerm k -> Comp 'TBool k +is_lam :: TExp TTerm k -> Comp 'TBool k is_lam t = case_term t diff --git a/examples/Snarkl/Example/List.hs b/examples/Snarkl/Example/List.hs index fe2157b..48dcef3 100644 --- a/examples/Snarkl/Example/List.hs +++ b/examples/Snarkl/Example/List.hs @@ -27,12 +27,12 @@ type TList a = 'TMu (TF a) type List a k = TExp (TList a) k -nil :: (Typeable a, GaloisField k) => Comp (TList a) k +nil :: (Typeable a) => Comp (TList a) k nil = do t <- inl unit roll t -cons :: (Typeable a, GaloisField k) => TExp a k -> List a k -> Comp (TList a) k +cons :: (Typeable a) => TExp a k -> List a k -> Comp (TList a) k cons f t = do p <- pair f t @@ -41,9 +41,7 @@ cons f t = case_list :: ( Typeable a, - Typeable ty, - Zippable ty k, - GaloisField k + Zippable ty k ) => List a k -> Comp ty k -> @@ -62,9 +60,7 @@ case_list t f_nil f_cons = head_list :: ( Typeable a, - Zippable a k, - Derive a k, - GaloisField k + Zippable a k ) => TExp a k -> List a k -> @@ -78,8 +74,7 @@ head_list def l = tail_list :: ( Typeable a, Zippable a k, - Derive a k, - GaloisField k + Derive a k ) => List a k -> Comp (TList a) k @@ -96,8 +91,7 @@ tail_list l = app_list :: ( Typeable a, Zippable a k, - Derive a k, - GaloisField k + Derive a k ) => List a k -> List a k -> @@ -116,8 +110,7 @@ app_list l1 l2 = fix go l1 rev_list :: ( Typeable a, Zippable a k, - Derive a k, - GaloisField k + Derive a k ) => List a k -> Comp (TList a) k @@ -136,12 +129,9 @@ rev_list l = fix go l map_list :: ( Typeable a, - Zippable a k, - Derive a k, Typeable b, Zippable b k, - Derive b k, - GaloisField k + Derive b k ) => (TExp a k -> Comp b k) -> List a k -> @@ -161,7 +151,7 @@ map_list f l = ) last_list :: - (Typeable a, Zippable a k, Derive a k, GaloisField k) => + (Typeable a, Zippable a k) => TExp a k -> List a k -> Comp a k diff --git a/examples/Snarkl/Example/Peano.hs b/examples/Snarkl/Example/Peano.hs index cfbbdb3..5fb1d08 100644 --- a/examples/Snarkl/Example/Peano.hs +++ b/examples/Snarkl/Example/Peano.hs @@ -22,13 +22,13 @@ type TF = 'TFSum ('TFConst 'TUnit) 'TFId type TNat = 'TMu TF -nat_zero :: (GaloisField k) => Comp TNat k +nat_zero :: Comp TNat k nat_zero = do x <- inl unit roll x -nat_succ :: (GaloisField k) => TExp TNat k -> Comp TNat k +nat_succ :: TExp TNat k -> Comp TNat k nat_succ n = do x <- inr n diff --git a/examples/Snarkl/Example/Queue.hs b/examples/Snarkl/Example/Queue.hs index ed06e12..b69dddd 100644 --- a/examples/Snarkl/Example/Queue.hs +++ b/examples/Snarkl/Example/Queue.hs @@ -27,14 +27,14 @@ type TQueue a = 'TProd (TStack a) (TStack a) type Queue a k = TExp (TQueue a) k -empty_queue :: (Typeable a, GaloisField k) => Comp (TQueue a) k +empty_queue :: (Typeable a) => Comp (TQueue a) k empty_queue = do l <- empty_stack r <- empty_stack pair l r enqueue :: - (Typeable a, GaloisField k) => + (Typeable a) => TExp a k -> Queue a k -> Comp (TQueue a) k @@ -45,7 +45,7 @@ enqueue v q = do pair l' r dequeue :: - (Zippable a k, Derive a k, Typeable a, GaloisField k) => + (Zippable a k, Derive a k, Typeable a) => Queue a k -> TExp a k -> Comp ('TProd a (TQueue a)) k @@ -73,7 +73,7 @@ dequeue q def = do pair h p dequeue_rec :: - (Zippable a k, Derive a k, Typeable a, GaloisField k) => + (Zippable a k, Derive a k, Typeable a) => Queue a k -> TExp a k -> Comp ('TProd a (TQueue a)) k @@ -113,7 +113,7 @@ is_empty q = do (\_ _ -> return false) last_queue :: - (Zippable a k, Derive a k, Typeable a, GaloisField k) => + (Zippable a k, Derive a k, Typeable a) => Queue a k -> TExp a k -> Comp a k diff --git a/examples/Snarkl/Example/Stack.hs b/examples/Snarkl/Example/Stack.hs index 7c54c95..c7c4371 100644 --- a/examples/Snarkl/Example/Stack.hs +++ b/examples/Snarkl/Example/Stack.hs @@ -28,19 +28,19 @@ type TStack a = TList a type Stack a k = TExp (TStack a) k -empty_stack :: (Typeable a, GaloisField k) => Comp (TStack a) k +empty_stack :: (Typeable a) => Comp (TStack a) k empty_stack = nil -push_stack :: (Typeable a, GaloisField k) => TExp a k -> Stack a k -> Comp (TStack a) k +push_stack :: (Typeable a) => TExp a k -> Stack a k -> Comp (TStack a) k push_stack p q = cons p q -pop_stack :: (Derive a k, Zippable a k, Typeable a, GaloisField k) => Stack a k -> Comp (TStack a) k +pop_stack :: (Derive a k, Zippable a k, Typeable a) => Stack a k -> Comp (TStack a) k pop_stack f = tail_list f -top_stack :: (Derive a k, Zippable a k, Typeable a, GaloisField k) => TExp a k -> Stack a k -> Comp a k +top_stack :: (Zippable a k, Typeable a) => TExp a k -> Stack a k -> Comp a k top_stack def e = head_list def e -is_empty_stack :: (Typeable a, GaloisField k) => Stack a k -> Comp 'TBool k +is_empty_stack :: (Typeable a) => Stack a k -> Comp 'TBool k is_empty_stack s = case_list s (return true) (\_ _ -> return false) diff --git a/examples/Snarkl/Example/Tree.hs b/examples/Snarkl/Example/Tree.hs index b5a1af7..a3cb137 100644 --- a/examples/Snarkl/Example/Tree.hs +++ b/examples/Snarkl/Example/Tree.hs @@ -27,12 +27,12 @@ type Rat k = TExp 'TField k type Tree a k = TExp (TTree a) k -leaf :: (Typeable a, GaloisField k) => Comp (TTree a) k +leaf :: (Typeable a) => Comp (TTree a) k leaf = do t <- inl unit roll t -node :: (Typeable a, GaloisField k) => TExp a k -> Tree a k -> Tree a k -> Comp (TTree a) k +node :: (Typeable a) => TExp a k -> Tree a k -> Tree a k -> Comp (TTree a) k node v t1 t2 = do p <- pair t1 t2 p' <- pair v p @@ -41,8 +41,6 @@ node v t1 t2 = do case_tree :: ( Typeable a, - GaloisField k, - Typeable a1, Zippable a1 k ) => Tree a k -> @@ -64,8 +62,7 @@ map_tree :: ( Typeable a, Typeable a1, Zippable a1 k, - Derive a1 k, - GaloisField k + Derive a1 k ) => (TExp a k -> Comp a1 k) -> TExp (TTree a) k -> From 6cf05f691edca038be562273c83e121afe6c15bb Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 23:06:11 -0800 Subject: [PATCH 16/19] lint --- benchmarks/Harness.hs | 1 - snarkl.cabal | 2 +- src/Snarkl/Backend/R1CS/Poly.hs | 2 -- src/Snarkl/Compile.hs | 2 -- src/Snarkl/Language.hs | 3 ++- 5 files changed, 3 insertions(+), 7 deletions(-) diff --git a/benchmarks/Harness.hs b/benchmarks/Harness.hs index bb3c7e2..6ee4efd 100644 --- a/benchmarks/Harness.hs +++ b/benchmarks/Harness.hs @@ -31,7 +31,6 @@ import Snarkl.Toplevel compileTexpToConstraints, do_simplify, execute, - lastSeq, serializeR1CSAsJson, serializeWitnessAsJson, wit_of_r1cs, diff --git a/snarkl.cabal b/snarkl.cabal index dddc05b..a35ac90 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -25,7 +25,7 @@ source-repository head library ghc-options: - -Wall -Wredundant-constraints -funbox-strict-fields -optc-O3 + -Wall -Werror -Wredundant-constraints -funbox-strict-fields -optc-O3 -- -threaded exposed-modules: diff --git a/src/Snarkl/Backend/R1CS/Poly.hs b/src/Snarkl/Backend/R1CS/Poly.hs index 23e5a32..eeaa1af 100644 --- a/src/Snarkl/Backend/R1CS/Poly.hs +++ b/src/Snarkl/Backend/R1CS/Poly.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE InstanceSigs #-} - module Snarkl.Backend.R1CS.Poly where import qualified Data.Aeson as A diff --git a/src/Snarkl/Compile.hs b/src/Snarkl/Compile.hs index d11957a..83f3820 100644 --- a/src/Snarkl/Compile.hs +++ b/src/Snarkl/Compile.hs @@ -22,8 +22,6 @@ import Control.Monad.State import qualified Control.Monad.State as State import Data.Either (fromRight) import Data.Field.Galois (GaloisField) --- do_const_prop, - import Data.Foldable (traverse_) import Data.List (sort) import qualified Data.Map as Map diff --git a/src/Snarkl/Language.hs b/src/Snarkl/Language.hs index ca413c9..0af8035 100644 --- a/src/Snarkl/Language.hs +++ b/src/Snarkl/Language.hs @@ -86,6 +86,7 @@ module Snarkl.Language where import Data.Field.Galois (GaloisField) +import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Language.Core ( Assignment (..), Exp (..), @@ -170,4 +171,4 @@ compileTExpToProgram te = let eprog = mkProgram . expOfLambdaExp . tExpToLambdaExp $ te in case eprog of Right p -> p - Left err -> error $ "compileTExpToProgram: failed to convert TExp to Program: " <> err + Left err -> failWith $ ErrMsg $ "compileTExpToProgram: failed to convert TExp to Program: " <> err From f17df14e5d6b19d38317646ecae115172705a520 Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 23:45:17 -0800 Subject: [PATCH 17/19] narrow down language --- examples/Snarkl/Example/Basic.hs | 6 ++--- examples/Snarkl/Example/Games.hs | 6 ++--- examples/Snarkl/Example/Keccak.hs | 5 +--- examples/Snarkl/Example/List.hs | 5 +--- examples/Snarkl/Example/Matrix.hs | 6 ++--- examples/Snarkl/Example/Queue.hs | 5 +--- examples/Snarkl/Example/Stack.hs | 5 +--- src/Snarkl/Compile.hs | 41 +++++++++++------------------- src/Snarkl/Interp.hs | 11 ++++---- src/Snarkl/Language.hs | 38 +++++---------------------- src/Snarkl/Language/Expr.hs | 5 ++-- src/Snarkl/Language/SyntaxMonad.hs | 16 +++++++++++- src/Snarkl/Toplevel.hs | 6 ----- tests/Test/Snarkl/Unit/Programs.hs | 5 +--- 14 files changed, 56 insertions(+), 104 deletions(-) diff --git a/examples/Snarkl/Example/Basic.hs b/examples/Snarkl/Example/Basic.hs index f601911..644aaa9 100644 --- a/examples/Snarkl/Example/Basic.hs +++ b/examples/Snarkl/Example/Basic.hs @@ -7,10 +7,8 @@ import Data.Typeable (Typeable) import GHC.TypeLits (KnownNat) import Snarkl.Compile import Snarkl.Field (F_BN128) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Language +import Snarkl.Toplevel (R1CS, comp_interp) import System.Exit (ExitCode) import Prelude hiding ( fromRational, diff --git a/examples/Snarkl/Example/Games.hs b/examples/Snarkl/Example/Games.hs index 0d0bf4e..9735235 100644 --- a/examples/Snarkl/Example/Games.hs +++ b/examples/Snarkl/Example/Games.hs @@ -11,10 +11,8 @@ import Data.Kind (Type) import Data.Typeable import Snarkl.Errors import Snarkl.Field (F_BN128) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Language +import Snarkl.Toplevel (comp_interp) import Prelude hiding ( fromRational, negate, diff --git a/examples/Snarkl/Example/Keccak.hs b/examples/Snarkl/Example/Keccak.hs index ce8f136..4d440bb 100644 --- a/examples/Snarkl/Example/Keccak.hs +++ b/examples/Snarkl/Example/Keccak.hs @@ -9,10 +9,7 @@ import Data.Bits hiding (xor) import Data.Field.Galois (GaloisField, Prime) import qualified Data.Map.Strict as Map import GHC.TypeLits (KnownNat) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Language import Prelude hiding ( fromRational, negate, diff --git a/examples/Snarkl/Example/List.hs b/examples/Snarkl/Example/List.hs index 48dcef3..bd0589b 100644 --- a/examples/Snarkl/Example/List.hs +++ b/examples/Snarkl/Example/List.hs @@ -5,10 +5,7 @@ module Snarkl.Example.List where import Data.Field.Galois (GaloisField, Prime) import Data.Typeable import GHC.TypeLits (KnownNat) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Language import Prelude hiding ( negate, return, diff --git a/examples/Snarkl/Example/Matrix.hs b/examples/Snarkl/Example/Matrix.hs index 9ac950a..0356604 100644 --- a/examples/Snarkl/Example/Matrix.hs +++ b/examples/Snarkl/Example/Matrix.hs @@ -4,10 +4,8 @@ module Snarkl.Example.Matrix where import Data.Field.Galois (GaloisField, Prime) import GHC.TypeLits (KnownNat) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Language +import Snarkl.Toplevel (comp_interp) import Prelude hiding ( fromRational, negate, diff --git a/examples/Snarkl/Example/Queue.hs b/examples/Snarkl/Example/Queue.hs index b69dddd..142d23b 100644 --- a/examples/Snarkl/Example/Queue.hs +++ b/examples/Snarkl/Example/Queue.hs @@ -6,10 +6,7 @@ import Data.Field.Galois (GaloisField) import Data.Typeable import Snarkl.Example.List import Snarkl.Example.Stack -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Language import Prelude hiding ( fromRational, negate, diff --git a/examples/Snarkl/Example/Stack.hs b/examples/Snarkl/Example/Stack.hs index c7c4371..eac62aa 100644 --- a/examples/Snarkl/Example/Stack.hs +++ b/examples/Snarkl/Example/Stack.hs @@ -7,10 +7,7 @@ import Data.Typeable import GHC.TypeLits (KnownNat) import Snarkl.Compile import Snarkl.Example.List -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Language import Prelude hiding ( fromRational, negate, diff --git a/src/Snarkl/Compile.hs b/src/Snarkl/Compile.hs index 83f3820..8799187 100644 --- a/src/Snarkl/Compile.hs +++ b/src/Snarkl/Compile.hs @@ -10,6 +10,7 @@ module Snarkl.Compile compileCompToTexp, compileTexpToConstraints, compileCompToConstraints, + compileTExpToProgram, ) where @@ -43,16 +44,11 @@ import Snarkl.Constraint solve, ) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Language - ( Comp, - Env (Env, input_vars, next_variable), - TExp, - Variable (Variable), - booleanVarsOfTexp, - compileTExpToProgram, - runState, - ) import qualified Snarkl.Language.Core as Core +import Snarkl.Language.Expr (mkProgram) +import Snarkl.Language.LambdaExpr (expOfLambdaExp) +import Snarkl.Language.SyntaxMonad (Comp, Env (..), runComp) +import Snarkl.Language.TExpr (TExp, booleanVarsOfTexp, tExpToLambdaExp) import Text.PrettyPrint.Leijen.Text (Pretty (..)) ---------------------------------------------------------------- @@ -447,9 +443,9 @@ compileConstraintsToR1CS simpl cs = -- | The result of desugaring a Snarkl computation. data TExpPkg ty k = TExpPkg { -- | The number of free variables in the computation. - out_variable :: Variable, + out_variable :: Core.Variable, -- | The variables marked as inputs. - comp_input_variables :: [Variable], + comp_input_variables :: [Core.Variable], -- | The resulting 'TExp'. comp_texp :: TExp ty k } @@ -468,23 +464,16 @@ compileCompToTexp :: Comp ty k -> TExpPkg ty k compileCompToTexp mf = - case run mf of + case runComp mf of Left err -> failWith err Right (e, rho) -> - let out = Variable (next_variable rho) + let out = Core.Variable (next_variable rho) in_vars = sort $ input_vars rho in TExpPkg out in_vars e - where - run mf0 = - runState - mf0 - ( Env - 0 - 0 - [] - Map.empty - Map.empty - ) + +compileTExpToProgram :: (GaloisField k) => TExp ty k -> Core.Program k +compileTExpToProgram te = + mkProgram . expOfLambdaExp . tExpToLambdaExp $ te -- | Snarkl.Compile 'TExp's to constraint systems. Re-exported from 'Snarkl.Compile.Snarkl.Compile'. compileTexpToConstraints :: @@ -551,7 +540,7 @@ compileCompToR1CS simpl = compileConstraintsToR1CS simpl . compileCompToConstrai -------------------------------------------------------------------------------- -_Var :: Iso' Variable Var -_Var = iso (\(Variable v) -> Var v) (\(Var v) -> Variable v) +_Var :: Iso' Core.Variable Var +_Var = iso (\(Core.Variable v) -> Var v) (\(Var v) -> Core.Variable v) -------------------------------------------------------------------------------- diff --git a/src/Snarkl/Interp.hs b/src/Snarkl/Interp.hs index 4d66dd5..576dcfa 100644 --- a/src/Snarkl/Interp.hs +++ b/src/Snarkl/Interp.hs @@ -9,11 +9,12 @@ import Data.Foldable (traverse_) import Data.Map (Map) import qualified Data.Map as Map import Snarkl.Common (Op (..), UnOp (ZEq)) +import Snarkl.Compile (compileTExpToProgram) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Language (TExp, Variable, compileTExpToProgram) import qualified Snarkl.Language.Core as Core +import Snarkl.Language.TExpr (TExp) -type Env a = Map Variable (Maybe a) +type Env a = Map Core.Variable (Maybe a) newtype InterpM a b = InterpM {runInterpM :: Env a -> Either ErrMsg (Env a, b)} @@ -38,11 +39,11 @@ raiseErr :: ErrMsg -> InterpM a b raiseErr err = InterpM (\_ -> Left err) -addBinds :: [(Variable, Maybe a)] -> InterpM a (Maybe b) +addBinds :: [(Core.Variable, Maybe a)] -> InterpM a (Maybe b) addBinds binds = InterpM (\rho -> Right (Map.union (Map.fromList binds) rho, Nothing)) -lookupVar :: (Show a) => Variable -> InterpM a (Maybe a) +lookupVar :: (Show a) => Core.Variable -> InterpM a (Maybe a) lookupVar x = InterpM ( \rho -> case Map.lookup x rho of @@ -86,7 +87,7 @@ interpTExp e = do interp :: (GaloisField k) => - Map Variable k -> + Map Core.Variable k -> TExp ty k -> Either ErrMsg (Env k, Maybe k) interp rho e = runInterpM (interpTExp e) $ Map.map Just rho diff --git a/src/Snarkl/Language.hs b/src/Snarkl/Language.hs index 0af8035..b56ce65 100644 --- a/src/Snarkl/Language.hs +++ b/src/Snarkl/Language.hs @@ -1,22 +1,15 @@ +{-# LANGUAGE NoImplicitPrelude #-} + module Snarkl.Language - ( compileTExpToProgram, - -- | Snarkl.Language.TExpr, - booleanVarsOfTexp, + ( -- | Snarkl.Language.TExpr, TExp, - -- | Snarkl.Language.Core, - Variable (..), - Program (..), - Assignment (..), - Exp (..), - -- types module Snarkl.Language.Type, -- | SyntaxMonad and Syntax Comp, - runState, + runComp, return, (>>=), (>>), - Env (..), -- | Return a fresh input variable. fresh_input, -- | Classes @@ -85,16 +78,6 @@ module Snarkl.Language ) where -import Data.Field.Galois (GaloisField) -import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Language.Core - ( Assignment (..), - Exp (..), - Program (..), - Variable (..), - ) -import Snarkl.Language.Expr (mkProgram) -import Snarkl.Language.LambdaExpr (expOfLambdaExp) import Snarkl.Language.Syntax ( Derive, Zippable, @@ -152,23 +135,14 @@ import Snarkl.Language.Syntax ) import Snarkl.Language.SyntaxMonad ( Comp, - Env (..), false, fresh_input, return, - runState, + runComp, true, unit, (>>), (>>=), ) -import Snarkl.Language.TExpr (TExp, booleanVarsOfTexp, tExpToLambdaExp) +import Snarkl.Language.TExpr (TExp) import Snarkl.Language.Type -import Prelude (Either (..), error, ($), (.), (<>)) - -compileTExpToProgram :: (GaloisField k) => TExp ty k -> Program k -compileTExpToProgram te = - let eprog = mkProgram . expOfLambdaExp . tExpToLambdaExp $ te - in case eprog of - Right p -> p - Left err -> failWith $ ErrMsg $ "compileTExpToProgram: failed to convert TExp to Program: " <> err diff --git a/src/Snarkl/Language/Expr.hs b/src/Snarkl/Language/Expr.hs index 287e34e..42647fe 100644 --- a/src/Snarkl/Language/Expr.hs +++ b/src/Snarkl/Language/Expr.hs @@ -21,6 +21,7 @@ import Data.Map (Map) import qualified Data.Map as Map import Data.Sequence (Seq, fromList, (<|), (><), (|>), pattern Empty, pattern (:<|)) import Snarkl.Common (Op, UnOp, isAssoc) +import Snarkl.Errors (ErrMsg (..), failWith) import qualified Snarkl.Language.Core as Core import Text.PrettyPrint.Leijen.Text ( Pretty (pretty), @@ -144,8 +145,8 @@ mkAssignment e = throwError $ "mkAssignment: expected EAssert, got " <> show e -- At this point the expression should be either: -- 1. A sequence of assignments, followed by an expression -- 2. An expression -mkProgram :: (GaloisField k) => Exp k -> Either String (Core.Program k) -mkProgram _exp = do +mkProgram :: (GaloisField k) => Exp k -> Core.Program k +mkProgram _exp = either (failWith . ErrMsg) id $ do let e' = do_const_prop _exp case e' of ESeq es -> do diff --git a/src/Snarkl/Language/SyntaxMonad.hs b/src/Snarkl/Language/SyntaxMonad.hs index 2bf94a1..c57c42d 100644 --- a/src/Snarkl/Language/SyntaxMonad.hs +++ b/src/Snarkl/Language/SyntaxMonad.hs @@ -6,13 +6,14 @@ module Snarkl.Language.SyntaxMonad ( -- | Computation monad Comp, - CompResult, runState, + runComp, return, (>>=), (>>), raise_err, Env (..), + defaultEnv, State (..), -- | Return a fresh input variable. fresh_input, @@ -169,8 +170,21 @@ data Env k = Env } deriving (Show) +defaultEnv :: Env k +defaultEnv = + Env + { next_variable = 0, + next_loc = 0, + input_vars = [], + obj_map = Map.empty, + anal_map = Map.empty + } + type Comp ty k = State (Env k) (TExp ty k) +runComp :: Comp ty k -> Either ErrMsg (TExp ty k, Env k) +runComp f = runState f defaultEnv + {----------------------------------------------- Units, Booleans (used below) ------------------------------------------------} diff --git a/src/Snarkl/Toplevel.hs b/src/Snarkl/Toplevel.hs index 4f24429..70c7506 100644 --- a/src/Snarkl/Toplevel.hs +++ b/src/Snarkl/Toplevel.hs @@ -32,12 +32,6 @@ import Snarkl.Language import Text.PrettyPrint.Leijen.Text (Pretty (..), line, (<+>)) import Prelude ----------------------------------------------------- --- --- Snarkl.Toplevel Stuff --- ----------------------------------------------------- - -- | Using the executable semantics for the 'TExp' language, execute -- the computation on the provided inputs, returning the 'k' result. comp_interp :: diff --git a/tests/Test/Snarkl/Unit/Programs.hs b/tests/Test/Snarkl/Unit/Programs.hs index df1d889..fac87c6 100644 --- a/tests/Test/Snarkl/Unit/Programs.hs +++ b/tests/Test/Snarkl/Unit/Programs.hs @@ -13,10 +13,7 @@ import Snarkl.Example.List import Snarkl.Example.Peano import Snarkl.Example.Tree import Snarkl.Field (F_BN128) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Language import Prelude hiding ( fromRational, negate, From f998e4e325e8d1b0ae3e0c91ce3035a77c4fa019 Mon Sep 17 00:00:00 2001 From: martyall Date: Sun, 7 Jan 2024 23:59:13 -0800 Subject: [PATCH 18/19] build benchmarks with nix build --- app/Main.hs | 4 ++-- benchmarks/Harness.hs | 12 +++++------- flake.nix | 2 ++ src/Snarkl/Toplevel.hs | 3 --- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/app/Main.hs b/app/Main.hs index 40b9008..f8dd48d 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -6,9 +6,9 @@ import Data.Field.Galois (PrimeField) import Data.Typeable (Typeable) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Field (F_BN128) +import Snarkl.Language (Comp) import Snarkl.Toplevel - ( Comp, - Result (..), + ( Result (..), SimplParam (..), execute, mkInputsFilePath, diff --git a/benchmarks/Harness.hs b/benchmarks/Harness.hs index 6ee4efd..42c22d2 100644 --- a/benchmarks/Harness.hs +++ b/benchmarks/Harness.hs @@ -10,26 +10,24 @@ import qualified Data.ByteString.Lazy as LBS import Data.Field.Galois (GaloisField, Prime, PrimeField) import qualified Data.Map as Map import qualified Data.Set as Set -import Data.Typeable -import GHC.IO.Exception +import Data.Typeable (Typeable) +import GHC.IO.Exception (ExitCode (..)) import GHC.TypeLits (KnownNat) import Snarkl.Compile (SimplParam) +import Snarkl.Constraint (ConstraintSystem (..), do_simplify) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) +import Snarkl.Language (Comp) import Snarkl.Language.TExpr ( TExp (..), lastSeq, ) import Snarkl.Toplevel - ( Comp, - ConstraintSystem (..), - Result (..), - TExp (..), + ( Result (..), TExpPkg (..), comp_interp, compileCompToR1CS, compileCompToTexp, compileTexpToConstraints, - do_simplify, execute, serializeR1CSAsJson, serializeWitnessAsJson, diff --git a/flake.nix b/flake.nix index 3aa2fef..bdab0aa 100644 --- a/flake.nix +++ b/flake.nix @@ -30,10 +30,12 @@ packages = { lib = flake.packages."snarkl:lib:snarkl"; + benchmark = flake.packages."snarkl:bench:criterion"; all = pkgs.symlinkJoin { name = "all"; paths = with packages; [ lib + benchmark ]; }; default = packages.all; diff --git a/src/Snarkl/Toplevel.hs b/src/Snarkl/Toplevel.hs index 70c7506..59b333a 100644 --- a/src/Snarkl/Toplevel.hs +++ b/src/Snarkl/Toplevel.hs @@ -11,8 +11,6 @@ module Snarkl.Toplevel execute, -- * Re-exported modules - module Snarkl.Language, - module Snarkl.Constraint, module Snarkl.Backend.R1CS, module Snarkl.Compile, ) @@ -25,7 +23,6 @@ import Data.Typeable (Typeable) import Snarkl.Backend.R1CS import Snarkl.Common (Assgn) import Snarkl.Compile -import Snarkl.Constraint import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Interp (interp) import Snarkl.Language From f1c22f9d85e53849414eead95ead6cbadec2627c Mon Sep 17 00:00:00 2001 From: martyall Date: Mon, 8 Jan 2024 09:12:44 -0800 Subject: [PATCH 19/19] split out syntax module --- benchmarks/Main.hs | 2 +- examples/Snarkl/Example/Basic.hs | 2 +- examples/Snarkl/Example/Games.hs | 2 +- examples/Snarkl/Example/Keccak.hs | 2 +- examples/Snarkl/Example/Lam.hs | 2 +- examples/Snarkl/Example/List.hs | 2 +- examples/Snarkl/Example/Matrix.hs | 2 +- examples/Snarkl/Example/Peano.hs | 2 +- examples/Snarkl/Example/Queue.hs | 2 +- examples/Snarkl/Example/Stack.hs | 2 +- examples/Snarkl/Example/Tree.hs | 2 +- snarkl.cabal | 5 +- src/Snarkl/Compile.hs | 109 +++++++++--------- src/Snarkl/Language.hs | 165 ++++++---------------------- src/Snarkl/{Language => }/Syntax.hs | 46 +++++++- tests/Test/Snarkl/LambdaSpec.hs | 8 +- tests/Test/Snarkl/Unit/Programs.hs | 2 +- 17 files changed, 153 insertions(+), 204 deletions(-) rename src/Snarkl/{Language => }/Syntax.hs (95%) diff --git a/benchmarks/Main.hs b/benchmarks/Main.hs index 8a21d4a..87562e8 100644 --- a/benchmarks/Main.hs +++ b/benchmarks/Main.hs @@ -14,7 +14,7 @@ import qualified Snarkl.Example.Keccak as Keccak import qualified Snarkl.Example.List as List import qualified Snarkl.Example.Matrix as Matrix import Snarkl.Field (F_BN128, P_BN128) -import Snarkl.Language (Comp, fromField) +import Snarkl.Syntax (Comp, fromField) mk_bgroup :: (Typeable ty) => String -> Comp ty F_BN128 -> [Int] -> F_BN128 -> Benchmark mk_bgroup nm mf inputs result = diff --git a/examples/Snarkl/Example/Basic.hs b/examples/Snarkl/Example/Basic.hs index 644aaa9..bac359a 100644 --- a/examples/Snarkl/Example/Basic.hs +++ b/examples/Snarkl/Example/Basic.hs @@ -7,7 +7,7 @@ import Data.Typeable (Typeable) import GHC.TypeLits (KnownNat) import Snarkl.Compile import Snarkl.Field (F_BN128) -import Snarkl.Language +import Snarkl.Syntax import Snarkl.Toplevel (R1CS, comp_interp) import System.Exit (ExitCode) import Prelude hiding diff --git a/examples/Snarkl/Example/Games.hs b/examples/Snarkl/Example/Games.hs index 9735235..de7daf6 100644 --- a/examples/Snarkl/Example/Games.hs +++ b/examples/Snarkl/Example/Games.hs @@ -11,7 +11,7 @@ import Data.Kind (Type) import Data.Typeable import Snarkl.Errors import Snarkl.Field (F_BN128) -import Snarkl.Language +import Snarkl.Syntax import Snarkl.Toplevel (comp_interp) import Prelude hiding ( fromRational, diff --git a/examples/Snarkl/Example/Keccak.hs b/examples/Snarkl/Example/Keccak.hs index 4d440bb..bf1085c 100644 --- a/examples/Snarkl/Example/Keccak.hs +++ b/examples/Snarkl/Example/Keccak.hs @@ -9,7 +9,7 @@ import Data.Bits hiding (xor) import Data.Field.Galois (GaloisField, Prime) import qualified Data.Map.Strict as Map import GHC.TypeLits (KnownNat) -import Snarkl.Language +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, diff --git a/examples/Snarkl/Example/Lam.hs b/examples/Snarkl/Example/Lam.hs index c33e7e8..d4fc4ab 100644 --- a/examples/Snarkl/Example/Lam.hs +++ b/examples/Snarkl/Example/Lam.hs @@ -10,7 +10,7 @@ import Data.Field.Galois (GaloisField, Prime) import Data.Typeable import GHC.TypeLits (KnownNat) import Snarkl.Errors -import Snarkl.Language +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, diff --git a/examples/Snarkl/Example/List.hs b/examples/Snarkl/Example/List.hs index bd0589b..ecb422e 100644 --- a/examples/Snarkl/Example/List.hs +++ b/examples/Snarkl/Example/List.hs @@ -5,7 +5,7 @@ module Snarkl.Example.List where import Data.Field.Galois (GaloisField, Prime) import Data.Typeable import GHC.TypeLits (KnownNat) -import Snarkl.Language +import Snarkl.Syntax import Prelude hiding ( negate, return, diff --git a/examples/Snarkl/Example/Matrix.hs b/examples/Snarkl/Example/Matrix.hs index 0356604..cfa2a09 100644 --- a/examples/Snarkl/Example/Matrix.hs +++ b/examples/Snarkl/Example/Matrix.hs @@ -4,7 +4,7 @@ module Snarkl.Example.Matrix where import Data.Field.Galois (GaloisField, Prime) import GHC.TypeLits (KnownNat) -import Snarkl.Language +import Snarkl.Syntax import Snarkl.Toplevel (comp_interp) import Prelude hiding ( fromRational, diff --git a/examples/Snarkl/Example/Peano.hs b/examples/Snarkl/Example/Peano.hs index 5fb1d08..8cbb443 100644 --- a/examples/Snarkl/Example/Peano.hs +++ b/examples/Snarkl/Example/Peano.hs @@ -4,7 +4,7 @@ module Snarkl.Example.Peano where import Data.Field.Galois (GaloisField, Prime) import GHC.TypeLits (KnownNat) -import Snarkl.Language +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, diff --git a/examples/Snarkl/Example/Queue.hs b/examples/Snarkl/Example/Queue.hs index 142d23b..c3d526f 100644 --- a/examples/Snarkl/Example/Queue.hs +++ b/examples/Snarkl/Example/Queue.hs @@ -6,7 +6,7 @@ import Data.Field.Galois (GaloisField) import Data.Typeable import Snarkl.Example.List import Snarkl.Example.Stack -import Snarkl.Language +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, diff --git a/examples/Snarkl/Example/Stack.hs b/examples/Snarkl/Example/Stack.hs index eac62aa..0680743 100644 --- a/examples/Snarkl/Example/Stack.hs +++ b/examples/Snarkl/Example/Stack.hs @@ -7,7 +7,7 @@ import Data.Typeable import GHC.TypeLits (KnownNat) import Snarkl.Compile import Snarkl.Example.List -import Snarkl.Language +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, diff --git a/examples/Snarkl/Example/Tree.hs b/examples/Snarkl/Example/Tree.hs index a3cb137..9a2ea8a 100644 --- a/examples/Snarkl/Example/Tree.hs +++ b/examples/Snarkl/Example/Tree.hs @@ -5,7 +5,7 @@ module Snarkl.Example.Tree where import Data.Field.Galois (GaloisField, Prime) import Data.Typeable import GHC.TypeLits (KnownNat) -import Snarkl.Language +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, diff --git a/snarkl.cabal b/snarkl.cabal index a35ac90..d08a292 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -25,7 +25,8 @@ source-repository head library ghc-options: - -Wall -Werror -Wredundant-constraints -funbox-strict-fields -optc-O3 + -Wall -Werror -Wredundant-constraints -funbox-strict-fields + -optc-O3 -- -threaded exposed-modules: @@ -51,10 +52,10 @@ library Snarkl.Language.Core Snarkl.Language.Expr Snarkl.Language.LambdaExpr - Snarkl.Language.Syntax Snarkl.Language.SyntaxMonad Snarkl.Language.TExpr Snarkl.Language.Type + Snarkl.Syntax Snarkl.Toplevel default-extensions: diff --git a/src/Snarkl/Compile.hs b/src/Snarkl/Compile.hs index 8799187..2b83f7a 100644 --- a/src/Snarkl/Compile.hs +++ b/src/Snarkl/Compile.hs @@ -44,11 +44,20 @@ import Snarkl.Constraint solve, ) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import qualified Snarkl.Language.Core as Core -import Snarkl.Language.Expr (mkProgram) -import Snarkl.Language.LambdaExpr (expOfLambdaExp) -import Snarkl.Language.SyntaxMonad (Comp, Env (..), runComp) -import Snarkl.Language.TExpr (TExp, booleanVarsOfTexp, tExpToLambdaExp) +import Snarkl.Language + ( Assignment (..), + Comp, + Env (..), + Exp (..), + Program (..), + TExp, + Variable (..), + booleanVarsOfTexp, + expOfLambdaExp, + mkProgram, + runComp, + tExpToLambdaExp, + ) import Text.PrettyPrint.Leijen.Text (Pretty (..)) ---------------------------------------------------------------- @@ -105,13 +114,13 @@ encode_or :: (GaloisField a) => (Var, Var, Var) -> State (CEnv a) () encode_or (x, y, z) = do x_mult_y <- fresh_var - cs_of_exp x_mult_y (Core.EBinop Mult [Core.EVar (_Var # x), Core.EVar (_Var # y)]) + cs_of_exp x_mult_y (EBinop Mult [EVar (_Var # x), EVar (_Var # y)]) cs_of_exp x_mult_y - ( Core.EBinop + ( EBinop Sub - [ Core.EBinop Add [Core.EVar (_Var # x), Core.EVar (_Var # y)], - Core.EVar (_Var # z) + [ EBinop Add [EVar (_Var # x), EVar (_Var # y)], + EVar (_Var # z) ] ) @@ -148,13 +157,13 @@ encode_boolean_eq :: (GaloisField a) => (Var, Var, Var) -> State (CEnv a) () encode_boolean_eq (x, y, z) = cs_of_exp z e where e = - Core.EBinop + EBinop Add - [ Core.EBinop Mult [Core.EVar (_Var # x), Core.EVar (_Var # y)], - Core.EBinop + [ EBinop Mult [EVar (_Var # x), EVar (_Var # y)], + EBinop Mult - [ Core.EBinop Sub [Core.EVal 1, Core.EVar (_Var # x)], - Core.EBinop Sub [Core.EVal 1, Core.EVar (_Var # y)] + [ EBinop Sub [EVal 1, EVar (_Var # x)], + EBinop Sub [EVal 1, EVar (_Var # y)] ] ] @@ -163,9 +172,9 @@ encode_boolean_eq (x, y, z) = cs_of_exp z e encode_eq :: (GaloisField a) => (Var, Var, Var) -> State (CEnv a) () encode_eq (x, y, z) = cs_of_assignment $ - Core.Assignment + Assignment (_Var # z) - (Core.EUnop ZEq (Core.EBinop Sub [Core.EVar (_Var # x), Core.EVar (_Var # y)])) + (EUnop ZEq (EBinop Sub [EVar (_Var # x), EVar (_Var # y)])) -- | Constraint 'y = x!=0 ? 1 : 0'. -- The encoding is: @@ -185,8 +194,8 @@ encode_zneq (x, y) = nm <- fresh_var add_constraint (CMagic nm [x, m] mf) -- END magic. - cs_of_exp y (Core.EBinop Mult [Core.EVar (_Var # x), Core.EVar (_Var # m)]) - cs_of_exp neg_y (Core.EBinop Sub [Core.EVal 1, Core.EVar (_Var # y)]) + cs_of_exp y (EBinop Mult [EVar (_Var # x), EVar (_Var # m)]) + cs_of_exp neg_y (EBinop Sub [EVal 1, EVar (_Var # y)]) add_constraint (CMult (1, neg_y) (1, x) (0, Nothing)) where @@ -213,7 +222,7 @@ encode_zeq (x, y) = do neg_y <- fresh_var encode_zneq (x, neg_y) - cs_of_exp y (Core.EBinop Sub [Core.EVal 1, Core.EVar (_Var # neg_y)]) + cs_of_exp y (EBinop Sub [EVal 1, EVar (_Var # neg_y)]) -- | Encode the constraint 'un_op x = y' encode_unop :: (GaloisField a) => UnOp -> (Var, Var) -> State (CEnv a) () @@ -255,20 +264,20 @@ encode_linear out xs = remove_consts (Left p : l) = p : remove_consts l remove_consts (Right _ : l) = remove_consts l -cs_of_exp :: (GaloisField k) => Var -> Core.Exp k -> State (CEnv k) () +cs_of_exp :: (GaloisField k) => Var -> Exp k -> State (CEnv k) () cs_of_exp out e = case e of - Core.EVar x -> + EVar x -> ensure_equal (out, view _Var x) - Core.EVal c -> + EVal c -> ensure_const (out, c) - Core.EUnop op (Core.EVar x) -> + EUnop op (EVar x) -> encode_unop op (view _Var x, out) - Core.EUnop op e1 -> + EUnop op e1 -> do e1_out <- fresh_var cs_of_exp e1_out e1 encode_unop op (e1_out, out) - Core.EBinop op es -> + EBinop op es -> -- [NOTE linear combination optimization:] cf. also -- 'encode_linear' above. 'go_linear' returns a list of -- (label*coeff + constant) pairs. @@ -281,33 +290,33 @@ cs_of_exp out e = case e of -- We special-case linear combinations in this way to avoid having -- to introduce new multiplication gates for multiplication by -- constant scalars. - let go_linear :: (GaloisField k) => [Core.Exp k] -> State (CEnv k) [Either (Var, k) k] + let go_linear :: (GaloisField k) => [Exp k] -> State (CEnv k) [Either (Var, k) k] go_linear [] = return [] - go_linear (Core.EBinop Mult [Core.EVar x, Core.EVal coeff] : es') = + go_linear (EBinop Mult [EVar x, EVal coeff] : es') = do labels <- go_linear es' return $ Left (x ^. _Var, coeff) : labels - go_linear (Core.EBinop Mult [Core.EVal coeff, Core.EVar y] : es') = + go_linear (EBinop Mult [EVal coeff, EVar y] : es') = do labels <- go_linear es' return $ Left (y ^. _Var, coeff) : labels - go_linear (Core.EBinop Mult [e_left, Core.EVal coeff] : es') = + go_linear (EBinop Mult [e_left, EVal coeff] : es') = do e_left_out <- fresh_var cs_of_exp e_left_out e_left labels <- go_linear es' return $ Left (e_left_out, coeff) : labels - go_linear (Core.EBinop Mult [Core.EVal coeff, e_right] : es') = + go_linear (EBinop Mult [EVal coeff, e_right] : es') = do e_right_out <- fresh_var cs_of_exp e_right_out e_right labels <- go_linear es' return $ Left (e_right_out, coeff) : labels - go_linear (Core.EVal c : es') = + go_linear (EVal c : es') = do labels <- go_linear es' return $ Right c : labels - go_linear (Core.EVar x : es') = + go_linear (EVar x : es') = do labels <- go_linear es' return $ Left (x ^. _Var, 1) : labels @@ -332,9 +341,9 @@ cs_of_exp out e = case e of rev_pol (Left (x, c) : ls) = Left (x, -c) : rev_pol ls rev_pol (Right c : ls) = Right (-c) : rev_pol ls - go_other :: (GaloisField k) => [Core.Exp k] -> State (CEnv k) [Var] + go_other :: (GaloisField k) => [Exp k] -> State (CEnv k) [Var] go_other [] = return [] - go_other (Core.EVar x : es') = + go_other (EVar x : es') = do labels <- go_other es' return $ (x ^. _Var) : labels @@ -372,24 +381,24 @@ cs_of_exp out e = case e of encode_labels labels -- Encoding: out = b*e1 + (1-b)e2 - Core.EIf b e1 e2 -> cs_of_exp out e0 + EIf b e1 e2 -> cs_of_exp out e0 where e0 = - Core.EBinop + EBinop Add - [ Core.EBinop Mult [b, e1], - Core.EBinop Mult [Core.EBinop Sub [Core.EVal 1, b], e2] + [ EBinop Mult [b, e1], + EBinop Mult [EBinop Sub [EVal 1, b], e2] ] - Core.EUnit -> + EUnit -> -- NOTE: [[ EUnit ]]_{out} = [[ EVal zero ]]_{out}. - cs_of_exp out (Core.EVal 0) + cs_of_exp out (EVal 0) ---- NOTE: when compiling assignments, the naive thing to do is ---- to introduce a new var, e2_out, bound to result of e2 and ---- then ensure that e2_out == x. We optimize by passing x to ---- compilation of e2 directly. -cs_of_assignment :: (GaloisField a) => Core.Assignment a -> State (CEnv a) () -cs_of_assignment (Core.Assignment x e) = cs_of_exp (view _Var x) e +cs_of_assignment :: (GaloisField a) => Assignment a -> State (CEnv a) () +cs_of_assignment (Assignment x e) = cs_of_exp (view _Var x) e data SimplParam = NoSimplify @@ -443,9 +452,9 @@ compileConstraintsToR1CS simpl cs = -- | The result of desugaring a Snarkl computation. data TExpPkg ty k = TExpPkg { -- | The number of free variables in the computation. - out_variable :: Core.Variable, + out_variable :: Variable, -- | The variables marked as inputs. - comp_input_variables :: [Core.Variable], + comp_input_variables :: [Variable], -- | The resulting 'TExp'. comp_texp :: TExp ty k } @@ -467,11 +476,11 @@ compileCompToTexp mf = case runComp mf of Left err -> failWith err Right (e, rho) -> - let out = Core.Variable (next_variable rho) + let out = Variable (next_variable rho) in_vars = sort $ input_vars rho in TExpPkg out in_vars e -compileTExpToProgram :: (GaloisField k) => TExp ty k -> Core.Program k +compileTExpToProgram :: (GaloisField k) => TExp ty k -> Program k compileTExpToProgram te = mkProgram . expOfLambdaExp . tExpToLambdaExp $ te @@ -490,11 +499,11 @@ compileTexpToConstraints (TExpPkg _out _in_vars te) = Set.toList $ Set.fromList in_vars `Set.intersection` Set.fromList (map (view _Var) $ booleanVarsOfTexp te) - Core.Program assignments e = compileTExpToProgram te + Program assignments e = compileTExpToProgram te traverse_ cs_of_assignment assignments -- e = do_const_prop e0 -- Snarkl.Compile 'e' to constraints 'cs', with output wire 'out'. - cs_of_assignment $ Core.Assignment (_Var # out) e + cs_of_assignment $ Assignment (_Var # out) e -- Add boolean constraints mapM_ ensure_boolean boolean_in_vars cs <- get_constraints @@ -540,7 +549,7 @@ compileCompToR1CS simpl = compileConstraintsToR1CS simpl . compileCompToConstrai -------------------------------------------------------------------------------- -_Var :: Iso' Core.Variable Var -_Var = iso (\(Core.Variable v) -> Var v) (\(Var v) -> Core.Variable v) +_Var :: Iso' Variable Var +_Var = iso (\(Variable v) -> Var v) (\(Var v) -> Variable v) -------------------------------------------------------------------------------- diff --git a/src/Snarkl/Language.hs b/src/Snarkl/Language.hs index b56ce65..6fe2f02 100644 --- a/src/Snarkl/Language.hs +++ b/src/Snarkl/Language.hs @@ -1,148 +1,53 @@ -{-# LANGUAGE NoImplicitPrelude #-} - module Snarkl.Language - ( -- | Snarkl.Language.TExpr, - TExp, + ( module Snarkl.Language.Core, + module Snarkl.Language.Expr, + module Snarkl.Language.LambdaExpr, + module Snarkl.Language.SyntaxMonad, + module Snarkl.Language.TExpr, module Snarkl.Language.Type, - -- | SyntaxMonad and Syntax - Comp, - runComp, - return, - (>>=), - (>>), - -- | Return a fresh input variable. - fresh_input, - -- | Classes - Zippable, - Derive, - -- | Basic values - unit, - false, - true, - fromField, - -- | Sums, products, recursive types - inl, - inr, - case_sum, - pair, - fst_pair, - snd_pair, - roll, - unroll, - fixN, - fix, - -- | Arithmetic and boolean operations - (+), - (-), - (*), - (/), - (&&), - zeq, - not, - xor, - eq, - beq, - exp_of_int, - inc, - dec, - ifThenElse, - negate, - -- | Arrays - arr, - arr2, - arr3, - input_arr, - input_arr2, - input_arr3, - set, - set2, - set3, - set4, - get, - get2, - get3, - get4, - -- | Iteration - iter, - iterM, - bigsum, - times, - forall, - forall2, - forall3, - -- | Function combinators - lambda, - curry, - uncurry, - apply, ) where -import Snarkl.Language.Syntax - ( Derive, - Zippable, - apply, +import Snarkl.Language.Core (Assignment (..), Exp (..), Program (..), Variable (..)) +import Snarkl.Language.Expr (mkProgram) +import Snarkl.Language.LambdaExpr (expOfLambdaExp) +import Snarkl.Language.SyntaxMonad + ( Comp, + Env (..), + State (..), arr, - arr2, - arr3, - beq, - bigsum, - case_sum, - curry, - dec, - eq, - exp_of_int, - fix, - fixN, - forall, - forall2, - forall3, - fromField, + assert_bot, + assert_false, + assert_true, + defaultEnv, + false, + fresh_input, + fresh_var, fst_pair, get, - get2, - get3, - get4, - ifThenElse, - inc, - inl, + guard, input_arr, - input_arr2, - input_arr3, - inr, - iter, - iterM, - lambda, - negate, - not, + is_bot, + is_false, + is_true, pair, - roll, - set, - set2, - set3, - set4, - snd_pair, - times, - uncurry, - unroll, - xor, - zeq, - (&&), - (*), - (+), - (-), - (/), - ) -import Snarkl.Language.SyntaxMonad - ( Comp, - false, - fresh_input, + raise_err, return, runComp, + runState, + set, + snd_pair, true, unit, (>>), (>>=), ) -import Snarkl.Language.TExpr (TExp) +import Snarkl.Language.TExpr + ( TExp (..), + TOp (..), + TUnop (..), + Val (..), + booleanVarsOfTexp, + tExpToLambdaExp, + ) import Snarkl.Language.Type diff --git a/src/Snarkl/Language/Syntax.hs b/src/Snarkl/Syntax.hs similarity index 95% rename from src/Snarkl/Language/Syntax.hs rename to src/Snarkl/Syntax.hs index 7cae059..371625d 100644 --- a/src/Snarkl/Language/Syntax.hs +++ b/src/Snarkl/Syntax.hs @@ -1,8 +1,27 @@ {-# LANGUAGE RebindableSyntax #-} -module Snarkl.Language.Syntax - ( Zippable, +module Snarkl.Syntax + ( -- | Snarkl.Language.TExpr, + TExp, + Ty (..), + TFunct (..), + Rep, + -- | SyntaxMonad and Syntax + Comp, + runComp, + return, + (>>=), + (>>), + -- | Return a fresh input variable. + fresh_input, + -- | Classes + Zippable, Derive, + -- | Basic values + unit, + false, + true, + fromField, -- | Sums, products, recursive types inl, inr, @@ -28,7 +47,6 @@ module Snarkl.Language.Syntax exp_of_int, inc, dec, - fromField, ifThenElse, negate, -- | Arrays @@ -54,10 +72,22 @@ module Snarkl.Language.Syntax forall, forall2, forall3, + -- | Function combinators lambda, curry, uncurry, apply, + assert_bot, + assert_false, + assert_true, + defaultEnv, + fresh_var, + guard, + is_bot, + is_false, + is_true, + raise_err, + runState, ) where @@ -65,8 +95,8 @@ import Data.Field.Galois (GaloisField) import Data.String (IsString (..)) import Data.Typeable (Typeable) import Snarkl.Common - ( Op (Add, And, BEq, Div, Eq, Mult, Sub, XOr), - UnOp (ZEq), + ( Op (..), + UnOp (..), ) import Snarkl.Errors (ErrMsg (ErrMsg)) import Snarkl.Language.SyntaxMonad @@ -77,7 +107,9 @@ import Snarkl.Language.SyntaxMonad assert_bot, assert_false, assert_true, + defaultEnv, false, + fresh_input, fresh_var, fst_pair, get, @@ -89,11 +121,13 @@ import Snarkl.Language.SyntaxMonad pair, raise_err, return, + runComp, runState, set, snd_pair, true, unit, + (>>), (>>=), ) import Snarkl.Language.TExpr @@ -102,7 +136,7 @@ import Snarkl.Language.TExpr TUnop (TUnop), Val (VFalse, VField, VTrue, VUnit), ) -import Snarkl.Language.Type (Rep, Ty (..)) +import Snarkl.Language.Type (Rep, TFunct (..), Ty (..)) import Unsafe.Coerce (unsafeCoerce) import Prelude hiding ( curry, diff --git a/tests/Test/Snarkl/LambdaSpec.hs b/tests/Test/Snarkl/LambdaSpec.hs index 08f1f5d..e6b3fa9 100644 --- a/tests/Test/Snarkl/LambdaSpec.hs +++ b/tests/Test/Snarkl/LambdaSpec.hs @@ -6,7 +6,10 @@ module Test.Snarkl.LambdaSpec where import Snarkl.Field (F_BN128) -import Snarkl.Language.Syntax +import qualified Snarkl.Language.SyntaxMonad as SM +import Snarkl.Language.TExpr (TExp) +import Snarkl.Language.Type (Ty (TField, TFun, TProd)) +import Snarkl.Syntax ( apply, curry, lambda, @@ -15,9 +18,6 @@ import Snarkl.Language.Syntax (*), (+), ) -import qualified Snarkl.Language.SyntaxMonad as SM -import Snarkl.Language.TExpr (TExp) -import Snarkl.Language.Type (Ty (TField, TFun, TProd)) import Snarkl.Toplevel (comp_interp) import Test.Hspec (Spec, describe, it) import Test.QuickCheck (Testable (property)) diff --git a/tests/Test/Snarkl/Unit/Programs.hs b/tests/Test/Snarkl/Unit/Programs.hs index fac87c6..b915f25 100644 --- a/tests/Test/Snarkl/Unit/Programs.hs +++ b/tests/Test/Snarkl/Unit/Programs.hs @@ -13,7 +13,7 @@ import Snarkl.Example.List import Snarkl.Example.Peano import Snarkl.Example.Tree import Snarkl.Field (F_BN128) -import Snarkl.Language +import Snarkl.Syntax import Prelude hiding ( fromRational, negate,