Skip to content

Commit

Permalink
Fix the fold demux combinators to not restart the fold
Browse files (browse the repository at this point in the history)
  • Loading branch information
adithyaov committed Dec 16, 2024
1 parent 9969702 commit 5f27b3e
Showing 1 changed file with 47 additions and 32 deletions.
79 changes: 47 additions & 32 deletions core/src/Streamly/Internal/Data/Fold/Container.hs
Original file line number | Diff line number | Diff line change
Expand Up @@ -387,9 +387,12 @@ demuxerToContainer getKey getFold =
step (Tuple' kv kv1) a = do
let k = getKey a
case IsMap.mapLookup k kv of
Nothing -> do
fld <- getFold k
runFold kv kv1 fld (k, a)
Nothing ->
case IsMap.mapLookup k kv1 of
Just _ -> pure $ Tuple' kv kv1
Nothing -> do
fld <- getFold k
runFold kv kv1 fld (k, a)
Just f -> runFold kv kv1 f (k, a)

final (Tuple' kv kv1) = do
Expand All @@ -406,7 +409,7 @@ demuxerToContainer getKey getFold =

-- | Scanning variant of 'demuxerToContainer'.
{-# INLINE demuxScanGeneric #-}
demuxScanGeneric :: (Monad m, IsMap f, Traversable f) =>
demuxScanGeneric :: (Monad m, IsMap f, Traversable f, Ord (Key f)) =>
(a -> Key f)
-> (Key f -> m (Fold m a b))
-> Scanl m a (m (f b), Maybe (Key f, b))
Expand All @@ -415,10 +418,10 @@ demuxScanGeneric getKey getFold =

where

initial = return $ Tuple' IsMap.mapEmpty Nothing
initial = return $ Tuple3' IsMap.mapEmpty Set.empty Nothing

{-# INLINE runFold #-}
runFold kv (Fold step1 initial1 extract1 final1) (k, a) = do
runFold kv set (Fold step1 initial1 extract1 final1) (k, a) = do
res <- initial1
case res of
Partial s -> do
Expand All @@ -427,23 +430,28 @@ demuxScanGeneric getKey getFold =
$ case res1 of
Partial _ ->
let fld = Fold step1 (return res1) extract1 final1
in Tuple' (IsMap.mapInsert k fld kv) Nothing
Done b -> Tuple' (IsMap.mapDelete k kv) (Just (k, b))
set1 = Set.insert k set
kv1 = IsMap.mapInsert k fld kv
in Tuple3' kv1 set1 Nothing
Done b ->
let kv1 = IsMap.mapDelete k kv
set1 = Set.insert k set
in Tuple3' kv1 set1 (Just (k, b))
Done b ->
-- Done in "initial" is possible only for the very first time
-- the fold is initialized, and in that case we have not yet
-- inserted it in the Map, so we do not need to delete it.
return $ Tuple' kv (Just (k, b))
return $ Tuple3' kv (Set.insert k set) (Just (k, b))

-- Route one input element to the fold registered for its key.
--
-- Lookup order matters: @kv@ holds the folds that are still running, while
-- @set@ records every key that has been seen (runFold inserts the key on
-- both the Partial and the Done paths). A key that is in @set@ but no longer
-- in @kv@ therefore belongs to a fold that has already terminated; such
-- input is dropped instead of calling getFold again, which would restart
-- the fold. This mirrors the Set.member check in demuxScanGenericIO.
step (Tuple3' kv set _) a = do
    let k = getKey a
    case IsMap.mapLookup k kv of
        -- Fold for this key is still running: feed it the element.
        Just f -> runFold kv set f (k, a)
        Nothing
            -- Fold for this key already finished: ignore the element and
            -- report no newly completed fold.
            | Set.member k set -> return (Tuple3' kv set Nothing)
            -- First time this key is seen: create and start its fold.
            | otherwise -> do
                fld <- getFold k
                runFold kv set fld (k, a)

extract (Tuple' kv x) = return (Prelude.mapM f kv, x)
extract (Tuple3' kv _ x) = return (Prelude.mapM f kv, x)

where

Expand All @@ -453,7 +461,7 @@ demuxScanGeneric getKey getFold =
Partial s -> e s
_ -> error "demuxGeneric: unreachable code"

final (Tuple' kv x) = return (Prelude.mapM f kv, x)
final (Tuple3' kv _ x) = return (Prelude.mapM f kv, x)

where

Expand Down Expand Up @@ -647,8 +655,11 @@ demuxerToContainerIO getKey getFold =
let k = getKey a
case IsMap.mapLookup k kv of
Nothing -> do
f <- getFold k
initFold kv kv1 f (k, a)
case IsMap.mapLookup k kv1 of
Just _ -> pure $ Tuple' kv kv1
Nothing -> do
f <- getFold k
initFold kv kv1 f (k, a)
Just ref -> do
f <- liftIO $ readIORef ref
runFold kv kv1 ref f (k, a)
Expand All @@ -673,7 +684,7 @@ demuxerToContainerIO getKey getFold =
-- ongoing fold if you are using those concurrently in another thread.
--
{-# INLINE demuxScanGenericIO #-}
demuxScanGenericIO :: (MonadIO m, IsMap f, Traversable f) =>
demuxScanGenericIO :: (MonadIO m, IsMap f, Traversable f, Ord (Key f)) =>
(a -> Key f)
-> (Key f -> m (Fold m a b))
-> Scanl m a (m (f b), Maybe (Key f, b))
Expand All @@ -682,10 +693,10 @@ demuxScanGenericIO getKey getFold =

where

initial = return $ Tuple' IsMap.mapEmpty Nothing
initial = return $ Tuple3' IsMap.mapEmpty Set.empty Nothing

{-# INLINE initFold #-}
initFold kv (Fold step1 initial1 extract1 final1) (k, a) = do
initFold kv set (Fold step1 initial1 extract1 final1) (k, a) = do
res <- initial1
case res of
Partial s -> do
Expand All @@ -697,12 +708,12 @@ demuxScanGenericIO getKey getFold =
-- accumulator. That will reduce the allocations.
let fld = Fold step1 (return res1) extract1 final1
ref <- liftIO $ newIORef fld
return $ Tuple' (IsMap.mapInsert k ref kv) Nothing
Done b -> return $ Tuple' kv (Just (k, b))
Done b -> return $ Tuple' kv (Just (k, b))
return $ Tuple3' (IsMap.mapInsert k ref kv) set Nothing
Done b -> pure $ Tuple3' kv (Set.insert k set) (Just (k, b))
Done b -> return $ Tuple3' kv (Set.insert k set) (Just (k, b))

{-# INLINE runFold #-}
runFold kv ref (Fold step1 initial1 extract1 final1) (k, a) = do
runFold kv set ref (Fold step1 initial1 extract1 final1) (k, a) = do
res <- initial1
case res of
Partial s -> do
Expand All @@ -711,23 +722,27 @@ demuxScanGenericIO getKey getFold =
Partial _ -> do
let fld = Fold step1 (return res1) extract1 final1
liftIO $ writeIORef ref fld
return $ Tuple' kv Nothing
return $ Tuple3' kv set Nothing
Done b ->
let kv1 = IsMap.mapDelete k kv
in return $ Tuple' kv1 (Just (k, b))
set1 = Set.insert k set
in return $ Tuple3' kv1 set1 (Just (k, b))
Done _ -> error "demuxGenericIO: unreachable"

step (Tuple' kv _) a = do
step (Tuple3' kv set _) a = do
let k = getKey a
case IsMap.mapLookup k kv of
Nothing -> do
f <- getFold k
initFold kv f (k, a)
Nothing ->
if Set.member k set
then return (Tuple3' kv set Nothing)
else do
f <- getFold k
initFold kv set f (k, a)
Just ref -> do
f <- liftIO $ readIORef ref
runFold kv ref f (k, a)
runFold kv set ref f (k, a)

extract (Tuple' kv x) = return (Prelude.mapM f kv, x)
extract (Tuple3' kv _ x) = return (Prelude.mapM f kv, x)

where

Expand All @@ -738,7 +753,7 @@ demuxScanGenericIO getKey getFold =
Partial s -> e s
_ -> error "demuxGenericIO: unreachable code"

final (Tuple' kv x) = return (Prelude.mapM f kv, x)
final (Tuple3' kv _ x) = return (Prelude.mapM f kv, x)

where

Expand Down

0 comments on commit 5f27b3e

Please sign in to comment.