diff --git a/src/Data/Graph/Haggle/Algorithms/DFS.hs b/src/Data/Graph/Haggle/Algorithms/DFS.hs index 6b9bbbf..e3a033f 100644 --- a/src/Data/Graph/Haggle/Algorithms/DFS.hs +++ b/src/Data/Graph/Haggle/Algorithms/DFS.hs @@ -67,22 +67,25 @@ xdfsWith :: (Graph g) -> [c] xdfsWith g nextVerts f roots | isEmpty g || null roots = [] - | otherwise = runST $ do - bs <- newBitSet (maxVertexId g + 1) - res <- foldM (go bs) [] roots - return $ reverse res + | otherwise = + if any (not . (`elem` vertices g)) roots + then [] + else runST $ do + bs <- newBitSet (maxVertexId g + 1) + res <- foldM (go bs) [] roots + return $ reverse res where go bs acc v = do - isMarked <- testBit bs (vertexId v) + isMarked <- testBitUnsafe bs (vertexId v) case isMarked of True -> return acc False -> do - setBit bs (vertexId v) + setBitUnsafe bs (vertexId v) nxt <- filterM (notVisited bs) (nextVerts v) foldM (go bs) (f v : acc) nxt notVisited :: BitSet s -> Vertex -> ST s Bool -notVisited bs v = liftM not (testBit bs (vertexId v)) +notVisited bs v = liftM not (testBitUnsafe bs (vertexId v)) -- | Forward parameterized DFS dfsWith :: (Graph g) @@ -130,17 +133,20 @@ xdffWith :: (Graph g) -> [Tree c] xdffWith g nextVerts f roots | isEmpty g || null roots = [] - | otherwise = runST $ do - bs <- newBitSet (maxVertexId g + 1) - res <- foldM (go bs) [] roots - return $ reverse res + | otherwise = + if any (not . (`elem` vertices g)) roots + then [] + else runST $ do + bs <- newBitSet (maxVertexId g + 1) + res <- foldM (go bs) [] roots + return $ reverse res where go bs acc v = do - isMarked <- testBit bs (vertexId v) + isMarked <- testBitUnsafe bs (vertexId v) case isMarked of True -> return acc False -> do - setBit bs (vertexId v) + setBitUnsafe bs (vertexId v) nxt <- filterM (notVisited bs) (nextVerts v) ts <- foldM (go bs) [] nxt return $ T.Node (f v) (reverse ts) : acc diff --git a/src/Data/Graph/Haggle/Internal/BitSet.hs b/src/Data/Graph/Haggle/Internal/BitSet.hs index 3e9d1ad..2beb2da 100644 --- a/src/Data/Graph/Haggle/Internal/BitSet.hs +++ b/src/Data/Graph/Haggle/Internal/BitSet.hs @@ -2,7 +2,9 @@ module Data.Graph.Haggle.Internal.BitSet ( BitSet, newBitSet, setBit, - testBit + testBit, + setBitUnsafe, + testBitUnsafe ) where import Control.Monad.ST @@ -11,6 +13,21 @@ import Data.Vector.Unboxed.Mutable ( STVector ) import qualified Data.Vector.Unboxed.Mutable as V import Data.Word ( Word64 ) +-- Note that the implementation here assumes thaththe bit numbers are all +-- unsigned. A proper implementation would perhaps use 'Natural' instead of +-- 'Int', but that would require gratuitous fromEnum/toEnum conversions from all +-- the other API's that just use 'Int', which has about a 33% performance impact +-- when measured. +-- +-- The 'setBit' and 'testBit' operations use V.unsafeRead instead of V.read +-- (where the latter is roughly 25% slower) because this is an internal module +-- that is generally always used with a positive 'Int' value, and the value is +-- also checked against 'sz' (which is also probably superfluous). In other +-- words, this module prioritizes performance over robustness and should only be +-- used when the caller can guarantee positive Int values and otherwise good +-- behavior. + + data BitSet s = BS (STVector s Word64) {-# UNPACK #-} !Int bitsPerWord :: Int @@ -28,22 +45,31 @@ newBitSet n = do -- | Set a bit in the bitset. Out of range has no effect. setBit :: BitSet s -> Int -> ST s () -setBit (BS v sz) bitIx +setBit b@(BS _ sz) bitIx | bitIx >= sz = return () - | otherwise = do + | bitIx < 0 = return () + | otherwise = setBitUnsafe b bitIx + +-- | Set a bit in the bitset. The specified bit must be in range. +setBitUnsafe :: BitSet s -> Int -> ST s () +setBitUnsafe (BS v _) bitIx = do let wordIx = bitIx `div` bitsPerWord bitPos = bitIx `mod` bitsPerWord - oldWord <- V.read v wordIx + oldWord <- V.unsafeRead v wordIx let newWord = B.setBit oldWord bitPos V.write v wordIx newWord -- | Return True if the bit is set. Out of range will return False. testBit :: BitSet s -> Int -> ST s Bool -testBit (BS v sz) bitIx +testBit b@(BS _ sz) bitIx | bitIx >= sz = return False - | otherwise = do + | bitIx < 0 = return False + | otherwise = testBitUnsafe b bitIx + +-- | Return True if the bit is set. The specified bit must be in range. +testBitUnsafe :: BitSet s -> Int -> ST s Bool +testBitUnsafe (BS v _) bitIx = do let wordIx = bitIx `div` bitsPerWord bitPos = bitIx `mod` bitsPerWord - w <- V.read v wordIx + w <- V.unsafeRead v wordIx return $ B.testBit w bitPos -