diff --git a/lsm-tree.cabal b/lsm-tree.cabal index b7e89fb08..41cc6be87 100644 --- a/lsm-tree.cabal +++ b/lsm-tree.cabal @@ -148,6 +148,19 @@ benchmark lsm-tree-macro-bench main-is: Main.hs build-depends: base +library kmerge + import: warnings + default-language: Haskell2010 + hs-source-dirs: src-kmerge + exposed-modules: + KMerge.Heap + KMerge.LoserTree + + build-depends: + , base + , indexed-traversable + , primitive + test-suite kmerge-test import: warnings default-language: Haskell2010 @@ -155,9 +168,12 @@ test-suite kmerge-test hs-source-dirs: test main-is: kmerge-test.hs build-depends: - , base >=4.14 && <4.19 + , base >=4.14 && <4.19 , deepseq , heaps + , indexed-traversable + , lsm-tree:kmerge + , primitive , QuickCheck , splitmix , tasty diff --git a/src-kmerge/KMerge/Heap.hs b/src-kmerge/KMerge/Heap.hs new file mode 100644 index 000000000..d258166e3 --- /dev/null +++ b/src-kmerge/KMerge/Heap.hs @@ -0,0 +1,181 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -fexpose-all-unfoldings #-} +-- | Mutable heap for k-merge algorithm. +-- +-- This data-structure represents a min-heap with the root node *removed*. +-- (internally the filling of root value and sifting down is delayed). +-- +-- Also there isn't *insert* operation, i.e. the heap can only shrink. +-- Other heap usual heap opeartions are *create-heap*, *extract-min* and *replace*. +-- However, as the 'MutableHeap' always represents a heap with its root (minimum value) +-- extracted, *extract-min* is "fused" to other operations. +module KMerge.Heap ( + MutableHeap, + newMutableHeap, + replaceRoot, + extract, +) where + +import Control.Monad.Primitive (PrimMonad (PrimState), RealWorld) +import qualified Control.Monad.ST as Lazy +import qualified Control.Monad.ST as Strict +import Data.Bits (unsafeShiftL, unsafeShiftR) +import Data.Foldable.WithIndex (ifor_) +import Data.Primitive (SmallMutableArray, newSmallArray, + readSmallArray, writeSmallArray) +import Data.Primitive.PrimVar (PrimVar, newPrimVar, readPrimVar, + writePrimVar) +import Unsafe.Coerce (unsafeCoerce) + +-- | Mutable heap for k-merge algorithm. +data MutableHeap s a = MH + !(PrimVar s Int) -- ^ element count, size + !(SmallMutableArray s a) + +-- | Placeholder value used to fill the internal array. +placeholder :: a +placeholder = unsafeCoerce () + +-- | Create new heap, and immediately extract its minimum value. +newMutableHeap :: forall a m. (PrimMonad m, Ord a) => [a] -> m (MutableHeap (PrimState m) a, Maybe a) +newMutableHeap xs = do + let !size = length xs + + arr <- newSmallArray size placeholder + ifor_ xs $ \idx x -> do + writeSmallArray arr idx x + siftUp arr x idx + + sizeRef <- newPrimVar size + + if size <= 0 + then return $! (MH sizeRef arr, Nothing) + else do + x <- readSmallArray arr 0 + writeSmallArray arr 0 placeholder + return $! (MH sizeRef arr, Just x) + +-- | Replace the minimum-value, and immediately extract the new minimum value. +replaceRoot :: forall a m. (PrimMonad m, Ord a) => MutableHeap (PrimState m) a -> a -> m a +replaceRoot (MH sizeRef arr) val = do + size <- readPrimVar sizeRef + if size <= 1 + then return val + else do + writeSmallArray arr 0 val + siftDown arr size val 0 + x <- readSmallArray arr 0 + return x + +{-# SPECIALIZE replaceRoot :: forall a. Ord a => MutableHeap RealWorld a -> a -> IO a #-} +{-# SPECIALIZE replaceRoot :: forall a s. Ord a => MutableHeap s a -> a -> Strict.ST s a #-} +{-# SPECIALIZE replaceRoot :: forall a s. Ord a => MutableHeap s a -> a -> Lazy.ST s a #-} + +-- | Extract the next minimum value. +extract :: forall a m. (PrimMonad m, Ord a) => MutableHeap (PrimState m) a -> m (Maybe a) +extract (MH sizeRef arr) = do + size <- readPrimVar sizeRef + if size <= 1 + then return Nothing + else do + writePrimVar sizeRef $! size - 1 + val <- readSmallArray arr (size - 1) + writeSmallArray arr 0 val + siftDown arr size val 0 + x <- readSmallArray arr 0 + writeSmallArray arr (size - 1) placeholder + return $! Just x + +{-# SPECIALIZE extract :: forall a. Ord a => MutableHeap RealWorld a -> IO (Maybe a) #-} +{-# SPECIALIZE extract :: forall a s. Ord a => MutableHeap s a -> Strict.ST s (Maybe a) #-} +{-# SPECIALIZE extract :: forall a s. Ord a => MutableHeap s a -> Lazy.ST s (Maybe a) #-} + +{------------------------------------------------------------------------------- + Internal operations +-------------------------------------------------------------------------------} + +siftUp :: forall a m. (PrimMonad m, Ord a) => SmallMutableArray (PrimState m) a -> a -> Int -> m () +siftUp !arr !x = loop where + loop !idx + | idx <= 0 + = return () + + | otherwise + = do + let !parent = halfOf (idx - 1) + p <- readSmallArray arr parent + if x < p + then do + writeSmallArray arr parent x + writeSmallArray arr idx p + loop parent + else return () + +{-# SPECIALIZE siftUp :: forall a. Ord a => SmallMutableArray RealWorld a -> a -> Int -> IO () #-} +{-# SPECIALIZE siftUp :: forall a s. Ord a => SmallMutableArray s a -> a -> Int -> Strict.ST s () #-} +{-# SPECIALIZE siftUp :: forall a s. Ord a => SmallMutableArray s a -> a -> Int -> Lazy.ST s () #-} + +siftDown :: forall a m. (PrimMonad m, Ord a) => SmallMutableArray (PrimState m) a -> Int -> a -> Int -> m () +siftDown !arr !size !x = loop where + loop !idx + | rgt < size + = do + l <- readSmallArray arr lft + r <- readSmallArray arr rgt + + if x <= l + then do + if x <= r + then return () + else do + -- r < x <= l; swap x and r + writeSmallArray arr rgt x + writeSmallArray arr idx r + loop rgt + else do + if l <= r + then do + -- l < x, l <= r; swap x and l + writeSmallArray arr idx l + writeSmallArray arr lft x + loop lft + else do + -- r < l <= x; swap x and r + writeSmallArray arr rgt x + writeSmallArray arr idx r + loop rgt + + -- there's only left value + | lft < size + = do + l <- readSmallArray arr lft + if x <= l + then return () + else do + writeSmallArray arr idx l + writeSmallArray arr lft x + -- there is no need to loop further, lft was the last value. + + | otherwise + = return () + where + !lft = doubleOf idx + 1 + !rgt = doubleOf idx + 2 + +{-# SPECIALIZE siftDown :: forall a. Ord a => SmallMutableArray RealWorld a -> Int -> a -> Int -> IO () #-} +{-# SPECIALIZE siftDown :: forall a s. Ord a => SmallMutableArray s a -> Int -> a -> Int -> Strict.ST s () #-} +{-# SPECIALIZE siftDown :: forall a s. Ord a => SmallMutableArray s a -> Int -> a -> Int -> Lazy.ST s () #-} + +{------------------------------------------------------------------------------- + Helpers +-------------------------------------------------------------------------------} + +halfOf :: Int -> Int +halfOf i = unsafeShiftR i 1 +{-# INLINE halfOf #-} + +doubleOf :: Int -> Int +doubleOf i = unsafeShiftL i 1 +{-# INLINE doubleOf #-} diff --git a/src-kmerge/KMerge/LoserTree.hs b/src-kmerge/KMerge/LoserTree.hs new file mode 100644 index 000000000..2fc897b47 --- /dev/null +++ b/src-kmerge/KMerge/LoserTree.hs @@ -0,0 +1,201 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -fexpose-all-unfoldings #-} +module KMerge.LoserTree ( + MutableLoserTree, + newLoserTree, + replace, + remove, +) where + +import Control.Monad.Primitive (PrimMonad (PrimState), RealWorld) +import qualified Control.Monad.ST as Lazy +import qualified Control.Monad.ST as Strict +import Data.Bits (unsafeShiftR) +import Data.Primitive (MutablePrimArray, SmallMutableArray, + newPrimArray, newSmallArray, readPrimArray, readSmallArray, + setPrimArray, writePrimArray, writeSmallArray) +import Data.Primitive.PrimVar (PrimVar, newPrimVar, readPrimVar, + writePrimVar) +import Unsafe.Coerce (unsafeCoerce) + +-- | Mutable Loser Tree. +data MutableLoserTree s a = MLT + !(PrimVar s Int) -- ^ element count, i.e. size. + !(PrimVar s Int) -- ^ index of the hole (i.e. winner's initial index) + !(MutablePrimArray s Int) -- ^ indices, we store the index of first match. -1 if there is no match. + !(SmallMutableArray s a) -- ^ values + +placeholder :: a +placeholder = unsafeCoerce () + +-- | Create new 'MutableLoserTree'. +-- +-- The second half of a pair is the winner value (only losers are stored in the tree). +-- +newLoserTree :: forall a m. (PrimMonad m, Ord a) => [a] -> m (MutableLoserTree (PrimState m) a, Maybe a) +newLoserTree [] = do + ids <- newPrimArray 0 + arr <- newSmallArray 0 placeholder + sizeRef <- newPrimVar 0 + holeRef <- newPrimVar 0 + return $! (MLT sizeRef holeRef ids arr, Nothing) +newLoserTree [x] = do + ids <- newPrimArray 0 + arr <- newSmallArray 0 placeholder + sizeRef <- newPrimVar 0 + holeRef <- newPrimVar 0 + return $! (MLT sizeRef holeRef ids arr, Just x) +newLoserTree xs0 = do + -- allocate array, we need one less than there are elements. + -- one of the elements will be the winner. + ids <- newPrimArray (len - 1) + setPrimArray ids 0 (len - 1) (-1) + arr <- newSmallArray (len - 1) placeholder + + loop ids arr (len - 1) xs0 + where + !len = length xs0 + + loop :: MutablePrimArray (PrimState m) Int -> SmallMutableArray (PrimState m) a -> Int -> [a] -> m (MutableLoserTree (PrimState m) a, Maybe a) + loop !_ !_ !_ [] = error "should not happen" + loop ids arr idx (x:xs) = do + sift ids arr (parentOf idx) (parentOf idx) x idx xs + + sift :: MutablePrimArray (PrimState m) Int -> SmallMutableArray (PrimState m) a -> Int -> Int -> a -> Int -> [a] -> m (MutableLoserTree (PrimState m) a, Maybe a) + sift !ids !arr !idxX !j !x !idx0 xs = do + !idxY <- readPrimArray ids j + y <- readSmallArray arr j + if idxY < 0 + then do + writePrimArray ids j idxX + writeSmallArray arr j x + loop ids arr (idx0 + 1) xs + else if j <= 0 + then do + if x <= y + then do + sizeRef <- newPrimVar (len - 1) + holeRef <- newPrimVar idxX + return (MLT sizeRef holeRef ids arr, Just x) + else do + writePrimArray ids j idxX + writeSmallArray arr j x + sizeRef <- newPrimVar (len - 1) + holeRef <- newPrimVar idxY + return (MLT sizeRef holeRef ids arr, Just y) + else do + if x < y + then do + sift ids arr idxX (parentOf j) x idx0 xs + else do + writePrimArray ids j idxX + writeSmallArray arr j x + sift ids arr idxY (parentOf j) y idx0 xs + +{-# SPECIALIZE newLoserTree :: forall a. Ord a => [a] -> IO (MutableLoserTree RealWorld a, Maybe a) #-} +{-# SPECIALIZE newLoserTree :: forall a s. Ord a => [a] -> Strict.ST s (MutableLoserTree s a, Maybe a) #-} +{-# SPECIALIZE newLoserTree :: forall a s. Ord a => [a] -> Lazy.ST s (MutableLoserTree s a, Maybe a) #-} + +{------------------------------------------------------------------------------- + Updates +-------------------------------------------------------------------------------} + +{-# SPECIALIZE replace :: forall a. Ord a => MutableLoserTree RealWorld a -> a -> IO a #-} +{-# SPECIALIZE replace :: forall a s. Ord a => MutableLoserTree s a -> a -> Strict.ST s a #-} +{-# SPECIALIZE replace :: forall a s. Ord a => MutableLoserTree s a -> a -> Lazy.ST s a #-} + +{-# SPECIALIZE remove :: forall a. Ord a => MutableLoserTree RealWorld a -> IO (Maybe a) #-} +{-# SPECIALIZE remove :: forall a s. Ord a => MutableLoserTree s a -> Strict.ST s (Maybe a) #-} +{-# SPECIALIZE remove :: forall a s. Ord a => MutableLoserTree s a -> Lazy.ST s (Maybe a) #-} + +-- | Don't fill the winner "hole". Return a next winner of (smaller) tournament. +remove :: forall a m. (PrimMonad m, Ord a) => MutableLoserTree (PrimState m) a -> m (Maybe a) +remove (MLT sizeRef holeRef ids arr) = do + size <- readPrimVar sizeRef + if size <= 0 + then return Nothing + else do + writePrimVar sizeRef (size - 1) + hole <- readPrimVar holeRef + siftEmpty hole + where + siftEmpty :: Int -> m (Maybe a) + siftEmpty !j = do + !idxY <- readPrimArray ids j + y <- readSmallArray arr j + if j <= 0 + then if idxY < 0 + then return Nothing + else do + writePrimArray ids j (-1) + writeSmallArray arr j placeholder + writePrimVar holeRef idxY + return (Just y) + else if idxY < 0 + then + siftEmpty (parentOf j) + else do + writePrimArray ids j (-1) + writeSmallArray arr j placeholder + Just <$> siftUp ids arr holeRef (parentOf j) idxY y + +-- | Fill the winner "hole" with a new element. Return a new tournament winner. +replace :: forall a m. (PrimMonad m, Ord a) => MutableLoserTree (PrimState m) a -> a -> m a +replace (MLT sizeRef holeRef ids arr) val = do + size <- readPrimVar sizeRef + if size <= 0 + then return val + else do + hole <- readPrimVar holeRef + siftUp ids arr holeRef hole hole val + +{-# SPECIALIZE siftUp :: forall a. Ord a => MutablePrimArray RealWorld Int -> SmallMutableArray RealWorld a -> PrimVar RealWorld Int -> Int -> Int -> a -> IO a #-} +{-# SPECIALIZE siftUp :: forall a s. Ord a => MutablePrimArray s Int -> SmallMutableArray s a -> PrimVar s Int -> Int -> Int -> a -> Strict.ST s a #-} +{-# SPECIALIZE siftUp :: forall a s. Ord a => MutablePrimArray s Int -> SmallMutableArray s a -> PrimVar s Int -> Int -> Int -> a -> Lazy.ST s a #-} + +siftUp :: forall a m. (PrimMonad m, Ord a) => MutablePrimArray (PrimState m) Int -> SmallMutableArray (PrimState m) a -> PrimVar (PrimState m) Int -> Int -> Int -> a -> m a +siftUp ids arr holeRef = sift + where + sift :: Int -> Int -> a -> m a + sift !j !idxX !x = do + !idxY <- readPrimArray ids j + y <- readSmallArray arr j + if j <= 0 + then if idxY < 0 + then do + writePrimVar holeRef idxX + return x + else do + if x <= y + then do + writePrimVar holeRef idxX + return x + else do + writePrimArray ids j idxX + writeSmallArray arr j x + writePrimVar holeRef idxY + return y + else if idxY < 0 + then sift (parentOf j) idxX x + else do + if x <= y + then do + sift (parentOf j) idxX x + else do + writePrimArray ids j idxX + writeSmallArray arr j x + sift (parentOf j) idxY y + +{------------------------------------------------------------------------------- + Helpers +-------------------------------------------------------------------------------} + +halfOf :: Int -> Int +halfOf i = unsafeShiftR i 1 +{-# INLINE halfOf #-} + +parentOf :: Int -> Int +parentOf i = halfOf (i - 1) +{-# INLINE parentOf #-} diff --git a/test/kmerge-test.hs b/test/kmerge-test.hs index d4dbd4fe4..30a77ee49 100644 --- a/test/kmerge-test.hs +++ b/test/kmerge-test.hs @@ -1,56 +1,227 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -fspecialize-aggressively #-} module Main (main) where -import Control.DeepSeq (force) +import Control.DeepSeq (NFData (..), force) import Control.Exception (evaluate) +import Control.Monad.ST.Strict (ST, runST) import qualified Data.Heap as Heap +import Data.IORef import qualified Data.List as L import Data.WideWord.Word256 (Word256 (..)) import Data.Word (Word64) +import System.IO.Unsafe (unsafePerformIO) import qualified System.Random.SplitMix as SM -import Test.Tasty (defaultMainWithIngredients, testGroup) +import Test.Tasty (TestName, TestTree, defaultMainWithIngredients, + testGroup) import qualified Test.Tasty.Bench as B -import Test.Tasty.QuickCheck (Property, testProperty, (===)) +import Test.Tasty.HUnit (testCase, (@?=)) +import Test.Tasty.QuickCheck (testProperty, (===)) +import qualified KMerge.Heap as K.Heap +import qualified KMerge.LoserTree as K.Tree + +-- tests and benchmarks for various k-way merge implementations. +-- in short: loser tree is optimal in comparision counts performed, +-- but mutable heap implementation has lower constant factors. +-- +-- Noteworthy, maybe not obvious observations: +-- - mutable heap does the same amount of comparisions as persistent heap +-- (from @heaps@ package), +-- - tree-shaped iterative two-way merge performs optimal amount of comparisons +-- loser tree is an explicit state variant of that. +-- main :: IO () main = do _ <- evaluate $ force input8 + _ <- evaluate $ force input7 _ <- evaluate $ force input5 defaultMainWithIngredients B.benchIngredients $ testGroup "kmerge" [ testGroup "tests" - [ testProperty "twoWayMerge" prop_twoWayMerge - , testProperty "twoWayMerge2" prop_twoWayMerge2 - , testProperty "heapKWayMerge" prop_kWayMerge + [ testGroup "merge" + [ mergeProperty "listMerge" listMerge + , mergeProperty "treeMerge" treeMerge + , mergeProperty "heapMerge" heapMerge + , mergeProperty "loserTreeMerge" loserTreeMerge + , mergeProperty "mutHeapMerge" mutHeapMerge + ] + , testGroup "count" + [ testGroup "eight" + -- loserTree comparison upper bounds for 8 inputs is 3 x element count. + -- for 8 100-element lists, i.e. 800 elements the total comparison count is 2400 + -- loserTree (and tree merge) implementations hit exactly that number. + -- + -- (because the input values are unformly random, + -- there shouldn't be a lot of "cheap" leftovers elements, + -- i.e. when other inputs are exhausted, but there are few) + [ testCount "sortConcat" 3190 (L.sort . concat) input8 + , testCount "listMerge" 3479 listMerge input8 + , testCount "treeMerge" 2391 treeMerge input8 + , testCount "heapMerge" 3168 heapMerge input8 + , testCount "loserTreeMerge" 2391 loserTreeMerge input8 + , testCount "mutHeapMerge" 3169 mutHeapMerge input8 + ] + -- seven inputs: we have 6x100 elements with 3 comparisions + -- and 1x100 elements with just 2. + -- i.e. target is 2000 total comparisions. + -- + -- The difference here and in five-input case between + -- treeMerge and loserTreeMerge is caused by + -- different "tournament bracket" assignments done by the + -- algorithms. + -- + -- In particular in five case, the treeMerge bracket looks like + -- + -- * + -- / \ + -- * 5 + -- / \ + -- * * + -- / \ / \ + -- 1 2 3 4 + -- + -- But the LoserTree is balanced: + -- + -- * + -- / \ + -- * * + -- / \ / \ + -- * 3 4 5 + -- / \ + -- 1 2 + -- + -- (maybe treeMerge can be better balanced too, + -- but I'm too lazy to think how to do that) + -- + , testGroup "seven" + [ testCount "sortConcat" 2691 (L.sort . concat) input7 + , testCount "listMerge" 2682 listMerge input7 + , testCount "treeMerge" 1992 treeMerge input7 + , testCount "heapMerge" 2645 heapMerge input7 + , testCount "loserTreeMerge" 1989 loserTreeMerge input7 + , testCount "mutHeapMerge" 2570 mutHeapMerge input7 + ] + -- five inputs: we have 3x100 elements with 2 comparisons + -- and 2x100 with 3 comparisons. + -- i.e. target is 1200 total comparisions. + , testGroup "five" + [ testCount "sortConcat" 1790 (L.sort . concat) input5 + , testCount "listMerge" 1389 listMerge input5 + , testCount "treeMerge" 1291 treeMerge input5 + , testCount "heapMerge" 1485 heapMerge input5 + , testCount "loserTreeMerge" 1191 loserTreeMerge input5 + , testCount "mutHeapMerge" 1592 mutHeapMerge input5 + ] + ] ] , testGroup "bench" [ testGroup "eight" - [ B.bench "sortConcat" $ B.nf (L.sort . concat) input8 - , B.bench "twoWayMerge" $ B.nf recursiveTwoWayMerge input8 - , B.bench "twoWayMerge2" $ B.nf recursiveTwoWayMerge2 input8 - , B.bench "heapKWayMerge" $ B.nf heapKWayMerge input8 + [ B.bench "sortConcat" $ B.nf (L.sort . concat) input8 + , B.bench "listMerge" $ B.nf listMerge input8 + , B.bench "treeMerge" $ B.nf treeMerge input8 + , B.bench "heapMerge" $ B.nf heapMerge input8 + , B.bench "loserTreeMerge" $ B.nf loserTreeMerge input8 + , B.bench "mutHeapMerge" $ B.nf mutHeapMerge input8 + ] + , testGroup "seven" + [ B.bench "sortConcat" $ B.nf (L.sort . concat) input7 + , B.bench "listMerge" $ B.nf listMerge input7 + , B.bench "treeMerge" $ B.nf treeMerge input7 + , B.bench "heapMerge" $ B.nf heapMerge input7 + , B.bench "loserTreeMerge" $ B.nf loserTreeMerge input7 + , B.bench "mutHeapMerge" $ B.nf mutHeapMerge input7 ] , testGroup "five" - [ B.bench "sortConcat" $ B.nf (L.sort . concat) input5 - , B.bench "twoWayMerge" $ B.nf recursiveTwoWayMerge input5 - , B.bench "twoWayMerge2" $ B.nf recursiveTwoWayMerge2 input5 - , B.bench "heapKWayMerge" $ B.nf heapKWayMerge input5 + [ B.bench "sortConcat" $ B.nf (L.sort . concat) input5 + , B.bench "listMerge" $ B.nf listMerge input5 + , B.bench "treeMerge" $ B.nf treeMerge input5 + , B.bench "heapMerge" $ B.nf heapMerge input5 + , B.bench "loserTreeMerge" $ B.nf loserTreeMerge input5 + , B.bench "mutHeapMerge" $ B.nf mutHeapMerge input5 ] ] ] +{------------------------------------------------------------------------------- + Test utils +-------------------------------------------------------------------------------} + +counter :: IORef Int +counter = unsafePerformIO $ newIORef 0 +{-# NOINLINE counter #-} + +newtype Wrapped a = Wrap a -- { unwrap :: Word256 } + +instance Eq a => Eq (Wrapped a) where + Wrap x == Wrap y = unsafePerformIO $ do + atomicModifyIORef' counter $ \n -> (1 + n, ()) + return $! x == y + {-# NOINLINE (==) #-} + +instance Ord a => Ord (Wrapped a) where + compare (Wrap x) (Wrap y) = unsafePerformIO $ do + atomicModifyIORef' counter $ \n -> (1 + n, ()) + return $! compare x y + Wrap x < Wrap y = unsafePerformIO $ do + atomicModifyIORef' counter $ \n -> (1 + n, ()) + return $! x < y + Wrap x <= Wrap y = unsafePerformIO $ do + atomicModifyIORef' counter $ \n -> (1 + n, ()) + return $! x <= y + + {-# NOINLINE compare #-} + {-# NOINLINE (<) #-} + {-# NOINLINE (<=) #-} + +instance NFData a => NFData (Wrapped a) where + rnf (Wrap x) = rnf x + +testCount :: (NFData b, Ord b) => TestName -> Int -> (forall a. Ord a => [[a]] -> [a]) -> [[b]] -> TestTree +testCount name expected f input = testCase name $ do + n <- readIORef counter + _ <- evaluate $ force $ f $ map (map Wrap) input + m <- readIORef counter + m - n @?= expected +{-# NOINLINE testCount #-} + +mergeProperty :: TestName -> (forall a. Ord a => [[a]] -> [a]) -> TestTree +mergeProperty name f = testProperty name $ \xss -> + let lhs = L.sort (concat xss) + rhs = f $ map L.sort (xss :: [[Word64]]) + in lhs === rhs + +type Element = Word256 +-- type Element = (Word256, Word256, Word256, Word256) + -- Using Word256 to make key comparison a bit more expensive. -input8 :: [[Word256]] +input8 :: [[Element]] input8 = - [ L.sort $ take 100 $ L.unfoldr (Just . genWord256) $ SM.mkSMGen seed - | seed <- [1..8] + [ L.sort $ take 100 $ L.unfoldr (Just . genElement) $ SM.mkSMGen seed + | seed <- take 8 $ iterate (3 +) 42 ] +-- Seven inputs is not optimal case for "binary tree" patterns. +input7 :: [[Element]] +input7 = take 7 input8 + -- Five inputs is bad case for "binary tree" patterns. -input5 :: [[Word256]] +input5 :: [[Element]] input5 = take 5 input8 +genElement :: SM.SMGen -> (Element, SM.SMGen) +genElement = genWord256 +{- +genElement g0 = + let (!w1, g1) = genWord256 g0 + (!w2, g2) = genWord256 g1 + (!w3, g3) = genWord256 g2 + (!w4, g4) = genWord256 g3 + in ((w1, w2, w3, w4), g4) +-} + genWord256 :: SM.SMGen -> (Word256, SM.SMGen) genWord256 g0 = let (!w1, g1) = SM.nextWord64 g0 @@ -63,15 +234,10 @@ genWord256 g0 = Recursive 2-way merge -------------------------------------------------------------------------------} -prop_twoWayMerge :: [[Word64]] -> Property -prop_twoWayMerge xss = lhs === rhs where - lhs = L.sort (concat xss) - rhs = recursiveTwoWayMerge $ map L.sort xss - -recursiveTwoWayMerge :: Ord a => [[a]] -> [a] -recursiveTwoWayMerge [] = [] -recursiveTwoWayMerge [xs] = xs -recursiveTwoWayMerge (xs:xss) = merge xs (recursiveTwoWayMerge xss) +listMerge :: Ord a => [[a]] -> [a] +listMerge [] = [] +listMerge [xs] = xs +listMerge (xs:xss) = merge xs (listMerge xss) merge :: Ord a => [a] -> [a] -> [a] merge [] [] = [] @@ -82,21 +248,16 @@ merge xs@(x:xs') ys@(y:ys') | otherwise = y : merge xs ys' {------------------------------------------------------------------------------- - Recursive 2-way merge 2 + Recursive 2-way merge, tree shape -------------------------------------------------------------------------------} -prop_twoWayMerge2 :: [[Word64]] -> Property -prop_twoWayMerge2 xss = lhs === rhs where - lhs = L.sort (concat xss) - rhs = recursiveTwoWayMerge2 $ map L.sort xss - --- | Like 'recursiveTwoWayMerge', but merges in binary-tree pattern. +-- | Like 'listMerge', but merges in binary-tree pattern. -- -- Given inputs of about the same length, there will be less work in merges. -recursiveTwoWayMerge2 :: Ord a => [[a]] -> [a] -recursiveTwoWayMerge2 [] = [] -recursiveTwoWayMerge2 [xs] = xs -recursiveTwoWayMerge2 (xs:ys:xss) = recursiveTwoWayMerge2 (merge xs ys : go xss) where +treeMerge :: Ord a => [[a]] -> [a] +treeMerge [] = [] +treeMerge [xs] = xs +treeMerge (xs:ys:xss) = treeMerge (merge xs ys : go xss) where go [] = [] go [vs] = [vs] go (vs:ws:vss) = merge vs ws : go vss @@ -105,13 +266,8 @@ recursiveTwoWayMerge2 (xs:ys:xss) = recursiveTwoWayMerge2 (merge xs ys : go xss) Direct k-way merge using heaps Data.Heap.Heap -------------------------------------------------------------------------------} -prop_kWayMerge :: [[Word64]] -> Property -prop_kWayMerge xss = lhs === rhs where - lhs = L.sort (concat xss) - rhs = heapKWayMerge $ map L.sort xss - -heapKWayMerge :: forall a. Ord a => [[a]] -> [a] -heapKWayMerge xss = go $ Heap.fromList +heapMerge :: forall a. Ord a => [[a]] -> [a] +heapMerge xss = go $ Heap.fromList [ Heap.Entry x xs | x:xs <- xss ] @@ -122,3 +278,35 @@ heapKWayMerge xss = go $ Heap.fromList Just (Heap.Entry x xs, heap') -> x : case xs of [] -> go heap' x':xs' -> go (Heap.insert (Heap.Entry x' xs') heap') + +{------------------------------------------------------------------------------- + Direct k-way merge using LoserTree +-------------------------------------------------------------------------------} + +loserTreeMerge :: forall a. Ord a => [[a]] -> [a] +loserTreeMerge xss = runST $ do + -- we reuse Heap.Entry structure here. + (tree, element) <- K.Tree.newLoserTree [ Heap.Entry x xs | x:xs <- xss ] + go tree element + where + go :: K.Tree.MutableLoserTree s (Heap.Entry a [a]) -> Maybe (Heap.Entry a [a]) -> ST s [a] + go !_ Nothing = return [] + go !tree (Just (Heap.Entry x xs)) = fmap (x :) $ case xs of + [] -> K.Tree.remove tree >>= go tree + x':xs' -> K.Tree.replace tree (Heap.Entry x' xs') >>= go tree . Just + +{------------------------------------------------------------------------------- + Direct k-way merge using MutableHeap +-------------------------------------------------------------------------------} + +mutHeapMerge :: forall a. Ord a => [[a]] -> [a] +mutHeapMerge xss = runST $ do + -- we reuse Heap.Entry structure here. + (heap, element) <- K.Heap.newMutableHeap [ Heap.Entry x xs | x:xs <- xss ] + go heap element + where + go :: K.Heap.MutableHeap s (Heap.Entry a [a]) -> Maybe (Heap.Entry a [a]) -> ST s [a] + go !_ Nothing = return [] + go !heap (Just (Heap.Entry x xs)) = fmap (x :) $ case xs of + [] -> K.Heap.extract heap >>= go heap + x':xs' -> K.Heap.replaceRoot heap (Heap.Entry x' xs') >>= go heap . Just