diff --git a/src/ZkFold/Symbolic/Base/Circuit.hs b/src/ZkFold/Symbolic/Base/Circuit.hs new file mode 100644 index 000000000..9ed77c0f8 --- /dev/null +++ b/src/ZkFold/Symbolic/Base/Circuit.hs @@ -0,0 +1,443 @@ +{-# LANGUAGE +AllowAmbiguousTypes +, DerivingStrategies +, DerivingVia +, QuantifiedConstraints +, RankNTypes +, TypeOperators +, UndecidableInstances +, UndecidableSuperClasses +#-} + +module ZkFold.Symbolic.Base.Circuit + ( Circuit (..), circuit, evalC + , MonadCircuit (..) + , IxMonadCircuit (..) + , CircuitIx (..) + , Blueprint + , SysVar (..) + , Var (..) + , Register (..) + , binaryExpansion + , compileC + , desolderC + , solderC + , newVarsC + ) where + +import Control.Applicative +import Control.Category +import Control.Monad +import Control.Monad.Trans +import Control.Monad.Trans.Indexed +import Data.Either +import Data.Eq +import Data.Foldable hiding (sum, product) +import Data.Function (($)) +import Data.IntMap (IntMap) +import qualified Data.IntMap as IntMap +import Data.Maybe +import Data.Monoid +import Data.Ord +import Data.Semigroup +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Traversable +import Data.Type.Equality +import qualified Data.Vector as V +import qualified Prelude + +import ZkFold.Symbolic.Base.Function +import ZkFold.Symbolic.Base.Num +import ZkFold.Symbolic.Base.Polynomial +import ZkFold.Symbolic.Base.Vector + +data Circuit x i o = UnsafeCircuit + { systemC :: Set (Poly (SysVar x i) Natural x) + -- ^ The system of polynomial constraints, + -- each polynomial constitutes a "multi-edge" of the circuit graph, + -- whose "vertices" are variables. + -- Polynomials constrain input variables and new variables. + -- Constant variables are absorbed into the polynomial coefficients. + , witnessC :: IntMap (i x -> x) + -- ^ The witness generation map, + -- witness functions for new variables. + -- Input and constant variables don't need witness functions. + , outputC :: o (Var x i) + -- ^ The output variables, + -- they can be input, constant or new variables. + } + +newVarsC :: Circuit x i o -> Int +newVarsC c = maybe 0 Prelude.fst (IntMap.lookupMax (witnessC c)) + +type Blueprint x i o = + forall t m. (IxMonadCircuit x t, Monad m) => t i i m (o (Var x i)) + +circuit + :: (Ord x, VectorSpace x i) + => Blueprint x i o + -> Circuit x i o +circuit m = case unPar1 (runCircuitIx m mempty) of + (o, c) -> c {outputC = o} + +evalC :: (VectorSpace x i, Functor o) => Circuit x i o -> i x -> o x +evalC c i = fmap (indexW (witnessC c) i) (outputC c) + +data SysVar x i + = InVar (Basis x i) + | NewVar Int +deriving stock instance VectorSpace x i => Eq (SysVar x i) +deriving stock instance VectorSpace x i => Ord (SysVar x i) + +data Var x i + = SysVar (SysVar x i) + | ConstVar x +deriving stock instance (Eq x, VectorSpace x i) => Eq (Var x i) +deriving stock instance (Ord x, VectorSpace x i) => Ord (Var x i) + +evalConst + :: (Ord x, VectorSpace x i) + => Poly (Var x i) Natural x + -> Poly (SysVar x i) Natural x +evalConst = mapPoly $ \case + ConstVar x -> Left x + SysVar v -> Right v + +indexW + :: VectorSpace x i + => IntMap (i x -> x) -> i x -> Var x i -> x +indexW witnessMap inp = \case + SysVar (InVar basisIx) -> indexV inp basisIx + SysVar (NewVar ix) -> fromMaybe zero (($ inp) <$> witnessMap IntMap.!? ix) + ConstVar x -> x + +instance (Ord x, VectorSpace x i, o ~ U1) => Monoid (Circuit x i o) where + mempty = UnsafeCircuit mempty mempty U1 +instance (Ord x, VectorSpace x i, o ~ U1) => Semigroup (Circuit x i o) where + c0 <> c1 = + let + varMax = newVarsC c0 + sysF = \case + InVar ix -> Right (InVar ix) + NewVar ix -> Right (NewVar (varMax + ix)) + in + UnsafeCircuit + { systemC = systemC c0 <> Set.map (mapPoly sysF) (systemC c1) + , witnessC = witnessC c0 <> IntMap.mapKeys (varMax +) (witnessC c1) + , outputC = U1 + } + +class Monad m => MonadCircuit x i m | m -> x, m -> i where + runCircuit + :: (VectorSpace x i, Functor o) + => Circuit x i o -> m (o (Var x i)) + input :: VectorSpace x i => m (i (Var x i)) + input = return (fmap (SysVar . InVar) (basisV @x)) + constraint + :: VectorSpace x i + => (forall a. Algebra x a => (Var x i -> a) -> a) + -> m () + newConstrained + :: VectorSpace x i + => (forall a. Algebra x a => (Var x i -> a) -> Var x i -> a) + -> ((Var x i -> x) -> x) + -> m (Var x i) + newAssigned + :: VectorSpace x i + => (forall a. Algebra x a => (Var x i -> a) -> a) + -> m (Var x i) + newAssigned p = newConstrained (\x i -> p x - x i) p + +class + ( forall i m. Monad m => MonadCircuit x i (t i i m) + , IxMonadTrans t + ) => IxMonadCircuit x t | t -> x where + apply + :: (VectorSpace x i, VectorSpace x j, Monad m) + => i x -> t (i :*: j) j m () + newInput + :: (VectorSpace x i, VectorSpace x j, Monad m) + => t j (i :*: j) m () + +newtype CircuitIx x i j m r = UnsafeCircuitIx + {runCircuitIx :: Circuit x i U1 -> m (r, Circuit x j U1)} + deriving Functor + +instance (Field x, Ord x, Monad m) + => Applicative (CircuitIx x i i m) where + pure x = UnsafeCircuitIx $ \c -> return (x,c) + (<*>) = apIx + +instance (Field x, Ord x, Monad m) + => Monad (CircuitIx x i i m) where + return = pure + (>>=) = Prelude.flip bindIx + +instance (Field x, Ord x, Monad m) + => MonadCircuit x i (CircuitIx x i i m) where + + runCircuit c1 = UnsafeCircuitIx $ \c0 -> do + let + outF = \case + SysVar (NewVar ix) -> SysVar (NewVar (newVarsC c0 + ix)) + v -> v + return (fmap outF (outputC c1), c0 <> c1 {outputC = U1}) + + constraint p = UnsafeCircuitIx $ \c -> return + ((), c {systemC = Set.insert (evalConst (p var)) (systemC c)}) + + newConstrained p w = UnsafeCircuitIx $ \c -> return $ + let + maxIndexMaybe = IntMap.lookupMax (witnessC c) + newIndex = maybe 0 ((1 +) . Prelude.fst) maxIndexMaybe + newWitness = w . indexW (witnessC c) + outVar = SysVar (NewVar newIndex) + newConstraint = evalConst (p var outVar) + newSystemC = Set.insert newConstraint (systemC c) + newWitnessC = IntMap.insert newIndex newWitness (witnessC c) + in + (outVar, c {systemC = newSystemC, witnessC = newWitnessC}) + +instance (i ~ j, Ord x, Field x) + => MonadTrans (CircuitIx x i j) where + lift m = UnsafeCircuitIx $ \c -> (, c) <$> m + +instance (Field x, Ord x) => IxMonadTrans (CircuitIx x) where + joinIx (UnsafeCircuitIx f) = UnsafeCircuitIx $ \c -> do + (UnsafeCircuitIx g, c') <- f c + g c' + +instance (Field x, Ord x) + => IxMonadCircuit x (CircuitIx x) where + apply i = UnsafeCircuitIx $ \c -> return + ( () + , c { systemC = Set.map (mapPoly sysF) (systemC c) + , witnessC = fmap witF (witnessC c) + , outputC = U1 + } + ) where + sysF = \case + InVar (Left bi) -> Left (indexV i bi) + InVar (Right bj) -> Right (InVar bj) + NewVar n -> Right (NewVar n) + witF f j = f (i :*: j) + + newInput = UnsafeCircuitIx $ \c -> return + ( () + , c { systemC = Set.map (mapPoly sysF) (systemC c) + , witnessC = fmap witF (witnessC c) + , outputC = U1 + } + ) + where + sysF = \case + InVar bj -> Right (InVar (Right bj)) + NewVar n -> Right (NewVar n) + witF f (_ :*: j) = f j + +instance (Ord x, VectorSpace x i) + => From x (Circuit x i Par1) where + from x = mempty { outputC = Par1 (ConstVar x) } + +instance (Ord x, VectorSpace x i) + => AdditiveMonoid (Circuit x i Par1) where + zero = from @x zero + c0 + c1 = circuit $ do + Par1 v0 <- runCircuit c0 + Par1 v1 <- runCircuit c1 + Par1 <$> newAssigned (\x -> x v0 + x v1) + +instance (Ord x, VectorSpace x i) + => AdditiveGroup (Circuit x i Par1) where + negate c = circuit $ do + Par1 v <- runCircuit c + Par1 <$> newAssigned (\x -> negate (x v)) + c0 - c1 = circuit $ do + Par1 v0 <- runCircuit c0 + Par1 v1 <- runCircuit c1 + Par1 <$> newAssigned (\x -> x v0 - x v1) + +instance (Ord x, VectorSpace x i) + => From Natural (Circuit x i Par1) where + from = from @x . from + +instance (Ord x, VectorSpace x i) + => From Integer (Circuit x i Par1) where + from = from @x . from + +instance (Ord x, VectorSpace x i) + => From Rational (Circuit x i Par1) where + from = from @x . from + +instance (Ord x, VectorSpace x i) + => MultiplicativeMonoid (Circuit x i Par1) where + one = from @x one + c0 * c1 = circuit $ do + Par1 v0 <- runCircuit c0 + Par1 v1 <- runCircuit c1 + Par1 <$> newAssigned (\x -> x v0 * x v1) + +instance From (Circuit x i Par1) (Circuit x i Par1) + +instance (Ord x, VectorSpace x i) + => Scalar Natural (Circuit x i Par1) where + scale = scale @x . from + combine = combineN + +instance (Ord x, VectorSpace x i) + => Scalar Integer (Circuit x i Par1) where + scale = scale @x . from + combine = combineZ + +instance (Ord x, VectorSpace x i) + => Scalar Rational (Circuit x i Par1) where + scale = scale @x . from + combine xs = sum [ scale k x | (k, x) <- xs ] + +instance (Ord x, VectorSpace x i) + => Scalar x (Circuit x i Par1) where + scale k c = circuit $ do + Par1 v <- runCircuit c + Par1 <$> newAssigned (\x -> k `scale` x v) + combine xs = sum [ scale k x | (k, x) <- xs ] + +instance (Ord x, VectorSpace x i) + => Scalar (Circuit x i Par1) (Circuit x i Par1) + +instance (Ord x, VectorSpace x i) + => Exponent Natural (Circuit x i Par1) where + exponent x p = evalMono [(x, p)] + evalMono = evalMonoN + +instance (Ord x, Discrete x, VectorSpace x i) + => Exponent Integer (Circuit x i Par1) where + exponent x p = evalMono [(x, p)] + evalMono = evalMonoZ + +instance (Ord x, Discrete x, VectorSpace x i) + => MultiplicativeGroup (Circuit x i Par1) where + recip c = + let + cInv = invertC c + _ :*: inv = outputC cInv + in + cInv { outputC = inv } + +instance (Ord x, Discrete x, VectorSpace x i) + => Discrete (Circuit x i Par1) where + dichotomy x y = isZero (x - y) + isZero c = + let + cInv = invertC c + isZ :*: _ = outputC cInv + in + cInv { outputC = isZ } + +invertC + :: (Ord x, Discrete x, VectorSpace x i) + => Circuit x i Par1 -> Circuit x i (Par1 :*: Par1) +invertC c = circuit $ do + Par1 v <- runCircuit c + isZ <- newConstrained + (\x i -> let xi = x i in xi * (xi - one)) + (\x -> isZero (x v)) + inv <- newConstrained + (\x i -> x i * x v + x isZ - one) + (\x -> recip (x v)) + return (Par1 isZ :*: Par1 inv) + +instance (PrimeField x, VectorSpace x i) + => Comparable (Circuit x i Par1) where + trichotomy c0 c1 = circuit $ do + UnsafeRegister v0 <- runCircuit (binaryExpansion c0) + UnsafeRegister v1 <- runCircuit (binaryExpansion c1) + let reverseLexicographical a b = b * b * (b - a) + a + v <- newAssigned $ \x -> + V.foldl reverseLexicographical one + (V.zipWith (\i0 i1 -> x i0 - x i1) v0 v1) + return (Par1 v) + +instance (Ord x, FiniteChr x, VectorSpace x i) + => FiniteChr (Circuit x i Par1) where + type Chr (Circuit x i Par1) = Chr x + +instance (PrimeField x, VectorSpace x i) => Symbolic x (Circuit x i Par1) + +-- A list of bits whose length is the number of bits +-- needed to represent an element of +-- the Arithmetic field of a Symbolic field extension. +newtype Register a = UnsafeRegister {fromRegister :: V.Vector a} + deriving stock (Functor, Foldable, Traversable) +instance Symbolic x a => VectorSpace a Register where + type Basis a Register = Int + indexV (UnsafeRegister v) ix = fromMaybe zero (v V.!? ix) + dimV = numberOfBits @x + basisV = UnsafeRegister (V.generate (from (numberOfBits @x)) id) + tabulateV f = UnsafeRegister (V.generate (from (numberOfBits @x)) f) + +binaryExpansion + :: forall x i. (PrimeField x, VectorSpace x i) + => Circuit x i Par1 + -> Circuit x i Register +binaryExpansion c = circuit $ do + Par1 v <- runCircuit c + lst <- expansion (from (numberOfBits @x)) v + return (UnsafeRegister (V.fromList lst)) + +horner + :: (VectorSpace x i, MonadCircuit x i m) + => [Var x i] -> m (Var x i) +-- ^ @horner [b0,...,bn]@ computes the sum @b0 + 2 b1 + ... + 2^n bn@ using +-- Horner's scheme. +horner xs = case Prelude.reverse xs of + [] -> return (ConstVar zero) + (b : bs) -> + foldlM (\a i -> newAssigned (\x -> let xa = x a in x i + xa + xa)) b bs + +bitsOf + :: (SemiEuclidean x, VectorSpace x i, MonadCircuit x i m) + => Int -> Var x i -> m [Var x i] +-- ^ @bitsOf n k@ creates @n@ bits and +-- sets their witnesses equal to @n@ smaller bits of @k@. +bitsOf n k = for [0 .. n - 1] $ \j -> newConstrained + (\x i -> let xi = x i in xi * (xi - one)) + ((Prelude.!! j) . expand . ($ k)) + where + two = from (2 :: Natural) + expand x = let (d,m) = divMod x two in m : expand d + +expansion + :: (SemiEuclidean x, VectorSpace x i, MonadCircuit x i m) + => Int -> Var x i -> m [Var x i] +-- ^ @expansion n k@ computes a binary expansion of @k@ if it fits in @n@ bits. +expansion n k = do + bits <- bitsOf n k + k' <- horner bits + constraint (\x -> x k - x k') + return bits + +solderC + :: (Ord x, VectorSpace x i, Functor o, Foldable o) + => o (Circuit x i Par1) + -> Circuit x i o +solderC cs = (fold (fmap (\c -> c {outputC = U1}) cs)) + { outputC = fmap (unPar1 . outputC) cs } + +desolderC + :: Functor o + => Circuit x i o + -> o (Circuit x i Par1) +desolderC c = fmap (\o -> c {outputC = Par1 o}) (outputC c) + +compileC + :: ( FunctionSpace (Circuit x i Par1) f + , i ~ InputSpace (Circuit x i Par1) f + , o ~ OutputSpace (Circuit x i Par1) f + , VectorSpace x i + , Functor o + , Foldable o + , Ord x + ) + => f -> Circuit x i o +compileC f = solderC (uncurryF f (desolderC (circuit input))) diff --git a/src/ZkFold/Symbolic/Base/Function.hs b/src/ZkFold/Symbolic/Base/Function.hs new file mode 100644 index 000000000..13bd29d78 --- /dev/null +++ b/src/ZkFold/Symbolic/Base/Function.hs @@ -0,0 +1,54 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableSuperClasses #-} + +module ZkFold.Symbolic.Base.Function + ( FunctionSpace (..) + , InputSpace + , OutputSpace + ) where + +import Control.Category +import Data.Type.Equality + +import ZkFold.Symbolic.Base.Vector + +{- | `FunctionSpace` class of functions over variables. + +The type @FunctionSpace a f => f@ should be equal to some + +@vN a -> .. -> v1 a -> v0 a@ + +which via multiple-uncurrying is equivalent to + +@(vN :*: .. :*: v1 :*: U1) a -> v0 a@ +-} +class FunctionSpace a f where + uncurryF :: f -> InputSpace a f a -> OutputSpace a f a + curryF :: (InputSpace a f a -> OutputSpace a f a) -> f + +type family InputSpace a f where + InputSpace a (x a -> f) = x :*: InputSpace a f + InputSpace a (y a) = U1 + +type family OutputSpace a f where + OutputSpace a (x a -> f) = OutputSpace a f + OutputSpace a (y a) = y + +instance {-# OVERLAPPABLE #-} + ( OutputSpace a (y a) ~ y + , InputSpace a (y a) ~ U1 + ) => FunctionSpace a (y a) where + uncurryF f _ = f + curryF k = k U1 + +instance {-# OVERLAPPING #-} + ( OutputSpace a (x a -> f) ~ OutputSpace a f + , InputSpace a (x a -> f) ~ x :*: InputSpace a f + , FunctionSpace a f + ) => FunctionSpace a (x a -> f) where + uncurryF f (i :*: j) = uncurryF (f i) j + curryF k x = curryF (k . (:*:) x) diff --git a/src/ZkFold/Symbolic/Base/Num.hs b/src/ZkFold/Symbolic/Base/Num.hs new file mode 100644 index 000000000..e7367856b --- /dev/null +++ b/src/ZkFold/Symbolic/Base/Num.hs @@ -0,0 +1,743 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableSuperClasses #-} + +module ZkFold.Symbolic.Base.Num + ( -- * Numeric types + Integer + , Natural + , Rational + , Int + , Mod + -- * Arithmetic constraints + , Symbolic + , PrimeField + -- * Algebraic constraints + , AdditiveMonoid (..) + , MultiplicativeMonoid (..) + , AdditiveGroup (..) + , MultiplicativeGroup (..) + , Semiring + , Ring + , Field + , Algebra + , SemiEuclidean (..) + , Euclidean (..) + , SemiIntegral + , Integral + , Modular + , Discrete (..) + , Comparable (..) + -- * Algebraic inter-constraints + , From (..) + , Into (..) + , Exponent (..), (^), (^^) + , Scalar (..) + -- * Type level numbers + , KnownNat + , Prime + , Finite (..) + , FiniteChr (..) + -- * Numeric combinators + , sum + , product + , even + , odd + , fromInteger + , fromSemiIntegral + , knownNat + , order + , numberOfBits + , characteristic + , combineN + , combineZ + , evalMonoN + , evalMonoZ + ) where + +import Control.Applicative +import Control.Category +import Data.Bool +import Data.Eq +import Data.Foldable hiding (product, sum, toList) +import Data.Functor +import Data.Kind +import Data.Ord +import Data.Ratio +import Data.Type.Bool +import Data.Type.Equality +import GHC.Exts (proxy#) +import GHC.TypeLits (ErrorMessage (..), TypeError) +import GHC.TypeNats hiding (Mod) +import qualified GHC.TypeNats as Type +import Prelude (Int, Integer) +import qualified Prelude + +-- Symbolic field extensions [Arithmetic a : a] should include: +-- PrimeField x => Symbolic x +-- PrimeField x => Symbolic (i -> x) +-- PrimeField x => Symbolic (Circuit x i Par1) +class + ( Field a + , Comparable a + , FiniteChr a + , 3 <= Chr a + , PrimeField x + , Algebra x a + , Chr a ~ Order x + ) => Symbolic x a | a -> x where + +-- Prime fields should only include: +-- (SemiIntegral int, Prime p) => PrimeField (int `Mod` p) +-- and newtypes of int `Mod` p. +-- p = 2 is ruled out to allow/require trichotomy. +type PrimeField x = + ( Modular x + , Finite x + , Symbolic x x + ) + +class AdditiveMonoid a where + infixl 6 + + (+) :: a -> a -> a + zero :: a + +sum :: (Foldable t, AdditiveMonoid a) => t a -> a +sum = foldl' (+) zero + +class AdditiveMonoid a => AdditiveGroup a where + negate :: a -> a + negate a = zero - a + infixl 6 - + (-) :: a -> a -> a + a - b = a + negate b + +class MultiplicativeMonoid a where + infixl 7 * + (*) :: a -> a -> a + one :: a + +product :: (Foldable t, MultiplicativeMonoid a) => t a -> a +product = foldl' (*) one + +class MultiplicativeMonoid a => MultiplicativeGroup a where + recip :: a -> a + recip a = one / a + infixl 7 / + (/) :: a -> a -> a + a / b = a * recip b + +-- from @Natural is the unique homomorphism from the free Semiring +-- from @Integer is the unique homomorphism from the free Ring +-- from @Rational is the unique homomorphism from the free Field +-- +-- prop> from . from = from +-- prop> from @a @a = id +class From x a where + from :: x -> a + default from :: x ~ a => x -> a + from = id + +type Semiring a = + ( AdditiveMonoid a + , MultiplicativeMonoid a + , From Natural a + , From a a + , Scalar Natural a + , Scalar a a + , Exponent Natural a + ) + +type Ring a = + ( AdditiveGroup a + , MultiplicativeMonoid a + , From Natural a + , From Integer a + , From a a + , Scalar Natural a + , Scalar Integer a + , Scalar a a + , Exponent Natural a + ) + +type Field a = + ( AdditiveGroup a + , MultiplicativeGroup a + , From Natural a + , From Integer a + , From Rational a + , From a a + , Scalar Natural a + , Scalar Integer a + , Scalar Rational a + , Scalar a a + , Exponent Natural a + , Exponent Integer a + ) + +type Algebra x a = (Ring x, Ring a, From x a, Scalar x a) + +class Semiring a => SemiEuclidean a where + divMod :: a -> a -> (a,a) + div :: a -> a -> a + div a b = let (divisor,_) = divMod a b in divisor + mod :: a -> a -> a + mod a b = let (_,modulus) = divMod a b in modulus + quotRem :: a -> a -> (a,a) + quot :: a -> a -> a + quot a b = let (quotient,_) = quotRem a b in quotient + rem :: a -> a -> a + rem a b = let (_,remainder) = quotRem a b in remainder + +class (SemiEuclidean a, Ring a) => Euclidean a where + eea :: a -> a -> (a,a,a) + default eea :: Eq a => a -> a -> (a,a,a) + eea = xEuclid one zero zero one where + xEuclid x0 y0 x1 y1 u v + | v == zero = (u,x0,y0) + | otherwise = + let + (q , r) = u `divMod` v + x2 = x0 - q * x1 + y2 = y0 - q * y1 + in + xEuclid x1 y1 x2 y2 v r + gcd :: a -> a -> a + gcd a b = let (d,_,_) = eea a b in d + +even :: SemiIntegral a => a -> Bool +even a = a `mod` (from @Natural 2) == zero + +odd :: SemiIntegral a => a -> Bool +odd a = a `mod` (from @Natural 2) == one + +class Ring a => Discrete a where + dichotomy :: a -> a -> a + default dichotomy :: Eq a => a -> a -> a + dichotomy a b = if a == b then one else zero + isZero :: a -> a + default isZero :: a -> a + isZero = dichotomy zero + +class Discrete a => Comparable a where + trichotomy :: a -> a -> a + default trichotomy :: Ord a => a -> a -> a + trichotomy a b = case compare a b of + LT -> negate one + EQ -> zero + GT -> one + +-- prop> to . from = id +-- prop> to . to = to +-- prop> to @a @a = id +class Into y a where + to :: a -> y + default to :: y ~ a => a -> y + to = id + +-- e.g. `Integer`, `Natural`, `Mod int` +type SemiIntegral a = + ( Prelude.Ord a + , SemiEuclidean a + , Into Rational a + , Into Integer a + ) +-- e.g. `Integer`, `Mod int` +type Integral a = + ( Prelude.Ord a + , Euclidean a + , Into Rational a + , Into Integer a + ) +-- e.g. `Mod int` & fixed-width unsigned integer types +type Modular a = + ( Prelude.Ord a + , Euclidean a + , Into Rational a + , Into Integer a + , Into Natural a + ) + +fromSemiIntegral :: (SemiIntegral a, From Integer b) => a -> b +fromSemiIntegral = from . to @Integer + +fromInteger :: From Integer b => Integer -> b +fromInteger = from + +-- Type level numbers --------------------------------------------------------- + +knownNat :: forall n. KnownNat n => Natural +knownNat = natVal' (proxy# @n) + +class + ( KnownNat (Order a) + , KnownNat (NumberOfBits a) + ) => Finite a where + type Order a :: Natural + +order :: forall a. Finite a => Natural +order = knownNat @(Order a) + +type NumberOfBits a = Log2 (Order a - 1) + 1 + +numberOfBits :: forall a. Finite a => Natural +numberOfBits = knownNat @(NumberOfBits a) + +class + ( KnownNat (Chr a) + , Semiring a + ) => FiniteChr a where + type Chr a :: Natural + +characteristic :: forall a. FiniteChr a => Natural +characteristic = knownNat @(Chr a) + +-- Use orphan instances for large publicly verified primes +class KnownNat p => Prime p +-- Use this overlappable instance for small enough primes and testing +instance {-# OVERLAPPABLE #-} (KnownNat p, KnownPrime p) => Prime p + +type family KnownPrime p where + KnownPrime p = If (IsPrime p) (() :: Constraint) (TypeError (NotPrimeError p)) + +type NotPrimeError p = + 'Text "Error: " ':<>: 'ShowType p ':<>: 'Text " is not a prime number." + +type family IsPrime p where + IsPrime 0 = 'False + IsPrime 1 = 'False + IsPrime 2 = 'True + IsPrime 3 = 'True + IsPrime n = NotDividesFromTo n 2 (AtLeastSqrt n) + +type family NotZero n where + NotZero 0 = 'False + NotZero n = 'True + +type family NotDividesFromTo dividend divisor0 divisor1 where + NotDividesFromTo dividend divisor divisor = NotZero (dividend `Type.Mod` divisor) + NotDividesFromTo dividend divisor0 divisor1 = + NotZero (dividend `Type.Mod` divisor0) && NotDividesFromTo dividend (divisor0 + 1) divisor1 + +type family AtLeastSqrt n where + AtLeastSqrt 0 = 0 + AtLeastSqrt n = 2 ^ (Log2 n `Div` 2 + 1) + +-- Rational ------------------------------------------------------------------- + +instance AdditiveMonoid Rational where + (+) = (Prelude.+) + zero = 0 + +instance AdditiveGroup Rational where + negate = Prelude.negate + (-) = (Prelude.-) + +instance MultiplicativeMonoid Rational where + (*) = (Prelude.*) + one = 1 + +instance MultiplicativeGroup Rational where + (/) = (Prelude./) + +instance Scalar Natural Rational where + scale n q = to n * q + combine = combineN +instance Scalar Integer Rational where + scale i q = to i * q + combine = combineZ +instance Scalar Rational Rational where + scale = (*) + combine terms = + let + coefs = [c | (c,_) <- terms] + commonDenom :: Integer = product (fmap (denominator . to) coefs) + clearDenom c = (numerator c * commonDenom) `div` commonDenom + numerators = fmap (\(c,a) -> (clearDenom (to c), a)) terms + in + combine numerators / from commonDenom + +instance Exponent Natural Rational where + exponent = (Prelude.^) + evalMono = evalMonoN +instance Exponent Integer Rational where + exponent = (Prelude.^^) + evalMono = evalMonoZ + +instance From Rational Rational +instance From Natural Rational where from = Prelude.fromIntegral +instance From Integer Rational where from = Prelude.fromInteger +instance Into Rational Rational + +instance Discrete Rational +instance Comparable Rational + +-- Integer -------------------------------------------------------------------- + +instance AdditiveMonoid Integer where + (+) = (Prelude.+) + zero = 0 + +instance AdditiveGroup Integer where + negate = Prelude.negate + (-) = (Prelude.-) + +instance MultiplicativeMonoid Integer where + (*) = (Prelude.*) + one = 1 + +instance Scalar Natural Integer where + scale n z = to n * z + combine = combineN +instance Scalar Integer Integer where + scale = (*) + combine = combineZ + +instance Exponent Natural Integer where + exponent = (Prelude.^) + evalMono = evalMonoN + +instance From Integer Integer +instance From Natural Integer where from = Prelude.fromIntegral +instance Into Rational Integer where to = Prelude.toRational +instance Into Integer Integer + +instance Discrete Integer +instance Comparable Integer + +instance SemiEuclidean Integer where + divMod = Prelude.divMod + quotRem = Prelude.quotRem +instance Euclidean Integer + +-- Natural -------------------------------------------------------------------- + +instance AdditiveMonoid Natural where + (+) = (Prelude.+) + zero = 0 + +instance MultiplicativeMonoid Natural where + (*) = (Prelude.*) + one = 1 + +instance Scalar Natural Natural where + scale = (*) + combine = combineN + +instance Exponent Natural Natural where + exponent = (Prelude.^) + evalMono = evalMonoN + +instance From Natural Natural +instance Into Integer Natural where to = Prelude.toInteger +instance Into Rational Natural where to = Prelude.toRational +instance Into Natural Natural + +instance SemiEuclidean Natural where + divMod = Prelude.divMod + quotRem = Prelude.quotRem + +-- Int ------------------------------------------------------------------------ + +instance AdditiveMonoid Int where + (+) = (Prelude.+) + zero = 0 + +instance AdditiveGroup Int where + negate = Prelude.negate + (-) = (Prelude.-) + +instance MultiplicativeMonoid Int where + (*) = (Prelude.*) + one = 1 + +instance Scalar Natural Int where + scale n z = from n * z + combine = combineN +instance Scalar Integer Int where + scale n z = from n * z + combine = combineZ +instance Scalar Int Int where + scale = (*) + combine terms = + combineZ [(to @Integer c,x) | (c,x) <- terms] + +instance Exponent Natural Int where + exponent = (Prelude.^) + evalMono = evalMonoN + +instance From Int Int +instance From Natural Int where from = Prelude.fromIntegral +instance From Integer Int where from = Prelude.fromIntegral +instance Into Rational Int where to = Prelude.toRational +instance Into Integer Int where to = Prelude.fromIntegral +instance Into Int Int + +instance Discrete Int +instance Comparable Int + +instance SemiEuclidean Int where + divMod = Prelude.divMod + quotRem = Prelude.quotRem +instance Euclidean Int + +-- Function ------------------------------------------------------------------- + +instance {-# OVERLAPPING #-} Semiring a + => AdditiveMonoid (i -> a) where + zero = pure zero + (+) = liftA2 (+) + +instance Ring a => AdditiveGroup (i -> a) where + negate = fmap negate + (-) = liftA2 (-) + +instance MultiplicativeMonoid a => MultiplicativeMonoid (i -> a) where + one = pure one + (*) = liftA2 (*) + +instance (Scalar c a, Scalar c c, Scalar a a) + => Scalar (i -> c) (i -> a) where + scale c a i = scale (c i) (a i) + combine terms i = combine (fmap (\(c,f) -> (c i, f i)) terms) + +instance Semiring a => Scalar Natural (i -> a) where + scale c f = scale c . f + combine terms i = combine (fmap (\(c,f) -> (c, f i)) terms) + +instance (Semiring a, Scalar Integer a) + => Scalar Integer (i -> a) where + scale c f = scale c . f + combine terms i = combine (fmap (\(c,f) -> (c, f i)) terms) + +instance (Semiring a, Scalar Rational a) + => Scalar Rational (i -> a) where + scale c f = scale c . f + combine terms i = combine (fmap (\(c,f) -> (c, f i)) terms) + +instance (Exponent pow a) => Exponent pow (i -> a) where + exponent a p i = exponent (a i) p + evalMono factors i = evalMono (fmap (\(f,p) -> (f i, p)) factors) + +instance MultiplicativeGroup a => MultiplicativeGroup (i -> a) where + (/) = liftA2 (/) + +instance From Natural a => From Natural (i -> a) where + from = pure . from + +instance From Integer a => From Integer (i -> a) where + from = pure . from + +instance From Rational a => From Rational (i -> a) where + from = pure . from + +instance From (i -> a) (i -> a) + +instance Discrete a => Discrete (i -> a) where + dichotomy = liftA2 dichotomy + +instance Comparable a => Comparable (i -> a) where + trichotomy = liftA2 trichotomy + +-- Mod ------------------------------------------------------------------------ +newtype Mod int n = UnsafeMod {fromMod :: int} + deriving newtype (Eq, Ord, Prelude.Show) + +instance (SemiIntegral int, KnownNat n) + => AdditiveMonoid (Mod int n) where + a + b = from (to @Integer a + to b) + zero = UnsafeMod zero + +instance (SemiIntegral int, KnownNat n) + => AdditiveGroup (Mod int n) where + negate = from . negate . to @Integer + a - b = from (to @Integer a - to b) + +instance (SemiIntegral int, KnownNat n) + => MultiplicativeMonoid (Mod int n) where + a * b = from (to @Integer a * to b) + one = UnsafeMod one + +instance (SemiIntegral int, Prime p) + => MultiplicativeGroup (Mod int p) where + recip a = case eea (to @Integer a) (from (knownNat @p)) of + (_,q,_) -> from q + +instance (SemiIntegral int, KnownNat n) + => SemiEuclidean (Mod int n) where + divMod a b = case divMod (to @Natural a) (to b) of + (d,m) -> (from d, from m) + quotRem a b = case quotRem (to @Natural a) (to b) of + (q,r) -> (from q, from r) + +instance (SemiIntegral int, KnownNat n) + => Euclidean (Mod int n) where + eea a b = + let + (d,b0,b1) = eea (to @Integer a) (to b) + in + (from d, from b0, from b1) + +residue :: forall n int. (SemiEuclidean int, KnownNat n) => int -> int +residue int = int `mod` from (knownNat @n) + +instance From (Mod int n) (Mod int n) + +instance (From Natural int, SemiEuclidean int, KnownNat n) + => From Natural (Mod int n) where + from = UnsafeMod . residue @n . from + +instance (SemiIntegral int, KnownNat n) + => From Integer (Mod int n) where + from = UnsafeMod . from @Natural . Prelude.fromIntegral . residue @n + +instance (SemiIntegral int, Prime p) + => From Rational (Mod int p) where + from q = from (numerator q) / from (denominator q) + +instance Into (Mod int n) (Mod int n) + +instance Into Rational int + => Into Rational (Mod int n) where + to = to . fromMod + +instance Into Integer int + => Into Integer (Mod int n) where + to = to . fromMod + +instance SemiIntegral int + => Into Natural (Mod int n) where + to = Prelude.fromInteger . to + +instance (SemiIntegral int, KnownNat n) + => Scalar Natural (Mod int n) where + scale n z = from n * z + combine = combine . fmap (\(c,a) -> (from @_ @Natural c, a)) +instance (SemiIntegral int, KnownNat n) + => Scalar Integer (Mod int n) where + scale n z = from n * z + combine = combine . fmap (\(c,a) -> (from @_ @Integer c, a)) +instance (SemiIntegral int, Prime p) + => Scalar Rational (Mod int p) where + scale n z = from n * z + combine = combine . fmap (\(c,a) -> (from @_ @Rational c, a)) +instance (SemiIntegral int, KnownNat n) + => Scalar (Mod int n) (Mod int n) where + scale = (*) + combine = combineN . fmap (\(c,a) -> (to @Natural c, a)) + +instance (SemiIntegral int, KnownNat n) + => Exponent Natural (Mod int n) where + exponent a q = exponent a (from @_ @(Mod int n) q) + evalMono = evalMono . fmap (\(a,q) -> (a,from @_ @(Mod int n) q)) + +instance (SemiIntegral int, Prime p) + => Exponent Integer (Mod int p) where + exponent a q = + if q >= zero + then exponent a (from @_ @(Mod int p) q) + else one / exponent a (from @_ @(Mod int p) (negate q)) + evalMono = evalMono . fmap absPow where + absPow (a,p) = + if p >= 0 + then (a, from @_ @(Mod int p) p) + else (one / a, from (negate p)) + +instance (SemiIntegral int, KnownNat n) + => Exponent (Mod int n) (Mod int n) where + exponent a q = from (to @Natural a ^ to @Natural q) + evalMono + = from + . evalMono + . fmap (\(a,q) -> (to @Natural a, to @Natural q)) + +instance (SemiIntegral int, KnownNat n) => Discrete (Mod int n) +instance (SemiIntegral int, KnownNat n) => Comparable (Mod int n) + +instance (KnownNat n, KnownNat (Log2 (n - 1) + 1)) + => Finite (Mod int n) where + type Order (Mod int n) = n + +instance (SemiIntegral int, KnownNat n) + => FiniteChr (Mod int n) where + type Chr (Mod int n) = n + +instance (SemiIntegral int, Prime p, KnownNat (Log2 (p - 1) + 1), 3 <= p) + => Symbolic (Mod int p) (Mod int p) + +-- Scalar ------------------------------------------------------------------ +class (Semiring c, AdditiveMonoid a) + => Scalar c a where + scale :: c -> a -> a + default scale :: c ~ a => c -> a -> a + scale = (*) + combine :: [(c,a)] -> a + default combine :: c ~ a => [(c,a)] -> a + combine = sum . fmap (Prelude.uncurry (*)) + +combineN + :: (Into Natural c, AdditiveMonoid a) + => [(c,a)] -> a +combineN combination = combineNat naturalized where + naturalized = [(to @Natural c, a) | (c,a) <- combination] + combineNat [] = zero + combineNat terms = + let + halves = combineNat [(c `div` 2, a) | (c,a) <- terms, c > 1] + halveNots = sum [a | (c,a) <- terms, odd c] + in + halves + halves + halveNots + +combineZ + :: (Into Integer c, AdditiveGroup a) + => [(c,a)] -> a +combineZ = combineN . fmap absCoeff where + absCoeff (c,a) = + let + cZ = to @Integer c + in + if cZ >= 0 + then (Prelude.fromIntegral cZ :: Natural, a) + else (Prelude.fromIntegral (negate cZ), negate a) + +-- Exponent ------------------------------------------------------------------- +class (Semiring pow, MultiplicativeMonoid a) + => Exponent pow a where + exponent :: a -> pow -> a + evalMono :: [(a,pow)] -> a + +infixr 8 ^ +(^) :: (Exponent Natural a, Into Natural pow) => a -> pow -> a +a ^ b = exponent @Natural a (to b) + +infixr 8 ^^ +(^^) :: (Exponent Integer a, Into Integer pow) => a -> pow -> a +a ^^ b = exponent @Integer a (to b) + +evalMonoN + :: (Into Natural p, MultiplicativeMonoid a) + => [(a,p)] -> a +evalMonoN monomial = evalMonoNat naturalized where + naturalized = [(a, to @Natural p) | (a,p) <- monomial] + evalMonoNat [] = one + evalMonoNat factors = + let + sqrts = evalMonoNat [(a,p `div` 2) | (a,p) <- factors, p > 1] + sqrtNots = product [a | (a,p) <- factors, odd p] + in + sqrts * sqrts * sqrtNots + +evalMonoZ + :: (Into Integer p, MultiplicativeGroup a) + => [(a,p)] -> a +evalMonoZ = evalMonoN . fmap absPow where + absPow (a,p) = + let + pZ = to @Integer p + in + if pZ >= 0 + then (a, Prelude.fromIntegral pZ :: Natural) + else (one / a, Prelude.fromIntegral (negate pZ)) diff --git a/src/ZkFold/Symbolic/Base/Polynomial.hs b/src/ZkFold/Symbolic/Base/Polynomial.hs new file mode 100644 index 000000000..0155306a3 --- /dev/null +++ b/src/ZkFold/Symbolic/Base/Polynomial.hs @@ -0,0 +1,164 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableSuperClasses #-} + +module ZkFold.Symbolic.Base.Polynomial + ( Poly + , Mono (..), mono + , Combo (..), combo + , var + , varSet + , evalPoly + , mapPoly + ) where + +import Control.Category +import Data.Bifunctor +import Data.Either +import Data.Eq +import Data.Foldable hiding (product, sum, toList) +import Data.Functor +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map +import Data.Monoid +import Data.Ord +import Data.Set (Set) +import qualified Data.Set as Set +import GHC.IsList +import GHC.TypeNats +import qualified Prelude + +import ZkFold.Symbolic.Base.Num + +newtype Mono var pow = UnsafeMono {fromMono :: Map var pow} + deriving (Eq, Ord, Functor) + +mono :: (AdditiveMonoid pow, Eq pow) => Map var pow -> Mono var pow +mono = UnsafeMono . Map.filter (/= zero) + +instance (Ord var, AdditiveMonoid pow, Eq pow) + => IsList (Mono var pow) where + type Item (Mono var pow) = (var,pow) + toList (UnsafeMono m) = toList m + fromList l = + let + inserter m (v, pow) = Map.insertWith (+) v pow m + in + mono (foldl' inserter Map.empty l) + +instance (Ord var, Semiring pow, Eq pow) + => MultiplicativeMonoid (Mono var pow) where + one = UnsafeMono Map.empty + UnsafeMono x * UnsafeMono y = mono (Map.unionWith (+) x y) + +instance (Ord var, Semiring pow, Ord pow) + => Exponent Natural (Mono var pow) where + exponent a p = evalMono [(a,p)] + evalMono = evalMonoN + +newtype Combo var coef = UnsafeCombo {fromCombo :: Map var coef} + deriving (Eq, Ord, Functor) + +combo :: (AdditiveMonoid coef, Eq coef) => Map var coef -> Combo var coef +combo = UnsafeCombo . Map.filter (/= zero) + +instance (Ord var, Semiring coef, Eq coef) + => AdditiveMonoid (Combo var coef) where + zero = UnsafeCombo Map.empty + UnsafeCombo x + UnsafeCombo y = combo (Map.unionWith (+) x y) + +instance (Ord var, Ring coef, Eq coef) + => AdditiveGroup (Combo var coef) where + negate (UnsafeCombo x) = UnsafeCombo (Map.map negate x) + x - y = x + negate y + +instance (Ord var, AdditiveMonoid coef, Eq coef) + => IsList (Combo var coef) where + type Item (Combo var coef) = (coef,var) + toList (UnsafeCombo m) = [(c,v) | (v,c) <- toList m] + fromList l = + let + inserter m (c,v) = Map.insertWith (+) v c m + in + combo (foldl' inserter Map.empty l) + +type Poly var pow = Combo (Mono var pow) +instance (Ord var, Ord pow, Semiring pow, Semiring coef, Eq coef) + => From coef (Poly var pow coef) where + from coef = fromList [(coef,one)] +instance (Ord var, Ord pow, Semiring pow, Semiring coef, Eq coef) + => From Natural (Poly var pow coef) where + from = from @coef . from +instance (Ord var, Ord pow, Semiring pow, Ring coef, Eq coef) + => From Integer (Poly var pow coef) where + from = from @coef . from +instance (Ord var, Ord pow, Semiring pow, Field coef, Eq coef) + => From Rational (Poly var pow coef) where + from = from @coef . from +instance From (Poly var pow coef) (Poly var pow coef) +instance (Ord var, Ord pow, Semiring pow, Semiring coef, Eq coef) + => MultiplicativeMonoid (Poly var pow coef) where + one = fromList [(one, one)] + x * y = fromList + [ (xCoef * yCoef, xMono * yMono) + | (xCoef, xMono) <- toList x + , (yCoef, yMono) <- toList y + ] +instance (Ord var, Ord pow, Semiring pow, Semiring coef, Ord coef) + => Exponent Natural (Poly var pow coef) where + exponent x p = evalMono [(x,p)] + evalMono = evalMonoN +instance (Ord var, Ord pow, Semiring x, Eq x) + => Scalar x (Poly var pow x) where + scale c = if c == zero then Prelude.const zero else fmap (c *) + combine polys = + let + monos = + [(m,(c,c')) | (c,p) <- polys, (c',m) <- toList p] + insertCoefs mMap (m,cs) = Map.insertWith (<>) m [cs] mMap + monoMap = foldl' insertCoefs Map.empty monos + in + combo (fmap combine monoMap) +instance (Ord var, Ord pow, Semiring x, Eq x) + => Scalar Natural (Poly var pow x) where + scale c = if c == zero then Prelude.const zero else fmap (from c *) + combine = combineN +instance (Ord var, Ord pow, Ring x, Eq x) + => Scalar Integer (Poly var pow x) where + scale c = if c == zero then Prelude.const zero else fmap (from c *) + combine = combineZ +instance (Ord var, Ord pow, Semiring pow, Semiring x, Ord x) + => Scalar (Poly var pow x) (Poly var pow x) + +var + :: (Ord var, Ord pow, Semiring pow, Semiring coef, Eq coef) + => var -> Poly var pow coef +var x = fromList [(one, fromList [(x,one)])] + +varSet :: Ord var => Poly var pow coef -> Set var +varSet + = Set.unions + . Set.map (Map.keysSet . fromMono) + . Map.keysSet + . fromCombo + +-- evaluate a polynomial in its semiring of coefficients +evalPoly :: (Ord x, Semiring x) => Poly x Natural x -> x +evalPoly x = combine [(c, evalMono (toList m)) | (c,m) <- toList x] + +-- map a polynomial to new variables, evaluating some variables +mapPoly + :: (Eq x, Ord var0, Ord var1, Semiring x) + => (var0 -> Either x var1) + -> Poly var0 Natural x -> Poly var1 Natural x +mapPoly f polynomial = fromList + [ + let + (coefMono, varMono) = partitionEithers + [bimap (,p) (,p) (f v0) | (v0,p) <- toList monomial] + in + (c * evalMono coefMono, fromList varMono) + | (c, monomial) <- toList polynomial + ] diff --git a/src/ZkFold/Symbolic/Base/Vector.hs b/src/ZkFold/Symbolic/Base/Vector.hs new file mode 100644 index 000000000..d14794cb0 --- /dev/null +++ b/src/ZkFold/Symbolic/Base/Vector.hs @@ -0,0 +1,246 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableSuperClasses #-} + +module ZkFold.Symbolic.Base.Vector + ( -- * VectorSpace + VectorSpace (..) + -- * Vector types + , Vector (..), vector + , SparseV (..), sparseV + , Gen.U1 (..) + , Gen.Par1 (..) + , (Gen.:*:) (..) + , (Gen.:.:) (..) + -- * Structure combinators + , constV + , zipWithV + -- * VectorSpace combinators + , zeroV + , addV + , subtractV + , negateV + , scaleV + , dotV + ) where + +import Control.Category +import Control.Monad +import Data.Bool +import Data.Distributive +import Data.Either +import Data.Eq +import Data.Foldable hiding (sum) +import Data.Function (const, ($)) +import Data.Functor +import Data.Functor.Rep +import Data.IntMap (IntMap) +import qualified Data.IntMap as IntMap +import Data.Kind (Type) +import Data.Maybe +import Data.Monoid +import Data.Ord +import Data.Traversable +import Data.Type.Equality +import qualified Data.Vector as V +import Data.Void +import qualified GHC.Generics as Gen +import qualified Prelude + +import ZkFold.Symbolic.Base.Num + +{- | +Class of vector spaces with a basis. + +`VectorSpace` is a known sized "monorepresentable" class, +similar to `Representable` plus `Traversable`, +but with a fixed element type that is a `Field`. + +A "vector" in a `VectorSpace` can be thought of as a +tuple of numbers @(x1,..,xn)@. +-} +class + ( Field a + , Traversable v + , Ord (Basis a v) + ) => VectorSpace a v where + {- | The `Basis` for a `VectorSpace`. More accurately, + `Basis` will be a spanning set with "out-of-bounds" + basis elements corresponding with zero. + -} + + type Basis a v :: Type + type Basis a v = Basis a (Gen.Rep1 v) + + tabulateV :: (Basis a v -> a) -> v a + default tabulateV + :: ( Gen.Generic1 v + , VectorSpace a (Gen.Rep1 v) + , Basis a v ~ Basis a (Gen.Rep1 v) + ) + => (Basis a v -> a) -> v a + tabulateV = Gen.to1 . tabulateV + + indexV :: v a -> Basis a v -> a + default indexV + :: ( Gen.Generic1 v + , VectorSpace a (Gen.Rep1 v) + , Basis a v ~ Basis a (Gen.Rep1 v) + ) + => v a -> Basis a v -> a + indexV = indexV . Gen.from1 + + dimV :: Natural + default dimV :: VectorSpace a (Gen.Rep1 v) => Natural + dimV = dimV @a @(Gen.Rep1 v) + + basisV :: v (Basis a v) + default basisV + :: ( Gen.Generic1 v + , VectorSpace a (Gen.Rep1 v) + , Basis a v ~ Basis a (Gen.Rep1 v) + ) + => v (Basis a v) + basisV = Gen.to1 (basisV @a @(Gen.Rep1 v)) + +constV :: VectorSpace a v => a -> v a +constV = tabulateV . const + +zipWithV :: VectorSpace a v => (a -> a -> a) -> v a -> v a -> v a +zipWithV f as bs = tabulateV $ \k -> + f (indexV as k) (indexV bs k) + +zeroV :: VectorSpace a v => v a +zeroV = constV zero + +addV :: VectorSpace a v => v a -> v a -> v a +addV = zipWithV (+) + +subtractV :: VectorSpace a v => v a -> v a -> v a +subtractV = zipWithV (-) + +negateV :: VectorSpace a v => v a -> v a +negateV = fmap negate + +scaleV :: VectorSpace a v => a -> v a -> v a +scaleV c = fmap (c *) + +-- | dot product +dotV :: VectorSpace a v => v a -> v a -> a +v `dotV` w = sum (zipWithV (*) v w) + +-- generic vector space +instance VectorSpace a v + => VectorSpace a (Gen.M1 i c v) where + type Basis a (Gen.M1 i c v) = Basis a v + indexV (Gen.M1 v) = indexV v + tabulateV f = Gen.M1 (tabulateV f) + dimV = dimV @a @v + basisV = Gen.M1 (basisV @a @v) + +-- zero dimensional vector space +instance Field a => VectorSpace a Gen.U1 where + type Basis a Gen.U1 = Void + tabulateV = tabulate + indexV = index + dimV = zero + basisV = Gen.U1 + +-- one dimensional vector space +instance Field a => VectorSpace a Gen.Par1 where + type Basis a Gen.Par1 = () + tabulateV = tabulate + indexV = index + dimV = one + basisV = Gen.Par1 () + +-- direct sum of vector spaces +instance (VectorSpace a v, VectorSpace a u) + => VectorSpace a (v Gen.:*: u) where + type Basis a (v Gen.:*: u) = Either (Basis a v) (Basis a u) + tabulateV f = tabulateV (f . Left) Gen.:*: tabulateV (f . Right) + indexV (a Gen.:*: _) (Left i) = indexV a i + indexV (_ Gen.:*: b) (Right j) = indexV b j + dimV = dimV @a @v + dimV @a @u + basisV = fmap Left (basisV @a @v) Gen.:*: fmap Right (basisV @a @u) + +-- tensor product of vector spaces +instance + ( VectorSpace a v + , VectorSpace a u + , Representable v + , Basis a v ~ Rep v + ) => VectorSpace a (v Gen.:.: u) where + type Basis a (v Gen.:.: u) = (Basis a v, Basis a u) + tabulateV = Gen.Comp1 . tabulate . fmap tabulateV . Prelude.curry + indexV (Gen.Comp1 fg) (i, j) = indexV (index fg i) j + dimV = dimV @a @v * dimV @a @u + basisV = Gen.Comp1 + (basisV @a @v <&> \bv -> basisV @a @u <&> \bu -> (bv,bu)) + +-- | concrete vectors +newtype Vector (n :: Natural) a = UnsafeV {fromV :: V.Vector a} + deriving stock + (Functor, Foldable, Traversable, Eq, Ord) + +vector + :: forall a n. (AdditiveMonoid a, KnownNat n) + => V.Vector a -> Vector n a +vector v = + let + len = V.length v + n = from (knownNat @n) + in + case compare len n of + EQ -> UnsafeV v + GT -> UnsafeV (V.take n v) + LT -> UnsafeV (v <> V.replicate (n - len) zero) + +instance KnownNat n => Representable (Vector n) where + type Rep (Vector n) = Prelude.Int + index (UnsafeV v) i = v V.! i + tabulate = UnsafeV . V.generate (from (knownNat @n)) + +instance KnownNat n => Distributive (Vector n) where + distribute = distributeRep + collect = collectRep + +instance (Field a, KnownNat n) => VectorSpace a (Vector n) where + type Basis a (Vector n) = Prelude.Int + indexV (UnsafeV v) i = fromMaybe zero (v V.!? i) + tabulateV = tabulate + dimV = knownNat @n + basisV = tabulate id + +-- | sparse vectors +newtype SparseV (n :: Natural) a = + UnsafeSparseV {fromSparseV :: IntMap a} + deriving stock + (Functor, Foldable, Traversable, Eq, Ord) + +sparseV :: forall a n. (Eq a, Field a, KnownNat n) => IntMap a -> SparseV n a +sparseV intMap = UnsafeSparseV (IntMap.foldMapWithKey sparsify intMap) where + sparsify int a = + if a == zero || int < 0 || int >= from (knownNat @n) + then IntMap.empty + else IntMap.singleton int a + +instance (Eq a, Field a, KnownNat n) => VectorSpace a (SparseV n) where + type Basis a (SparseV n) = Prelude.Int + indexV v i = fromMaybe zero (fromSparseV v IntMap.!? i) + tabulateV f = UnsafeSparseV $ + IntMap.fromList + [ (i, f i) + | i <- [0 .. from (knownNat @n) - 1] + , f i /= zero + ] + dimV = knownNat @n + basisV = UnsafeSparseV $ + IntMap.fromList + [ (i, i) + | i <- [0 .. from (knownNat @n) - 1] + ] diff --git a/zkfold-base.cabal b/zkfold-base.cabal index 8966dc6e3..beb8c51bf 100644 --- a/zkfold-base.cabal +++ b/zkfold-base.cabal @@ -133,6 +133,11 @@ library ZkFold.Symbolic.Algorithms.Hash.MiMC.Constants ZkFold.Symbolic.Algorithms.Hash.SHA2 ZkFold.Symbolic.Algorithms.Hash.SHA2.Constants + ZkFold.Symbolic.Base.Circuit + ZkFold.Symbolic.Base.Function + ZkFold.Symbolic.Base.Num + ZkFold.Symbolic.Base.Polynomial + ZkFold.Symbolic.Base.Vector ZkFold.Symbolic.Cardano.Wrapper ZkFold.Symbolic.Cardano.Contracts.BatchTransfer ZkFold.Symbolic.Cardano.Contracts.RandomOracle @@ -208,6 +213,8 @@ library containers < 0.7, cryptohash-sha256 < 0.12, deepseq <= 1.5.0.0, + distributive < 0.7, + indexed-transformers < 0.2, lens , mtl < 2.4, optics < 0.5,