From 5f27b3e50d628f1c18ed1281827a3a36581272fd Mon Sep 17 00:00:00 2001 From: Adithya Kumar Date: Mon, 16 Dec 2024 17:18:50 +0530 Subject: [PATCH] Fix the fold demux combinators to not restart the fold --- .../Streamly/Internal/Data/Fold/Container.hs | 79 +++++++++++-------- 1 file changed, 47 insertions(+), 32 deletions(-) diff --git a/core/src/Streamly/Internal/Data/Fold/Container.hs b/core/src/Streamly/Internal/Data/Fold/Container.hs index 390f1d4a24..b7ec1bf307 100644 --- a/core/src/Streamly/Internal/Data/Fold/Container.hs +++ b/core/src/Streamly/Internal/Data/Fold/Container.hs @@ -387,9 +387,12 @@ demuxerToContainer getKey getFold = step (Tuple' kv kv1) a = do let k = getKey a case IsMap.mapLookup k kv of - Nothing -> do - fld <- getFold k - runFold kv kv1 fld (k, a) + Nothing -> + case IsMap.mapLookup k kv1 of + Just _ -> pure $ Tuple' kv kv1 + Nothing -> do + fld <- getFold k + runFold kv kv1 fld (k, a) Just f -> runFold kv kv1 f (k, a) final (Tuple' kv kv1) = do @@ -406,7 +409,7 @@ demuxerToContainer getKey getFold = -- | Scanning variant of 'demuxerToContainer'. {-# INLINE demuxScanGeneric #-} -demuxScanGeneric :: (Monad m, IsMap f, Traversable f) => +demuxScanGeneric :: (Monad m, IsMap f, Traversable f, Ord (Key f)) => (a -> Key f) -> (Key f -> m (Fold m a b)) -> Scanl m a (m (f b), Maybe (Key f, b)) @@ -415,10 +418,10 @@ demuxScanGeneric getKey getFold = where - initial = return $ Tuple' IsMap.mapEmpty Nothing + initial = return $ Tuple3' IsMap.mapEmpty Set.empty Nothing {-# INLINE runFold #-} - runFold kv (Fold step1 initial1 extract1 final1) (k, a) = do + runFold kv set (Fold step1 initial1 extract1 final1) (k, a) = do res <- initial1 case res of Partial s -> do @@ -427,23 +430,28 @@ demuxScanGeneric getKey getFold = $ case res1 of Partial _ -> let fld = Fold step1 (return res1) extract1 final1 - in Tuple' (IsMap.mapInsert k fld kv) Nothing - Done b -> Tuple' (IsMap.mapDelete k kv) (Just (k, b)) + set1 = Set.insert k set + kv1 = IsMap.mapInsert k fld kv + in Tuple3' kv1 set1 Nothing + Done b -> + let kv1 = IsMap.mapDelete k kv + set1 = Set.insert k set + in Tuple3' kv1 set1 (Just (k, b)) Done b -> -- Done in "initial" is possible only for the very first time -- the fold is initialized, and in that case we have not yet -- inserted it in the Map, so we do not need to delete it. - return $ Tuple' kv (Just (k, b)) + return $ Tuple3' kv (Set.insert k set) (Just (k, b)) - step (Tuple' kv _) a = do + step (Tuple3' kv set _) a = do let k = getKey a case IsMap.mapLookup k kv of Nothing -> do fld <- getFold k - runFold kv fld (k, a) - Just f -> runFold kv f (k, a) + runFold kv set fld (k, a) + Just f -> runFold kv set f (k, a) - extract (Tuple' kv x) = return (Prelude.mapM f kv, x) + extract (Tuple3' kv _ x) = return (Prelude.mapM f kv, x) where @@ -453,7 +461,7 @@ demuxScanGeneric getKey getFold = Partial s -> e s _ -> error "demuxGeneric: unreachable code" - final (Tuple' kv x) = return (Prelude.mapM f kv, x) + final (Tuple3' kv _ x) = return (Prelude.mapM f kv, x) where @@ -647,8 +655,11 @@ demuxerToContainerIO getKey getFold = let k = getKey a case IsMap.mapLookup k kv of Nothing -> do - f <- getFold k - initFold kv kv1 f (k, a) + case IsMap.mapLookup k kv1 of + Just _ -> pure $ Tuple' kv kv1 + Nothing -> do + f <- getFold k + initFold kv kv1 f (k, a) Just ref -> do f <- liftIO $ readIORef ref runFold kv kv1 ref f (k, a) @@ -673,7 +684,7 @@ demuxerToContainerIO getKey getFold = -- ongoing fold if you are using those concurrently in another thread. -- {-# INLINE demuxScanGenericIO #-} -demuxScanGenericIO :: (MonadIO m, IsMap f, Traversable f) => +demuxScanGenericIO :: (MonadIO m, IsMap f, Traversable f, Ord (Key f)) => (a -> Key f) -> (Key f -> m (Fold m a b)) -> Scanl m a (m (f b), Maybe (Key f, b)) @@ -682,10 +693,10 @@ demuxScanGenericIO getKey getFold = where - initial = return $ Tuple' IsMap.mapEmpty Nothing + initial = return $ Tuple3' IsMap.mapEmpty Set.empty Nothing {-# INLINE initFold #-} - initFold kv (Fold step1 initial1 extract1 final1) (k, a) = do + initFold kv set (Fold step1 initial1 extract1 final1) (k, a) = do res <- initial1 case res of Partial s -> do @@ -697,12 +708,12 @@ demuxScanGenericIO getKey getFold = -- accumulator. That will reduce the allocations. let fld = Fold step1 (return res1) extract1 final1 ref <- liftIO $ newIORef fld - return $ Tuple' (IsMap.mapInsert k ref kv) Nothing - Done b -> return $ Tuple' kv (Just (k, b)) - Done b -> return $ Tuple' kv (Just (k, b)) + return $ Tuple3' (IsMap.mapInsert k ref kv) set Nothing + Done b -> pure $ Tuple3' kv (Set.insert k set) (Just (k, b)) + Done b -> return $ Tuple3' kv (Set.insert k set) (Just (k, b)) {-# INLINE runFold #-} - runFold kv ref (Fold step1 initial1 extract1 final1) (k, a) = do + runFold kv set ref (Fold step1 initial1 extract1 final1) (k, a) = do res <- initial1 case res of Partial s -> do @@ -711,23 +722,27 @@ demuxScanGenericIO getKey getFold = Partial _ -> do let fld = Fold step1 (return res1) extract1 final1 liftIO $ writeIORef ref fld - return $ Tuple' kv Nothing + return $ Tuple3' kv set Nothing Done b -> let kv1 = IsMap.mapDelete k kv - in return $ Tuple' kv1 (Just (k, b)) + set1 = Set.insert k set + in return $ Tuple3' kv1 set1 (Just (k, b)) Done _ -> error "demuxGenericIO: unreachable" - step (Tuple' kv _) a = do + step (Tuple3' kv set _) a = do let k = getKey a case IsMap.mapLookup k kv of - Nothing -> do - f <- getFold k - initFold kv f (k, a) + Nothing -> + if Set.member k set + then return (Tuple3' kv set Nothing) + else do + f <- getFold k + initFold kv set f (k, a) Just ref -> do f <- liftIO $ readIORef ref - runFold kv ref f (k, a) + runFold kv set ref f (k, a) - extract (Tuple' kv x) = return (Prelude.mapM f kv, x) + extract (Tuple3' kv _ x) = return (Prelude.mapM f kv, x) where @@ -738,7 +753,7 @@ demuxScanGenericIO getKey getFold = Partial s -> e s _ -> error "demuxGenericIO: unreachable code" - final (Tuple' kv x) = return (Prelude.mapM f kv, x) + final (Tuple3' kv _ x) = return (Prelude.mapM f kv, x) where