diff --git a/hedis.cabal b/hedis.cabal index 7a115a7a..2efcbe25 100644 --- a/hedis.cabal +++ b/hedis.cabal @@ -64,14 +64,17 @@ library ghc-prof-options: -auto-all exposed-modules: Database.Redis build-depends: scanner >= 0.2, + async >= 2.1, base >= 4.6 && < 5, bytestring >= 0.9, bytestring-lexing >= 0.5, + unordered-containers, text, deepseq, mtl >= 2, network >= 2, resource-pool >= 0.2, + stm, time, vector >= 0.9 @@ -100,12 +103,17 @@ benchmark hedis-benchmark test-suite hedis-test type: exitcode-stdio-1.0 - main-is: test/Test.hs + hs-source-dirs: test + main-is: Test.hs + other-modules: PubSubTest build-depends: base == 4.*, bytestring >= 0.10, hedis, HUnit, + async, + stm, + text, mtl == 2.*, slave-thread, test-framework, diff --git a/src/Database/Redis/Core.hs b/src/Database/Redis/Core.hs index 8b30e8f3..73c098c9 100644 --- a/src/Database/Redis/Core.hs +++ b/src/Database/Redis/Core.hs @@ -2,7 +2,7 @@ MultiParamTypeClasses, FunctionalDependencies, FlexibleInstances, CPP #-} module Database.Redis.Core ( - Connection, connect, + Connection(..), connect, ConnectInfo(..), defaultConnectInfo, Redis(), runRedis, unRedis, reRedis, RedisCtx(..), MonadRedis(..), @@ -40,7 +40,6 @@ newtype Redis a = Redis (ReaderT RedisEnv IO a) data RedisEnv = Env { envConn :: PP.Connection, envLastReply :: IORef Reply } - -- |This class captures the following behaviour: In a context @m@, a command -- will return it's result wrapped in a \"container\" of type @f@. -- diff --git a/src/Database/Redis/ProtocolPipelining.hs b/src/Database/Redis/ProtocolPipelining.hs index ee9375cf..bddb50ea 100644 --- a/src/Database/Redis/ProtocolPipelining.hs +++ b/src/Database/Redis/ProtocolPipelining.hs @@ -15,7 +15,7 @@ -- module Database.Redis.ProtocolPipelining ( Connection, - connect, disconnect, request, send, recv, + connect, disconnect, request, send, recv, flush, ConnectionLostException(..), HostName, PortID(..) ) where @@ -92,6 +92,12 @@ recv Conn{..} = do writeIORef connReplies rs return r +-- | Flush the socket. Normally, the socket is flushed in 'recv' (actually 'conGetReplies'), but +-- for the multithreaded pub/sub code, the sending thread needs to explicitly flush the subscription +-- change requests. +flush :: Connection -> IO () +flush Conn{..} = hFlush connHandle + -- |Send a request and receive the corresponding reply request :: Connection -> S.ByteString -> IO Reply request conn req = send conn req >> recv conn diff --git a/src/Database/Redis/PubSub.hs b/src/Database/Redis/PubSub.hs index d8e4298c..8497fccb 100644 --- a/src/Database/Redis/PubSub.hs +++ b/src/Database/Redis/PubSub.hs @@ -1,23 +1,42 @@ {-# LANGUAGE CPP, OverloadedStrings, RecordWildCards, EmptyDataDecls, - FlexibleInstances, FlexibleContexts #-} + FlexibleInstances, FlexibleContexts, GeneralizedNewtypeDeriving #-} module Database.Redis.PubSub ( publish, + + -- ** Subscribing to channels + -- $pubsubexpl + + -- *** Single-thread Pub/Sub pubSub, Message(..), PubSub(), - subscribe, unsubscribe, psubscribe, punsubscribe + subscribe, unsubscribe, psubscribe, punsubscribe, + -- *** Continuous Pub/Sub message controller + pubSubForever, + RedisChannel, RedisPChannel, MessageCallback, PMessageCallback, + PubSubController, newPubSubController, currentChannels, currentPChannels, + addChannels, addChannelsAndWait, removeChannels, removeChannelsAndWait, + UnregisterCallbacksAction ) where #if __GLASGOW_HASKELL__ < 710 import Control.Applicative import Data.Monoid #endif +import Control.Concurrent.Async (withAsync, waitEitherCatch, waitEitherCatchSTM) +import Control.Concurrent.STM +import Control.Exception (throwIO) import Control.Monad import Control.Monad.State import Data.ByteString.Char8 (ByteString) +import Data.List (foldl') +import Data.Maybe (isJust) +import Data.Pool +import qualified Data.HashMap.Strict as HM import qualified Database.Redis.Core as Core -import Database.Redis.Protocol (Reply(..)) +import qualified Database.Redis.ProtocolPipelining as PP +import Database.Redis.Protocol (Reply(..), renderRequest) import Database.Redis.Types -- |While in PubSub mode, we keep track of the number of current subscriptions @@ -83,6 +102,18 @@ sendCmd cmd = do lift $ Core.send (redisCmd cmd : changes cmd) modifyPending (updatePending cmd) +cmdCount :: Cmd a b -> Int +cmdCount DoNothing = 0 +cmdCount (Cmd c) = length c + +totalPendingChanges :: PubSub -> Int +totalPendingChanges (PubSub{..}) = + cmdCount subs + cmdCount unsubs + cmdCount psubs + cmdCount punsubs + +rawSendCmd :: (Command (Cmd a b)) => PP.Connection -> Cmd a b -> IO () +rawSendCmd _ DoNothing = return () +rawSendCmd conn cmd = PP.send conn $ renderRequest $ redisCmd cmd : changes cmd + plusChangeCnt :: Cmd a b -> Int -> Int plusChangeCnt DoNothing = id plusChangeCnt (Cmd cs) = (+ length cs) @@ -212,6 +243,343 @@ pubSub initial callback PubSubState{..} <- get unless (subCnt == 0 && pending == 0) recv +-- | A Redis channel name +type RedisChannel = ByteString + +-- | A Redis pattern channel name +type RedisPChannel = ByteString + +-- | A handler for a message from a subscribed channel. +-- The callback is passed the message content. +-- +-- Messages are processed synchronously in the receiving thread, so if the callback +-- takes a long time it will block other callbacks and other messages from being +-- received. If you need to move long-running work to a different thread, we suggest +-- you use 'TBQueue' with a reasonable bound, so that if messages are arriving faster +-- than you can process them, you do eventually block. +-- +-- If the callback throws an exception, the exception will be thrown from 'pubSubForever' +-- which will cause the entire Redis connection for all subscriptions to be closed. +-- As long as you call 'pubSubForever' in a loop you will reconnect to your subscribed +-- channels, but you should probably add an exception handler to each callback to +-- prevent this. +type MessageCallback = ByteString -> IO () + +-- | A handler for a message from a psubscribed channel. +-- The callback is passed the channel the message was sent on plus the message content. +-- +-- Similar to 'MessageCallback', callbacks are executed synchronously and any exceptions +-- are rethrown from 'pubSubForever'. +type PMessageCallback = RedisChannel -> ByteString -> IO () + +-- | An action that when executed will unregister the callbacks. It is returned from 'addChannels' +-- or 'addChannelsAndWait' and typically you would use it in 'bracket' to guarantee that you +-- unsubscribe from channels. For example, if you are using websockets to distribute messages to +-- clients, you could use something such as: +-- +-- > websocketConn <- Network.WebSockets.acceptRequest pending +-- > let mycallback msg = Network.WebSockets.sendTextData websocketConn msg +-- > bracket (addChannelsAndWait ctrl [("hello", mycallback)] []) id $ const $ do +-- > {- loop here calling Network.WebSockets.receiveData -} +type UnregisterCallbacksAction = IO () + +newtype UnregisterHandle = UnregisterHandle Integer + deriving (Eq, Show, Num) + +-- | A controller that stores a set of channels, pattern channels, and callbacks. +-- It allows you to manage Pub/Sub subscriptions and pattern subscriptions and alter them at +-- any time throughout the life of your program. +-- You should typically create the controller at the start of your program and then store it +-- through the life of your program, using 'addChannels' and 'removeChannels' to update the +-- current subscriptions. +data PubSubController = PubSubController + { callbacks :: TVar (HM.HashMap RedisChannel [(UnregisterHandle, MessageCallback)]) + , pcallbacks :: TVar (HM.HashMap RedisPChannel [(UnregisterHandle, PMessageCallback)]) + , sendChanges :: TBQueue PubSub + , pendingCnt :: TVar Int + , lastUsedCallbackId :: TVar UnregisterHandle + } + +-- | Create a new 'PubSubController'. Note that this does not subscribe to any channels, it just +-- creates the controller. The subscriptions will happen once 'pubSubForever' is called. +newPubSubController :: MonadIO m => [(RedisChannel, MessageCallback)] -- ^ the initial subscriptions + -> [(RedisPChannel, PMessageCallback)] -- ^ the initial pattern subscriptions + -> m PubSubController +newPubSubController x y = liftIO $ do + cbs <- newTVarIO (HM.map (\z -> [(0,z)]) $ HM.fromList x) + pcbs <- newTVarIO (HM.map (\z -> [(0,z)]) $ HM.fromList y) + c <- newTBQueueIO 10 + pending <- newTVarIO 0 + lastId <- newTVarIO 0 + return $ PubSubController cbs pcbs c pending lastId + +-- | Get the list of current channels in the 'PubSubController'. WARNING! This might not +-- exactly reflect the subscribed channels in the Redis server, because there is a delay +-- between adding or removing a channel in the 'PubSubController' and when Redis receives +-- and processes the subscription change request. +currentChannels :: MonadIO m => PubSubController -> m [RedisChannel] +currentChannels ctrl = HM.keys <$> (liftIO $ atomically $ readTVar $ callbacks ctrl) + +-- | Get the list of current pattern channels in the 'PubSubController'. WARNING! This might not +-- exactly reflect the subscribed channels in the Redis server, because there is a delay +-- between adding or removing a channel in the 'PubSubController' and when Redis receives +-- and processes the subscription change request. +currentPChannels :: MonadIO m => PubSubController -> m [RedisPChannel] +currentPChannels ctrl = HM.keys <$> (liftIO $ atomically $ readTVar $ pcallbacks ctrl) + +-- | Add channels into the 'PubSubController', and if there is an active 'pubSubForever', send the subscribe +-- and psubscribe commands to Redis. The 'addChannels' function is thread-safe. This function +-- does not wait for Redis to acknowledge that the channels have actually been subscribed; use +-- 'addChannelsAndWait' for that. +-- +-- You can subscribe to the same channel or pattern channel multiple times; the 'PubSubController' keeps +-- a list of callbacks and executes each callback in response to a message. +-- +-- The return value is an action 'UnregisterCallbacksAction' which will unregister the callbacks, +-- which should typically used with 'bracket'. +addChannels :: MonadIO m => PubSubController + -> [(RedisChannel, MessageCallback)] -- ^ the channels to subscribe to + -> [(RedisPChannel, PMessageCallback)] -- ^ the channels to pattern subscribe to + -> m UnregisterCallbacksAction +addChannels _ [] [] = return $ return () +addChannels ctrl newChans newPChans = liftIO $ do + ident <- atomically $ do + modifyTVar (lastUsedCallbackId ctrl) (+1) + ident <- readTVar $ lastUsedCallbackId ctrl + cm <- readTVar $ callbacks ctrl + pm <- readTVar $ pcallbacks ctrl + let newChans' = [ n | (n,_) <- newChans, not $ HM.member n cm] + newPChans' = [ n | (n, _) <- newPChans, not $ HM.member n pm] + ps = subscribe newChans' `mappend` psubscribe newPChans' + writeTBQueue (sendChanges ctrl) ps + writeTVar (callbacks ctrl) (HM.unionWith (++) cm (fmap (\z -> [(ident,z)]) $ HM.fromList newChans)) + writeTVar (pcallbacks ctrl) (HM.unionWith (++) pm (fmap (\z -> [(ident,z)]) $ HM.fromList newPChans)) + modifyTVar (pendingCnt ctrl) (+ totalPendingChanges ps) + return ident + return $ unsubChannels ctrl (map fst newChans) (map fst newPChans) ident + +-- | Call 'addChannels' and then wait for Redis to acknowledge that the channels are actually subscribed. +-- +-- Note that this function waits for all pending subscription change requests, so if you for example call +-- 'addChannelsAndWait' from multiple threads simultaneously, they all will wait for all pending +-- subscription changes to be acknowledged by Redis (this is due to the fact that we just track the total +-- number of pending change requests sent to Redis and just wait until that count reaches zero). +-- +-- This also correctly waits if the network connection dies during the subscription change. Say that the +-- network connection dies right after we send a subscription change to Redis. 'pubSubForever' will throw +-- 'ConnectionLost' and 'addChannelsAndWait' will continue to wait. Once you recall 'pubSubForever' +-- with the same 'PubSubController', 'pubSubForever' will open a new connection, send subscription commands +-- for all channels in the 'PubSubController' (which include the ones we are waiting for), +-- and wait for the responses from Redis. Only once we receive the response from Redis that it has subscribed +-- to all channels in 'PubSubController' will 'addChannelsAndWait' unblock and return. +addChannelsAndWait :: MonadIO m => PubSubController + -> [(RedisChannel, MessageCallback)] -- ^ the channels to subscribe to + -> [(RedisPChannel, PMessageCallback)] -- ^ the channels to psubscribe to + -> m UnregisterCallbacksAction +addChannelsAndWait _ [] [] = return $ return () +addChannelsAndWait ctrl newChans newPChans = do + unreg <- addChannels ctrl newChans newPChans + liftIO $ atomically $ do + r <- readTVar (pendingCnt ctrl) + when (r > 0) retry + return unreg + +-- | Remove channels from the 'PubSubController', and if there is an active 'pubSubForever', send the +-- unsubscribe commands to Redis. Note that as soon as this function returns, no more callbacks will be +-- executed even if more messages arrive during the period when we request to unsubscribe from the channel +-- and Redis actually processes the unsubscribe request. This function is thread-safe. +-- +-- If you remove all channels, the connection in 'pubSubForever' to redis will stay open and waiting for +-- any new channels from a call to 'addChannels'. If you really want to close the connection, +-- use 'Control.Concurrent.killThread' or 'Control.Concurrent.Async.cancel' to kill the thread running +-- 'pubSubForever'. +removeChannels :: MonadIO m => PubSubController + -> [RedisChannel] + -> [RedisPChannel] + -> m () +removeChannels _ [] [] = return () +removeChannels ctrl remChans remPChans = liftIO $ atomically $ do + cm <- readTVar $ callbacks ctrl + pm <- readTVar $ pcallbacks ctrl + let remChans' = filter (\n -> HM.member n cm) remChans + remPChans' = filter (\n -> HM.member n pm) remPChans + ps = (if null remChans' then mempty else unsubscribe remChans') + `mappend` (if null remPChans' then mempty else punsubscribe remPChans') + writeTBQueue (sendChanges ctrl) ps + writeTVar (callbacks ctrl) (foldl' (flip HM.delete) cm remChans') + writeTVar (pcallbacks ctrl) (foldl' (flip HM.delete) pm remPChans') + modifyTVar (pendingCnt ctrl) (+ totalPendingChanges ps) + +-- | Internal function to unsubscribe only from those channels matching the given handle. +unsubChannels :: PubSubController -> [RedisChannel] -> [RedisPChannel] -> UnregisterHandle -> IO () +unsubChannels ctrl chans pchans h = liftIO $ atomically $ do + cm <- readTVar $ callbacks ctrl + pm <- readTVar $ pcallbacks ctrl + + -- only worry about channels that exist + let remChans = filter (\n -> HM.member n cm) chans + remPChans = filter (\n -> HM.member n pm) pchans + + -- helper functions to filter out handlers that match + let filterHandle :: Maybe [(UnregisterHandle,a)] -> Maybe [(UnregisterHandle,a)] + filterHandle Nothing = Nothing + filterHandle (Just lst) = case filter (\x -> fst x /= h) lst of + [] -> Nothing + xs -> Just xs + let removeHandles :: HM.HashMap ByteString [(UnregisterHandle,a)] + -> ByteString + -> HM.HashMap ByteString [(UnregisterHandle,a)] + removeHandles m k = case filterHandle (HM.lookup k m) of -- recent versions of unordered-containers have alter + Nothing -> HM.delete k m + Just v -> HM.insert k v m + + -- maps after taking out channels matching the handle + let cm' = foldl' removeHandles cm remChans + pm' = foldl' removeHandles pm remPChans + + -- the channels to unsubscribe are those that no longer exist in cm' and pm' + let remChans' = filter (\n -> not $ HM.member n cm') remChans + remPChans' = filter (\n -> not $ HM.member n pm') remPChans + ps = (if null remChans' then mempty else unsubscribe remChans') + `mappend` (if null remPChans' then mempty else punsubscribe remPChans') + + -- do the unsubscribe + writeTBQueue (sendChanges ctrl) ps + writeTVar (callbacks ctrl) cm' + writeTVar (pcallbacks ctrl) pm' + modifyTVar (pendingCnt ctrl) (+ totalPendingChanges ps) + return () + +-- | Call 'removeChannels' and then wait for all pending subscription change requests to be acknowledged +-- by Redis. This uses the same waiting logic as 'addChannelsAndWait'. Since 'removeChannels' immediately +-- notifies the 'PubSubController' to start discarding messages, you likely don't need this function and +-- can just use 'removeChannels'. +removeChannelsAndWait :: MonadIO m => PubSubController + -> [RedisChannel] + -> [RedisPChannel] + -> m () +removeChannelsAndWait _ [] [] = return () +removeChannelsAndWait ctrl remChans remPChans = do + removeChannels ctrl remChans remPChans + liftIO $ atomically $ do + r <- readTVar (pendingCnt ctrl) + when (r > 0) retry + +-- | Internal thread which listens for messages and executes callbacks. +-- This is the only thread which ever receives data from the underlying +-- connection. +listenThread :: PubSubController -> PP.Connection -> IO () +listenThread ctrl rawConn = forever $ do + msg <- PP.recv rawConn + case decodeMsg msg of + Msg (Message channel msgCt) -> do + cm <- atomically $ readTVar (callbacks ctrl) + case HM.lookup channel cm of + Nothing -> return () + Just c -> mapM_ (\(_,x) -> x msgCt) c + Msg (PMessage pattern channel msgCt) -> do + pm <- atomically $ readTVar (pcallbacks ctrl) + case HM.lookup pattern pm of + Nothing -> return () + Just c -> mapM_ (\(_,x) -> x channel msgCt) c + Subscribed -> atomically $ + modifyTVar (pendingCnt ctrl) (\x -> x - 1) + Unsubscribed _ -> atomically $ + modifyTVar (pendingCnt ctrl) (\x -> x - 1) + +-- | Internal thread which sends subscription change requests. +-- This is the only thread which ever sends data on the underlying +-- connection. +sendThread :: PubSubController -> PP.Connection -> IO () +sendThread ctrl rawConn = forever $ do + PubSub{..} <- atomically $ readTBQueue (sendChanges ctrl) + rawSendCmd rawConn subs + rawSendCmd rawConn unsubs + rawSendCmd rawConn psubs + rawSendCmd rawConn punsubs + -- normally, the socket is flushed during 'recv', but + -- 'recv' could currently be blocking on a message. + PP.flush rawConn + +-- | Open a connection to the Redis server, register to all channels in the 'PubSubController', +-- and process messages and subscription change requests forever. The only way this will ever +-- exit is if there is an exception from the network code or an unhandled exception +-- in a 'MessageCallback' or 'PMessageCallback'. For example, if the network connection to Redis +-- dies, 'pubSubForever' will throw a 'ConnectionLost'. When such an exception is +-- thrown, you can recall 'pubSubForever' with the same 'PubSubController' which will open a +-- new connection and resubscribe to all the channels which are tracked in the 'PubSubController'. +-- +-- The general pattern is therefore during program startup create a 'PubSubController' and fork +-- a thread which calls 'pubSubForever' in a loop (using an exponential backoff algorithm +-- such as the package to not hammer the Redis +-- server if it does die). For example, +-- +-- @ +-- myhandler :: ByteString -> IO () +-- myhandler msg = putStrLn $ unpack $ decodeUtf8 msg +-- +-- onInitialComplete :: IO () +-- onInitialComplete = putStrLn "Redis acknowledged that mychannel is now subscribed" +-- +-- main :: IO () +-- main = do +-- conn <- connect defaultConnectInfo +-- pubSubCtrl <- newPubSubController [("mychannel", myhandler)] [] +-- forkIO $ forever $ +-- pubSubForever conn pubSubCtrl onInitialComplete +-- \`catch\` (\\(e :: SomeException) -> do +-- putStrLn $ "Got error: " ++ show e +-- threadDelay $ 50*1000) -- TODO: use exponential backoff +-- +-- {- elsewhere in your program, use pubSubCtrl to change subscriptions -} +-- @ +-- +-- At most one active 'pubSubForever' can be running against a single 'PubSubController' at any time. If +-- two active calls to 'pubSubForever' share a single 'PubSubController' there will be deadlocks. If +-- you do want to process messages using multiple connections to Redis, you can create more than one +-- 'PubSubController'. For example, create one PubSubController for each 'Control.Concurrent.getNumCapabilities' +-- and then create a Haskell thread bound to each capability each calling 'pubSubForever' in a loop. +-- This will create one network connection per controller/capability and allow you to +-- register separate channels and callbacks for each controller, spreading the load across the capabilities. +pubSubForever :: Core.Connection -- ^ The connection pool + -> PubSubController -- ^ The controller which keeps track of all subscriptions and handlers + -> IO () -- ^ This action is executed once Redis acknowledges that all the subscriptions in + -- the controller are now subscribed. You can use this after an exception (such as + -- 'ConnectionLost') to signal that all subscriptions are now reactivated. + -> IO () +pubSubForever (Core.Conn pool) ctrl onInitialLoad = withResource pool $ \rawConn -> do + -- get initial subscriptions and write them into the queue. + atomically $ do + let loop = tryReadTBQueue (sendChanges ctrl) >>= + \x -> if isJust x then loop else return () + loop + cm <- readTVar $ callbacks ctrl + pm <- readTVar $ pcallbacks ctrl + let ps = subscribe (HM.keys cm) `mappend` psubscribe (HM.keys pm) + writeTBQueue (sendChanges ctrl) ps + writeTVar (pendingCnt ctrl) (totalPendingChanges ps) + + withAsync (listenThread ctrl rawConn) $ \listenT -> + withAsync (sendThread ctrl rawConn) $ \sendT -> do + + -- wait for initial subscription count to go to zero or for threads to fail + mret <- atomically $ + (Left <$> (waitEitherCatchSTM listenT sendT)) + `orElse` + (Right <$> (readTVar (pendingCnt ctrl) >>= + \x -> if x > 0 then retry else return ())) + case mret of + Right () -> onInitialLoad + _ -> return () -- if there is an error, waitEitherCatch below will also see it + + -- wait for threads to end with error + merr <- waitEitherCatch listenT sendT + case merr of + (Right (Left err)) -> throwIO err + (Left (Left err)) -> throwIO err + _ -> return () -- should never happen, since threads exit only with an error + + ------------------------------------------------------------------------------ -- Helpers -- @@ -230,8 +598,19 @@ decodeMsg r@(MultiBulk (Just (r0:r1:r2:rs))) = either (errMsg r) id $ do decodeMessage = Message <$> decode r1 <*> decode r2 decodePMessage = PMessage <$> decode r1 <*> decode r2 <*> decode (head rs) decodeCnt = fromInteger <$> decode r2 - + decodeMsg r = errMsg r errMsg :: Reply -> a errMsg r = error $ "Hedis: expected pub/sub-message but got: " ++ show r + + +-- $pubsubexpl +-- There are two Pub/Sub implementations. First, there is a single-threaded implementation 'pubSub' +-- which is simpler to use but has the restriction that subscription changes can only be made in +-- response to a message. Secondly, there is a more complicated Pub/Sub controller 'pubSubForever' +-- that uses concurrency to support changing subscriptions at any time but requires more setup. +-- You should only use one or the other. In addition, no types or utility functions (that are part +-- of the public API) are shared, so functions or types in one of the following sections cannot +-- be used for the other. In particular, be aware that they use different utility functions to subscribe +-- and unsubscribe to channels. diff --git a/test/ManualPubSub.hs b/test/ManualPubSub.hs new file mode 100644 index 00000000..e8e4c839 --- /dev/null +++ b/test/ManualPubSub.hs @@ -0,0 +1,92 @@ +{-# LANGUAGE OverloadedStrings, ScopedTypeVariables #-} +module ManualPubSub (main) where + +-- A test for PubSub which must be run manually to be able to kill and restart the redis-server. +-- I execute this with `stack runghc ManualPubSub.hs` + +import Database.Redis +import Data.Monoid ((<>)) +import Control.Monad +import Control.Exception +import Control.Monad.Trans (liftIO) +import Control.Concurrent +import Control.Concurrent.Async +import Data.Text +import Data.ByteString (ByteString) +import Data.Text.Encoding +import System.IO + +-- | publish messages every 2 seconds to several channels +publishThread :: Connection -> IO () +publishThread c = runRedis c $ loop (0 :: Int) + where + loop i = do + let msg = encodeUtf8 $ pack $ "Publish iteration " ++ show i + void $ publish "foo" ("foo" <> msg) + void $ publish "bar" ("bar" <> msg) + void $ publish "baz:1" ("baz1" <> msg) + void $ publish "baz:2" ("baz2" <> msg) + liftIO $ threadDelay $ 2*1000*1000 -- 2 seconds + loop (i+1) + +onInitialComplete :: IO () +onInitialComplete = hPutStrLn stderr "Initial subscr complete" + +handlerThread :: Connection -> PubSubController -> IO () +handlerThread conn ctrl = forever $ + pubSubForever conn ctrl onInitialComplete + `catch` (\(e :: SomeException) -> do + hPutStrLn stderr $ "Got error: " ++ show e + threadDelay $ 50*1000) + +msgHandler :: ByteString -> IO () +msgHandler msg = hPutStrLn stderr $ "Saw msg: " ++ unpack (decodeUtf8 msg) + +pmsgHandler :: RedisChannel -> ByteString -> IO () +pmsgHandler channel msg = hPutStrLn stderr $ "Saw pmsg: " ++ unpack (decodeUtf8 channel) ++ unpack (decodeUtf8 msg) + +showChannels :: Connection -> IO () +showChannels c = do + resp :: Either Reply [ByteString] <- runRedis c $ sendRequest ["PUBSUB", "CHANNELS"] + liftIO $ hPutStrLn stderr $ "Current redis channels: " ++ show resp + +main :: IO () +main = do + ctrl <- newPubSubController [("foo", msgHandler)] [] + conn <- connect defaultConnectInfo + + withAsync (publishThread conn) $ \_pubT -> do + withAsync (handlerThread conn ctrl) $ \_handlerT -> do + + void $ hPutStrLn stderr "Press enter to subscribe to bar" >> getLine + void $ addChannels ctrl [("bar", msgHandler)] [] + + void $ hPutStrLn stderr "Press enter to subscribe to baz:*" >> getLine + void $ addChannels ctrl [] [("baz:*", pmsgHandler)] + + void $ hPutStrLn stderr "Press enter to unsub from foo" >> getLine + removeChannels ctrl ["foo"] [] + + void $ hPutStrLn stderr "Try killing and restarting the redis server" >> getLine + withAsync (publishThread conn) $ \_pubT -> do + + void $ hPutStrLn stderr "Press enter to unsub from baz:*" >> getLine + removeChannels ctrl [] ["baz:*"] + + void $ hPutStrLn stderr "Press enter to sub to foo and baz:*" >> getLine + unsub1 <- addChannelsAndWait ctrl [("foo", msgHandler)] [("baz:*", pmsgHandler)] + showChannels conn + + void $ hPutStrLn stderr "Press enter to sub to foo again and baz:1" >> getLine + unsub2 <- addChannelsAndWait ctrl [("foo", msgHandler), ("baz:1", msgHandler)] [] + showChannels conn + + void $ hPutStrLn stderr "Press enter to unsub to foo and baz:1" >> getLine + unsub2 + + void $ hPutStrLn stderr "Press enter to unsub to foo and baz:*" >> getLine + showChannels conn + unsub1 + + void $ hPutStrLn stderr "Press enter to exit" >> getLine + showChannels conn diff --git a/test/PubSubTest.hs b/test/PubSubTest.hs new file mode 100644 index 00000000..17e6419d --- /dev/null +++ b/test/PubSubTest.hs @@ -0,0 +1,181 @@ +{-# LANGUAGE CPP, OverloadedStrings #-} +module PubSubTest (testPubSubThreaded) where + +#if __GLASGOW_HASKELL__ < 710 +import Control.Applicative +import Data.Monoid (mappend) +#endif +import Control.Concurrent +import Control.Monad +import Control.Concurrent.Async +import Control.Exception +import Data.Typeable +--import Control.Monad.Trans +--import Data.Time +import qualified Data.List +import Data.Text +import Data.ByteString +import Control.Concurrent.STM +--import Data.Time.Clock.POSIX +import qualified Test.Framework as Test +import qualified Test.Framework.Providers.HUnit as Test (testCase) +import qualified Test.HUnit as HUnit + +import Database.Redis + +testPubSubThreaded :: [Connection -> Test.Test] +testPubSubThreaded = [removeAllTest, callbackErrorTest, removeFromUnregister] + +-- | A handler label to be able to distinguish the handlers from one another +-- to help make sure we unregister the correct handler. +type HandlerLabel = Text + +data TestMsg = MsgFromChannel HandlerLabel ByteString + | MsgFromPChannel HandlerLabel RedisChannel ByteString + deriving (Show, Eq) + +type MsgVar = TVar [TestMsg] + +-- | A handler that just writes the message into the TVar +handler :: HandlerLabel -> MsgVar -> MessageCallback +handler label ref msg = atomically $ + modifyTVar ref $ \x -> x ++ [MsgFromChannel label msg] + +-- | A pattern handler that just writes the message into the TVar +phandler :: HandlerLabel -> MsgVar -> PMessageCallback +phandler label ref chan msg = atomically $ + modifyTVar ref $ \x -> x ++ [MsgFromPChannel label chan msg] + +-- | Wait for a given message to be received +waitForMessage :: MsgVar -> HandlerLabel -> ByteString -> IO () +waitForMessage ref label msg = atomically $ do + let expected = MsgFromChannel label msg + lst <- readTVar ref + unless (expected `Prelude.elem` lst) retry + writeTVar ref $ Prelude.filter (/= expected) lst + +-- | Wait for a given pattern message to be received +waitForPMessage :: MsgVar -> HandlerLabel -> RedisChannel -> ByteString -> IO () +waitForPMessage ref label chan msg = atomically $ do + let expected = MsgFromPChannel label chan msg + lst <- readTVar ref + unless (expected `Prelude.elem` lst) retry + writeTVar ref $ Prelude.filter (/= expected) lst + +expectRedisChannels :: Connection -> [RedisChannel] -> IO () +expectRedisChannels conn expected = do + actual <- runRedis conn $ sendRequest ["PUBSUB", "CHANNELS"] + case actual of + Left err -> HUnit.assertFailure $ "Error geting channels: " ++ show err + Right s -> HUnit.assertEqual "redis channels" (Data.List.sort s) (Data.List.sort expected) + +-- | Test basic messages, plus using removeChannels +removeAllTest :: Connection -> Test.Test +removeAllTest conn = Test.testCase "Multithreaded Pub/Sub - basic" $ do + msgVar <- newTVarIO [] + initialComplete <- newTVarIO False + ctrl <- newPubSubController [("foo1", handler "InitialFoo1" msgVar), ("foo2", handler "InitialFoo2" msgVar)] + [("bar1:*", phandler "InitialBar1" msgVar), ("bar2:*", phandler "InitialBar2" msgVar)] + withAsync (pubSubForever conn ctrl (atomically $ writeTVar initialComplete True)) $ \_ -> do + + -- wait for initial + atomically $ readTVar initialComplete >>= \b -> if b then return () else retry + expectRedisChannels conn ["foo1", "foo2"] + + runRedis conn $ publish "foo1" "Hello" + waitForMessage msgVar "InitialFoo1" "Hello" + + runRedis conn $ publish "bar2:zzz" "World" + waitForPMessage msgVar "InitialBar2" "bar2:zzz" "World" + + -- subscribe to foo1 and bar1 again + addChannelsAndWait ctrl [("foo1", handler "NewFoo1" msgVar)] [("bar1:*", phandler "NewBar1" msgVar)] + expectRedisChannels conn ["foo1", "foo2"] + + runRedis conn $ publish "foo1" "abcdef" + waitForMessage msgVar "InitialFoo1" "abcdef" + waitForMessage msgVar "NewFoo1" "abcdef" + + -- unsubscribe from foo1 and bar1 + removeChannelsAndWait ctrl ["foo1", "unusued"] ["bar1:*", "unused:*"] + expectRedisChannels conn ["foo2"] + + -- foo2 and bar2 are still subscribed + runRedis conn $ publish "foo2" "12345" + waitForMessage msgVar "InitialFoo2" "12345" + + runRedis conn $ publish "bar2:aaa" "0987" + waitForPMessage msgVar "InitialBar2" "bar2:aaa" "0987" + +data TestError = TestError ByteString + deriving (Eq, Show, Typeable) +instance Exception TestError + +-- | Test an error thrown from a message handler +callbackErrorTest :: Connection -> Test.Test +callbackErrorTest conn = Test.testCase "Multithreaded Pub/Sub - error in handler" $ do + initialComplete <- newTVarIO False + ctrl <- newPubSubController [("foo", throwIO . TestError)] [] + + thread <- async (pubSubForever conn ctrl (atomically $ writeTVar initialComplete True)) + atomically $ readTVar initialComplete >>= \b -> if b then return () else retry + + runRedis conn $ publish "foo" "Hello" + + ret <- waitCatch thread + case ret of + Left (SomeException e) | cast e == Just (TestError "Hello") -> return () + _ -> HUnit.assertFailure $ "Did not properly throw error from message thread " ++ show ret + +-- | Test removing channels by using the return value of 'addHandlersAndWait'. +removeFromUnregister :: Connection -> Test.Test +removeFromUnregister conn = Test.testCase "Multithreaded Pub/Sub - unregister handlers" $ do + msgVar <- newTVarIO [] + initialComplete <- newTVarIO False + ctrl <- newPubSubController [] [] + withAsync (pubSubForever conn ctrl (atomically $ writeTVar initialComplete True)) $ \_ -> do + atomically $ readTVar initialComplete >>= \b -> if b then return () else retry + + -- register to some channels + void $ addChannelsAndWait ctrl + [("abc", handler "InitialAbc" msgVar), ("xyz", handler "InitialXyz" msgVar)] + [("def:*", phandler "InitialDef" msgVar), ("uvw", phandler "InitialUvw" msgVar)] + expectRedisChannels conn ["abc", "xyz"] + + runRedis conn $ publish "abc" "Hello" + waitForMessage msgVar "InitialAbc" "Hello" + + -- register to some more channels + unreg <- addChannelsAndWait ctrl + [("abc", handler "SecondAbc" msgVar), ("123", handler "Second123" msgVar)] + [("def:*", phandler "SecondDef" msgVar), ("890:*", phandler "Second890" msgVar)] + expectRedisChannels conn ["abc", "xyz", "123"] + + -- check messages on all channels + runRedis conn $ publish "abc" "World" + waitForMessage msgVar "InitialAbc" "World" + waitForMessage msgVar "SecondAbc" "World" + + runRedis conn $ publish "123" "World2" + waitForMessage msgVar "Second123" "World2" + + runRedis conn $ publish "def:bbbb" "World3" + waitForPMessage msgVar "InitialDef" "def:bbbb" "World3" + waitForPMessage msgVar "SecondDef" "def:bbbb" "World3" + + runRedis conn $ publish "890:tttt" "World4" + waitForPMessage msgVar "Second890" "890:tttt" "World4" + + -- unregister + unreg + + -- we have no way of waiting until unregister actually happened, so just delay and hope + threadDelay $ 1000*1000 -- 1 second + expectRedisChannels conn ["abc", "xyz"] + + -- now only initial should be around. In particular, abc should still be subscribed + runRedis conn $ publish "abc" "World5" + waitForMessage msgVar "InitialAbc" "World5" + + runRedis conn $ publish "def:cccc" "World6" + waitForPMessage msgVar "InitialDef" "def:cccc" "World6" diff --git a/test/Test.hs b/test/Test.hs index 9d22797b..5cfdeb41 100644 --- a/test/Test.hs +++ b/test/Test.hs @@ -16,7 +16,7 @@ import qualified Test.Framework.Providers.HUnit as Test (testCase) import qualified Test.HUnit as HUnit import Database.Redis - +import PubSubTest ------------------------------------------------------------------------------ -- Main and helpers @@ -57,6 +57,7 @@ tests conn = map ($conn) $ concat [ testsMisc, testsKeys, testsStrings, [testHashes], testsLists, testsSets, [testHyperLogLog] , testsZSets, [testPubSub], [testTransaction], [testScripting] , testsConnection, testsServer, [testScans], [testZrangelex] + , testPubSubThreaded -- should always be run last as connection gets closed after it , [testQuit] ]