diff --git a/.github/workflows/nix-ci.yml b/.github/workflows/nix-ci.yml index 8f62193..0177596 100644 --- a/.github/workflows/nix-ci.yml +++ b/.github/workflows/nix-ci.yml @@ -20,11 +20,11 @@ jobs: repo: torsion-labs/arkworks-bridge tag: v0.2.0 - - uses: cachix/install-nix-action@v22 + - uses: cachix/install-nix-action@v24 with: nix_path: nixpkgs=channel:nixos-unstable - - uses: cachix/cachix-action@v12 + - uses: cachix/cachix-action@v13 with: name: martyall authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' diff --git a/app/Main.hs b/app/Main.hs index 5d2adc0..f8dd48d 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -4,10 +4,20 @@ import Control.Monad (unless) import qualified Data.ByteString.Lazy as LBS import Data.Field.Galois (PrimeField) import Data.Typeable (Typeable) -import Snarkl.Compile (SimplParam (NoSimplify)) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Field +import Snarkl.Field (F_BN128) +import Snarkl.Language (Comp) import Snarkl.Toplevel + ( Result (..), + SimplParam (..), + execute, + mkInputsFilePath, + mkR1CSFilePath, + mkWitnessFilePath, + serializeInputsAsJson, + serializeR1CSAsJson, + serializeWitnessAsJson, + ) import qualified Test.Snarkl.Unit.Programs as Programs main :: IO () diff --git a/benchmarks/Harness.hs b/benchmarks/Harness.hs index bb3c7e2..42c22d2 100644 --- a/benchmarks/Harness.hs +++ b/benchmarks/Harness.hs @@ -10,28 +10,25 @@ import qualified Data.ByteString.Lazy as LBS import Data.Field.Galois (GaloisField, Prime, PrimeField) import qualified Data.Map as Map import qualified Data.Set as Set -import Data.Typeable -import GHC.IO.Exception +import Data.Typeable (Typeable) +import GHC.IO.Exception (ExitCode (..)) import GHC.TypeLits (KnownNat) import Snarkl.Compile (SimplParam) +import Snarkl.Constraint (ConstraintSystem (..), do_simplify) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) +import Snarkl.Language (Comp) import Snarkl.Language.TExpr ( TExp (..), lastSeq, ) import Snarkl.Toplevel - ( Comp, - ConstraintSystem (..), - Result (..), - TExp (..), + ( Result (..), TExpPkg (..), comp_interp, compileCompToR1CS, compileCompToTexp, compileTexpToConstraints, - do_simplify, execute, - lastSeq, serializeR1CSAsJson, serializeWitnessAsJson, wit_of_r1cs, diff --git a/benchmarks/Main.hs b/benchmarks/Main.hs index 8a21d4a..87562e8 100644 --- a/benchmarks/Main.hs +++ b/benchmarks/Main.hs @@ -14,7 +14,7 @@ import qualified Snarkl.Example.Keccak as Keccak import qualified Snarkl.Example.List as List import qualified Snarkl.Example.Matrix as Matrix import Snarkl.Field (F_BN128, P_BN128) -import Snarkl.Language (Comp, fromField) +import Snarkl.Syntax (Comp, fromField) mk_bgroup :: (Typeable ty) => String -> Comp ty F_BN128 -> [Int] -> F_BN128 -> Benchmark mk_bgroup nm mf inputs result = diff --git a/examples/Snarkl/Example/Basic.hs b/examples/Snarkl/Example/Basic.hs index 7a86fd1..bac359a 100644 --- a/examples/Snarkl/Example/Basic.hs +++ b/examples/Snarkl/Example/Basic.hs @@ -7,10 +7,8 @@ import Data.Typeable (Typeable) import GHC.TypeLits (KnownNat) import Snarkl.Compile import Snarkl.Field (F_BN128) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Syntax +import Snarkl.Toplevel (R1CS, comp_interp) import System.Exit (ExitCode) import Prelude hiding ( fromRational, @@ -31,7 +29,7 @@ mult_ex :: Comp 'TField k mult_ex x y = return $ x * y -arr_ex :: (GaloisField k) => TExp 'TField k -> Comp 'TField k +arr_ex :: TExp 'TField k -> Comp 'TField k arr_ex x = do a <- arr 2 forall [0 .. 1] (\i -> set (a, i) x) @@ -63,13 +61,13 @@ interp2' = comp_interp p2 [256] compile1 :: (GaloisField k) => R1CS k compile1 = compileCompToR1CS Simplify p1 -comp1 :: (GaloisField k, Typeable a) => Comp ('TSum 'TBool a) k +comp1 :: (Typeable a) => Comp ('TSum 'TBool a) k comp1 = inl false comp2 :: (GaloisField k, Typeable a) => Comp ('TSum a 'TField) k comp2 = inr (fromField 0) -test1 :: (GaloisField k) => State (Env k) (TExp 'TBool k) +test1 :: (GaloisField k) => Comp 'TBool k test1 = do b <- fresh_input z <- if return b then comp1 else comp2 diff --git a/examples/Snarkl/Example/Games.hs b/examples/Snarkl/Example/Games.hs index e152e2c..de7daf6 100644 --- a/examples/Snarkl/Example/Games.hs +++ b/examples/Snarkl/Example/Games.hs @@ -6,15 +6,13 @@ module Snarkl.Example.Games where -import Data.Field.Galois (GaloisField, Prime) +import Data.Field.Galois (GaloisField) +import Data.Kind (Type) import Data.Typeable -import GHC.TypeLits (KnownNat, Nat) import Snarkl.Errors import Snarkl.Field (F_BN128) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Syntax +import Snarkl.Toplevel (comp_interp) import Prelude hiding ( fromRational, negate, @@ -38,7 +36,7 @@ data ISO (t :: Ty) (s :: Ty) k = Iso from :: TExp s k -> Comp t k } -data Game :: Ty -> * -> * where +data Game :: Ty -> Type -> Type where Single :: forall (s :: Ty) (t :: Ty) k. ( Typeable s, @@ -108,8 +106,7 @@ sum_game :: Zippable t1 k, Zippable t2 k, Derive t1 k, - Derive t2 k, - GaloisField k + Derive t2 k ) => Game t1 k -> Game t2 k -> @@ -133,9 +130,7 @@ t2 :: F_BN128 t2 = comp_interp basic_test [1, 23, 88] -- 88 (+>) :: - ( Typeable t, - Typeable s, - Zippable t k, + ( Typeable s, Zippable s k ) => Game t k -> @@ -151,8 +146,7 @@ prodI :: ( Typeable a, Typeable b, Typeable c, - Typeable d, - GaloisField k + Typeable d ) => ISO a b k -> ISO c d k -> @@ -174,13 +168,12 @@ prodI (Iso f g) (Iso f' g') = pair y1 y2 ) -seqI :: (Typeable b) => ISO a b p -> ISO b c p -> ISO a c p +seqI :: ISO a b p -> ISO b c p -> ISO a c p seqI (Iso f g) (Iso f' g') = Iso (\a -> f a >>= f') (\c -> g' c >>= g) prodLInputI :: ( Typeable a, - Typeable b, - GaloisField k + Typeable b ) => ISO ('TProd a b) b k prodLInputI = @@ -200,8 +193,7 @@ prodLSumI :: Zippable c k, Derive a k, Derive b k, - Derive c k, - GaloisField k + Derive c k ) => ISO ('TProd ('TSum b c) a) ('TSum ('TProd b a) ('TProd c a)) k prodLSumI = @@ -301,8 +293,7 @@ instance (GaloisField k) => Gameable 'TUnit k where mkGame = unit_game instance - ( Typeable a, - Typeable b, + ( Typeable b, Zippable a k, Zippable b k, Derive a k, @@ -323,8 +314,7 @@ instance Derive a k, Derive b k, Gameable a k, - Gameable b k, - GaloisField k + Gameable b k ) => Gameable ('TSum a b) k where diff --git a/examples/Snarkl/Example/Keccak.hs b/examples/Snarkl/Example/Keccak.hs index f0573e3..bf1085c 100644 --- a/examples/Snarkl/Example/Keccak.hs +++ b/examples/Snarkl/Example/Keccak.hs @@ -9,10 +9,7 @@ import Data.Bits hiding (xor) import Data.Field.Galois (GaloisField, Prime) import qualified Data.Map.Strict as Map import GHC.TypeLits (KnownNat) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, @@ -35,7 +32,6 @@ ln_width :: Int ln_width = 32 round1 :: - (GaloisField k) => (Int -> TExp 'TBool k) -> -- | 'i'th bit of round constant TExp ('TArr ('TArr ('TArr 'TBool))) k -> @@ -198,7 +194,7 @@ keccak_f1 num_rounds a = ) -- num_rounds = 12+2l, where 2^l = ln_width -keccak1 :: (GaloisField k) => Int -> Comp 'TBool k +keccak1 :: Int -> Comp 'TBool k keccak1 num_rounds = do a <- input_arr3 5 5 ln_width diff --git a/examples/Snarkl/Example/Lam.hs b/examples/Snarkl/Example/Lam.hs index 7f7c184..d4fc4ab 100644 --- a/examples/Snarkl/Example/Lam.hs +++ b/examples/Snarkl/Example/Lam.hs @@ -10,9 +10,7 @@ import Data.Field.Galois (GaloisField, Prime) import Data.Typeable import GHC.TypeLits (KnownNat) import Snarkl.Errors -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, @@ -35,7 +33,7 @@ type TFSubst = 'TFSum ('TFConst 'TField) ('TFProd ('TFConst TTerm) 'TFId) type TSubst = 'TMu TFSubst -subst_nil :: (GaloisField k) => TExp 'TField k -> State (Env k) (TExp TSubst k) +subst_nil :: TExp 'TField k -> Comp TSubst k subst_nil n = do n' <- inl n @@ -68,7 +66,6 @@ type TF = 'TFSum ('TFConst 'TField) ('TFSum 'TFId ('TFProd 'TFId 'TFId)) type TTerm = 'TMu TF varN :: - (GaloisField k) => TExp 'TField k -> Comp TTerm k varN e = @@ -86,7 +83,6 @@ varN' i = roll v lam :: - (GaloisField k) => TExp TTerm k -> Comp TTerm k lam t = @@ -96,7 +92,6 @@ lam t = roll v app :: - (GaloisField k) => TExp TTerm k -> TExp TTerm k -> Comp TTerm k @@ -108,9 +103,7 @@ app t1 t2 = roll v case_term :: - ( Typeable ty, - Zippable ty k, - GaloisField k + ( Zippable ty k ) => TExp TTerm k -> (TExp 'TField k -> Comp ty k) -> @@ -128,7 +121,7 @@ case_term t f_var f_lam f_app = e2 <- fst_pair p f_app e1 e2 -is_lam :: (GaloisField k) => TExp TTerm k -> Comp 'TBool k +is_lam :: TExp TTerm k -> Comp 'TBool k is_lam t = case_term t @@ -159,7 +152,7 @@ shift n t = fix go t app t1' t2' ) -compose :: (GaloisField k) => TExp TSubst k -> TExp ('TMu TFSubst) k -> State (Env k) (TExp ('TMu TFSubst) k) +compose :: (GaloisField k) => TExp TSubst k -> TExp ('TMu TFSubst) k -> Comp ('TMu TFSubst) k compose sigma1 sigma2 = do p <- pair sigma1 sigma2 @@ -192,7 +185,7 @@ compose sigma1 sigma2 = subst_cons t'' s2'' ) -subst_term :: (GaloisField k) => TExp ('TMu TFSubst) k -> TExp TTerm k -> State (Env k) (TExp ('TMu TF) k) +subst_term :: (GaloisField k) => TExp ('TMu TFSubst) k -> TExp TTerm k -> Comp ('TMu TF) k subst_term sigma t = do p <- pair sigma t diff --git a/examples/Snarkl/Example/List.hs b/examples/Snarkl/Example/List.hs index fe2157b..ecb422e 100644 --- a/examples/Snarkl/Example/List.hs +++ b/examples/Snarkl/Example/List.hs @@ -5,10 +5,7 @@ module Snarkl.Example.List where import Data.Field.Galois (GaloisField, Prime) import Data.Typeable import GHC.TypeLits (KnownNat) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Syntax import Prelude hiding ( negate, return, @@ -27,12 +24,12 @@ type TList a = 'TMu (TF a) type List a k = TExp (TList a) k -nil :: (Typeable a, GaloisField k) => Comp (TList a) k +nil :: (Typeable a) => Comp (TList a) k nil = do t <- inl unit roll t -cons :: (Typeable a, GaloisField k) => TExp a k -> List a k -> Comp (TList a) k +cons :: (Typeable a) => TExp a k -> List a k -> Comp (TList a) k cons f t = do p <- pair f t @@ -41,9 +38,7 @@ cons f t = case_list :: ( Typeable a, - Typeable ty, - Zippable ty k, - GaloisField k + Zippable ty k ) => List a k -> Comp ty k -> @@ -62,9 +57,7 @@ case_list t f_nil f_cons = head_list :: ( Typeable a, - Zippable a k, - Derive a k, - GaloisField k + Zippable a k ) => TExp a k -> List a k -> @@ -78,8 +71,7 @@ head_list def l = tail_list :: ( Typeable a, Zippable a k, - Derive a k, - GaloisField k + Derive a k ) => List a k -> Comp (TList a) k @@ -96,8 +88,7 @@ tail_list l = app_list :: ( Typeable a, Zippable a k, - Derive a k, - GaloisField k + Derive a k ) => List a k -> List a k -> @@ -116,8 +107,7 @@ app_list l1 l2 = fix go l1 rev_list :: ( Typeable a, Zippable a k, - Derive a k, - GaloisField k + Derive a k ) => List a k -> Comp (TList a) k @@ -136,12 +126,9 @@ rev_list l = fix go l map_list :: ( Typeable a, - Zippable a k, - Derive a k, Typeable b, Zippable b k, - Derive b k, - GaloisField k + Derive b k ) => (TExp a k -> Comp b k) -> List a k -> @@ -161,7 +148,7 @@ map_list f l = ) last_list :: - (Typeable a, Zippable a k, Derive a k, GaloisField k) => + (Typeable a, Zippable a k) => TExp a k -> List a k -> Comp a k diff --git a/examples/Snarkl/Example/Matrix.hs b/examples/Snarkl/Example/Matrix.hs index 9ac950a..cfa2a09 100644 --- a/examples/Snarkl/Example/Matrix.hs +++ b/examples/Snarkl/Example/Matrix.hs @@ -4,10 +4,8 @@ module Snarkl.Example.Matrix where import Data.Field.Galois (GaloisField, Prime) import GHC.TypeLits (KnownNat) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Syntax +import Snarkl.Toplevel (comp_interp) import Prelude hiding ( fromRational, negate, diff --git a/examples/Snarkl/Example/Peano.hs b/examples/Snarkl/Example/Peano.hs index 9d47e7f..8cbb443 100644 --- a/examples/Snarkl/Example/Peano.hs +++ b/examples/Snarkl/Example/Peano.hs @@ -4,9 +4,7 @@ module Snarkl.Example.Peano where import Data.Field.Galois (GaloisField, Prime) import GHC.TypeLits (KnownNat) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, @@ -24,13 +22,13 @@ type TF = 'TFSum ('TFConst 'TUnit) 'TFId type TNat = 'TMu TF -nat_zero :: (GaloisField k) => Comp TNat k +nat_zero :: Comp TNat k nat_zero = do x <- inl unit roll x -nat_succ :: (GaloisField k) => TExp TNat k -> Comp TNat k +nat_succ :: TExp TNat k -> Comp TNat k nat_succ n = do x <- inr n diff --git a/examples/Snarkl/Example/Queue.hs b/examples/Snarkl/Example/Queue.hs index b306692..c3d526f 100644 --- a/examples/Snarkl/Example/Queue.hs +++ b/examples/Snarkl/Example/Queue.hs @@ -2,16 +2,11 @@ module Snarkl.Example.Queue where -import Data.Field.Galois (GaloisField, Prime) +import Data.Field.Galois (GaloisField) import Data.Typeable -import GHC.TypeLits (KnownNat) -import Snarkl.Compile import Snarkl.Example.List import Snarkl.Example.Stack -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, @@ -29,14 +24,14 @@ type TQueue a = 'TProd (TStack a) (TStack a) type Queue a k = TExp (TQueue a) k -empty_queue :: (Typeable a, GaloisField k) => Comp (TQueue a) k +empty_queue :: (Typeable a) => Comp (TQueue a) k empty_queue = do l <- empty_stack r <- empty_stack pair l r enqueue :: - (Zippable a k, Derive a k, Typeable a, GaloisField k) => + (Typeable a) => TExp a k -> Queue a k -> Comp (TQueue a) k @@ -47,7 +42,7 @@ enqueue v q = do pair l' r dequeue :: - (Zippable a k, Derive a k, Typeable a, GaloisField k) => + (Zippable a k, Derive a k, Typeable a) => Queue a k -> TExp a k -> Comp ('TProd a (TQueue a)) k @@ -75,7 +70,7 @@ dequeue q def = do pair h p dequeue_rec :: - (Zippable a k, Derive a k, Typeable a, GaloisField k) => + (Zippable a k, Derive a k, Typeable a) => Queue a k -> TExp a k -> Comp ('TProd a (TQueue a)) k @@ -115,7 +110,7 @@ is_empty q = do (\_ _ -> return false) last_queue :: - (Zippable a k, Derive a k, Typeable a, GaloisField k) => + (Zippable a k, Derive a k, Typeable a) => Queue a k -> TExp a k -> Comp a k diff --git a/examples/Snarkl/Example/Stack.hs b/examples/Snarkl/Example/Stack.hs index 7c54c95..0680743 100644 --- a/examples/Snarkl/Example/Stack.hs +++ b/examples/Snarkl/Example/Stack.hs @@ -7,10 +7,7 @@ import Data.Typeable import GHC.TypeLits (KnownNat) import Snarkl.Compile import Snarkl.Example.List -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, @@ -28,19 +25,19 @@ type TStack a = TList a type Stack a k = TExp (TStack a) k -empty_stack :: (Typeable a, GaloisField k) => Comp (TStack a) k +empty_stack :: (Typeable a) => Comp (TStack a) k empty_stack = nil -push_stack :: (Typeable a, GaloisField k) => TExp a k -> Stack a k -> Comp (TStack a) k +push_stack :: (Typeable a) => TExp a k -> Stack a k -> Comp (TStack a) k push_stack p q = cons p q -pop_stack :: (Derive a k, Zippable a k, Typeable a, GaloisField k) => Stack a k -> Comp (TStack a) k +pop_stack :: (Derive a k, Zippable a k, Typeable a) => Stack a k -> Comp (TStack a) k pop_stack f = tail_list f -top_stack :: (Derive a k, Zippable a k, Typeable a, GaloisField k) => TExp a k -> Stack a k -> Comp a k +top_stack :: (Zippable a k, Typeable a) => TExp a k -> Stack a k -> Comp a k top_stack def e = head_list def e -is_empty_stack :: (Typeable a, GaloisField k) => Stack a k -> Comp 'TBool k +is_empty_stack :: (Typeable a) => Stack a k -> Comp 'TBool k is_empty_stack s = case_list s (return true) (\_ _ -> return false) diff --git a/examples/Snarkl/Example/Tree.hs b/examples/Snarkl/Example/Tree.hs index 643ba0c..9a2ea8a 100644 --- a/examples/Snarkl/Example/Tree.hs +++ b/examples/Snarkl/Example/Tree.hs @@ -5,9 +5,7 @@ module Snarkl.Example.Tree where import Data.Field.Galois (GaloisField, Prime) import Data.Typeable import GHC.TypeLits (KnownNat) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, @@ -29,12 +27,12 @@ type Rat k = TExp 'TField k type Tree a k = TExp (TTree a) k -leaf :: (Typeable a, GaloisField k) => Comp (TTree a) k +leaf :: (Typeable a) => Comp (TTree a) k leaf = do t <- inl unit roll t -node :: (Typeable a, GaloisField k) => TExp a k -> Tree a k -> Tree a k -> Comp (TTree a) k +node :: (Typeable a) => TExp a k -> Tree a k -> Tree a k -> Comp (TTree a) k node v t1 t2 = do p <- pair t1 t2 p' <- pair v p @@ -43,8 +41,6 @@ node v t1 t2 = do case_tree :: ( Typeable a, - GaloisField k, - Typeable a1, Zippable a1 k ) => Tree a k -> @@ -66,10 +62,9 @@ map_tree :: ( Typeable a, Typeable a1, Zippable a1 k, - Derive a1 k, - GaloisField k + Derive a1 k ) => - (TExp a k -> State (Env k) (TExp a1 k)) -> + (TExp a k -> Comp a1 k) -> TExp (TTree a) k -> Comp (TTree a1) k map_tree f t = diff --git a/flake.lock b/flake.lock index f706238..18024b6 100644 --- a/flake.lock +++ b/flake.lock @@ -84,6 +84,20 @@ } }, "flake-compat": { + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "revCount": 57, + "type": "tarball", + "url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz" + }, + "original": { + "type": "tarball", + "url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz" + } + }, + "flake-compat_2": { "flake": false, "locked": { "lastModified": 1672831974, @@ -195,7 +209,7 @@ "cabal-34": "cabal-34", "cabal-36": "cabal-36", "cardano-shell": "cardano-shell", - "flake-compat": "flake-compat", + "flake-compat": "flake-compat_2", "ghc-8.6.5-iohk": "ghc-8.6.5-iohk", "ghc98X": "ghc98X", "ghc99": "ghc99", @@ -594,6 +608,7 @@ }, "root": { "inputs": { + "flake-compat": "flake-compat", "flake-utils": "flake-utils", "haskellNix": "haskellNix", "nixpkgs": [ diff --git a/flake.nix b/flake.nix index 13af1e5..bdab0aa 100644 --- a/flake.nix +++ b/flake.nix @@ -30,12 +30,12 @@ packages = { lib = flake.packages."snarkl:lib:snarkl"; - print = flake.packages."snarkl:exe:print-examples"; + benchmark = flake.packages."snarkl:bench:criterion"; all = pkgs.symlinkJoin { name = "all"; paths = with packages; [ lib - print + benchmark ]; }; default = packages.all; diff --git a/print-examples/Main.hs b/print-examples/Main.hs deleted file mode 100644 index 63a7904..0000000 --- a/print-examples/Main.hs +++ /dev/null @@ -1,69 +0,0 @@ -{-# LANGUAGE FlexibleContexts #-} - -module Main where - -import Data.Foldable (traverse_) -import Prettyprinter -import Prettyprinter.Render.String (renderString) -import Snarkl.Field () -import Snarkl.Toplevel (compileCompToTexp) -import Test.Snarkl.Unit.Programs - -main :: IO () -main = do - traverse_ - printProg - [ ("Program 1", prog1), - ("Program 2", prog2 42), - ("Program 3", prog3), - ("Program 4", prog4), - ("Program 5", prog5), - ("Program 6", prog6), - ("Program 7", prog7), - ("Program 8", prog8), - ("Program 11", prog11), - ("Program 13", prog13), - ("Program 14", prog14), - ("Program 15", prog15), - ("Program 26", prog26), - ("Program 27", prog27), - ("Program 28", prog28), - ("Program 29", prog29), - ("Program 30", prog30), - ("Program 31", prog31), - ("Program 34", prog34), - ("Program 35", prog35), - ("Program 36", prog36) - ] - traverse_ - printProg - [ ("Bool Program 10", bool_prog10), - ("Bool Program 12", bool_prog12), - ("Bool Program 16", bool_prog16), - ("Bool Program 17", bool_prog17), - ("Bool Program 18", bool_prog18), - ("Bool Program 19", bool_prog19), - ("Bool Program 20", bool_prog20), - ("Bool Program 21", bool_prog21), - ("Bool Program 22", bool_prog22), - ("Bool Program 23", bool_prog23), - ("Bool Program 24", bool_prog24), - ("Bool Program 25", bool_prog25), - ("Bool Program 32", bool_prog32), - ("Bool Program 33", bool_prog33) - ] - where - printProg (name, prog) = do - let texp = compileCompToTexp prog - -- this is just a sanity check because of the new Eq instances - if texp /= texp - then putStrLn ("--| " <> name <> " (FAILED)") - else do - let doc = pretty texp - layout = layoutPretty defaultLayoutOptions doc - putStrLn (replicate 80 '-') - putStrLn "\n" - putStrLn ("--| " <> name) - putStrLn "\n" - putStrLn $ renderString layout - putStrLn "\n" diff --git a/snarkl.cabal b/snarkl.cabal index f6a9bfa..d08a292 100644 --- a/snarkl.cabal +++ b/snarkl.cabal @@ -25,7 +25,7 @@ source-repository head library ghc-options: - -Wall -Wredundant-constraints -Werror -funbox-strict-fields + -Wall -Werror -Wredundant-constraints -funbox-strict-fields -optc-O3 -- -threaded @@ -49,11 +49,13 @@ library Snarkl.Field Snarkl.Interp Snarkl.Language + Snarkl.Language.Core Snarkl.Language.Expr Snarkl.Language.LambdaExpr - Snarkl.Language.Syntax Snarkl.Language.SyntaxMonad Snarkl.Language.TExpr + Snarkl.Language.Type + Snarkl.Syntax Snarkl.Toplevel default-extensions: @@ -65,29 +67,34 @@ library GADTs GeneralizedNewtypeDeriving KindSignatures + LambdaCase + MultiParamTypeClasses OverloadedStrings PolyKinds RankNTypes ScopedTypeVariables StandaloneDeriving + TypeApplications TypeFamilies TypeSynonymInstances UndecidableInstances build-depends: aeson - , base >=4.7 + , base >=4.7 , bytestring - , Cabal >=1.22 - , containers >=0.5 && <0.7 - , galois-field >=1.0.4 - , hspec >=2.0 - , jsonl >=0.1.4 + , Cabal >=1.22 + , containers >=0.5 && <0.7 + , errors + , galois-field >=1.0.4 + , hspec >=2.0 + , jsonl >=0.1.4 , lens - , mtl >=2.2 && <2.3 - , parallel >=3.2 && <3.3 - , prettyprinter - , process >=1.2 + , mtl >=2.2 && <2.3 + , parallel >=3.2 && <3.3 + , process >=1.2 + , transformers + , wl-pprint-text hs-source-dirs: src default-language: Haskell2010 @@ -136,10 +143,12 @@ test-suite spec , QuickCheck , snarkl >=0.1.0.0 + ghc-options: -Wredundant-constraints + benchmark criterion type: exitcode-stdio-1.0 main-is: Main.hs - ghc-options: -threaded -O2 + ghc-options: -threaded -O2 other-modules: Harness Snarkl.Example.Basic @@ -182,38 +191,6 @@ benchmark criterion , process >=1.2 , snarkl >=0.1.0.0 -executable print-examples - main-is: Main.hs - other-modules: - Snarkl.Example.Basic - Snarkl.Example.Games - Snarkl.Example.Keccak - Snarkl.Example.Lam - Snarkl.Example.List - Snarkl.Example.Matrix - Snarkl.Example.Peano - Snarkl.Example.Queue - Snarkl.Example.Stack - Snarkl.Example.Tree - Test.Snarkl.Unit.Programs - - default-extensions: - DataKinds - GADTs - KindSignatures - RankNTypes - ScopedTypeVariables - - hs-source-dirs: print-examples examples tests - default-language: Haskell2010 - build-depends: - base >=4.7 - , containers - , galois-field >=1.0.4 - , hspec >=2.0 - , prettyprinter - , snarkl >=0.1.0.0 - executable compile main-is: Main.hs other-modules: @@ -239,10 +216,8 @@ executable compile hs-source-dirs: app examples tests default-language: Haskell2010 build-depends: - base >=4.7 + base >=4.7 , bytestring , containers - , galois-field >=1.0.4 - , hspec >=2.0 - , prettyprinter - , snarkl >=0.1.0.0 + , galois-field >=1.0.4 + , snarkl >=0.1.0.0 diff --git a/src/Snarkl/Backend/R1CS/Poly.hs b/src/Snarkl/Backend/R1CS/Poly.hs index 72b07ca..eeaa1af 100644 --- a/src/Snarkl/Backend/R1CS/Poly.hs +++ b/src/Snarkl/Backend/R1CS/Poly.hs @@ -1,19 +1,17 @@ -{-# LANGUAGE InstanceSigs #-} - module Snarkl.Backend.R1CS.Poly where import qualified Data.Aeson as A import Data.Field.Galois (GaloisField, PrimeField, fromP) import qualified Data.Map as Map -import Prettyprinter (Pretty (..)) import Snarkl.Common +import Text.PrettyPrint.Leijen.Text (Pretty (..)) -data Poly a where - Poly :: (GaloisField a) => Assgn a -> Poly a +data Poly k where + Poly :: (GaloisField k) => Assgn k -> Poly k -deriving instance (Show a) => Show (Poly a) +deriving instance Show (Poly k) -instance (Pretty a) => Pretty (Poly a) where +instance Pretty (Poly k) where pretty (Poly m) = pretty $ Map.toList m -- The reason we use incVar is that we want to use -1 internally as the constant @@ -21,22 +19,21 @@ instance (Pretty a) => Pretty (Poly a) where -- harder to work with downstream where e.g. arkworks expects positive indices). -- The reason we use show is because it's hard to deserialize large integers -- in certain langauges (e.g. javascript, even rust). -instance (PrimeField a) => A.ToJSON (Poly a) where - toJSON :: Poly a -> A.Value +instance (PrimeField k) => A.ToJSON (Poly k) where toJSON (Poly m) = let kvs = map (\(var, coeff) -> (show $ fromP coeff, incVar var)) $ Map.toList m in A.toJSON kvs -- | The constant polynomial equal 'c' -const_poly :: (GaloisField a) => a -> Poly a +const_poly :: (GaloisField k) => k -> Poly k const_poly c = Poly $ Map.insert (Var (-1)) c Map.empty -- | The polynomial equal variable 'x' var_poly :: - (GaloisField a) => + (GaloisField k) => -- | Variable, with coeff - (a, Var) -> + (k, Var) -> -- | Resulting polynomial - Poly a + Poly k var_poly (coeff, x) = Poly $ Map.insert x coeff Map.empty diff --git a/src/Snarkl/Backend/R1CS/R1CS.hs b/src/Snarkl/Backend/R1CS/R1CS.hs index 0e47e12..e077fd6 100644 --- a/src/Snarkl/Backend/R1CS/R1CS.hs +++ b/src/Snarkl/Backend/R1CS/R1CS.hs @@ -11,19 +11,19 @@ import Control.Parallel.Strategies import qualified Data.Aeson as A import Data.Field.Galois (GaloisField, PrimeField) import qualified Data.Map as Map -import Prettyprinter (Pretty (..), (<+>)) import Snarkl.Backend.R1CS.Poly import Snarkl.Common import Snarkl.Errors +import Text.PrettyPrint.Leijen.Text (Pretty (..), (<+>)) ---------------------------------------------------------------- -- Rank-1 Constraint Systems -- ---------------------------------------------------------------- -data R1C a where - R1C :: (GaloisField a) => (Poly a, Poly a, Poly a) -> R1C a +data R1C k where + R1C :: (GaloisField k) => (Poly k, Poly k, Poly k) -> R1C k -deriving instance (Show a) => Show (R1C a) +deriving instance (Show k) => Show (R1C k) instance (PrimeField k) => A.ToJSON (R1C k) where toJSON (R1C (a, b, c)) = @@ -33,7 +33,7 @@ instance (PrimeField k) => A.ToJSON (R1C k) where "C" A..= c ] -instance (Pretty a) => Pretty (R1C a) where +instance Pretty (R1C k) where pretty (R1C (aV, bV, cV)) = pretty aV <+> "*" <+> pretty bV <+> "==" <+> pretty cV data R1CS a = R1CS @@ -44,19 +44,19 @@ data R1CS a = R1CS r1cs_gen_witness :: Assgn a -> Assgn a } -instance (Show a) => Show (R1CS a) where +instance (Show k) => Show (R1CS k) where show (R1CS cs nvs ivs ovs _) = show (cs, nvs, ivs, ovs) -num_constraints :: R1CS a -> Int +num_constraints :: R1CS k -> Int num_constraints = length . r1cs_clauses -- sat_r1c: Does witness 'w' satisfy constraint 'c'? -sat_r1c :: (GaloisField a) => Assgn a -> R1C a -> Bool +sat_r1c :: (GaloisField k) => Assgn k -> R1C k -> Bool sat_r1c w c | R1C (aV, bV, cV) <- c = inner aV w * inner bV w == inner cV w where - inner :: (GaloisField a) => Poly a -> Assgn a -> a + inner :: (GaloisField k) => Poly k -> Assgn k -> k inner (Poly v) w' = let c0 = Map.findWithDefault 0 (Var (-1)) v in Map.foldlWithKey (f w') c0 v @@ -65,7 +65,7 @@ sat_r1c w c (v_val * Map.findWithDefault 0 v_key w') + acc -- sat_r1cs: Does witness 'w' satisfy constraint set 'cs'? -sat_r1cs :: (GaloisField a) => Assgn a -> R1CS a -> Bool +sat_r1cs :: (GaloisField k) => Assgn k -> R1CS k -> Bool sat_r1cs w cs = and $ is_sat (r1cs_clauses cs) where is_sat cs0 = map g cs0 `using` parListChunk (chunk_sz cs0) rseq diff --git a/src/Snarkl/Common.hs b/src/Snarkl/Common.hs index 708a954..953d841 100644 --- a/src/Snarkl/Common.hs +++ b/src/Snarkl/Common.hs @@ -4,7 +4,7 @@ module Snarkl.Common where import qualified Data.Aeson as A import qualified Data.Map as Map -import Prettyprinter (Pretty (pretty)) +import Text.PrettyPrint.Leijen.Text (Pretty (pretty)) newtype Var = Var Int deriving (Eq, Ord, Show, A.ToJSON) diff --git a/src/Snarkl/Compile.hs b/src/Snarkl/Compile.hs index e17aec1..2b83f7a 100644 --- a/src/Snarkl/Compile.hs +++ b/src/Snarkl/Compile.hs @@ -10,6 +10,7 @@ module Snarkl.Compile compileCompToTexp, compileTexpToConstraints, compileCompToConstraints, + compileTExpToProgram, ) where @@ -22,11 +23,11 @@ import Control.Monad.State import qualified Control.Monad.State as State import Data.Either (fromRight) import Data.Field.Galois (GaloisField) +import Data.Foldable (traverse_) import Data.List (sort) import qualified Data.Map as Map import qualified Data.Set as Set import Data.Typeable (Typeable) -import Prettyprinter (Pretty (..)) import Snarkl.Backend.R1CS.R1CS (R1CS) import Snarkl.Common (Op (..), UnOp (..), Var (Var), incVar) import Snarkl.Constraint @@ -44,17 +45,20 @@ import Snarkl.Constraint ) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Language - ( Comp, - Env (Env, input_vars, next_variable), + ( Assignment (..), + Comp, + Env (..), Exp (..), + Program (..), TExp, - Variable (Variable), + Variable (..), booleanVarsOfTexp, - do_const_prop, - expOfTExp, - runState, - var_of_exp, + expOfLambdaExp, + mkProgram, + runComp, + tExpToLambdaExp, ) +import Text.PrettyPrint.Leijen.Text (Pretty (..)) ---------------------------------------------------------------- -- @@ -166,12 +170,11 @@ encode_boolean_eq (x, y, z) = cs_of_exp z e -- | Constraint 'x == y = z'. -- The encoding is: z = (x-y == 0). encode_eq :: (GaloisField a) => (Var, Var, Var) -> State (CEnv a) () -encode_eq (x, y, z) = cs_of_exp z e - where - e = - EAssert - (EVar (_Var # z)) - (EUnop ZEq (EBinop Sub [EVar (_Var # x), EVar (_Var # y)])) +encode_eq (x, y, z) = + cs_of_assignment $ + Assignment + (_Var # z) + (EUnop ZEq (EBinop Sub [EVar (_Var # x), EVar (_Var # y)])) -- | Constraint 'y = x!=0 ? 1 : 0'. -- The encoding is: @@ -249,19 +252,19 @@ encode_binop op (x, y, z) = go op add_constraint $ CMult (1, y) (1, z) (1, Just x) -encode_linear :: (GaloisField a) => Var -> [Either (Var, a) a] -> State (CEnv a) () +encode_linear :: (GaloisField k) => Var -> [Either (Var, k) k] -> State (CEnv k) () encode_linear out xs = let c = foldl (flip (+)) 0 $ map (fromRight 0) xs in add_constraint $ cadd c $ (out, -1) : remove_consts xs where - remove_consts :: [Either (Var, a) a] -> [(Var, a)] + remove_consts :: [Either (Var, k) k] -> [(Var, k)] remove_consts [] = [] remove_consts (Left p : l) = p : remove_consts l remove_consts (Right _ : l) = remove_consts l -cs_of_exp :: (GaloisField a) => Var -> Exp a -> State (CEnv a) () +cs_of_exp :: (GaloisField k) => Var -> Exp k -> State (CEnv k) () cs_of_exp out e = case e of EVar x -> ensure_equal (out, view _Var x) @@ -287,7 +290,7 @@ cs_of_exp out e = case e of -- We special-case linear combinations in this way to avoid having -- to introduce new multiplication gates for multiplication by -- constant scalars. - let go_linear :: (GaloisField a) => [Exp a] -> State (CEnv a) [Either (Var, a) a] + let go_linear :: (GaloisField k) => [Exp k] -> State (CEnv k) [Either (Var, k) k] go_linear [] = return [] go_linear (EBinop Mult [EVar x, EVal coeff] : es') = do @@ -338,7 +341,7 @@ cs_of_exp out e = case e of rev_pol (Left (x, c) : ls) = Left (x, -c) : rev_pol ls rev_pol (Right c : ls) = Right (-c) : rev_pol ls - go_other :: (GaloisField a) => [Exp a] -> State (CEnv a) [Var] + go_other :: (GaloisField k) => [Exp k] -> State (CEnv k) [Var] go_other [] = return [] go_other (EVar x : es') = do @@ -386,30 +389,17 @@ cs_of_exp out e = case e of [ EBinop Mult [b, e1], EBinop Mult [EBinop Sub [EVal 1, b], e2] ] - - -- NOTE: when compiling assignments, the naive thing to do is - -- to introduce a new var, e2_out, bound to result of e2 and - -- then ensure that e2_out == x. We optimize by passing x to - -- compilation of e2 directly. - EAssert e1 e2 -> - do - let x = var_of_exp e1 - cs_of_exp (x ^. _Var) e2 - ESeq le -> - do - x <- fresh_var -- x is garbage - go x le - where - go _ [] = failWith $ ErrMsg "internal error: empty ESeq" - go _ [e1] = cs_of_exp out e1 - go garbage_var (e1 : le') = - do - cs_of_exp garbage_var e1 - go garbage_var le' EUnit -> -- NOTE: [[ EUnit ]]_{out} = [[ EVal zero ]]_{out}. cs_of_exp out (EVal 0) +---- NOTE: when compiling assignments, the naive thing to do is +---- to introduce a new var, e2_out, bound to result of e2 and +---- then ensure that e2_out == x. We optimize by passing x to +---- compilation of e2 directly. +cs_of_assignment :: (GaloisField a) => Assignment a -> State (CEnv a) () +cs_of_assignment (Assignment x e) = cs_of_exp (view _Var x) e + data SimplParam = NoSimplify | Simplify @@ -470,10 +460,10 @@ data TExpPkg ty k = TExpPkg } deriving (Show) -instance (Typeable ty, Pretty k) => Pretty (TExpPkg ty k) where +instance Pretty (TExpPkg ty k) where pretty (TExpPkg _ _ e) = pretty e -deriving instance (Eq (TExp ty k)) => Eq (TExpPkg ty k) +deriving instance Eq (TExpPkg ty k) -- | Desugar a 'Comp'utation to a pair of: -- the total number of vars, @@ -483,23 +473,16 @@ compileCompToTexp :: Comp ty k -> TExpPkg ty k compileCompToTexp mf = - case run mf of + case runComp mf of Left err -> failWith err Right (e, rho) -> let out = Variable (next_variable rho) in_vars = sort $ input_vars rho in TExpPkg out in_vars e - where - run mf0 = - runState - mf0 - ( Env - 0 - 0 - [] - Map.empty - Map.empty - ) + +compileTExpToProgram :: (GaloisField k) => TExp ty k -> Program k +compileTExpToProgram te = + mkProgram . expOfLambdaExp . tExpToLambdaExp $ te -- | Snarkl.Compile 'TExp's to constraint systems. Re-exported from 'Snarkl.Compile.Snarkl.Compile'. compileTexpToConstraints :: @@ -516,10 +499,11 @@ compileTexpToConstraints (TExpPkg _out _in_vars te) = Set.toList $ Set.fromList in_vars `Set.intersection` Set.fromList (map (view _Var) $ booleanVarsOfTexp te) - e0 = expOfTExp te - e = do_const_prop e0 + Program assignments e = compileTExpToProgram te + traverse_ cs_of_assignment assignments + -- e = do_const_prop e0 -- Snarkl.Compile 'e' to constraints 'cs', with output wire 'out'. - cs_of_exp out e + cs_of_assignment $ Assignment (_Var # out) e -- Add boolean constraints mapM_ ensure_boolean boolean_in_vars cs <- get_constraints diff --git a/src/Snarkl/Errors.hs b/src/Snarkl/Errors.hs index 2a7c076..f917226 100644 --- a/src/Snarkl/Errors.hs +++ b/src/Snarkl/Errors.hs @@ -1,8 +1,8 @@ module Snarkl.Errors where -import Control.Exception +import Control.Exception (Exception, throw) import Data.String (IsString) -import Data.Typeable +import Data.Typeable (Typeable) newtype ErrMsg = ErrMsg {errMsg :: String} deriving (Typeable, IsString) diff --git a/src/Snarkl/Field.hs b/src/Snarkl/Field.hs index f9616ca..1f193af 100644 --- a/src/Snarkl/Field.hs +++ b/src/Snarkl/Field.hs @@ -2,12 +2,8 @@ module Snarkl.Field where -import Data.Field.Galois (Prime, fromP) -import Prettyprinter (Pretty (..)) +import Data.Field.Galois (Prime) type P_BN128 = 21888242871839275222246405745257275088548364400416034343698204186575808495617 type F_BN128 = Prime P_BN128 - -instance Pretty F_BN128 where - pretty = pretty . fromP diff --git a/src/Snarkl/Interp.hs b/src/Snarkl/Interp.hs index 94de305..576dcfa 100644 --- a/src/Snarkl/Interp.hs +++ b/src/Snarkl/Interp.hs @@ -1,20 +1,20 @@ -{-# LANGUAGE LambdaCase #-} - module Snarkl.Interp ( interp, ) where import Control.Monad (ap, foldM) -import Data.Data (Typeable) import Data.Field.Galois (GaloisField) +import Data.Foldable (traverse_) import Data.Map (Map) import qualified Data.Map as Map import Snarkl.Common (Op (..), UnOp (ZEq)) +import Snarkl.Compile (compileTExpToProgram) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Language (Exp (..), TExp, Variable, expOfTExp) +import qualified Snarkl.Language.Core as Core +import Snarkl.Language.TExpr (TExp) -type Env a = Map Variable (Maybe a) +type Env a = Map Core.Variable (Maybe a) newtype InterpM a b = InterpM {runInterpM :: Env a -> Either ErrMsg (Env a, b)} @@ -39,11 +39,11 @@ raiseErr :: ErrMsg -> InterpM a b raiseErr err = InterpM (\_ -> Left err) -addBinds :: [(Variable, Maybe a)] -> InterpM a (Maybe b) +addBinds :: [(Core.Variable, Maybe a)] -> InterpM a (Maybe b) addBinds binds = InterpM (\rho -> Right (Map.union (Map.fromList binds) rho, Nothing)) -lookupVar :: (Show a) => Variable -> InterpM a (Maybe a) +lookupVar :: (Show a) => Core.Variable -> InterpM a (Maybe a) lookupVar x = InterpM ( \rho -> case Map.lookup x rho of @@ -77,62 +77,63 @@ boolOfField v = ) interpTExp :: - ( GaloisField a, - Typeable ty + ( GaloisField k ) => - TExp ty a -> - InterpM a (Maybe a) + TExp ty k -> + InterpM k (Maybe k) interpTExp e = do - let _exp = expOfTExp e - interpExpr _exp + let _exp = compileTExpToProgram e + interpProg _exp interp :: - (GaloisField a, Typeable ty) => - Map Variable a -> - TExp ty a -> - Either ErrMsg (Env a, Maybe a) + (GaloisField k) => + Map Core.Variable k -> + TExp ty k -> + Either ErrMsg (Env k, Maybe k) interp rho e = runInterpM (interpTExp e) $ Map.map Just rho -interpExpr :: +interpProg :: + (GaloisField a) => + Core.Program a -> + InterpM a (Maybe a) +interpProg (Core.Program assignments finalExp) = + let f (Core.Assignment var e) = do + e' <- interpCoreExpr e + addBinds [(var, e')] + in do + traverse_ f assignments + interpCoreExpr finalExp + +interpCoreExpr :: (GaloisField a) => - Exp a -> + Core.Exp a -> InterpM a (Maybe a) -interpExpr e = case e of - EVar x -> lookupVar x - EVal v -> pure $ Just v - EUnop op e2 -> do - v2 <- interpExpr e2 +interpCoreExpr = \case + Core.EVar x -> lookupVar x + Core.EVal v -> pure $ Just v + Core.EUnop op e2 -> do + v2 <- interpCoreExpr e2 case v2 of Nothing -> pure Nothing Just v2' -> case op of ZEq -> return $ Just $ fieldOfBool (v2' == 0) - EBinop op _es -> case _es of + Core.EBinop op _es -> case _es of [] -> failWith $ ErrMsg "empty binary args" (a : as) -> do - b <- interpExpr a + b <- interpCoreExpr a foldM (interpBinopExpr op) b as - EIf eb e1 e2 -> + Core.EIf eb e1 e2 -> do - mb <- interpExpr eb + mb <- interpCoreExpr eb case mb of Nothing -> pure Nothing - Just _b -> boolOfField _b >>= \b -> if b then interpExpr e1 else interpExpr e2 - EAssert e1 e2 -> - case (e1, e2) of - (EVar x, _) -> - do - v2 <- interpExpr e2 - addBinds [(x, v2)] - (_, _) -> raiseErr $ ErrMsg $ show e1 ++ " not a variable" - ESeq es -> case es of - [] -> failWith $ ErrMsg "empty sequence" - _ -> last <$> mapM interpExpr es - EUnit -> return $ Just 1 + Just _b -> boolOfField _b >>= \b -> if b then interpCoreExpr e1 else interpCoreExpr e2 + Core.EUnit -> return $ Just 1 where - interpBinopExpr :: (GaloisField a) => Op -> Maybe a -> Exp a -> InterpM a (Maybe a) + interpBinopExpr :: (GaloisField a) => Op -> Maybe a -> Core.Exp a -> InterpM a (Maybe a) interpBinopExpr _ Nothing _ = return Nothing interpBinopExpr _op (Just a1) _exp = do - ma2 <- interpExpr _exp + ma2 <- interpCoreExpr _exp case ma2 of Nothing -> return Nothing Just a2 -> Just <$> op a1 a2 diff --git a/src/Snarkl/Language.hs b/src/Snarkl/Language.hs index 0f1e150..6fe2f02 100644 --- a/src/Snarkl/Language.hs +++ b/src/Snarkl/Language.hs @@ -1,19 +1,53 @@ module Snarkl.Language - ( expOfTExp, - module Snarkl.Language.TExpr, + ( module Snarkl.Language.Core, module Snarkl.Language.Expr, + module Snarkl.Language.LambdaExpr, module Snarkl.Language.SyntaxMonad, - module Snarkl.Language.Syntax, + module Snarkl.Language.TExpr, + module Snarkl.Language.Type, ) where -import Data.Data (Typeable) -import Data.Field.Galois (GaloisField) -import Snarkl.Language.Expr +import Snarkl.Language.Core (Assignment (..), Exp (..), Program (..), Variable (..)) +import Snarkl.Language.Expr (mkProgram) import Snarkl.Language.LambdaExpr (expOfLambdaExp) -import Snarkl.Language.Syntax import Snarkl.Language.SyntaxMonad + ( Comp, + Env (..), + State (..), + arr, + assert_bot, + assert_false, + assert_true, + defaultEnv, + false, + fresh_input, + fresh_var, + fst_pair, + get, + guard, + input_arr, + is_bot, + is_false, + is_true, + pair, + raise_err, + return, + runComp, + runState, + set, + snd_pair, + true, + unit, + (>>), + (>>=), + ) import Snarkl.Language.TExpr - -expOfTExp :: (GaloisField a, Typeable ty) => TExp ty a -> Exp a -expOfTExp = expOfLambdaExp . lambdaExpOfTExp + ( TExp (..), + TOp (..), + TUnop (..), + Val (..), + booleanVarsOfTexp, + tExpToLambdaExp, + ) +import Snarkl.Language.Type diff --git a/src/Snarkl/Language/Core.hs b/src/Snarkl/Language/Core.hs new file mode 100644 index 0000000..9dffa2e --- /dev/null +++ b/src/Snarkl/Language/Core.hs @@ -0,0 +1,26 @@ +module Snarkl.Language.Core where + +import Data.Field.Galois (GaloisField) +import Data.Kind (Type) +import Data.Sequence (Seq) +import Snarkl.Common (Op, UnOp) +import Text.PrettyPrint.Leijen.Text (Pretty) + +newtype Variable = Variable Int deriving (Eq, Ord, Show, Pretty) + +data Exp :: Type -> Type where + EVar :: Variable -> Exp k + EVal :: (GaloisField k) => k -> Exp k + EUnop :: UnOp -> Exp k -> Exp k + EBinop :: Op -> [Exp k] -> Exp k + EIf :: Exp k -> Exp k -> Exp k -> Exp k + EUnit :: Exp k + +deriving instance Eq (Exp k) + +deriving instance Show (Exp k) + +data Assignment a = Assignment Variable (Exp a) + +data Program :: Type -> Type where + Program :: Seq (Assignment a) -> Exp a -> Program a diff --git a/src/Snarkl/Language/Expr.hs b/src/Snarkl/Language/Expr.hs index d8149fe..42647fe 100644 --- a/src/Snarkl/Language/Expr.hs +++ b/src/Snarkl/Language/Expr.hs @@ -1,89 +1,51 @@ +{-# LANGUAGE PatternSynonyms #-} + module Snarkl.Language.Expr ( Exp (..), - Variable (..), - exp_binop, - exp_seq, - is_pure, - var_of_exp, - do_const_prop, + mkProgram, + expSeq, + expBinop, ) where -import Control.Monad.State (State, evalState, gets, modify) +import Control.Error (hoistEither, runExceptT) +import Control.Monad.Except + ( ExceptT, + MonadError (throwError), + ) +import Control.Monad.State (State, evalState, gets, modify, runState) import Data.Field.Galois (GaloisField) +import Data.Foldable (toList) import Data.Kind (Type) import Data.Map (Map) import qualified Data.Map as Map -import Prettyprinter +import Data.Sequence (Seq, fromList, (<|), (><), (|>), pattern Empty, pattern (:<|)) +import Snarkl.Common (Op, UnOp, isAssoc) +import Snarkl.Errors (ErrMsg (..), failWith) +import qualified Snarkl.Language.Core as Core +import Text.PrettyPrint.Leijen.Text ( Pretty (pretty), hsep, parens, punctuate, (<+>), ) -import Snarkl.Common (Op, UnOp, isAssoc) -import Snarkl.Errors (ErrMsg (ErrMsg), failWith) - -newtype Variable = Variable Int deriving (Eq, Ord, Show, Pretty) data Exp :: Type -> Type where - EVar :: Variable -> Exp a - EVal :: (GaloisField a) => a -> Exp a - EUnop :: UnOp -> Exp a -> Exp a - EBinop :: Op -> [Exp a] -> Exp a - EIf :: Exp a -> Exp a -> Exp a -> Exp a - EAssert :: Exp a -> Exp a -> Exp a - ESeq :: [Exp a] -> Exp a - EUnit :: Exp a - -deriving instance (Eq a) => Eq (Exp a) + EVar :: Core.Variable -> Exp k + EVal :: (GaloisField k) => k -> Exp k + EUnop :: UnOp -> Exp k -> Exp k + EBinop :: Op -> [Exp k] -> Exp k + EIf :: Exp k -> Exp k -> Exp k -> Exp k + EAssert :: Exp k -> Exp k -> Exp k + ESeq :: Seq (Exp k) -> Exp k + EUnit :: Exp k -deriving instance (Show a) => Show (Exp a) - -var_of_exp :: (Show a) => Exp a -> Variable -var_of_exp e = case e of - EVar x -> x - _ -> failWith $ ErrMsg ("var_of_exp: expected variable: " ++ show e) - --- | Smart constructor for EBinop, ensuring all expressions (involving --- associative operations) are flattened to top level. -exp_binop :: Op -> Exp a -> Exp a -> Exp a -exp_binop op e1 e2 = - case (e1, e2) of - (EBinop op1 l1, EBinop op2 l2) - | op1 == op2 && op2 == op && isAssoc op -> - EBinop op (l1 ++ l2) - (EBinop op1 l1, _) - | op1 == op && isAssoc op -> - EBinop op (l1 ++ [e2]) - (_, EBinop op2 l2) - | op2 == op && isAssoc op -> - EBinop op (e1 : l2) - (_, _) -> EBinop op [e1, e2] +deriving instance Eq (Exp k) --- | Smart constructor for sequence, ensuring all expressions are --- flattened to top level. -exp_seq :: Exp a -> Exp a -> Exp a -exp_seq e1 e2 = - case (e1, e2) of - (ESeq l1, ESeq l2) -> ESeq (l1 ++ l2) - (ESeq l1, _) -> ESeq (l1 ++ [e2]) - (_, ESeq l2) -> ESeq (e1 : l2) - (_, _) -> ESeq [e1, e2] +deriving instance Show (Exp k) -is_pure :: Exp a -> Bool -is_pure e = - case e of - EVar _ -> True - EVal _ -> True - EUnop _ e1 -> is_pure e1 - EBinop _ es -> all is_pure es - EIf b e1 e2 -> is_pure b && is_pure e1 && is_pure e2 - EAssert _ _ -> False - ESeq es -> all is_pure es - EUnit -> True - -const_prop :: (GaloisField a) => Exp a -> State (Map Variable a) (Exp a) +const_prop :: (GaloisField k) => Exp k -> State (Map Core.Variable k) (Exp k) const_prop e = case e of EVar x -> lookup_var x @@ -114,23 +76,23 @@ const_prop e = return $ ESeq es' EUnit -> return EUnit where - lookup_var :: (GaloisField a) => Variable -> State (Map Variable a) (Exp a) + lookup_var :: (GaloisField k) => Core.Variable -> State (Map Core.Variable k) (Exp k) lookup_var x0 = gets ( \m -> case Map.lookup x0 m of Nothing -> EVar x0 Just c -> EVal c ) - add_bind :: (Variable, a) -> State (Map Variable a) (Exp a) + add_bind :: (Core.Variable, k) -> State (Map Core.Variable k) (Exp k) add_bind (x0, c0) = do modify (Map.insert x0 c0) return EUnit -do_const_prop :: (GaloisField a) => Exp a -> Exp a +do_const_prop :: (GaloisField k) => Exp k -> Exp k do_const_prop e = evalState (const_prop e) Map.empty -instance (Pretty a) => Pretty (Exp a) where +instance Pretty (Exp k) where pretty (EVar x) = "var_" <> pretty x pretty (EVal c) = pretty c pretty (EUnop op e1) = pretty op <> parens (pretty e1) @@ -139,5 +101,67 @@ instance (Pretty a) => Pretty (Exp a) where pretty (EIf b e1 e2) = "if" <+> pretty b <+> "then" <+> pretty e1 <+> "else" <+> pretty e2 pretty (EAssert e1 e2) = pretty e1 <+> ":=" <+> pretty e2 - pretty (ESeq es) = parens $ hsep $ punctuate ";" $ map pretty es + pretty (ESeq es) = parens $ hsep $ punctuate ";" $ map pretty (toList es) pretty EUnit = "()" + +mkExpression :: Exp k -> Either String (Core.Exp k) +mkExpression = \case + EVar x -> pure $ Core.EVar x + EVal v -> pure $ Core.EVal v + EUnop op e -> Core.EUnop op <$> mkExpression e + EBinop op es -> Core.EBinop op <$> traverse mkExpression es + EIf e1 e2 e3 -> Core.EIf <$> mkExpression e1 <*> mkExpression e2 <*> mkExpression e3 + EUnit -> pure Core.EUnit + e -> throwError $ "mkExpression: " ++ show e + +-- | Smart constructor for sequence, ensuring all expressions are +-- flattened to top level. +expSeq :: Exp k -> Exp k -> Exp k +expSeq e1 e2 = + case (e1, e2) of + (ESeq l1, ESeq l2) -> ESeq (l1 >< l2) + (ESeq l1, _) -> ESeq (l1 |> e2) + (_, ESeq l2) -> ESeq (e1 <| l2) + (_, _) -> ESeq (fromList [e1, e2]) + +expBinop :: Op -> Exp k -> Exp k -> Exp k +expBinop op e1 e2 = + case (e1, e2) of + (EBinop op1 l1, EBinop op2 l2) + | op1 == op2 && op2 == op && isAssoc op -> + EBinop op (l1 ++ l2) + (EBinop op1 l1, _) + | op1 == op && isAssoc op -> + EBinop op (l1 ++ [e2]) + (_, EBinop op2 l2) + | op2 == op && isAssoc op -> + EBinop op (e1 : l2) + (_, _) -> EBinop op [e1, e2] + +mkAssignment :: Exp k -> Either String (Core.Assignment k) +mkAssignment (EAssert (EVar v) e) = Core.Assignment v <$> mkExpression e +mkAssignment e = throwError $ "mkAssignment: expected EAssert, got " <> show e + +-- At this point the expression should be either: +-- 1. A sequence of assignments, followed by an expression +-- 2. An expression +mkProgram :: (GaloisField k) => Exp k -> Core.Program k +mkProgram _exp = either (failWith . ErrMsg) id $ do + let e' = do_const_prop _exp + case e' of + ESeq es -> do + let (eexpr, assignments) = runState (runExceptT $ go es) mempty + Core.Program assignments <$> eexpr + where + go :: Seq (Exp k) -> ExceptT String (State (Seq (Core.Assignment k))) (Core.Exp k) + go = \case + Empty -> throwError "mkProgram: empty sequence" + e :<| Empty -> hoistEither $ mkExpression e + e :<| rest -> do + case e of + EUnit -> go rest + _ -> do + assignment <- hoistEither $ mkAssignment e + modify (|> assignment) + go rest + _ -> Core.Program Empty <$> mkExpression e' diff --git a/src/Snarkl/Language/LambdaExpr.hs b/src/Snarkl/Language/LambdaExpr.hs index c2d6873..d1b79cf 100644 --- a/src/Snarkl/Language/LambdaExpr.hs +++ b/src/Snarkl/Language/LambdaExpr.hs @@ -3,128 +3,86 @@ module Snarkl.Language.LambdaExpr ( Exp (..), expOfLambdaExp, - expBinop, ) where import Control.Monad.Error.Class (throwError) -import Control.Monad.State (State, gets, modify, runState) import Data.Field.Galois (GaloisField) import Data.Kind (Type) -import Data.Map (Map) -import qualified Data.Map as Map -import Snarkl.Common (Op, UnOp, isAssoc) -import Snarkl.Language.Expr (Variable) -import qualified Snarkl.Language.Expr as Core +import Snarkl.Common (Op, UnOp) +import Snarkl.Errors (ErrMsg (ErrMsg), failWith) +import Snarkl.Language.Core (Variable) +import Snarkl.Language.Expr (expSeq) +import qualified Snarkl.Language.Expr as E +-- This expression language is just the untyped version of the typed +-- expression language TEExp. It is used to remove lamba application +-- and abstraction before passing on to the next expression language. +-- There is also a certain amount of "flattening" between this representation +-- and the underlying Expr language -- we reassociate all of the Seq constructors +-- to the right and then flatten them. Similarly nested Binops of the same +-- operator are flattened into a single list if that operator is associative. data Exp :: Type -> Type where - EVar :: Variable -> Exp a - EVal :: (GaloisField a) => a -> Exp a - EUnop :: UnOp -> Exp a -> Exp a - EBinop :: Op -> [Exp a] -> Exp a - EIf :: Exp a -> Exp a -> Exp a -> Exp a - EAssert :: Exp a -> Exp a -> Exp a - ESeq :: Exp a -> Exp a -> Exp a - EUnit :: Exp a - EAbs :: Variable -> Exp a -> Exp a - EApp :: Exp a -> Exp a -> Exp a + EVar :: Variable -> Exp k + EVal :: (GaloisField k) => k -> Exp k + EUnop :: UnOp -> Exp k -> Exp k + EBinop :: Op -> Exp k -> Exp k -> Exp k + EIf :: Exp k -> Exp k -> Exp k -> Exp k + EAssert :: Exp k -> Exp k -> Exp k + ESeq :: Exp k -> Exp k -> Exp k + EUnit :: Exp k + EAbs :: Variable -> Exp k -> Exp k + EApp :: Exp k -> Exp k -> Exp k -deriving instance (Show a) => Show (Exp a) +deriving instance Show (Exp k) -deriving instance (Eq a) => Eq (Exp a) +deriving instance Eq (Exp k) -type Env a = Map Variable (Exp a) - -defaultEnv :: Env a -defaultEnv = Map.empty - -applyLambdas :: - Exp a -> - (Exp a, Env a) -applyLambdas expression = runState (go expression) defaultEnv - where - go :: Exp a -> State (Env a) (Exp a) - go = \case - EVar var -> do - ma <- gets (Map.lookup var) - case ma of - Nothing -> do - pure (EVar var) - Just e -> go e - e@(EVal _) -> pure e - EUnit -> pure EUnit - EUnop op e -> EUnop op <$> go e - EBinop op es -> EBinop op <$> mapM go es - EIf b e1 e2 -> EIf <$> go b <*> go e1 <*> go e2 - EAssert (EVar v1) e@(EAbs _ _) -> do - _e <- go e - modify $ Map.insert v1 _e - pure EUnit - EAssert e1 e2 -> EAssert <$> go e1 <*> go e2 - ESeq e1 e2 -> ESeq <$> go e1 <*> go e2 - EAbs var e -> do - EAbs var <$> go e - EApp (EAbs var e1) e2 -> do - _e2 <- go e2 - modify $ Map.insert var e2 - go (substitute var e2 e1) - EApp e1 e2 -> do - _e1 <- go e1 - _e2 <- go e2 - go (EApp _e1 _e2) - -substitute :: Variable -> Exp a -> Exp a -> Exp a -substitute var e1 = \case - e@(EVar var') -> if var == var' then e1 else e - e@(EVal _) -> e +betaNormalize :: Exp k -> Exp k +betaNormalize = \case + EVar x -> EVar x + EVal v -> EVal v + EUnop op e -> EUnop op (betaNormalize e) + EBinop op l r -> EBinop op (betaNormalize l) (betaNormalize r) + EIf e1 e2 e3 -> EIf (betaNormalize e1) (betaNormalize e2) (betaNormalize e3) + EAssert e1 e2 -> EAssert (betaNormalize e1) (betaNormalize e2) + ESeq e1 e2 -> ESeq (betaNormalize e1) (betaNormalize e2) + EAbs v e -> EAbs v (betaNormalize e) + EApp f a -> + case betaNormalize f of + EAbs v e -> substitute (v, betaNormalize a) e + f' -> EApp f' (betaNormalize a) EUnit -> EUnit - EUnop op e -> EUnop op (substitute var e1 e) - EBinop op es -> EBinop op (map (substitute var e1) es) - EIf b e2 e3 -> EIf (substitute var e1 b) (substitute var e1 e2) (substitute var e1 e3) - EAssert e2 e3 -> EAssert (substitute var e1 e2) (substitute var e1 e3) - ESeq e2 e3 -> ESeq (substitute var e1 e2) (substitute var e1 e3) - EAbs var' e -> EAbs var' (substitute var e1 e) - EApp e2 e3 -> EApp (substitute var e1 e2) (substitute var e1 e3) - -expBinop :: Op -> Exp a -> Exp a -> Exp a -expBinop op e1 e2 = - case (e1, e2) of - (EBinop op1 l1, EBinop op2 l2) - | op1 == op2 && op2 == op && isAssoc op -> - EBinop op (l1 ++ l2) - (EBinop op1 l1, _) - | op1 == op && isAssoc op -> - EBinop op (l1 ++ [e2]) - (_, EBinop op2 l2) - | op2 == op && isAssoc op -> - EBinop op (e1 : l2) - (_, _) -> EBinop op [e1, e2] - --- | Smart constructor for sequence, ensuring all expressions are --- flattened to top level. -expSeq :: Core.Exp a -> Core.Exp a -> Core.Exp a -expSeq e1 e2 = - case (e1, e2) of - (Core.ESeq l1, Core.ESeq l2) -> Core.ESeq (l1 ++ l2) - (Core.ESeq l1, _) -> Core.ESeq (l1 ++ [e2]) - (_, Core.ESeq l2) -> Core.ESeq (e1 : l2) - (_, _) -> Core.ESeq [e1, e2] + where + -- substitute x e1 e2 = e2 [x := e1 ] + substitute :: (Variable, Exp k) -> Exp k -> Exp k + substitute (var, e1) = \case + e@(EVar var') -> if var == var' then e1 else e + e@(EVal _) -> e + EUnit -> EUnit + EUnop op e -> EUnop op (substitute (var, e1) e) + EBinop op l r -> EBinop op (substitute (var, e1) l) (substitute (var, e1) r) + EIf b e2 e3 -> EIf (substitute (var, e1) b) (substitute (var, e1) e2) (substitute (var, e1) e3) + EAssert e2 e3 -> EAssert (substitute (var, e1) e2) (substitute (var, e1) e3) + ESeq l r -> ESeq (substitute (var, e1) l) (substitute (var, e1) r) + EAbs var' e -> EAbs var' (substitute (var, e1) e) + EApp e2 e3 -> EApp (substitute (var, e1) e2) (substitute (var, e1) e3) -expOfLambdaExp :: (Show a) => Exp a -> Core.Exp a +expOfLambdaExp :: Exp k -> E.Exp k expOfLambdaExp _exp = - let (coreExp, _) = applyLambdas _exp + let coreExp = betaNormalize _exp in case expOfLambdaExp' coreExp of - Left err -> error err + Left err -> failWith $ ErrMsg err Right e -> e where - expOfLambdaExp' :: (Show a) => Exp a -> Either String (Core.Exp a) + expOfLambdaExp' :: Exp k -> Either String (E.Exp k) expOfLambdaExp' = \case - EVar var -> pure $ Core.EVar var - EVal v -> pure $ Core.EVal v - EUnit -> pure Core.EUnit - EUnop op e -> Core.EUnop op <$> expOfLambdaExp' e - EBinop op es -> Core.EBinop op <$> mapM expOfLambdaExp' es - EIf b e1 e2 -> Core.EIf <$> expOfLambdaExp' b <*> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 - EAssert e1 e2 -> Core.EAssert <$> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 + EVar var -> pure $ E.EVar var + EVal v -> pure $ E.EVal v + EUnit -> pure E.EUnit + EUnop op e -> E.EUnop op <$> expOfLambdaExp' e + EBinop op l r -> E.expBinop op <$> expOfLambdaExp' l <*> expOfLambdaExp' r + EIf b e1 e2 -> E.EIf <$> expOfLambdaExp' b <*> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 + EAssert e1 e2 -> E.EAssert <$> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 ESeq e1 e2 -> expSeq <$> expOfLambdaExp' e1 <*> expOfLambdaExp' e2 - e -> throwError ("Impossible after IR simplicifaction: " <> show e) + e -> throwError ("Impossible after lambda simplicifaction: " <> show e) diff --git a/src/Snarkl/Language/SyntaxMonad.hs b/src/Snarkl/Language/SyntaxMonad.hs index ee57287..c57c42d 100644 --- a/src/Snarkl/Language/SyntaxMonad.hs +++ b/src/Snarkl/Language/SyntaxMonad.hs @@ -6,13 +6,14 @@ module Snarkl.Language.SyntaxMonad ( -- | Computation monad Comp, - CompResult, runState, + runComp, return, (>>=), (>>), raise_err, Env (..), + defaultEnv, State (..), -- | Return a fresh input variable. fresh_input, @@ -50,24 +51,23 @@ where import Control.Monad (forM, replicateM) import Control.Monad.Supply (Supply, runSupply) import Control.Monad.Supply.Class (MonadSupply (fresh)) -import Data.Field.Galois (GaloisField) import qualified Data.Map.Strict as Map import Data.String (IsString (..)) import Data.Typeable (Typeable) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Language.Expr (Variable (..)) +import Snarkl.Language.Core (Variable (..)) import Snarkl.Language.TExpr ( Loc, - TExp (TEAssert, TEBinop, TEBot, TESeq, TEUnop, TEVal, TEVar), + TExp (..), TLoc (TLoc), TVar (TVar), - Ty (TArr, TBool, TProd, TUnit), Val (VFalse, VLoc, VTrue, VUnit), lastSeq, locOfTexp, teSeq, varOfTExp, ) +import Snarkl.Language.Type (Ty (..)) import Prelude hiding ( fromRational, negate, @@ -103,7 +103,6 @@ raise_err msg = State (\_ -> Left msg) -- of the results of 'mf', 'g' (not just whatever 'g' returns) (>>=) :: forall (ty1 :: Ty) (ty2 :: Ty) s a. - (Typeable ty1) => State s (TExp ty1 a) -> (TExp ty1 a -> State s (TExp ty2 a)) -> State s (TExp ty2 a) @@ -119,7 +118,6 @@ raise_err msg = State (\_ -> Left msg) (>>) :: forall (ty1 :: Ty) (ty2 :: Ty) s a. - (Typeable ty1) => State s (TExp ty1 a) -> State s (TExp ty2 a) -> State s (TExp ty2 a) @@ -172,8 +170,21 @@ data Env k = Env } deriving (Show) +defaultEnv :: Env k +defaultEnv = + Env + { next_variable = 0, + next_loc = 0, + input_vars = [], + obj_map = Map.empty, + anal_map = Map.empty + } + type Comp ty k = State (Env k) (TExp ty k) +runComp :: Comp ty k -> Either ErrMsg (TExp ty k, Env k) +runComp f = runState f defaultEnv + {----------------------------------------------- Units, Booleans (used below) ------------------------------------------------} @@ -286,7 +297,6 @@ get_addr (l, i) = guard :: (Typeable ty2) => - (GaloisField k) => (TExp ty k -> State (Env k) (TExp ty2 k)) -> TExp ty k -> State (Env k) (TExp ty2 k) @@ -300,19 +310,18 @@ guard f e = guarded_get_addr :: (Typeable ty2) => - (GaloisField k) => TExp ty k -> Int -> State (Env k) (TExp ty2 k) guarded_get_addr e i = guard (\e0 -> get_addr (locOfTexp e0, i)) e -get :: (Typeable ty) => (GaloisField k) => (TExp ('TArr ty) k, Int) -> Comp ty k +get :: (Typeable ty) => (TExp ('TArr ty) k, Int) -> Comp ty k get (TEBot, _) = return TEBot get (a, i) = guarded_get_addr a i -- | Smart constructor for TEAssert -te_assert :: (Typeable ty) => (GaloisField k) => TExp ty k -> TExp ty k -> Comp 'TUnit k +te_assert :: (Typeable ty) => TExp ty k -> TExp ty k -> Comp 'TUnit k te_assert x@(TEVar _) e = do e_bot <- is_bot e @@ -333,7 +342,6 @@ te_assert _ e = -- in the object map. set_addr :: (Typeable ty) => - (GaloisField k) => (TExp ('TArr ty) k, Int) -> TExp ty k -> Comp 'TUnit k @@ -361,7 +369,7 @@ set_addr (TEVal (VLoc (TLoc l)), i) e = set_addr (e1, _) _ = raise_err $ ErrMsg ("expected " ++ show e1 ++ " a loc") -set :: (Typeable ty, GaloisField k) => (TExp ('TArr ty) k, Int) -> TExp ty k -> Comp 'TUnit k +set :: (Typeable ty) => (TExp ('TArr ty) k, Int) -> TExp ty k -> Comp 'TUnit k set (a, i) e = set_addr (a, i) e {----------------------------------------------- @@ -370,8 +378,7 @@ set (a, i) e = set_addr (a, i) e pair :: ( Typeable ty1, - Typeable ty2, - GaloisField k + Typeable ty2 ) => TExp ty1 k -> TExp ty2 k -> @@ -411,15 +418,13 @@ pair te1 te2 = fst_pair :: (Typeable ty1) => - (GaloisField k) => TExp ('TProd ty1 ty2) k -> Comp ty1 k fst_pair TEBot = return TEBot fst_pair e = guarded_get_addr e 0 snd_pair :: - ( Typeable ty2, - GaloisField k + ( Typeable ty2 ) => TExp ('TProd ty1 ty2) k -> Comp ty2 k @@ -494,7 +499,7 @@ add_statics binds = ) -- | Does boolean expression 'e' resolve (statically) to 'b'? -is_bool :: (GaloisField k) => TExp ty k -> Bool -> Comp 'TBool k +is_bool :: TExp ty k -> Bool -> Comp 'TBool k is_bool (TEVal VFalse) False = return true is_bool (TEVal VTrue) True = return true is_bool e@(TEVar _) b = @@ -511,24 +516,24 @@ is_bool e@(TEVar _) b = ) is_bool _ _ = return false -is_false :: (GaloisField k) => TExp ty k -> Comp 'TBool k +is_false :: TExp ty k -> Comp 'TBool k is_false = flip is_bool False -is_true :: (GaloisField k) => TExp ty k -> Comp 'TBool k +is_true :: TExp ty k -> Comp 'TBool k is_true = flip is_bool True -- | Add binding 'x = b'. -assert_bool :: (GaloisField k) => TExp ty k -> Bool -> Comp 'TUnit k +assert_bool :: TExp ty k -> Bool -> Comp 'TUnit k assert_bool (TEVar (TVar x)) b = add_statics [(x, AnalBool b)] assert_bool e _ = raise_err $ ErrMsg $ "expected " ++ show e ++ " a variable" -assert_false :: (GaloisField k) => TExp ty k -> Comp 'TUnit k +assert_false :: TExp ty k -> Comp 'TUnit k assert_false = flip assert_bool False -assert_true :: (GaloisField k) => TExp ty k -> Comp 'TUnit k +assert_true :: TExp ty k -> Comp 'TUnit k assert_true = flip assert_bool True -var_is_bot :: (GaloisField k) => TExp ty k -> Comp 'TBool k +var_is_bot :: TExp ty k -> Comp 'TBool k var_is_bot e@(TEVar (TVar _)) = State ( \s -> @@ -542,7 +547,7 @@ var_is_bot e@(TEVar (TVar _)) = ) var_is_bot _ = return false -is_bot :: (GaloisField k) => TExp ty k -> Comp 'TBool k +is_bot :: TExp ty k -> Comp 'TBool k is_bot e = case e of e0@(TEVar _) -> var_is_bot e0 @@ -552,7 +557,7 @@ is_bot e = TEBot -> return true _ -> return false where - either_is_bot :: (GaloisField k) => TExp ty1 k -> TExp ty2 k -> Comp 'TBool k + either_is_bot :: TExp ty1 k -> TExp ty2 k -> Comp 'TBool k either_is_bot e10 e20 = do e1_bot <- is_bot e10 @@ -562,6 +567,6 @@ is_bot e = (_, TEVal VTrue) -> return true _ -> return false -assert_bot :: (GaloisField k) => TExp ty k -> Comp 'TUnit k +assert_bot :: TExp ty k -> Comp 'TUnit k assert_bot (TEVar (TVar x)) = add_statics [(x, AnalBot)] assert_bot e = raise_err $ ErrMsg $ "in assert_bot, expected " ++ show e ++ " a variable" diff --git a/src/Snarkl/Language/TExpr.hs b/src/Snarkl/Language/TExpr.hs index 7690c79..36c294e 100644 --- a/src/Snarkl/Language/TExpr.hs +++ b/src/Snarkl/Language/TExpr.hs @@ -1,101 +1,31 @@ -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE UndecidableInstances #-} module Snarkl.Language.TExpr ( Val (..), TExp (..), - TFunct (..), - Ty (..), - Rep, TUnop (..), TOp (..), TVar (..), Loc, TLoc (..), + tExpToLambdaExp, booleanVarsOfTexp, - lambdaExpOfTExp, varOfTExp, locOfTexp, teSeq, lastSeq, - -- expOfTExp, ) where import Data.Field.Galois (GaloisField) import Data.Kind (Type) import Data.Typeable (Proxy (..), Typeable, eqT, typeOf, typeRep, type (:~:) (Refl)) -import Prettyprinter (Pretty (pretty), line, parens, (<+>)) import Snarkl.Common (Op, UnOp) import Snarkl.Errors (ErrMsg (ErrMsg), failWith) -import Snarkl.Language.Expr (Variable) +import Snarkl.Language.Core (Variable) import qualified Snarkl.Language.LambdaExpr as LE - -data TFunct where - TFConst :: Ty -> TFunct - TFId :: TFunct - TFProd :: TFunct -> TFunct -> TFunct - TFSum :: TFunct -> TFunct -> TFunct - TFComp :: TFunct -> TFunct -> TFunct - deriving (Typeable) - -instance Pretty TFunct where - pretty f = case f of - TFConst ty -> "Const" <+> pretty ty - TFId -> "Id" - TFProd f1 f2 -> parens (pretty f1 <+> "⊗" <+> pretty f2) - TFSum f1 f2 -> parens (pretty f1 <+> "⊕" <+> pretty f2) - TFComp f1 f2 -> parens (pretty f1 <+> "∘" <+> pretty f2) - -data Ty where - TField :: Ty - TBool :: Ty - TArr :: Ty -> Ty - TProd :: Ty -> Ty -> Ty - TSum :: Ty -> Ty -> Ty - TMu :: TFunct -> Ty - TUnit :: Ty - TFun :: Ty -> Ty -> Ty - deriving (Typeable) - -deriving instance Typeable 'TField - -deriving instance Typeable 'TBool - -deriving instance Typeable 'TArr - -deriving instance Typeable 'TProd - -deriving instance Typeable 'TSum - -deriving instance Typeable 'TMu - -deriving instance Typeable 'TUnit - -deriving instance Typeable 'TFun - -instance Pretty Ty where - pretty ty = case ty of - TField -> "Field" - TBool -> "Bool" - TArr _ty -> "Array" <+> pretty _ty - TProd ty1 ty2 -> parens (pretty ty1 <+> "⨉" <+> pretty ty2) - TSum ty1 ty2 -> parens (pretty ty1 <+> "+" <+> pretty ty2) - TMu f -> "μ" <> parens (pretty f) - TUnit -> "()" - TFun ty1 ty2 -> parens (pretty ty1 <+> "->" <+> pretty ty2) - -type family Rep (f :: TFunct) (x :: Ty) :: Ty - -type instance Rep ('TFConst ty) x = ty - -type instance Rep 'TFId x = x - -type instance Rep ('TFProd f g) x = 'TProd (Rep f x) (Rep g x) - -type instance Rep ('TFSum f g) x = 'TSum (Rep f x) (Rep g x) - -type instance Rep ('TFComp f g) x = Rep f (Rep g x) +import Snarkl.Language.Type (Ty (TBool, TField, TFun, TUnit)) +import Text.PrettyPrint.Leijen.Text (Pretty (pretty), line, parens, (<+>)) newtype TVar (ty :: Ty) = TVar Variable deriving (Eq, Show) @@ -128,17 +58,17 @@ instance Pretty (TOp ty1 ty2 ty3) where pretty (TOp op) = pretty op data Val :: Ty -> Type -> Type where - VField :: (GaloisField a) => a -> Val 'TField a - VTrue :: Val 'TBool a - VFalse :: Val 'TBool a - VUnit :: Val 'TUnit a - VLoc :: TLoc ty -> Val ty a + VField :: (GaloisField k) => k -> Val 'TField k + VTrue :: Val 'TBool k + VFalse :: Val 'TBool k + VUnit :: Val 'TUnit k + VLoc :: TLoc ty -> Val ty k -deriving instance (Eq a) => Eq (Val (b :: Ty) a) +deriving instance Eq (Val (b :: Ty) k) -deriving instance (Show a) => Show (Val (b :: Ty) a) +deriving instance Show (Val (b :: Ty) k) -instance (Pretty a) => Pretty (Val ty a) where +instance Pretty (Val ty k) where pretty v = case v of VField a -> pretty a VTrue -> "true" @@ -170,22 +100,22 @@ data TExp :: Ty -> Type -> Type where TEAbs :: (Typeable ty, Typeable ty1) => TVar ty -> TExp ty1 a -> TExp ('TFun ty ty1) a TEApp :: (Typeable ty, Typeable ty1) => TExp ('TFun ty ty1) a -> TExp ty a -> TExp ty1 a -deriving instance (Show a) => Show (TExp (b :: Ty) a) +deriving instance Show (TExp (b :: Ty) k) -instance (Eq a) => Eq (TExp (b :: Ty) a) where +instance Eq (TExp (b :: Ty) k) where TEVar x == TEVar y = x == y TEVal a == TEVal b = a == b TEUnop (op :: TUnop ty1 ty) e1 == TEUnop (op' :: TUnop ty2 ty) e1' = case eqT @ty1 @ty2 of Just Refl -> op == op' && e1 == e1' Nothing -> False - TEBinop (op :: TOp ty1 ty2 ty) (e1 :: TExp ty1 a) (e2 :: TExp ty2 a) == TEBinop (op' :: TOp ty3 ty4 ty) e1' e2' = + TEBinop (op :: TOp ty1 ty2 ty) (e1 :: TExp ty1 k) (e2 :: TExp ty2 k) == TEBinop (op' :: TOp ty3 ty4 ty) e1' e2' = case (eqT @ty1 @ty3, eqT @ty2 @ty4) of (Just Refl, Just Refl) -> op == op' && e1 == e1' && e2 == e2' _ -> False TEIf e e1 e2 == TEIf e' e1' e2' = e == e' && e1 == e1' && e2 == e2' - TEAssert (e1 :: TExp ty1 a) (e2 :: TExp ty1 a) == TEAssert (e1' :: TExp ty2 a) (e2' :: TExp ty2 a) = + TEAssert (e1 :: TExp ty1 k) (e2 :: TExp ty1 a) == TEAssert (e1' :: TExp ty2 k) (e2' :: TExp ty2 k) = case eqT @ty1 @ty2 of Just Refl -> e1 == e1' && e2 == e2' Nothing -> False @@ -194,24 +124,39 @@ instance (Eq a) => Eq (TExp (b :: Ty) a) where TEBot == TEBot = True _ == _ = False -lambdaExpOfTExp :: (GaloisField a, Typeable ty) => TExp ty a -> LE.Exp a -lambdaExpOfTExp te = case te of +instance Pretty (TExp ty a) where + pretty (TEVar var) = pretty var + pretty (TEVal val) = pretty val + pretty (TEUnop unop _exp) = pretty unop <+> pretty _exp + pretty (TEBinop binop exp1 exp2) = pretty exp1 <+> pretty binop <+> pretty exp2 + pretty (TEIf condExp thenExp elseExp) = "if" <+> pretty condExp <+> "then" <+> pretty thenExp <+> "else" <+> pretty elseExp + pretty (TEAssert exp1 exp2) = pretty exp1 <+> ":=" <+> pretty exp2 + pretty (TESeq exp1 exp2) = parens (pretty exp1 <+> ";" <> line <> pretty exp2) + pretty TEBot = "⊥" + pretty (TEAbs var _exp) = parens ("\\" <> pretty var <+> "->" <+> pretty _exp) + pretty (TEApp exp1 exp2) = parens (pretty exp1 <+> pretty exp2) + +tExpToLambdaExp :: + (GaloisField k) => + TExp ty k -> + LE.Exp k +tExpToLambdaExp te = case te of TEVar (TVar x) -> LE.EVar x TEVal v -> lambdaExpOfVal v TEUnop (TUnop op) te1 -> - LE.EUnop op (lambdaExpOfTExp te1) + LE.EUnop op (tExpToLambdaExp te1) TEBinop (TOp op) te1 te2 -> - LE.expBinop op (lambdaExpOfTExp te1) (lambdaExpOfTExp te2) + LE.EBinop op (tExpToLambdaExp te1) (tExpToLambdaExp te2) TEIf te1 te2 te3 -> - LE.EIf (lambdaExpOfTExp te1) (lambdaExpOfTExp te2) (lambdaExpOfTExp te3) + LE.EIf (tExpToLambdaExp te1) (tExpToLambdaExp te2) (tExpToLambdaExp te3) TEAssert te1 te2 -> - LE.EAssert (lambdaExpOfTExp te1) (lambdaExpOfTExp te2) - TESeq te1 te2 -> LE.ESeq (lambdaExpOfTExp te1) (lambdaExpOfTExp te2) + LE.EAssert (tExpToLambdaExp te1) (tExpToLambdaExp te2) + TESeq te1 te2 -> LE.ESeq (tExpToLambdaExp te1) (tExpToLambdaExp te2) TEBot -> LE.EUnit - TEAbs (TVar v) e -> LE.EAbs v (lambdaExpOfTExp e) - TEApp e1 e2 -> LE.EApp (lambdaExpOfTExp e1) (lambdaExpOfTExp e2) + TEAbs (TVar v) e -> LE.EAbs v (tExpToLambdaExp e) + TEApp e1 e2 -> LE.EApp (tExpToLambdaExp e1) (tExpToLambdaExp e2) where - lambdaExpOfVal :: (GaloisField a) => Val ty a -> LE.Exp a + lambdaExpOfVal :: (GaloisField k) => Val ty k -> LE.Exp k lambdaExpOfVal v = case v of VField c -> LE.EVal c VTrue -> LE.EVal 1 @@ -222,13 +167,19 @@ lambdaExpOfTExp te = case te of -- | Smart constructor for 'TESeq'. Simplify 'TESeq te1 te2' to 'te2' -- whenever the normal form of 'te1' (with seq's reassociated right) -- is *not* equal 'TEAssert _ _'. -teSeq :: (Typeable ty1) => TExp ty1 a -> TExp ty2 a -> TExp ty2 a +teSeq :: + TExp ty1 a -> + TExp ty2 a -> + TExp ty2 a teSeq te1 te2 = case (te1, te2) of (TEAssert _ _, _) -> TESeq te1 te2 (TESeq tx ty, _) -> teSeq tx (teSeq ty te2) (_, _) -> te2 -booleanVarsOfTexp :: (Typeable ty) => TExp ty a -> [Variable] +booleanVarsOfTexp :: + (Typeable ty) => + TExp ty a -> + [Variable] booleanVarsOfTexp = go [] where go :: (Typeable ty) => [Variable] -> TExp ty a -> [Variable] @@ -247,29 +198,23 @@ booleanVarsOfTexp = go [] go vars (TEAbs _ e) = go vars e go vars (TEApp e1 e2) = go (go vars e1) e2 -varOfTExp :: (Show (TExp ty a)) => TExp ty a -> Variable +varOfTExp :: + TExp ty a -> + Variable varOfTExp te = case lastSeq te of TEVar (TVar x) -> x _ -> failWith $ ErrMsg ("varOfTExp: expected var: " ++ show te) -locOfTexp :: (Show (TExp ty a)) => TExp ty a -> Loc +locOfTexp :: + TExp ty a -> + Loc locOfTexp te = case lastSeq te of TEVal (VLoc (TLoc l)) -> l _ -> failWith $ ErrMsg ("locOfTexp: expected loc: " ++ show te) -lastSeq :: TExp ty a -> TExp ty a +lastSeq :: + TExp ty a -> + TExp ty a lastSeq te = case te of TESeq _ te2 -> lastSeq te2 _ -> te - -instance (Pretty a, Typeable ty) => Pretty (TExp ty a) where - pretty (TEVar var) = pretty var - pretty (TEVal val) = pretty val - pretty (TEUnop unop _exp) = pretty unop <+> pretty _exp - pretty (TEBinop binop exp1 exp2) = pretty exp1 <+> pretty binop <+> pretty exp2 - pretty (TEIf condExp thenExp elseExp) = "if" <+> pretty condExp <+> "then" <+> pretty thenExp <+> "else" <+> pretty elseExp - pretty (TEAssert exp1 exp2) = pretty exp1 <+> ":=" <+> pretty exp2 - pretty (TESeq exp1 exp2) = parens (pretty exp1 <+> ";" <> line <> pretty exp2) - pretty TEBot = "⊥" - pretty (TEAbs var _exp) = parens ("\\" <> pretty var <+> "->" <+> pretty _exp) - pretty (TEApp exp1 exp2) = parens (pretty exp1 <+> pretty exp2) diff --git a/src/Snarkl/Language/Type.hs b/src/Snarkl/Language/Type.hs new file mode 100644 index 0000000..145b2f9 --- /dev/null +++ b/src/Snarkl/Language/Type.hs @@ -0,0 +1,77 @@ +{-# LANGUAGE UndecidableInstances #-} + +module Snarkl.Language.Type + ( TFunct (..), + Ty (..), + Rep, + ) +where + +import Data.Typeable (Typeable) +import Text.PrettyPrint.Leijen.Text (Pretty (pretty), parens, (<+>)) + +data TFunct where + TFConst :: Ty -> TFunct + TFId :: TFunct + TFProd :: TFunct -> TFunct -> TFunct + TFSum :: TFunct -> TFunct -> TFunct + TFComp :: TFunct -> TFunct -> TFunct + deriving (Typeable) + +instance Pretty TFunct where + pretty f = case f of + TFConst ty -> "Const" <+> pretty ty + TFId -> "Id" + TFProd f1 f2 -> parens (pretty f1 <+> "⊗" <+> pretty f2) + TFSum f1 f2 -> parens (pretty f1 <+> "⊕" <+> pretty f2) + TFComp f1 f2 -> parens (pretty f1 <+> "∘" <+> pretty f2) + +data Ty where + TField :: Ty + TBool :: Ty + TArr :: Ty -> Ty + TProd :: Ty -> Ty -> Ty + TSum :: Ty -> Ty -> Ty + TMu :: TFunct -> Ty + TUnit :: Ty + TFun :: Ty -> Ty -> Ty + deriving (Typeable) + +deriving instance Typeable 'TField + +deriving instance Typeable 'TBool + +deriving instance Typeable 'TArr + +deriving instance Typeable 'TProd + +deriving instance Typeable 'TSum + +deriving instance Typeable 'TMu + +deriving instance Typeable 'TUnit + +deriving instance Typeable 'TFun + +instance Pretty Ty where + pretty ty = case ty of + TField -> "Field" + TBool -> "Bool" + TArr _ty -> "Array" <+> pretty _ty + TProd ty1 ty2 -> parens (pretty ty1 <+> "⨉" <+> pretty ty2) + TSum ty1 ty2 -> parens (pretty ty1 <+> "+" <+> pretty ty2) + TMu f -> "μ" <> parens (pretty f) + TUnit -> "()" + TFun ty1 ty2 -> parens (pretty ty1 <+> "->" <+> pretty ty2) + +type family Rep (f :: TFunct) (x :: Ty) :: Ty + +type instance Rep ('TFConst ty) x = ty + +type instance Rep 'TFId x = x + +type instance Rep ('TFProd f g) x = 'TProd (Rep f x) (Rep g x) + +type instance Rep ('TFSum f g) x = 'TSum (Rep f x) (Rep g x) + +type instance Rep ('TFComp f g) x = Rep f (Rep g x) diff --git a/src/Snarkl/Language/Syntax.hs b/src/Snarkl/Syntax.hs similarity index 88% rename from src/Snarkl/Language/Syntax.hs rename to src/Snarkl/Syntax.hs index fe741d6..371625d 100644 --- a/src/Snarkl/Language/Syntax.hs +++ b/src/Snarkl/Syntax.hs @@ -1,9 +1,27 @@ -{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RebindableSyntax #-} -module Snarkl.Language.Syntax - ( Zippable, +module Snarkl.Syntax + ( -- | Snarkl.Language.TExpr, + TExp, + Ty (..), + TFunct (..), + Rep, + -- | SyntaxMonad and Syntax + Comp, + runComp, + return, + (>>=), + (>>), + -- | Return a fresh input variable. + fresh_input, + -- | Classes + Zippable, Derive, + -- | Basic values + unit, + false, + true, + fromField, -- | Sums, products, recursive types inl, inr, @@ -29,7 +47,6 @@ module Snarkl.Language.Syntax exp_of_int, inc, dec, - fromField, ifThenElse, negate, -- | Arrays @@ -55,10 +72,22 @@ module Snarkl.Language.Syntax forall, forall2, forall3, + -- | Function combinators lambda, curry, uncurry, apply, + assert_bot, + assert_false, + assert_true, + defaultEnv, + fresh_var, + guard, + is_bot, + is_false, + is_true, + raise_err, + runState, ) where @@ -66,8 +95,8 @@ import Data.Field.Galois (GaloisField) import Data.String (IsString (..)) import Data.Typeable (Typeable) import Snarkl.Common - ( Op (Add, And, BEq, Div, Eq, Mult, Sub, XOr), - UnOp (ZEq), + ( Op (..), + UnOp (..), ) import Snarkl.Errors (ErrMsg (ErrMsg)) import Snarkl.Language.SyntaxMonad @@ -78,7 +107,9 @@ import Snarkl.Language.SyntaxMonad assert_bot, assert_false, assert_true, + defaultEnv, false, + fresh_input, fresh_var, fst_pair, get, @@ -90,21 +121,22 @@ import Snarkl.Language.SyntaxMonad pair, raise_err, return, + runComp, runState, set, snd_pair, true, unit, + (>>), (>>=), ) import Snarkl.Language.TExpr - ( Rep, - TExp (TEAbs, TEApp, TEBinop, TEBot, TEIf, TEUnop, TEVal, TEVar), + ( TExp (TEAbs, TEApp, TEBinop, TEBot, TEIf, TEUnop, TEVal, TEVar), TOp (TOp), TUnop (TUnop), - Ty (TArr, TBool, TField, TFun, TMu, TProd, TSum, TUnit), Val (VFalse, VField, VTrue, VUnit), ) +import Snarkl.Language.Type (Rep, TFunct (..), Ty (..)) import Unsafe.Coerce (unsafeCoerce) import Prelude hiding ( curry, @@ -137,7 +169,7 @@ dec n = (P.-) n 1 -- | 2-d arrays. 'width' is the size, in "bits" (#field elements), of -- each array element. -arr2 :: (Typeable ty, GaloisField k) => Int -> Int -> Comp ('TArr ('TArr ty)) k +arr2 :: (Typeable ty) => Int -> Int -> Comp ('TArr ('TArr ty)) k arr2 len width = do a <- arr len @@ -152,7 +184,7 @@ arr2 len width = return a -- | 3-d arrays. -arr3 :: (Typeable ty, GaloisField k) => Int -> Int -> Int -> Comp ('TArr ('TArr ('TArr ty))) k +arr3 :: (Typeable ty) => Int -> Int -> Int -> Comp ('TArr ('TArr ('TArr ty))) k arr3 len width height = do a <- arr2 len width @@ -166,7 +198,7 @@ arr3 len width height = ) return a -input_arr2 :: (Typeable ty, GaloisField k) => Int -> Int -> Comp ('TArr ('TArr ty)) k +input_arr2 :: (Typeable ty) => Int -> Int -> Comp ('TArr ('TArr ty)) k input_arr2 0 _ = raise_err $ ErrMsg "array must have size > 0" input_arr2 len width = do @@ -181,7 +213,7 @@ input_arr2 len width = ) return a -input_arr3 :: (Typeable ty, GaloisField k) => Int -> Int -> Int -> Comp ('TArr ('TArr ('TArr ty))) k +input_arr3 :: (Typeable ty) => Int -> Int -> Int -> Comp ('TArr ('TArr ('TArr ty))) k input_arr3 len width height = do a <- arr2 len width @@ -195,13 +227,13 @@ input_arr3 len width height = ) return a -set2 :: (Typeable ty2, GaloisField k) => (TExp ('TArr ('TArr ty2)) k, Int, Int) -> TExp ty2 k -> Comp 'TUnit k +set2 :: (Typeable ty2) => (TExp ('TArr ('TArr ty2)) k, Int, Int) -> TExp ty2 k -> Comp 'TUnit k set2 (a, i, j) e = do a' <- get (a, i) set (a', j) e set3 :: - (Typeable ty, GaloisField k) => + (Typeable ty) => ( TExp ('TArr ('TArr ('TArr ty))) k, Int, Int, @@ -214,7 +246,7 @@ set3 (a, i, j, k) e = do set (a', k) e set4 :: - (Typeable ty, GaloisField k) => + (Typeable ty) => ( TExp ('TArr ('TArr ('TArr ('TArr ty)))) k, Int, Int, @@ -227,13 +259,13 @@ set4 (a, i, j, k, l) e = do a' <- get3 (a, i, j, k) set (a', l) e -get2 :: (Typeable ty2, GaloisField k) => (TExp ('TArr ('TArr ty2)) k, Int, Int) -> State (Env k) (TExp ty2 k) +get2 :: (Typeable ty2) => (TExp ('TArr ('TArr ty2)) k, Int, Int) -> State (Env k) (TExp ty2 k) get2 (a, i, j) = do a' <- get (a, i) get (a', j) get3 :: - (Typeable ty, GaloisField k) => + (Typeable ty) => ( TExp ('TArr ('TArr ('TArr ty))) k, Int, Int, @@ -245,7 +277,7 @@ get3 (a, i, j, k) = do get (a', k) get4 :: - (Typeable ty, GaloisField k) => + (Typeable ty) => ( TExp ('TArr ('TArr ('TArr ('TArr ty)))) k, Int, Int, @@ -274,8 +306,7 @@ unrep_sum :: unrep_sum = unsafe_cast inl :: - (GaloisField k) => - forall ty1 ty2. + forall ty1 ty2 k. ( Typeable ty1, Typeable ty2 ) => @@ -295,8 +326,7 @@ inl te1 = inr :: forall ty1 ty2 k. ( Typeable ty1, - Typeable ty2, - GaloisField k + Typeable ty2 ) => TExp ty2 k -> Comp ('TSum ty1 ty2) k @@ -315,9 +345,7 @@ case_sum :: forall ty1 ty2 ty k. ( Typeable ty1, Typeable ty2, - Typeable ty, - Zippable ty k, - GaloisField k + Zippable ty k ) => (TExp ty1 k -> Comp ty k) -> (TExp ty2 k -> Comp ty k) -> @@ -357,7 +385,7 @@ instance Derive 'TBool k where instance (GaloisField k) => Derive 'TField k where derive _ = return $ TEVal (VField 0) -instance (Typeable ty, Derive ty k, GaloisField k) => Derive ('TArr ty) k where +instance (Typeable ty, Derive ty k) => Derive ('TArr ty) k where derive n = do a <- arr 1 @@ -369,8 +397,7 @@ instance ( Typeable ty1, Derive ty1 k, Typeable ty2, - Derive ty2 k, - GaloisField k + Derive ty2 k ) => Derive ('TProd ty1 ty2) k where @@ -383,8 +410,7 @@ instance instance ( Typeable ty1, Derive ty1 k, - Typeable ty2, - GaloisField k + Typeable ty2 ) => Derive ('TSum ty1 ty2) k where @@ -394,10 +420,7 @@ instance inl v1 instance - ( Typeable f, - Typeable (Rep f ('TMu f)), - Derive (Rep f ('TMu f)) k, - GaloisField k + ( Derive (Rep f ('TMu f)) k ) => Derive ('TMu f) k where @@ -427,7 +450,7 @@ class Zippable ty k where instance Zippable 'TUnit k where zip_vals _ _ _ = return unit -zip_base :: (Typeable ty, GaloisField k) => TExp 'TBool k -> TExp ty k -> TExp ty k -> Comp ty k +zip_base :: (Typeable ty) => TExp 'TBool k -> TExp ty k -> TExp ty k -> Comp ty k zip_base TEBot _ _ = return TEBot zip_base _ TEBot e2 = return e2 zip_base _ e1 TEBot = return e1 @@ -451,18 +474,17 @@ zip_base b e1 e2 = ) b -instance (GaloisField k) => Zippable 'TBool k where +instance Zippable 'TBool k where zip_vals b b1 b2 = zip_base b b1 b2 -instance (GaloisField k) => Zippable 'TField k where +instance Zippable 'TField k where zip_vals b e1 e2 = zip_base b e1 e2 fuel :: Int fuel = 1 check_bots :: - ( Derive ty k, - GaloisField k + ( Derive ty k ) => Comp ty k -> TExp 'TBool k -> @@ -492,8 +514,7 @@ instance Derive ty1 k, Zippable ty2 k, Typeable ty2, - Derive ty2 k, - GaloisField k + Derive ty2 k ) => Zippable ('TProd ty1 ty2) k where @@ -514,8 +535,7 @@ instance Derive ty1 k, Zippable ty2 k, Typeable ty2, - Derive ty2 k, - GaloisField k + Derive ty2 k ) => Zippable ('TSum ty1 ty2) k where @@ -528,11 +548,8 @@ instance return $ unrep_sum p' instance - ( Typeable f, - Typeable (Rep f ('TMu f)), - Zippable (Rep f ('TMu f)) k, - Derive (Rep f ('TMu f)) k, - GaloisField k + ( Zippable (Rep f ('TMu f)) k, + Derive (Rep f ('TMu f)) k ) => Zippable ('TMu f) k where @@ -632,7 +649,7 @@ fix = fixN 100 zeq :: TExp 'TField k -> TExp 'TBool k zeq e = TEUnop (TUnop ZEq) e -not :: (Eq k) => TExp 'TBool k -> TExp 'TBool k +not :: TExp 'TBool k -> TExp 'TBool k not e = ifThenElse_aux e false true xor :: TExp 'TBool k -> TExp 'TBool k -> TExp 'TBool k @@ -651,7 +668,6 @@ exp_of_int :: (GaloisField k) => Int -> TExp 'TField k exp_of_int i = TEVal (VField $ fromIntegral i) ifThenElse_aux :: - (Eq a) => TExp 'TBool a -> TExp ty a -> TExp ty a -> @@ -665,8 +681,7 @@ ifThenElse_aux b e1 e2 _ -> TEIf b e1 e2 ifThenElse :: - ( Zippable ty k, - Typeable ty + ( Zippable ty k ) => Comp 'TBool k -> Comp ty k -> @@ -699,7 +714,6 @@ iter n f e = g n f e g m f' e' = f' m $ g (dec m) f' e' iterM :: - (Typeable ty) => Int -> (Int -> TExp ty k -> Comp ty k) -> TExp ty k -> @@ -766,7 +780,6 @@ curry :: (Typeable a) => (Typeable b) => (Typeable c) => - (GaloisField k) => (TExp ('TProd a b) k -> Comp c k) -> TExp a k -> Comp ('TFun b c) k @@ -779,7 +792,6 @@ uncurry :: (Typeable a) => (Typeable b) => (Typeable c) => - (GaloisField k) => (TExp a k -> Comp ('TFun b c) k) -> TExp ('TProd a b) k -> Comp c k diff --git a/src/Snarkl/Toplevel.hs b/src/Snarkl/Toplevel.hs index 53503ab..59b333a 100644 --- a/src/Snarkl/Toplevel.hs +++ b/src/Snarkl/Toplevel.hs @@ -11,8 +11,6 @@ module Snarkl.Toplevel execute, -- * Re-exported modules - module Snarkl.Language, - module Snarkl.Constraint, module Snarkl.Backend.R1CS, module Snarkl.Compile, ) @@ -22,26 +20,20 @@ import Data.Field.Galois (GaloisField, PrimeField) import Data.List (intercalate) import qualified Data.Map as Map import Data.Typeable (Typeable) -import Prettyprinter (Pretty (..), line, (<+>)) import Snarkl.Backend.R1CS import Snarkl.Common (Assgn) import Snarkl.Compile -import Snarkl.Constraint import Snarkl.Errors (ErrMsg (ErrMsg), failWith) import Snarkl.Interp (interp) import Snarkl.Language +import Text.PrettyPrint.Leijen.Text (Pretty (..), line, (<+>)) import Prelude ----------------------------------------------------- --- --- Snarkl.Toplevel Stuff --- ----------------------------------------------------- - -- | Using the executable semantics for the 'TExp' language, execute -- the computation on the provided inputs, returning the 'k' result. comp_interp :: - (Typeable ty, GaloisField k) => + forall ty k. + (GaloisField k) => Comp ty k -> [k] -> k diff --git a/tests/Test/ArkworksBridge.hs b/tests/Test/ArkworksBridge.hs index e1e4703..f847a9b 100644 --- a/tests/Test/ArkworksBridge.hs +++ b/tests/Test/ArkworksBridge.hs @@ -4,8 +4,16 @@ import qualified Data.ByteString.Lazy as LBS import Data.Field.Galois (GaloisField, PrimeField) import Data.Typeable (Typeable) import Snarkl.Backend.R1CS + ( mkInputsFilePath, + mkR1CSFilePath, + mkWitnessFilePath, + serializeInputsAsJson, + serializeR1CSAsJson, + serializeWitnessAsJson, + wit_of_r1cs, + ) import Snarkl.Compile (SimplParam, compileCompToR1CS) -import Snarkl.Language (Comp) +import Snarkl.Language.SyntaxMonad (Comp) import qualified System.Exit as GHC import System.Process (createProcess, shell, waitForProcess) diff --git a/tests/Test/Snarkl/DataflowSpec.hs b/tests/Test/Snarkl/DataflowSpec.hs index e4a5384..5dc2d1b 100644 --- a/tests/Test/Snarkl/DataflowSpec.hs +++ b/tests/Test/Snarkl/DataflowSpec.hs @@ -2,9 +2,7 @@ module Test.Snarkl.DataflowSpec where -import Data.Field.Galois (GaloisField, Prime, PrimeField, toP) -import qualified Data.IntMap as IntMap -import qualified Data.Map as Map +import Data.Field.Galois (Prime, toP) import qualified Data.Set as Set import GHC.TypeLits (KnownNat) import Snarkl.Common (Var (Var)) @@ -30,8 +28,8 @@ constraint2 :: (KnownNat p) => Constraint (Prime p) constraint2 = CMult (toP 2, Var 1) (toP 3, Var 2) (toP 4, Just $ Var 3) -- NOTE: notice 4 doesn't count as a variable here, WHY? -constraint3 :: (GaloisField k) => Constraint k -constraint3 = CMagic (Var 4) [Var 2, Var 3] $ \vars -> return True +constraint3 :: Constraint k +constraint3 = CMagic (Var 4) [Var 2, Var 3] $ \_ -> return True -- 4 is independent from 1,2,3 constraint4 :: (KnownNat p) => Constraint (Prime p) diff --git a/tests/Test/Snarkl/LambdaSpec.hs b/tests/Test/Snarkl/LambdaSpec.hs index ecf636e..e6b3fa9 100644 --- a/tests/Test/Snarkl/LambdaSpec.hs +++ b/tests/Test/Snarkl/LambdaSpec.hs @@ -5,12 +5,11 @@ {-# HLINT ignore "Redundant uncurry" #-} module Test.Snarkl.LambdaSpec where -import Data.Field.Galois (GaloisField, Prime) -import qualified Data.Map as Map -import GHC.TypeLits (KnownNat) -import Snarkl.Field -import Snarkl.Interp (interp) -import Snarkl.Language.Syntax +import Snarkl.Field (F_BN128) +import qualified Snarkl.Language.SyntaxMonad as SM +import Snarkl.Language.TExpr (TExp) +import Snarkl.Language.Type (Ty (TField, TFun, TProd)) +import Snarkl.Syntax ( apply, curry, lambda, @@ -19,29 +18,27 @@ import Snarkl.Language.Syntax (*), (+), ) -import qualified Snarkl.Language.SyntaxMonad as SM -import Snarkl.Language.TExpr (TExp, Ty (TField, TFun, TProd)) import Snarkl.Toplevel (comp_interp) -import Test.Hspec (Spec, describe, it, shouldBe) +import Test.Hspec (Spec, describe, it) import Test.QuickCheck (Testable (property)) -import Prelude hiding (apply, curry, return, uncurry, (*), (+)) +import Prelude hiding (curry, return, uncurry, (*), (+)) spec :: Spec spec = do describe "Snarkl.Lambda" $ do describe "curry/uncurry identities for simply operations" $ do it "curry . uncurry == id" $ do - let f :: (GaloisField k) => TExp 'TField k -> SM.Comp ('TFun 'TField 'TField) k + let f :: TExp 'TField k -> SM.Comp ('TFun 'TField 'TField) k f x = lambda $ \y -> SM.return (x + y) - g :: (GaloisField k) => TExp 'TField k -> SM.Comp ('TFun 'TField 'TField) k + g :: TExp 'TField k -> SM.Comp ('TFun 'TField 'TField) k g = curry (uncurry f) - prog1 :: (GaloisField k) => SM.Comp 'TField k + prog1 :: SM.Comp 'TField k prog1 = SM.fresh_input SM.>>= \a -> SM.fresh_input SM.>>= \b -> f a SM.>>= \k -> apply k b - prog2 :: (GaloisField k) => SM.Comp 'TField k + prog2 :: SM.Comp 'TField k prog2 = SM.fresh_input SM.>>= \a -> SM.fresh_input SM.>>= \b -> @@ -51,21 +48,21 @@ spec = do comp_interp @_ @F_BN128 prog1 [a, b] == comp_interp prog2 [a, b] it "uncurry . curry == id" $ do - let f :: (GaloisField k) => TExp ('TProd 'TField 'TField) k -> SM.Comp 'TField k + let f :: TExp ('TProd 'TField 'TField) k -> SM.Comp 'TField k f p = SM.fst_pair p SM.>>= \x -> SM.snd_pair p SM.>>= \y -> SM.return (x * y) - g :: (GaloisField k) => TExp ('TProd 'TField 'TField) k -> SM.Comp 'TField k + g :: TExp ('TProd 'TField 'TField) k -> SM.Comp 'TField k g = uncurry (curry f) - prog1 :: (GaloisField k) => SM.Comp 'TField k + prog1 :: SM.Comp 'TField k prog1 = SM.fresh_input SM.>>= \a -> SM.fresh_input SM.>>= \b -> pair a b SM.>>= \p -> f p - prog2 :: (GaloisField k) => SM.Comp 'TField k + prog2 :: SM.Comp 'TField k prog2 = SM.fresh_input SM.>>= \a -> SM.fresh_input SM.>>= \b -> diff --git a/tests/Test/Snarkl/Unit/Programs.hs b/tests/Test/Snarkl/Unit/Programs.hs index fa316ef..b915f25 100644 --- a/tests/Test/Snarkl/Unit/Programs.hs +++ b/tests/Test/Snarkl/Unit/Programs.hs @@ -8,19 +8,12 @@ module Test.Snarkl.Unit.Programs where -import Data.Field.Galois (GaloisField, Prime) -import Snarkl.Compile -import Snarkl.Example.Keccak import Snarkl.Example.Lam import Snarkl.Example.List import Snarkl.Example.Peano import Snarkl.Example.Tree -import Snarkl.Field (F_BN128, P_BN128) -import Snarkl.Language.Syntax -import Snarkl.Language.SyntaxMonad -import Snarkl.Language.TExpr -import Snarkl.Toplevel -import Test.Hspec (Spec, describe, it, shouldBe, shouldReturn) +import Snarkl.Field (F_BN128) +import Snarkl.Syntax import Prelude hiding ( fromRational, negate, diff --git a/tests/Test/Snarkl/UnitSpec.hs b/tests/Test/Snarkl/UnitSpec.hs index 0a3701a..6e356ac 100644 --- a/tests/Test/Snarkl/UnitSpec.hs +++ b/tests/Test/Snarkl/UnitSpec.hs @@ -7,21 +7,16 @@ import Data.Field.Galois (PrimeField) import Data.Typeable (Typeable) import Snarkl.Compile import Snarkl.Example.Keccak -import Snarkl.Example.Lam -import Snarkl.Example.List -import Snarkl.Example.Peano -import Snarkl.Example.Tree import Snarkl.Field -import Snarkl.Language (Comp) -import Snarkl.Language.Syntax hiding (negate) +import Snarkl.Language (Comp, Ty (..)) import Snarkl.Toplevel (Result (result_result), execute) import System.Exit (ExitCode (..)) import Test.ArkworksBridge (CMD (RunR1CS), runCMD) -import Test.Hspec (Spec, describe, it, shouldBe, shouldReturn) +import Test.Hspec (Spec, describe, it, shouldReturn) import Test.Snarkl.Unit.Programs import Prelude -test_comp :: (Typeable ty, PrimeField k) => SimplParam -> Comp ty k -> [k] -> IO (Either ExitCode k) +test_comp :: forall ty k. (Typeable ty, PrimeField k) => SimplParam -> Comp ty k -> [k] -> IO (Either ExitCode k) test_comp simpl mf args = do exit_code <- runCMD $ RunR1CS "./scripts" "hspec" simpl mf args @@ -71,7 +66,7 @@ spec = do it "8-1" $ test_comp @_ @F_BN128 Simplify prog8 [] `shouldReturn` Right 29 describe "unused inputs" $ do - it "11-1" $ test_comp @_ @F_BN128 Simplify prog11 [1, 1] `shouldReturn` Right 1 + it "11-1" $ test_comp @'TField @F_BN128 Simplify prog11 [1, 1] `shouldReturn` Right 1 describe "multiplicative identity" $ do it "13-1" $ test_comp @_ @F_BN128 Simplify prog13 [1] `shouldReturn` Right 1 @@ -146,7 +141,7 @@ spec = do it "8-1" $ test_comp @_ @F_BN128 Simplify prog8 [] `shouldReturn` Right 29 describe "unused inputs" $ do - it "11-1" $ test_comp @_ @F_BN128 Simplify prog11 [1, 1] `shouldReturn` Right 1 + it "11-1" $ test_comp @'TField @F_BN128 Simplify prog11 [1, 1] `shouldReturn` Right 1 describe "multiplicative identity" $ do it "13-1" $ test_comp @_ @F_BN128 Simplify prog13 [1] `shouldReturn` Right 1 diff --git a/tests/Test/UnionFindSpec.hs b/tests/Test/UnionFindSpec.hs index a932eb9..ee9774a 100644 --- a/tests/Test/UnionFindSpec.hs +++ b/tests/Test/UnionFindSpec.hs @@ -6,11 +6,21 @@ module Test.UnionFindSpec where -import Snarkl.Common +import Snarkl.Common (Var (..)) import Snarkl.Constraint.UnionFind -import Snarkl.Errors -import Test.Hspec + ( UnionFind, + empty, + insert, + root, + unite, + ) +import Test.Hspec (Spec, describe, it, shouldBe) import Test.QuickCheck + ( Arbitrary (arbitrary), + Testable (property), + forAll, + suchThat, + ) spec :: Spec spec = do