diff --git a/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs b/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs index e7d51104ab..13620703c3 100644 --- a/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs +++ b/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs @@ -8,6 +8,9 @@ module Network.Mux.Bearer.AttenuatedChannel , Size , SuccessOrFailure (..) , Attenuation (..) + , QueueChannel + , newAttenuatedChannel + , echoQueueChannel , newConnectedAttenuatedChannelPair , attenuationChannelAsBearer -- * Trace @@ -60,6 +63,19 @@ data QueueChannel m = QueueChannel { qcWrite :: StrictTVar m (Maybe (StrictTQueue m Message)) } + +-- A `QueueChannel` which receives what is written to it. +-- +echoQueueChannel :: MonadSTM m => STM m (QueueChannel m) +echoQueueChannel = do + q <- newTQueue + v <- newTVar (Just q) + return QueueChannel { + qcRead = v, + qcWrite = v + } + + -- -- QueueChannel API -- diff --git a/ouroboros-network-framework/CHANGELOG.md b/ouroboros-network-framework/CHANGELOG.md index 49f4022aa2..823d534dc3 100644 --- a/ouroboros-network-framework/CHANGELOG.md +++ b/ouroboros-network-framework/CHANGELOG.md @@ -12,6 +12,8 @@ `unregister{Inbound,Outbound}Connection` to `release{Inbound,Outbound}Connection`. `AssertionLocation` constructors were renamed as well. * Added `RawBearer` API (see https://github.com/IntersectMBO/ouroboros-network/pull/4395) +* Connection manager is using `ConnectionId`s to identify connections, this + affects its API. ### Non-breaking changes diff --git a/ouroboros-network-framework/demo/connection-manager.hs b/ouroboros-network-framework/demo/connection-manager.hs index 577dc37d7e..f05892a373 100644 --- a/ouroboros-network-framework/demo/connection-manager.hs +++ b/ouroboros-network-framework/demo/connection-manager.hs @@ -485,7 +485,7 @@ bidirectionalExperiment muxBundle res <- releaseOutboundConnection - connectionManager remoteAddr + connectionManager connId case res of UnsupportedState inState -> do traceWith debugTracer ( "initiator-loop" diff --git a/ouroboros-network-framework/ouroboros-network-framework.cabal b/ouroboros-network-framework/ouroboros-network-framework.cabal index 18f64c26ce..580e3acfbb 100644 --- a/ouroboros-network-framework/ouroboros-network-framework.cabal +++ b/ouroboros-network-framework/ouroboros-network-framework.cabal @@ -29,8 +29,10 @@ library Ouroboros.Network.Channel Ouroboros.Network.ConnectionHandler Ouroboros.Network.ConnectionId + Ouroboros.Network.ConnectionManager.ConnMap Ouroboros.Network.ConnectionManager.Core Ouroboros.Network.ConnectionManager.InformationChannel + Ouroboros.Network.ConnectionManager.State Ouroboros.Network.ConnectionManager.Types Ouroboros.Network.Context Ouroboros.Network.Driver diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs index d8e04f422c..0e6c9e49ed 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs @@ -422,7 +422,7 @@ mkSnocket scheduleMap = do ) . getSchedule - <$> (getScheduleMap scheduleMap) + <$> getScheduleMap scheduleMap v <- newTVarIO inboundSchedule return $ Snocket { getLocalAddr, @@ -449,10 +449,10 @@ mkSnocket scheduleMap = do $> x getLocalAddr (FD v) = - fdLocalAddress <$> atomically (readTVar v) + fdLocalAddress <$> readTVarIO v getRemoteAddr (FD v) = do - mbRemote <- fdRemoteAddress <$> atomically (readTVar v) + mbRemote <- fdRemoteAddress <$> readTVarIO v case mbRemote of Nothing -> throwIO InvalidArgumentError Just addr -> pure addr @@ -830,7 +830,7 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap = Right (Just (Disconnected {})) -> pure () - Right (Just (Connected _ _ _)) -> do + Right (Just (Connected connId _ _)) -> do threadDelay (either id id (seActiveDelay conn)) -- if this outbound connection is not -- executed within inbound connection, @@ -850,11 +850,11 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap = -- successful. void $ releaseInboundConnection - connectionManager addr + connectionManager connId res <- releaseOutboundConnection - connectionManager addr + connectionManager connId case res of UnsupportedState st -> throwIO (UnsupportedStateError @@ -879,17 +879,20 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap = (Accepted fd' addr', acceptNext) -> do thread <- async $ do + localAddress <- getLocalAddr snocket fd' + let connId = ConnectionId { localAddress, + remoteAddress = addr' } labelThisThread ("th-inbound-" ++ show (getTestAddress addr)) Just conn' <- fdScheduleEntry - <$> atomically (readTVar (fdState fd')) + <$> readTVarIO (fdState fd') when (addr /= addr' && seIdx conn /= seIdx conn') $ throwIO (MismatchedScheduleEntry (addr, seIdx conn) (addr', seIdx conn')) _ <- includeInboundConnection - connectionManager maxBound fd' addr + connectionManager maxBound fd' connId t <- getMonotonicTime let activeDelay = either id id (seActiveDelay conn) @@ -902,11 +905,11 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap = threadDelay x _ <- promotedToWarmRemote - connectionManager addr + connectionManager connId threadDelay y _ <- demotedToColdRemote - connectionManager addr + connectionManager connId return () ) (threadDelay activeDelay) @@ -930,7 +933,7 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap = -- TODO: should we run 'unregisterInboundConnection' depending on 'seActiveDelay' void $ releaseInboundConnection - connectionManager addr + connectionManager connId go (thread : threads) acceptNext conns' (AcceptFailure err, _acceptNext) -> throwIO err diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs index 26980ae994..bbb968f1b0 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs @@ -884,7 +884,7 @@ multinodeExperiment inboundTrTracer trTracer inboundTracer debugTracer cmTracer m <- readTVar connVar check (Map.member (connId remoteAddr) m) writeTVar connVar (Map.delete (connId remoteAddr) m) - void (releaseOutboundConnection cm remoteAddr) + void (releaseOutboundConnection cm (connId remoteAddr)) go (Map.delete remoteAddr connMap) RunMiniProtocols remoteAddr reqs -> do atomically $ do @@ -1039,6 +1039,7 @@ prop_connection_manager_valid_transitions_racy (Fixed rnd) serverAcc (ArbDataFlow dataFlow) defaultBearerInfo mns@(MultiNodeScript events attenuationMap) = exploreSimTrace id sim $ \_ trace -> + counterexample (ppTrace trace) $ validate_transitions mns trace where sim :: IOSim s () @@ -1159,12 +1160,13 @@ prop_connection_manager_valid_transition_order (Fixed rnd) serverAcc (ArbDataFlo in tabulate "ConnectionEvents" (map showConnectionEvents events) . counterexample (ppScript mns) . counterexample (Trace.ppTrace show show abstractTransitionEvents) + . counterexample (ppTrace trace) . bifoldMap ( \ case MainReturn {} -> mempty _ -> All False ) - (verifyAbstractTransitionOrder True) + (verifyAbstractTransitionOrder id True) . fmap (map ttTransition) . groupConns id abstractStateIsFinalTransition $ abstractTransitionEvents @@ -1204,7 +1206,7 @@ prop_connection_manager_valid_transition_order_racy (Fixed rnd) serverAcc (ArbDa MainReturn {} -> mempty _ -> All False ) - (verifyAbstractTransitionOrder True) + (verifyAbstractTransitionOrder id True) . fmap (map ttTransition) . groupConns id abstractStateIsFinalTransition $ abstractTransitionEvents diff --git a/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs b/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs index a02fa86e35..1c5e11399f 100644 --- a/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs @@ -40,6 +40,7 @@ import Data.ByteString.Lazy (ByteString) import Data.Foldable (traverse_) import Data.Functor (void) import Data.Map qualified as Map +import Data.Maybe (isNothing) import Data.Set (Set) import Data.Set qualified as Set import Text.Printf @@ -50,6 +51,7 @@ import Ouroboros.Network.Snocket import Simulation.Network.Snocket import Network.Mux as Mx +import Network.Mux.Types as Mx import Network.TypedProtocol.Codec.CBOR import Network.TypedProtocol.Core import Network.TypedProtocol.Peer @@ -81,6 +83,8 @@ tests = prop_connect_to_not_listening_socket , testProperty "simultaneous_open" prop_simultaneous_open + , testProperty "self connect" + prop_self_connect ] type TestAddr = TestAddress Int @@ -543,6 +547,47 @@ prop_simultaneous_open defaultBearerInfo = snocket getState wait clientAsync + +-- | Check that when we bind both outbound and inbound socket to the same +-- address, and connect the outbound to inbound: +-- +-- * accept loop never returns +-- * the outbound socket acts as a mirror +-- +-- This is how socket API behaves on Linux. +-- +prop_self_connect :: ByteString -> Property +prop_self_connect payload = + runSimOrThrow sim + where + addr :: TestAddress Int + addr = TestAddress 0 + + sim :: forall s. IOSim s Property + sim = + withSnocket nullTracer noAttenuation Map.empty + $ \snocket _getState -> + withAsync (runServer addr snocket + (close snocket) acceptOne return) + $ \serverThread -> do + bracket (openToConnect snocket addr) + (close snocket) + $ \fd -> do + bind snocket fd addr + connect snocket fd addr + bearer <- getBearer makeFDBearer 10 nullTracer fd + let channel = bearerAsChannel bearer (MiniProtocolNum 0) InitiatorDir + send channel payload + payload' <- recv channel + threadDelay 1 + serverResult <- atomically $ + (Just <$> waitSTM serverThread) + `orElse` + pure Nothing + return $ Just payload === payload' + .&&. isNothing serverResult + + -- -- Utils -- diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs index 6971e051b3..e2969ab874 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs @@ -411,10 +411,10 @@ makeConnectionHandler muxTracer singMuxMode -- --- | 'ConnectionHandlerTrace' is embedded into 'ConnectionManagerTrace' with --- 'Ouroboros.Network.ConnectionManager.Types.ConnectionHandlerTrace' --- constructor. It already includes 'ConnectionId' so we don't need to take --- care of it here. +-- | 'ConnectionHandlerTrace' is embedded into +-- 'Ouroboros.Network.ConnectionManager.Core.Trace' with +-- 'Ouroboros.Network.ConnectionManager.Types.TrConnectionHandler' constructor. +-- It already includes 'ConnectionId' so we don't need to take care of it here. -- -- TODO: when 'Handshake' will get its own tracer, independent of 'Mux', it -- should be embedded into 'ConnectionHandlerTrace'. diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionId.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionId.hs index b607fbff1c..b505db4af0 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionId.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionId.hs @@ -33,7 +33,8 @@ data ConnectionId addr = ConnectionId { -- -- /Note:/ we relay on the fact that `remoteAddress` is an order -- preserving map (which allows us to use `Map.mapKeysMonotonic` in some --- cases). +-- cases. We also relay on this particular order in +-- `Ouroboros.Network.ConnectionManager.State.liveConnections` -- instance Ord addr => Ord (ConnectionId addr) where conn `compare` conn' = diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/ConnMap.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/ConnMap.hs new file mode 100644 index 0000000000..22c4e79ea3 --- /dev/null +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/ConnMap.hs @@ -0,0 +1,269 @@ +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Ouroboros.Network.ConnectionManager.ConnMap + ( ConnMap (..) + , LocalAddr (..) + , toList + , toMap + , empty + , insert + , insertUnknownLocalAddr + , delete + -- , deleteUnknownLocalAddr + , deleteAtRemoteAddr + , lookup + , lookupByRemoteAddr + , updateLocalAddr + , traverseMaybeWithKey + ) where + +import Prelude hiding (lookup) + +import Data.Foldable qualified as Foldable +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as Map +import System.Random (RandomGen, uniformR) + +import Ouroboros.Network.ConnectionId + +data LocalAddr peerAddr = + -- | A reserved slot for an outbound connection which is being created. The + -- outbound connection must be in the `ReservedOutbound` state. + UnknownLocalAddr + -- | All connections which are not in the `ReservedOutbound` state use + -- `LocalAddr`, since for them the local address is known. + | LocalAddr peerAddr + deriving (Show, Eq, Ord) + + +-- | The outer map keys are remote addresses, the internal ones are local +-- addresses. +-- +newtype ConnMap peerAddr a + = ConnMap { + getConnMap :: + Map peerAddr + (Map (LocalAddr peerAddr) a) + } + deriving (Show, Functor, Foldable) + +instance Traversable (ConnMap peerAddr) where + traverse f (ConnMap m) = ConnMap <$> traverse (traverse f) m + + +toList :: ConnMap m a -> [a] +toList = Foldable.toList + + +-- | Create a map of all connections with a known `ConnectionId`. +-- +toMap :: forall peerAddr a. + Ord peerAddr + => ConnMap peerAddr a + -> Map (ConnectionId peerAddr) a +toMap = + -- We can use `fromAscList` because of the `Ord` instance of `ConnectionId`. + -- /NOTE:/ if `fromAscList` is used on input which doesn't satisfy its + -- precondition, then `Map.lookup` might fail when it shouldn't. + Map.fromAscList + . Map.foldrWithKey + (\remoteAddress st conns -> + Map.foldrWithKey + (\localAddr conn conns' -> + case localAddr of + UnknownLocalAddr -> conns' + LocalAddr localAddress -> + (ConnectionId { remoteAddress, localAddress }, conn) : conns' + ) + conns + st + ) + [] + . getConnMap + + +empty :: ConnMap peerAddr a +empty = ConnMap Map.empty + + +insert :: Ord peerAddr + => ConnectionId peerAddr + -> a + -> ConnMap peerAddr a + -> ConnMap peerAddr a +insert ConnectionId { remoteAddress, localAddress } a = + ConnMap + . Map.alter + (\case + Nothing -> Just $! Map.singleton (LocalAddr localAddress) a + Just st -> Just $! Map.insert (LocalAddr localAddress) a st) + remoteAddress + . getConnMap + + +insertUnknownLocalAddr + :: Ord peerAddr + => peerAddr + -> a + -> ConnMap peerAddr a + -> ConnMap peerAddr a +insertUnknownLocalAddr remoteAddress a = + ConnMap + . Map.alter + (\case + Nothing -> Just $! Map.singleton UnknownLocalAddr a + Just st -> Just $! Map.insert UnknownLocalAddr a st + ) + remoteAddress + . getConnMap + + +delete :: Ord peerAddr + => ConnectionId peerAddr + -> ConnMap peerAddr a + -> ConnMap peerAddr a +delete ConnectionId { remoteAddress, localAddress } = + ConnMap + . Map.alter + (\case + Nothing -> Nothing + Just st -> + let st' = Map.delete (LocalAddr localAddress) st + in if Map.null st' + then Nothing + else Just st' + ) + remoteAddress + . getConnMap + + +deleteAtRemoteAddr + :: (Ord peerAddr, Eq a) + => peerAddr + -- ^ remoteAddr + -> a + -- ^ element to remove + -> ConnMap peerAddr a + -> ConnMap peerAddr a +deleteAtRemoteAddr remoteAddress a = + ConnMap + . Map.alter + (\case + Nothing -> Nothing + Just st -> + let st' = Map.filter (/=a) st in + if Map.null st' + then Nothing + else Just st' + ) + remoteAddress + . getConnMap + + + +lookup :: Ord peerAddr + => ConnectionId peerAddr + -> ConnMap peerAddr a + -> Maybe a +lookup ConnectionId { remoteAddress, localAddress } (ConnMap st) = + case remoteAddress `Map.lookup` st of + Nothing -> Nothing + Just st' -> LocalAddr localAddress `Map.lookup` st' + + +-- | Find a random entry for a given remote address. +-- +-- NOTE: the outbound governor will only ask for a connection to a peer if it +-- doesn't have one (and one isn't being created). This property, simplifies +-- `lookupOutbound`: we can pick (randomly) one of the connections to the +-- remote peer. When the outbound governor is asking, likely all of the +-- connections are in an inbound state. The outbound governor will has a grace +-- period after demoting a peer, so a race (when a is being demoted and +-- promoted at the same time) is unlikely. +-- +lookupByRemoteAddr + :: forall rnd peerAddr a. + ( Ord peerAddr + , RandomGen rnd + ) + => rnd + -- ^ a fresh `rnd` (it must come from a `split`) + -> peerAddr + -- ^ remote address + -> ConnMap peerAddr a + -> (Maybe a) +lookupByRemoteAddr rnd remoteAddress (ConnMap st) = + case remoteAddress `Map.lookup` st of + Nothing -> Nothing + Just st' -> + case UnknownLocalAddr `Map.lookup` st' of + Just a -> Just a + Nothing -> + if Map.null st' + then Nothing + else let (indx, _rnd') = uniformR (0, Map.size st' - 1) rnd + in Just $ snd $ Map.elemAt indx st' + + +-- | Promote `UnknownLocalAddr` to `LocalAddr`. +-- +updateLocalAddr + :: Ord peerAddr + => ConnectionId peerAddr + -> ConnMap peerAddr a + -> (Bool, ConnMap peerAddr a) + -- ^ Return `True` iff the entry was updated. +updateLocalAddr ConnectionId { remoteAddress, localAddress } (ConnMap m) = + ConnMap + <$> Map.alterF + (\case + Nothing -> (False, Nothing) + Just m' -> + let -- delete & lookup for entry in `UnknownLocalAddr` + (ma, m'') = + Map.alterF + (\x -> (x,Nothing)) + UnknownLocalAddr + m' + in + case ma of + -- there was no entry, so no need to update the inner map + Nothing -> (False, Just m') + -- we have an entry: put it in the `LocalAddr`, but only if it's + -- not present in the map. + Just {} -> + fmap Just + . Map.alterF + (\case + Nothing -> (True, ma) + a@Just{} -> (False, a) + ) + (LocalAddr localAddress) + $ m'' + ) + remoteAddress + m + + +traverseMaybeWithKey + :: Applicative f + => (Either peerAddr (ConnectionId peerAddr) -> a -> f (Maybe b)) + -> ConnMap peerAddr a + -> f [b] +traverseMaybeWithKey fn = + fmap (concat . Map.elems) + . Map.traverseMaybeWithKey + (\remoteAddress st -> + fmap (Just . Map.elems) + . Map.traverseMaybeWithKey + (\case + UnknownLocalAddr -> fn (Left remoteAddress) + LocalAddr localAddress -> fn (Right ConnectionId { remoteAddress, + localAddress }) + ) + $ st + ) + . getConnMap diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs index 90083bac36..8b461a6e7f 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs @@ -1,17 +1,14 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} --- Undecidable instances are need for 'Show' instance of 'ConnectionState'. -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE UndecidableInstances #-} -- | The implementation of connection manager. -- @@ -25,14 +22,14 @@ module Ouroboros.Network.ConnectionManager.Core , defaultProtocolIdleTimeout , defaultResetTimeout , ConnectionState (..) - , abstractState + , State.abstractState ) where import Control.Applicative (Alternative) import Control.Concurrent.Class.MonadSTM qualified as LazySTM import Control.Concurrent.Class.MonadSTM.Strict import Control.Exception (assert) -import Control.Monad (forM_, guard, when, (>=>)) +import Control.Monad (forM_, guard, unless, when, (>=>)) import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadFork (throwTo) import Control.Monad.Class.MonadThrow hiding (handle) @@ -41,20 +38,19 @@ import Control.Monad.Class.MonadTimer.SI import Control.Monad.Fix import Control.Tracer (Tracer, contramap, traceWith) import Data.Foldable (foldMap', traverse_) -import Data.Function (on) import Data.Functor (void, ($>)) -import Data.Maybe (maybeToList) import Data.Proxy (Proxy (..)) import Data.Typeable (Typeable) import GHC.Stack (CallStack, HasCallStack, callStack) import System.Random (StdGen, split) -import Data.Map (Map) -import Data.Map qualified as Map +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as Map import Data.Set qualified as Set import Data.Monoid.Synchronisation import Data.Set (Set) +import Data.Tuple (swap) import Data.Wedge import Data.Word (Word32) @@ -65,12 +61,14 @@ import Ouroboros.Network.ConnectionId import Ouroboros.Network.ConnectionManager.InformationChannel (InformationChannel) import Ouroboros.Network.ConnectionManager.InformationChannel qualified as InfoChannel +import Ouroboros.Network.ConnectionManager.State (ConnectionManagerState, + ConnectionState (..), FreshIdSupply, MutableConnState (..)) +import Ouroboros.Network.ConnectionManager.State qualified as State import Ouroboros.Network.ConnectionManager.Types import Ouroboros.Network.InboundGovernor.Event (NewConnectionInfo (..)) import Ouroboros.Network.MuxMode import Ouroboros.Network.Server.RateLimiting (AcceptedConnectionsLimit (..)) import Ouroboros.Network.Snocket -import Ouroboros.Network.Testing.Utils (WithName (..)) -- | Arguments for a 'ConnectionManager' which are independent of 'MuxMode'. @@ -83,6 +81,8 @@ data Arguments handlerTrace socket peerAddr handle handleError versionNumber ver -- | Trace state transitions. -- + -- TODO: do we need this tracer? In some tests we relay on `traceTVar` in + -- `newNetworkMutableState` instead. trTracer :: Tracer m (TransitionTrace peerAddr (ConnectionState peerAddr handle handleError versionNumber m)), @@ -151,151 +151,12 @@ data Arguments handlerTrace socket peerAddr handle handleError versionNumber ver } --- | 'MutableConnState', which supplies a unique identifier. --- --- TODO: We can get away without id, by tracking connections in --- `TerminatingState` using a separate priority search queue. --- -data MutableConnState peerAddr handle handleError version m = MutableConnState { - -- | A unique identifier - -- - connStateId :: !Int - - , -- | Mutable state - -- - connVar :: !(StrictTVar m (ConnectionState peerAddr handle handleError - version m)) - } - - -instance Eq (MutableConnState peerAddr handle handleError version m) where - (==) = (==) `on` connStateId - - --- | A supply of fresh id's. --- --- We use a fresh ids for 'MutableConnState'. --- -newtype FreshIdSupply m = FreshIdSupply { getFreshId :: STM m Int } - - --- | Create a 'FreshIdSupply' inside an 'STM' monad. --- -newFreshIdSupply :: forall m. MonadSTM m - => Proxy m -> STM m (FreshIdSupply m) -newFreshIdSupply _ = do - (v :: StrictTVar m Int) <- newTVar 0 - let getFreshId :: STM m Int - getFreshId = do - c <- readTVar v - writeTVar v (succ c) - return c - return $ FreshIdSupply { getFreshId } - - -newMutableConnState :: forall peerAddr handle handleError version m. - ( MonadTraceSTM m - , Typeable peerAddr - ) - => peerAddr - -> FreshIdSupply m - -> ConnectionState peerAddr handle handleError - version m - -> STM m (MutableConnState peerAddr handle handleError - version m) -newMutableConnState peerAddr freshIdSupply connState = do - connStateId <- getFreshId freshIdSupply - connVar <- newTVar connState - -- This tracing is a no op in IO. - -- - -- We need this for IOSimPOR testing of connection manager state - -- transition tests. It can happen that the transitions happen - -- correctly but IOSimPOR reorders the threads that log the transitions. - -- This is a false positive and we don't want that to happen. - -- - -- The simplest way to do so is to leverage the `traceTVar` IOSim - -- capabilities. These trace messages won't be reordered by IOSimPOR - -- since these happen atomically in STM. - -- - traceTVar - (Proxy @m) connVar - (\mbPrev curr -> - let currAbs = abstractState (Known curr) - in case mbPrev of - Just prev | - let prevAbs = abstractState (Known prev) - , prevAbs /= currAbs -> pure - $ TraceDynamic - $ WithName connStateId - $ TransitionTrace peerAddr - $ mkAbsTransition prevAbs - currAbs - Nothing -> pure - $ TraceDynamic - $ WithName connStateId - $ TransitionTrace peerAddr - $ mkAbsTransition TerminatedSt - currAbs - _ -> pure DontTrace - ) - return $ MutableConnState { connStateId, connVar } - - --- | 'ConnectionManager' state: for each peer we keep a 'ConnectionState' in --- a mutable variable, which reduces congestion on the 'TMVar' which keeps --- 'ConnectionManagerState'. --- --- It is important we can lookup by remote @peerAddr@; this way we can find if --- the connection manager is already managing a connection towards that --- @peerAddr@ and reuse the 'ConnectionState'. --- -type ConnectionManagerState peerAddr handle handleError version m - = Map peerAddr (MutableConnState peerAddr handle handleError version m) connectionManagerStateToCounters - :: Map peerAddr (ConnectionState peerAddr handle handleError version m) + :: State.ConnMap peerAddr (ConnectionState peerAddr handle handleError version m) -> ConnectionManagerCounters -connectionManagerStateToCounters = - foldMap' connectionStateToCounters - --- | State of a connection. --- -data ConnectionState peerAddr handle handleError version m = - -- | Each outbound connections starts in this state. - ReservedOutboundState - - -- | Each inbound connection starts in this state, outbound connection - -- reach this state once `connect` call returns. - -- - -- note: the async handle is lazy, because it's passed with 'mfix'. - | UnnegotiatedState !Provenance - !(ConnectionId peerAddr) - (Async m ()) - - -- | @OutboundState Unidirectional@ state. - | OutboundUniState !(ConnectionId peerAddr) !(Async m ()) !handle +connectionManagerStateToCounters = foldMap' connectionStateToCounters - -- | Either @OutboundState Duplex@ or @OutboundState^\tau Duplex@. - | OutboundDupState !(ConnectionId peerAddr) !(Async m ()) !handle !TimeoutExpired - - -- | Before connection is reset it is put in 'OutboundIdleState' for the - -- duration of 'outboundIdleTimeout'. - -- - | OutboundIdleState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow - | InboundIdleState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow - | InboundState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow - | DuplexState !(ConnectionId peerAddr) !(Async m ()) !handle - | TerminatingState !(ConnectionId peerAddr) !(Async m ()) !(Maybe handleError) - | TerminatedState !(Maybe handleError) - - --- | Return 'True' for states in which the connection was already closed. --- -connectionTerminated :: ConnectionState peerAddr handle handleError version m - -> Bool -connectionTerminated TerminatingState {} = True -connectionTerminated TerminatedState {} = True -connectionTerminated _ = False -- | Perform counting from an 'AbstractState' @@ -310,10 +171,10 @@ connectionStateToCounters state = UnnegotiatedState Outbound _ _ -> outboundConn - OutboundUniState _ _ _ -> unidirectionalConn + OutboundUniState {} -> unidirectionalConn <> outboundConn - OutboundDupState _ _ _ _ -> duplexConn + OutboundDupState {} -> duplexConn <> outboundConn OutboundIdleState _ _ _ Unidirectional -> unidirectionalConn @@ -334,13 +195,13 @@ connectionStateToCounters state = InboundState _ _ _ Duplex -> duplexConn <> inboundConn - DuplexState _ _ _ -> fullDuplexConn + DuplexState {} -> fullDuplexConn <> duplexConn - <> inboundConn + <> inboundConn <> outboundConn - TerminatingState _ _ _ -> mempty - TerminatedState _ -> mempty + TerminatingState {} -> mempty + TerminatedState {} -> mempty where fullDuplexConn = ConnectionManagerCounters 1 0 0 0 0 duplexConn = ConnectionManagerCounters 0 1 0 0 0 @@ -349,76 +210,6 @@ connectionStateToCounters state = outboundConn = ConnectionManagerCounters 0 0 0 0 1 -instance ( Show peerAddr - , Show handleError - , MonadAsync m - ) - => Show (ConnectionState peerAddr handle handleError version m) where - show ReservedOutboundState = "ReservedOutboundState" - show (UnnegotiatedState pr connId connThread) = - concat ["UnnegotiatedState " - , show pr - , " " - , show connId - , " " - , show (asyncThreadId connThread) - ] - show (OutboundUniState connId connThread _handle) = - concat [ "OutboundState Unidirectional " - , show connId - , " " - , show (asyncThreadId connThread) - ] - show (OutboundDupState connId connThread _handle expired) = - concat [ "OutboundState " - , show connId - , " " - , show (asyncThreadId connThread) - , " " - , show expired - ] - show (OutboundIdleState connId connThread _handle df) = - concat [ "OutboundIdleState " - , show connId - , " " - , show (asyncThreadId connThread) - , " " - , show df - ] - show (InboundIdleState connId connThread _handle df) = - concat [ "InboundIdleState " - , show connId - , " " - , show (asyncThreadId connThread) - , " " - , show df - ] - show (InboundState connId connThread _handle df) = - concat [ "InboundState " - , show connId - , " " - , show (asyncThreadId connThread) - , " " - , show df - ] - show (DuplexState connId connThread _handle) = - concat [ "DuplexState " - , show connId - , " " - , show (asyncThreadId connThread) - ] - show (TerminatingState connId connThread handleError) = - concat ([ "TerminatingState " - , show connId - , " " - , show (asyncThreadId connThread) - ] - ++ maybeToList ((' ' :) . show <$> handleError)) - show (TerminatedState handleError) = - concat (["TerminatedState"] - ++ maybeToList ((' ' :) . show <$> handleError)) - - getConnThread :: ConnectionState peerAddr handle handleError version m -> Maybe (Async m ()) getConnThread ReservedOutboundState = Nothing @@ -465,25 +256,6 @@ isInboundConn DuplexState {} = True isInboundConn TerminatingState {} = False isInboundConn TerminatedState {} = False - -abstractState :: MaybeUnknown (ConnectionState muxMode peerAddr m a b) -> AbstractState -abstractState = \case - Unknown -> UnknownConnectionSt - Race s' -> go s' - Known s' -> go s' - where - go :: ConnectionState muxMode peerAddr m a b -> AbstractState - go ReservedOutboundState {} = ReservedOutboundSt - go (UnnegotiatedState pr _ _) = UnnegotiatedSt pr - go (OutboundUniState _ _ _) = OutboundUniSt - go (OutboundDupState _ _ _ te) = OutboundDupSt te - go (OutboundIdleState _ _ _ df) = OutboundIdleSt df - go (InboundIdleState _ _ _ df) = InboundIdleSt df - go (InboundState _ _ _ df) = InboundSt df - go DuplexState {} = DuplexSt - go TerminatingState {} = TerminatingSt - go TerminatedState {} = TerminatedSt - -- | The default value for 'timeWaitTimeout'. -- defaultTimeWaitTimeout :: DiffTime @@ -606,9 +378,9 @@ with -- will be closed. -> m a with args@Arguments { - tracer = tracer, - trTracer = trTracer, - muxTracer = muxTracer, + tracer, + trTracer, + muxTracer, ipv4Address, ipv6Address, addressType, @@ -634,38 +406,28 @@ with args@Arguments { , StrictTVar m StdGen )) <- atomically $ do - v <- newTMVar Map.empty + v <- newTMVar State.empty labelTMVar v "cm-state" - traceTMVar (Proxy :: Proxy m) v - $ \old new -> - case (old, new) of - (Nothing, _) -> pure DontTrace - -- taken - (Just (Just _), Nothing) -> pure (TraceString "cm-state: taken") - -- released - (Just Nothing, Just _) -> pure (TraceString "cm-state: released") - (_, _) -> pure DontTrace - - freshIdSupply <- newFreshIdSupply (Proxy :: Proxy m) + traceTMVar (Proxy :: Proxy m) v $ \_ mbst -> do + st' <- case mbst of + Nothing -> pure Nothing + Just st -> Just <$> traverse (inspectTVar (Proxy :: Proxy m) . toLazyTVar . connVar) st + return (TraceString (show st')) + + freshIdSupply <- State.newFreshIdSupply (Proxy :: Proxy m) stdGenVar <- newTVar (stdGen args) return (freshIdSupply, v, stdGenVar) let readState - :: STM m (Map peerAddr AbstractState) - readState = do - state <- readTMVar stateVar - traverse ( fmap (abstractState . Known) - . readTVar - . connVar - ) - state + :: STM m (State.ConnMap peerAddr AbstractState) + readState = readTMVar stateVar >>= State.readAbstractStateMap waitForOutboundDemotion - :: peerAddr + :: ConnectionId peerAddr -> STM m () - waitForOutboundDemotion addr = do + waitForOutboundDemotion connId = do state <- readState - case Map.lookup addr state of + case State.lookup connId state of Nothing -> return () Just UnknownConnectionSt -> return () Just InboundIdleSt {} -> return () @@ -692,7 +454,7 @@ with args@Arguments { OutboundConnectionManager { ocmAcquireConnection = acquireOutboundConnectionImpl freshIdSupply stateVar - outboundHandler, + stdGenVar outboundHandler, ocmReleaseConnection = releaseOutboundConnectionImpl stateVar stdGenVar }, @@ -728,7 +490,7 @@ with args@Arguments { OutboundConnectionManager { ocmAcquireConnection = acquireOutboundConnectionImpl freshIdSupply stateVar - outboundHandler, + stdGenVar outboundHandler, ocmReleaseConnection = releaseOutboundConnectionImpl stateVar stdGenVar } @@ -760,9 +522,9 @@ with args@Arguments { -- Spawning one thread for each connection cleanup avoids spending time -- waiting for locks and cleanup logic that could delay closing the -- connections and making us not respecting certain timeouts. - asyncs <- Map.elems - <$> Map.traverseMaybeWithKey - (\peerAddr MutableConnState { connVar } -> do + asyncs <- State.traverseMaybeWithKey + (\peerAddrOrConnId MutableConnState { connVar } -> do + let remoteAddr = either id remoteAddress peerAddrOrConnId -- cleanup handler for that thread will close socket associated -- with the thread. We put each connection in 'TerminatedState' to -- try that none of the connection threads will enter @@ -777,12 +539,12 @@ with args@Arguments { connState <- readTVar connVar let connState' = TerminatedState Nothing trT = - TransitionTrace peerAddr (mkTransition connState connState') - absConnState = abstractState (Known connState) + TransitionTrace remoteAddr (mkTransition connState connState') + absConnState = State.abstractState (Known connState) shouldTraceTerminated = absConnState /= TerminatedSt shouldTraceUnknown = absConnState == ReservedOutboundSt trU = TransitionTrace - peerAddr + remoteAddr (Transition { fromState = Known connState' , toState = Unknown }) @@ -818,8 +580,8 @@ with args@Arguments { where traceCounters :: StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) -> m () traceCounters stateVar = do - mState <- atomically $ readTMVar stateVar >>= traverse (readTVar . connVar) - traceWith tracer (TrConnectionManagerCounters (connectionManagerStateToCounters mState)) + state <- atomically $ readTMVar stateVar >>= State.readConnectionStates + traceWith tracer (TrConnectionManagerCounters (connectionManagerStateToCounters state)) countIncomingConnections :: ConnectionManagerState peerAddr handle handleError version m @@ -827,7 +589,7 @@ with args@Arguments { countIncomingConnections st = inboundConns . connectionManagerStateToCounters - <$> traverse (readTVar . connVar) st + <$> State.readConnectionStates st -- Fork connection thread. @@ -872,7 +634,7 @@ with args@Arguments { -- hits there we will update `connVar`. uninterruptibleMask $ \unmask -> do traceWith tracer (TrConnectionCleanup connId) - eTransition <- modifyTMVar stateVar $ \state -> do + mbTransition <- modifyTMVar stateVar $ \state -> do eTransition <- atomically $ do connState <- readTVar connVar let connState' = TerminatedState Nothing @@ -913,17 +675,16 @@ with args@Arguments { traverse_ (traceWith trTracer) mbTransition close snocket socket return ( state - , Left () + , Nothing ) Right transition -> do close snocket socket return ( state - , Right transition + , Just transition ) - case eTransition of - Left () -> do - + case mbTransition of + Nothing -> do let transition = TransitionTrace peerAddr @@ -931,17 +692,18 @@ with args@Arguments { { fromState = Known (TerminatedState Nothing) , toState = Unknown } - mbTransition <- modifyTMVar stateVar $ \state -> - case Map.lookup peerAddr state of + mbTransition' <- modifyTMVar stateVar $ \state -> + case State.lookup connId state of Nothing -> pure (state, Nothing) Just v -> if mutableConnState == v - then pure (Map.delete peerAddr state , Just transition) + then pure (State.delete connId state , Just transition) else pure (state , Nothing) - traverse_ (traceWith trTracer) mbTransition + traverse_ (traceWith trTracer) mbTransition' traceCounters stateVar - Right transition -> + + Just transition -> do traceWith tracer (TrConnectionTimeWait connId) when (timeWaitTimeout > 0) $ let -- make sure we wait at least 'timeWaitTimeout', we @@ -967,7 +729,7 @@ with args@Arguments { trs <- atomically $ do connState <- readTVar connVar let transition' = transition { fromState = Known connState } - shouldTrace = abstractState (Known connState) + shouldTrace = State.abstractState (Known connState) /= TerminatedSt writeTVar connVar (TerminatedState Nothing) -- We have to be careful when deleting it from @@ -976,12 +738,12 @@ with args@Arguments { modifyTMVarPure stateVar ( \state -> - case Map.lookup peerAddr state of + case State.lookup connId state of Nothing -> (state, False) Just v -> if mutableConnState == v - then (Map.delete peerAddr state , True) - else (state , False) + then (State.delete connId state, True) + else (state , False) ) if updated @@ -1016,7 +778,7 @@ with args@Arguments { -- their state to 'TerminatedState'; -- * an io action which logs and cancels all the connection handler -- threads. - mkPruneAction :: peerAddr + mkPruneAction :: ConnectionId peerAddr -> Int -- ^ number of connections to prune -> ConnectionManagerState peerAddr handle handleError version m @@ -1028,27 +790,28 @@ with args@Arguments { -> STM m (Bool, PruneAction m) -- ^ return if the connection was choose to be pruned and the -- 'PruneAction' - mkPruneAction peerAddr numberToPrune state connState' connVar stdGenVar connThread = do - (choiceMap' :: Map peerAddr ( ConnectionType - , Async m () - , StrictTVar m - (ConnectionState - peerAddr - handle handleError - version m) - )) - <- flip Map.traverseMaybeWithKey state $ \_peerAddr MutableConnState { connVar = connVar' } -> - (\cs -> do - -- this expression returns @Maybe (connType, connThread)@; - -- 'traverseMaybeWithKey' collects all 'Just' cases. - guard (isInboundConn cs) - (,,connVar') <$> getConnType cs - <*> getConnThread cs) - <$> readTVar connVar' - let choiceMap = + mkPruneAction connId numberToPrune state connState' connVar stdGenVar connThread = do + choiceMap' + <- Map.traverseMaybeWithKey + (\_ MutableConnState { connVar = connVar' } -> + (\cs -> do + -- this expression returns @Maybe (connType, connThread)@; + -- 'traverseMaybeWithKey' collects all 'Just' cases. + guard (isInboundConn cs) + (,,connVar') <$> getConnType cs + <*> getConnThread cs) + <$> readTVar connVar' + ) + (State.toMap state) + let choiceMap :: Map (ConnectionId peerAddr) + ( ConnectionType + , Async m () + , StrictTVar m (ConnectionState peerAddr handle handleError version m) + ) + choiceMap = case getConnType connState' of Nothing -> assert False choiceMap' - Just a -> Map.insert peerAddr (a, connThread, connVar) + Just a -> Map.insert connId (a, connThread, connVar) choiceMap' stdGen <- stateTVar stdGenVar split @@ -1061,7 +824,7 @@ with args@Arguments { forM_ pruneMap $ \(_, _, connVar') -> writeTVar connVar' (TerminatedState Nothing) - return ( peerAddr `Set.member` pruneSet + return ( connId `Set.member` pruneSet , PruneAction $ do traceWith tracer (TrPruneConnections (Map.keysSet pruneMap) numberToPrune @@ -1089,29 +852,26 @@ with args@Arguments { -- whether we can include the connection or not. -> socket -- ^ resource to include in the state - -> peerAddr - -- ^ remote address used as an identifier of the resource + -> ConnectionId peerAddr + -- ^ connection id used as an identifier of the resource -> m (Connected peerAddr handle handleError) includeInboundConnectionImpl freshIdSupply stateVar handler hardLimit socket - peerAddr = do - (r, connId) <- modifyTMVar stateVar $ \state -> do - localAddress <- getLocalAddr snocket socket + connId = do + r <- modifyTMVar stateVar $ \state -> do numberOfCons <- atomically $ countIncomingConnections state - let connId = ConnectionId { localAddress, remoteAddress = peerAddr } - - -- Check if after accepting this connection we get above the + let -- Check if after accepting this connection we get above the -- hard limit canAccept = numberOfCons + 1 <= fromIntegral hardLimit if canAccept then do let provenance = Inbound - traceWith tracer (TrIncludeConnection provenance peerAddr) + traceWith tracer (TrIncludeConnection provenance (remoteAddress connId)) (reader, writer) <- newEmptyPromiseIO (connThread, connVar, connState0, connState) <- mfix $ \ ~(connThread, _mutableConnVar, _connState0, _connState) -> do @@ -1137,11 +897,11 @@ with args@Arguments { let connState' = UnnegotiatedState provenance connId connThread (mutableConnVar', connState0') <- atomically $ do - let v0 = Map.lookup peerAddr state + let v0 = State.lookup connId state case v0 of Nothing -> do -- 'Accepted' - v <- newMutableConnState peerAddr freshIdSupply connState' + v <- State.newMutableConnState (remoteAddress connId) freshIdSupply connState' labelTVar (connVar v) ("conn-state-" ++ show connId) return (v, Nothing) Just v -> do @@ -1169,8 +929,8 @@ with args@Arguments { InboundState {} -> writeTVar (connVar v) connState' $> assert False v - TerminatingState {} -> newMutableConnState peerAddr freshIdSupply connState' - TerminatedState {} -> newMutableConnState peerAddr freshIdSupply connState' + TerminatingState {} -> State.newMutableConnState (remoteAddress connId) freshIdSupply connState' + TerminatedState {} -> State.newMutableConnState (remoteAddress connId) freshIdSupply connState' labelTVar (connVar v') ("conn-state-" ++ show connId) return (v', Just connState0') @@ -1180,16 +940,16 @@ with args@Arguments { stateVar mutableConnVar' socket connId writer handler return (connThread', mutableConnVar', connState0', connState') - traceWith trTracer (TransitionTrace peerAddr + traceWith trTracer (TransitionTrace (remoteAddress connId) Transition { fromState = maybe Unknown Known connState0 , toState = Known connState }) - return ( Map.insert peerAddr connVar state - , (Just (connVar, connThread, reader), connId) + return ( State.insert connId connVar state + , Just (connVar, connThread, reader) ) else return ( state - , (Nothing, connId) + , Nothing ) case r of @@ -1203,10 +963,10 @@ with args@Arguments { res <- atomically $ readPromise reader case res of Left handleError -> do - terminateInboundWithErrorOrQuery connId connVar connThread peerAddr stateVar mutableConnState $ Just handleError + terminateInboundWithErrorOrQuery connId connVar connThread stateVar mutableConnState $ Just handleError Right HandshakeConnectionQuery -> do - terminateInboundWithErrorOrQuery connId connVar connThread peerAddr stateVar mutableConnState Nothing + terminateInboundWithErrorOrQuery connId connVar connThread stateVar mutableConnState Nothing Right (HandshakeConnectionResult handle (_version, versionData)) -> do let dataFlow = connectionDataFlow versionData @@ -1216,7 +976,7 @@ with args@Arguments { -- Inbound connections cannot be found in this state at this -- stage. ReservedOutboundState -> - throwSTM (withCallStack (ImpossibleState peerAddr)) + throwSTM (withCallStack (ImpossibleState (remoteAddress connId))) -- -- The common case. @@ -1250,23 +1010,23 @@ with args@Arguments { ) InboundIdleState {} -> - throwSTM (withCallStack (ImpossibleState peerAddr)) + throwSTM (withCallStack (ImpossibleState (remoteAddress connId))) -- At this stage the inbound connection cannot be in -- 'InboundState', it would mean that there was another thread -- that included that connection, but this would violate @TCP@ -- constraints. InboundState {} -> - throwSTM (withCallStack (ImpossibleState peerAddr)) + throwSTM (withCallStack (ImpossibleState (remoteAddress connId))) DuplexState {} -> - throwSTM (withCallStack (ImpossibleState peerAddr)) + throwSTM (withCallStack (ImpossibleState (remoteAddress connId))) TerminatingState {} -> return (False, Nothing, Inbound) TerminatedState {} -> return (False, Nothing, Inbound) - traverse_ (traceWith trTracer . TransitionTrace peerAddr) mbTransition + traverse_ (traceWith trTracer . TransitionTrace (remoteAddress connId)) mbTransition traceCounters stateVar -- Note that we don't set a timeout thread here which would @@ -1296,7 +1056,16 @@ with args@Arguments { else return $ Disconnected connId Nothing - terminateInboundWithErrorOrQuery connId connVar connThread peerAddr stateVar mutableConnState handleErrorM = do + terminateInboundWithErrorOrQuery + :: ConnectionId peerAddr + -> StrictTVar m (ConnectionState peerAddr handle handleError version m) + -> Async m () + -> StrictTMVar + m (ConnectionManagerState peerAddr handle handleError version m) + -> MutableConnState peerAddr handle handleError version m + -> Maybe handleError + -> m (Connected peerAddr handle1 handleError) + terminateInboundWithErrorOrQuery connId connVar connThread stateVar mutableConnState handleErrorM = do transitions <- atomically $ do connState <- readTVar connVar @@ -1304,22 +1073,22 @@ with args@Arguments { case classifyHandleError <$> handleErrorM of Just HandshakeFailure -> TerminatingState connId connThread - handleErrorM + handleErrorM Just HandshakeProtocolViolation -> TerminatedState handleErrorM -- On inbound query, connection is terminating. Nothing -> TerminatingState connId connThread - handleErrorM + handleErrorM transition = mkTransition connState connState' - absConnState = abstractState (Known connState) + absConnState = State.abstractState (Known connState) shouldTrace = absConnState /= TerminatedSt updated <- modifyTMVarSTM stateVar ( \state -> - case Map.lookup peerAddr state of + case State.lookup connId state of Nothing -> return (state, False) Just mutableConnState' -> if mutableConnState' == mutableConnState @@ -1338,8 +1107,9 @@ with args@Arguments { -- tracing accordingly. writeTVar connVar connState' - return (Map.delete peerAddr state , True) - else return (state , False) + return (State.delete connId state, True) + else + return (state , False) ) if updated @@ -1370,7 +1140,7 @@ with args@Arguments { -- overwriting. else return [ ] - traverse_ (traceWith trTracer . TransitionTrace peerAddr) transitions + traverse_ (traceWith trTracer . TransitionTrace (remoteAddress connId)) transitions traceCounters stateVar return (Disconnected connId handleErrorM) @@ -1380,13 +1150,13 @@ with args@Arguments { -- action. releaseInboundConnectionImpl :: StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) - -> peerAddr + -> ConnectionId peerAddr -> m (OperationResult DemotedToColdRemoteTr) - releaseInboundConnectionImpl stateVar peerAddr = mask_ $ do - traceWith tracer (TrReleaseConnection Inbound peerAddr) + releaseInboundConnectionImpl stateVar connId = mask_ $ do + traceWith tracer (TrReleaseConnection Inbound connId) (mbThread, mbTransition, result, mbAssertion) <- atomically $ do state <- readTMVar stateVar - case Map.lookup peerAddr state of + case State.lookup connId state of Nothing -> do -- Note: this can happen if the inbound connection manager is -- notified late about the connection which has already terminated @@ -1398,7 +1168,7 @@ with args@Arguments { ) Just MutableConnState { connVar } -> do connState <- readTVar connVar - let st = abstractState (Known connState) + let st = State.abstractState (Known connState) case connState of -- In any of the following two states releasing is not -- supported. 'includeInboundConnection' is a synchronous @@ -1421,7 +1191,7 @@ with args@Arguments { -- TimeoutExpired : OutboundState^\tau Duplex -- → OutboundState Duplex -- @ - OutboundDupState connId connThread handle Ticking -> do + OutboundDupState _connId connThread handle Ticking -> do let connState' = OutboundDupState connId connThread handle Expired writeTVar connVar connState' return ( Nothing @@ -1429,7 +1199,7 @@ with args@Arguments { , OperationSuccess KeepTr , Nothing ) - OutboundDupState connId _connThread _handle Expired -> + OutboundDupState _connId _connThread _handle Expired -> assert False $ return ( Nothing , Nothing @@ -1449,7 +1219,7 @@ with args@Arguments { -- unexpected state, this state is reachable only from outbound -- states - OutboundIdleState connId _connThread _handle _dataFlow -> + OutboundIdleState _connId _connThread _handle _dataFlow -> return ( Nothing , Nothing , OperationSuccess CommitTr @@ -1465,7 +1235,7 @@ with args@Arguments { -- @ -- -- Note: the 'TrDemotedToColdRemote' is logged by the server. - InboundIdleState connId connThread _handle _dataFlow -> do + InboundIdleState _connId connThread _handle _dataFlow -> do let connState' = TerminatingState connId connThread Nothing writeTVar connVar connState' return ( Just connThread @@ -1476,7 +1246,7 @@ with args@Arguments { -- the inbound protocol governor was supposed to call -- 'demotedToColdRemote' first. - InboundState connId connThread _handle _dataFlow -> do + InboundState _connId connThread _handle _dataFlow -> do let connState' = TerminatingState connId connThread Nothing writeTVar connVar connState' return ( Just connThread @@ -1490,7 +1260,7 @@ with args@Arguments { -- the inbound connection governor ought to call -- 'demotedToColdRemote' first. - DuplexState connId connThread handle -> do + DuplexState _connId connThread handle -> do let connState' = OutboundDupState connId connThread handle Ticking writeTVar connVar connState' return ( Nothing @@ -1521,7 +1291,7 @@ with args@Arguments { , Nothing ) - traverse_ (traceWith trTracer . TransitionTrace peerAddr) mbTransition + traverse_ (traceWith trTracer . TransitionTrace (remoteAddress connId)) mbTransition traceCounters stateVar -- 'throwTo' avoids blocking until 'timeWaitTimeout' expires. @@ -1539,19 +1309,21 @@ with args@Arguments { :: HasCallStack => FreshIdSupply m -> StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) + -> StrictTVar m StdGen -> ConnectionHandlerFn handlerTrace socket peerAddr handle handleError (version, versionData) m -> peerAddr -> m (Connected peerAddr handle handleError) - acquireOutboundConnectionImpl freshIdSupply stateVar handler peerAddr = do + acquireOutboundConnectionImpl freshIdSupply stateVar stdGenVar handler peerAddr = do let provenance = Outbound traceWith tracer (TrIncludeConnection provenance peerAddr) (trace, mutableConnState@MutableConnState { connVar } , eHandleWedge) <- atomically $ do state <- readTMVar stateVar - case Map.lookup peerAddr state of + stdGen <- stateTVar stdGenVar split + case State.lookupByRemoteAddr stdGen peerAddr state of Just mutableConnState@MutableConnState { connVar } -> do connState <- readTVar connVar - let st = abstractState (Known connState) + let st = State.abstractState (Known connState) case connState of ReservedOutboundState -> return ( Just (Right (TrConnectionExists provenance peerAddr st)) @@ -1591,7 +1363,7 @@ with args@Arguments { ) OutboundIdleState _connId _connThread _handle _dataFlow -> - let tr = abstractState (Known connState) in + let tr = State.abstractState (Known connState) in return ( Just (Right (TrForbiddenOperation peerAddr tr)) , mutableConnState , Left (withCallStack (ForbiddenOperation peerAddr tr)) @@ -1663,23 +1435,16 @@ with args@Arguments { let connState' = ReservedOutboundState (mutableConnState :: MutableConnState peerAddr handle handleError version m) - <- newMutableConnState peerAddr freshIdSupply connState' + <- State.newMutableConnState peerAddr freshIdSupply connState' -- TODO: label `connVar` using 'ConnectionId' labelTVar (connVar mutableConnState) ("conn-state-" ++ show peerAddr) - -- record the @connVar@ in 'ConnectionManagerState' we can use - -- 'swapTMVar' as we did not use 'takeTMVar' at the beginning of - -- this transaction. Since we already 'readTMVar', it will not - -- block. - (mbConnState - :: Maybe (ConnectionState peerAddr handle handleError version m)) - <- swapTMVar stateVar - (Map.insert peerAddr mutableConnState state) - >>= traverse (readTVar . connVar) . Map.lookup peerAddr + writeTMVar stateVar + (State.insertUnknownLocalAddr peerAddr mutableConnState state) return ( Just (Left (TransitionTrace peerAddr Transition { - fromState = maybe Unknown Known mbConnState, + fromState = Unknown, toState = Known connState' })) , mutableConnState @@ -1723,51 +1488,17 @@ with args@Arguments { (\socket -> uninterruptibleMask_ $ do close snocket socket trs <- atomically $ modifyTMVarSTM stateVar $ \state -> do - case Map.lookup peerAddr state of - -- Lookup failed, which means connection was already - -- removed. So we just update the connVar and trace - -- accordingly. - Nothing -> do - connState <- readTVar connVar - let connState' = TerminatedState Nothing - writeTVar connVar connState' - return - ( state - , [ mkTransition connState connState' - , Transition (Known connState') Unknown - ] - ) - - -- Current connVar. - Just mutableConnState' -> do - connState <- readTVar connVar - case connState of - -- Update the state only if the connection was in - -- 'ReservedOutboundState'. This covers the case - -- when we connect to ourselves, in which case: we - -- first set the connection state to - -- `ReservedOutboundState`, then race connect - -- & accept calls. If the connection was - -- accepted it, it will use the same - -- 'MutableConnState', and if the inbound side is - -- using the connection the state will be - -- different than `ReservedOutboundState`. - ReservedOutboundState | mutableConnState' == mutableConnState -> do - let state' = Map.delete peerAddr state - connState' = TerminatedState Nothing - writeTVar connVar connState' - return - ( state' - , [ mkTransition connState connState' - , Transition (Known connState') - Unknown - ] - ) - - -- self connection: the connection might have - -- been accepted, in such case do not modify its - -- state. - _ -> return (state, []) + connState <- readTVar connVar + let state' = State.deleteAtRemoteAddr peerAddr mutableConnState state + connState' = TerminatedState Nothing + writeTVar connVar connState' + return + ( state' + , [ mkTransition connState connState' + , Transition (Known connState') + Unknown + ] + ) traverse_ (traceWith trTracer . TransitionTrace peerAddr) trs traceCounters stateVar @@ -1799,6 +1530,21 @@ with args@Arguments { let connId = ConnectionId { localAddress , remoteAddress = peerAddr } + updated <- atomically $ modifyTMVarPure stateVar (swap . State.updateLocalAddr connId) + unless updated $ + -- there exists a connection with exact same + -- `ConnectionId` + -- + -- NOTE: + -- When we are connecting from our own `(ip, port)` to + -- itself. In this case on linux, the `connect` + -- returns, while `accept` doesn't. The outbound + -- socket is connected to itself (simultaneuos TCP + -- open?). Since the `accept` call never returns, the + -- `connId` slot must have been available, and thus + -- `State.updateLocalAddr` must have returned `True`. + throwIO (withCallStack $ ConnectionExists provenance peerAddr) + return (socket, connId) -- @@ -1847,7 +1593,7 @@ with args@Arguments { , Just (TrUnexpectedlyFalseAssertion (AcquireOutboundConnection (Just connId) - (abstractState (Known connState)) + (State.abstractState (Known connState)) ) ) ) @@ -1932,7 +1678,7 @@ with args@Arguments { TerminatedState _ -> return Nothing _ -> - let st = abstractState (Known connState) in + let st = State.abstractState (Known connState) in throwSTM (withCallStack (ForbiddenOperation peerAddr st)) traverse_ (traceWith trTracer . TransitionTrace peerAddr) mbTransition @@ -1966,7 +1712,7 @@ with args@Arguments { throwSTM (withCallStack (ConnectionExists provenance connId)) OutboundIdleState _connId _connThread _handle _dataFlow -> - let tr = abstractState (Known connState) in + let tr = State.abstractState (Known connState) in throwSTM (withCallStack (ForbiddenOperation peerAddr tr)) InboundIdleState _connId connThread handle dataFlow@Duplex -> do @@ -2054,32 +1800,33 @@ with args@Arguments { case classifyHandleError <$> handleErrorM of Just HandshakeFailure -> TerminatingState connId connThread - handleErrorM + handleErrorM Just HandshakeProtocolViolation -> TerminatedState handleErrorM -- On outbound query, connection is terminated. Nothing -> TerminatedState handleErrorM transition = mkTransition connState connState' - absConnState = abstractState (Known connState) - shouldTrace = absConnState /= TerminatedSt + absConnState = State.abstractState (Known connState) + shouldTransition = absConnState /= TerminatedSt -- 'handleError' might be either a handshake negotiation -- a protocol failure (an IO exception, a timeout or -- codec failure). In the first case we should not reset -- the connection as this is not a protocol error. - writeTVar connVar connState' + when shouldTransition $ do + writeTVar connVar connState' updated <- modifyTMVarPure stateVar ( \state -> - case Map.lookup peerAddr state of + case State.lookup connId state of Nothing -> (state, False) Just mutableConnState' -> if mutableConnState' == mutableConnState - then (Map.delete peerAddr state , True) - else (state , False) + then (State.delete connId state, True) + else (state , False) ) if updated @@ -2087,7 +1834,7 @@ with args@Arguments { -- Key was present in the dictionary (stateVar) and -- removed so we trace the removal. return $ - if shouldTrace + if shouldTransition then [ transition , Transition { fromState = Known (TerminatedState Nothing) @@ -2121,14 +1868,14 @@ with args@Arguments { :: StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) -> StrictTVar m StdGen - -> peerAddr + -> ConnectionId peerAddr -> m (OperationResult AbstractState) - releaseOutboundConnectionImpl stateVar stdGenVar peerAddr = do - traceWith tracer (TrReleaseConnection Outbound peerAddr) + releaseOutboundConnectionImpl stateVar stdGenVar connId = do + traceWith tracer (TrReleaseConnection Outbound connId) (transition, mbAssertion) <- atomically $ do state <- readTMVar stateVar - case Map.lookup peerAddr state of + case State.lookup connId state of -- if the connection errored, it will remove itself from the state. -- Calling 'releaseOutboundConnection' is a no-op in this case. Nothing -> pure ( DemoteToColdLocalNoop Nothing UnknownConnectionSt @@ -2136,7 +1883,7 @@ with args@Arguments { Just MutableConnState { connVar } -> do connState <- readTVar connVar - let st = abstractState (Known connState) + let st = State.abstractState (Known connState) case connState of -- In any of the following three states releaseing is not -- supported. 'acquireOutboundConnection' is a synchronous @@ -2145,20 +1892,20 @@ with args@Arguments { ReservedOutboundState -> return ( DemoteToColdLocalError - (TrForbiddenOperation peerAddr st) + (TrForbiddenOperation (remoteAddress connId) st) st , Nothing ) - UnnegotiatedState _ _ _ -> + UnnegotiatedState {} -> return ( DemoteToColdLocalError - (TrForbiddenOperation peerAddr st) + (TrForbiddenOperation (remoteAddress connId) st) st , Nothing ) - OutboundUniState connId connThread handle -> do + OutboundUniState _connId connThread handle -> do -- @ -- DemotedToCold^{Unidirectional}_{Local} -- : OutboundState Unidirectional @@ -2172,7 +1919,7 @@ with args@Arguments { , Nothing ) - OutboundDupState connId connThread handle Expired -> do + OutboundDupState _connId connThread handle Expired -> do -- @ -- DemotedToCold^{Duplex}_{Local} -- : OutboundState Duplex @@ -2186,7 +1933,7 @@ with args@Arguments { , Nothing ) - OutboundDupState connId connThread handle Ticking -> do + OutboundDupState _connId connThread handle Ticking -> do let connState' = InboundIdleState connId connThread handle Duplex tr = mkTransition connState connState' @@ -2204,7 +1951,7 @@ with args@Arguments { if numberToPrune > 0 then do (_, prune) - <- mkPruneAction peerAddr numberToPrune state connState' connVar stdGenVar connThread + <- mkPruneAction connId numberToPrune state connState' connVar stdGenVar connThread return ( PruneConnections prune (Left connState) , Nothing @@ -2235,7 +1982,7 @@ with args@Arguments { return ( DemoteToColdLocalNoop Nothing st , Nothing ) - InboundState connId _connThread _handle dataFlow -> do + InboundState _connId _connThread _handle dataFlow -> do let mbAssertion = if dataFlow == Duplex then Nothing @@ -2246,12 +1993,12 @@ with args@Arguments { ) return ( DemoteToColdLocalError - (TrForbiddenOperation peerAddr st) + (TrForbiddenOperation (remoteAddress connId) st) st , mbAssertion ) - DuplexState connId connThread handle -> do + DuplexState _connId connThread handle -> do -- @ -- DemotedToCold^{Duplex}_{Local} : DuplexState -- → InboundState Duplex @@ -2286,8 +2033,8 @@ with args@Arguments { pure () case transition of - DemotedToColdLocal connId connThread connVar tr -> do - traceWith trTracer (TransitionTrace peerAddr tr) + DemotedToColdLocal _connId connThread connVar tr -> do + traceWith trTracer (TransitionTrace (remoteAddress connId) tr) traceCounters stateVar timeoutVar <- registerDelay outboundIdleTimeout r <- atomically $ runFirstToFinish $ @@ -2306,7 +2053,7 @@ with args@Arguments { Right connState -> do let connState' = TerminatingState connId connThread Nothing atomically $ writeTVar connVar connState' - traceWith trTracer (TransitionTrace peerAddr + traceWith trTracer (TransitionTrace (remoteAddress connId) (mkTransition connState connState')) traceCounters stateVar -- We rely on the `finally` handler of connection thread to: @@ -2316,26 +2063,26 @@ with args@Arguments { -- - 'throwTo' avoids blocking until 'timeWaitTimeout' expires. throwTo (asyncThreadId connThread) AsyncCancelled - return (OperationSuccess (abstractState $ Known connState')) + return (OperationSuccess (State.abstractState $ Known connState')) - Left connState | connectionTerminated connState + Left connState | State.connectionTerminated connState -> - return (OperationSuccess (abstractState $ Known connState)) + return (OperationSuccess (State.abstractState $ Known connState)) Left connState -> - return (UnsupportedState (abstractState $ Known connState)) + return (UnsupportedState (State.abstractState $ Known connState)) PruneConnections prune eTr -> do - traverse_ (traceWith trTracer . TransitionTrace peerAddr) eTr + traverse_ (traceWith trTracer . TransitionTrace (remoteAddress connId)) eTr runPruneAction prune traceCounters stateVar - return (OperationSuccess (abstractState (either Known fromState eTr))) + return (OperationSuccess (State.abstractState (either Known fromState eTr))) DemoteToColdLocalError trace st -> do traceWith tracer trace return (UnsupportedState st) DemoteToColdLocalNoop tr a -> do - traverse_ (traceWith trTracer) (TransitionTrace peerAddr <$> tr) + traverse_ (traceWith trTracer . TransitionTrace (remoteAddress connId)) tr traceCounters stateVar return (OperationSuccess a) @@ -2346,12 +2093,12 @@ with args@Arguments { promotedToWarmRemoteImpl :: StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) -> StrictTVar m StdGen - -> peerAddr + -> ConnectionId peerAddr -> m (OperationResult AbstractState) - promotedToWarmRemoteImpl stateVar stdGenVar peerAddr = mask_ $ do + promotedToWarmRemoteImpl stateVar stdGenVar connId = mask_ $ do (result, pruneTr, mbAssertion) <- atomically $ do state <- readTMVar stateVar - let mbConnVar = Map.lookup peerAddr state + let mbConnVar = State.lookup connId state case mbConnVar of Nothing -> return ( UnsupportedState UnknownConnectionSt , Nothing @@ -2359,7 +2106,7 @@ with args@Arguments { ) Just MutableConnState { connVar } -> do connState <- readTVar connVar - let st = abstractState (Known connState) + let st = State.abstractState (Known connState) case connState of ReservedOutboundState {} -> do return ( UnsupportedState st @@ -2370,7 +2117,7 @@ with args@Arguments { st) ) ) - UnnegotiatedState _ connId _ -> + UnnegotiatedState _ _connId _ -> return ( UnsupportedState st , Nothing , Just (TrUnexpectedlyFalseAssertion @@ -2379,7 +2126,7 @@ with args@Arguments { st) ) ) - OutboundUniState connId _connThread _handle -> + OutboundUniState _connId _connThread _handle -> return ( UnsupportedState st , Nothing , Just (TrUnexpectedlyFalseAssertion @@ -2388,7 +2135,7 @@ with args@Arguments { st) ) ) - OutboundDupState connId connThread handle _expired -> do + OutboundDupState _connId connThread handle _expired -> do -- @ -- PromotedToWarm^{Duplex}_{Remote} : OutboundState Duplex -- → DuplexState @@ -2418,7 +2165,7 @@ with args@Arguments { if numberToPrune > 0 then do (pruneSelf, prune) - <- mkPruneAction peerAddr numberToPrune state connState' connVar stdGenVar connThread + <- mkPruneAction connId numberToPrune state connState' connVar stdGenVar connThread when (not pruneSelf) $ writeTVar connVar connState' @@ -2435,7 +2182,7 @@ with args@Arguments { , Nothing , Nothing ) - OutboundIdleState connId connThread handle dataFlow@Duplex -> do + OutboundIdleState _connId connThread handle dataFlow@Duplex -> do -- @ -- Awake^{Duplex}_{Remote} : OutboundIdleState^\tau Duplex -- → InboundState Duplex @@ -2457,7 +2204,7 @@ with args@Arguments { if numberToPrune > 0 then do (pruneSelf, prune) - <- mkPruneAction peerAddr numberToPrune state connState' connVar stdGenVar connThread + <- mkPruneAction connId numberToPrune state connState' connVar stdGenVar connThread when (not pruneSelf) $ writeTVar connVar connState' @@ -2478,7 +2225,7 @@ with args@Arguments { , Nothing , Nothing ) - InboundIdleState connId connThread handle dataFlow -> do + InboundIdleState _connId connThread handle dataFlow -> do -- @ -- Awake^{dataFlow}_{Remote} : InboundIdleState Duplex -- → InboundState Duplex @@ -2489,7 +2236,7 @@ with args@Arguments { , Nothing , Nothing ) - InboundState connId _ _ _ -> + InboundState _connId _ _ _ -> return ( OperationSuccess (mkTransition connState connState) , Nothing -- already in 'InboundState'? @@ -2523,32 +2270,32 @@ with args@Arguments { -- trace transition case (result, pruneTr) of (OperationSuccess tr, Nothing) -> do - traceWith trTracer (TransitionTrace peerAddr tr) + traceWith trTracer (TransitionTrace (remoteAddress connId) tr) traceCounters stateVar (OperationSuccess tr, Just prune) -> do - traceWith trTracer (TransitionTrace peerAddr tr) + traceWith trTracer (TransitionTrace (remoteAddress connId) tr) runPruneAction prune traceCounters stateVar _ -> return () - return (abstractState . fromState <$> result) + return (State.abstractState . fromState <$> result) demotedToColdRemoteImpl :: StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) - -> peerAddr + -> ConnectionId peerAddr -> m (OperationResult AbstractState) - demotedToColdRemoteImpl stateVar peerAddr = do + demotedToColdRemoteImpl stateVar connId = do (result, mbAssertion) <- atomically $ do - mbConnVar <- Map.lookup peerAddr <$> readTMVar stateVar + mbConnVar <- State.lookup connId <$> readTMVar stateVar case mbConnVar of Nothing -> return ( UnsupportedState UnknownConnectionSt , Nothing ) Just MutableConnState { connVar } -> do connState <- readTVar connVar - let st = abstractState (Known connState) + let st = State.abstractState (Known connState) case connState of ReservedOutboundState {} -> do return ( UnsupportedState st @@ -2558,7 +2305,7 @@ with args@Arguments { st) ) ) - UnnegotiatedState _ connId _ -> + UnnegotiatedState _ _connId _ -> return ( UnsupportedState st , Just (TrUnexpectedlyFalseAssertion (DemotedToColdRemote @@ -2566,7 +2313,7 @@ with args@Arguments { st) ) ) - OutboundUniState connId _connThread _handle -> + OutboundUniState _connId _connThread _handle -> return ( UnsupportedState st , Just (TrUnexpectedlyFalseAssertion (DemotedToColdRemote @@ -2594,7 +2341,7 @@ with args@Arguments { -- : InboundState dataFlow -- → InboundIdleState^\tau dataFlow -- @ - InboundState connId connThread handle dataFlow -> do + InboundState _connId connThread handle dataFlow -> do let connState' = InboundIdleState connId connThread handle dataFlow writeTVar connVar connState' return ( OperationSuccess (mkTransition connState connState') @@ -2606,7 +2353,7 @@ with args@Arguments { -- : DuplexState -- → OutboundState^\tau Duplex -- @ - DuplexState connId connThread handle -> do + DuplexState _connId connThread handle -> do let connState' = OutboundDupState connId connThread handle Ticking writeTVar connVar connState' return ( OperationSuccess (mkTransition connState connState') @@ -2630,11 +2377,11 @@ with args@Arguments { -- trace transition case result of OperationSuccess tr -> - traceWith trTracer (TransitionTrace peerAddr tr) + traceWith trTracer (TransitionTrace (remoteAddress connId) tr) _ -> return () traceCounters stateVar - return (abstractState . fromState <$> result) + return (State.abstractState . fromState <$> result) -- @@ -2697,7 +2444,7 @@ withCallStack k = k callStack -- data Trace peerAddr handlerTrace = TrIncludeConnection Provenance peerAddr - | TrReleaseConnection Provenance peerAddr + | TrReleaseConnection Provenance (ConnectionId peerAddr) | TrConnect (Maybe peerAddr) -- ^ local address peerAddr -- ^ remote address | TrConnectError (Maybe peerAddr) -- ^ local address @@ -2712,14 +2459,14 @@ data Trace peerAddr handlerTrace | TrConnectionFailure (ConnectionId peerAddr) | TrConnectionNotFound Provenance peerAddr | TrForbiddenOperation peerAddr AbstractState - | TrPruneConnections (Set peerAddr) -- ^ pruning set + | TrPruneConnections (Set (ConnectionId peerAddr)) -- ^ pruning set Int -- ^ number connections that must be pruned - (Set peerAddr) -- ^ choice set + (Set (ConnectionId peerAddr)) -- ^ choice set | TrConnectionCleanup (ConnectionId peerAddr) | TrConnectionTimeWait (ConnectionId peerAddr) | TrConnectionTimeWaitDone (ConnectionId peerAddr) | TrConnectionManagerCounters ConnectionManagerCounters - | TrState (Map peerAddr AbstractState) + | TrState (State.ConnMap peerAddr AbstractState) -- ^ traced on SIGUSR1 signal, installed in 'runDataDiffusion' | TrUnexpectedlyFalseAssertion (AssertionLocation peerAddr) -- ^ This case is unexpected at call site. diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs new file mode 100644 index 0000000000..10031f5875 --- /dev/null +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs @@ -0,0 +1,277 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +-- Undecidable instances are need for 'Show' instance of 'ConnectionState'. +{-# LANGUAGE QuantifiedConstraints #-} + +module Ouroboros.Network.ConnectionManager.State + ( -- * ConnectionManagerState API + ConnectionManagerState + , module ConnMap + -- ** Monadic API + , readConnectionStates + , readAbstractStateMap + -- * MutableConnState + , MutableConnState (..) + , FreshIdSupply + , newFreshIdSupply + , newMutableConnState + -- * ConnectionState + , ConnectionState (..) + , abstractState + , connectionTerminated + ) where + +import Control.Concurrent.Class.MonadSTM.Strict +import Control.Monad.Class.MonadAsync +import Data.Function (on) +import Data.Proxy (Proxy (..)) +import Data.Typeable (Typeable) +import Prelude hiding (lookup) + +import Ouroboros.Network.ConnectionId +import Ouroboros.Network.ConnectionManager.ConnMap as ConnMap +import Ouroboros.Network.ConnectionManager.Types + +import Ouroboros.Network.Testing.Utils (WithName (..)) + +-- | 'ConnectionManager' state: for each peer we keep a 'ConnectionState' in +-- a mutable variable, which reduces congestion on the 'TMVar' which keeps +-- 'ConnectionManagerState'. +-- +type ConnectionManagerState peerAddr handle handleError version m = + ConnMap peerAddr (MutableConnState peerAddr handle handleError version m) + + +readConnectionStates + :: MonadSTM m + => ConnectionManagerState peerAddr handle handleError version m + -> STM m (ConnMap peerAddr (ConnectionState peerAddr handle handleError version m)) +readConnectionStates = traverse (readTVar . connVar) + + +readAbstractStateMap + :: MonadSTM m + => ConnectionManagerState peerAddr handle handleError version m + -> STM m (ConnMap peerAddr AbstractState) +readAbstractStateMap = traverse (fmap (abstractState . Known) . readTVar . connVar) + +-- | 'MutableConnState', which supplies a unique identifier. +-- +-- TODO: We can get away without id, by tracking connections in +-- `TerminatingState` using a separate priority search queue. +-- +data MutableConnState peerAddr handle handleError version m = MutableConnState { + -- | A unique identifier + -- + connStateId :: !Int + + , -- | Mutable state + -- + connVar :: !(StrictTVar m (ConnectionState peerAddr handle handleError + version m)) + } + + +instance Eq (MutableConnState peerAddr handle handleError version m) where + (==) = (==) `on` connStateId + + +-- | A supply of fresh id's. +-- +-- We use a fresh ids for 'MutableConnState'. +-- +newtype FreshIdSupply m = FreshIdSupply { getFreshId :: STM m Int } + + +-- | Create a 'FreshIdSupply' inside an 'STM' monad. +-- +newFreshIdSupply :: forall m. MonadSTM m + => Proxy m -> STM m (FreshIdSupply m) +newFreshIdSupply _ = do + (v :: StrictTVar m Int) <- newTVar 0 + let getFreshId :: STM m Int + getFreshId = do + c <- readTVar v + writeTVar v (succ c) + return c + return $ FreshIdSupply { getFreshId } + + +newMutableConnState :: forall peerAddr handle handleError version m. + ( MonadTraceSTM m + , Typeable peerAddr + ) + => peerAddr + -> FreshIdSupply m + -> ConnectionState peerAddr handle handleError + version m + -> STM m (MutableConnState peerAddr handle handleError + version m) +newMutableConnState peerAddr freshIdSupply connState = do + connStateId <- getFreshId freshIdSupply + connVar <- newTVar connState + -- This tracing is a no op in IO. + -- + -- We need this for IOSimPOR testing of connection manager state + -- transition tests. It can happen that the transitions happen + -- correctly but IOSimPOR reorders the threads that log the transitions. + -- This is a false positive and we don't want that to happen. + -- + -- The simplest way to do so is to leverage the `traceTVar` IOSim + -- capabilities. These trace messages won't be reordered by IOSimPOR + -- since these happen atomically in STM. + -- + traceTVar + (Proxy @m) connVar + (\mbPrev curr -> + let currAbs = abstractState (Known curr) + in case mbPrev of + Just prev | + let prevAbs = abstractState (Known prev) + , prevAbs /= currAbs -> pure + $ TraceDynamic + $ WithName connStateId + $ TransitionTrace peerAddr + $ mkAbsTransition prevAbs + currAbs + Nothing -> pure + $ TraceDynamic + $ WithName connStateId + $ TransitionTrace peerAddr + $ mkAbsTransition TerminatedSt + currAbs + _ -> pure DontTrace + ) + return $ MutableConnState { connStateId, connVar } + + +abstractState :: MaybeUnknown (ConnectionState muxMode peerAddr m a b) + -> AbstractState +abstractState = \case + Unknown -> UnknownConnectionSt + Race s' -> go s' + Known s' -> go s' + where + go :: ConnectionState muxMode peerAddr m a b -> AbstractState + go ReservedOutboundState {} = ReservedOutboundSt + go (UnnegotiatedState pr _ _) = UnnegotiatedSt pr + go OutboundUniState {} = OutboundUniSt + go (OutboundDupState _ _ _ te) = OutboundDupSt te + go (OutboundIdleState _ _ _ df) = OutboundIdleSt df + go (InboundIdleState _ _ _ df) = InboundIdleSt df + go (InboundState _ _ _ df) = InboundSt df + go DuplexState {} = DuplexSt + go TerminatingState {} = TerminatingSt + go TerminatedState {} = TerminatedSt + + +-- | State of a connection. +-- +data ConnectionState peerAddr handle handleError version m = + -- | Each outbound connections starts in this state. + ReservedOutboundState + + -- | Each inbound connection starts in this state, outbound connection + -- reach this state once `connect` call returns. + -- + -- note: the async handle is lazy, because it's passed with 'mfix'. + | UnnegotiatedState !Provenance + !(ConnectionId peerAddr) + (Async m ()) + + -- | @OutboundState Unidirectional@ state. + | OutboundUniState !(ConnectionId peerAddr) !(Async m ()) !handle + + -- | Either @OutboundState Duplex@ or @OutboundState^\tau Duplex@. + | OutboundDupState !(ConnectionId peerAddr) !(Async m ()) !handle !TimeoutExpired + + -- | Before connection is reset it is put in 'OutboundIdleState' for the + -- duration of 'outboundIdleTimeout'. + -- + | OutboundIdleState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow + | InboundIdleState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow + | InboundState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow + | DuplexState !(ConnectionId peerAddr) !(Async m ()) !handle + | TerminatingState !(ConnectionId peerAddr) !(Async m ()) !(Maybe handleError) + | TerminatedState !(Maybe handleError) + + +instance ( Show peerAddr + -- , Show handleError + , MonadAsync m + ) + => Show (ConnectionState peerAddr handle handleError version m) where + show ReservedOutboundState = "ReservedOutboundState" + show (UnnegotiatedState pr connId connThread) = + concat ["UnnegotiatedState " + , show pr + , " " + , show connId + , " " + , show (asyncThreadId connThread) + ] + show (OutboundUniState connId connThread _handle) = + concat [ "OutboundState Unidirectional " + , show connId + , " " + , show (asyncThreadId connThread) + ] + show (OutboundDupState connId connThread _handle expired) = + concat [ "OutboundState " + , show connId + , " " + , show (asyncThreadId connThread) + , " " + , show expired + ] + show (OutboundIdleState connId connThread _handle df) = + concat [ "OutboundIdleState " + , show connId + , " " + , show (asyncThreadId connThread) + , " " + , show df + ] + show (InboundIdleState connId connThread _handle df) = + concat [ "InboundIdleState " + , show connId + , " " + , show (asyncThreadId connThread) + , " " + , show df + ] + show (InboundState connId connThread _handle df) = + concat [ "InboundState " + , show connId + , " " + , show (asyncThreadId connThread) + , " " + , show df + ] + show (DuplexState connId connThread _handle) = + concat [ "DuplexState " + , show connId + , " " + , show (asyncThreadId connThread) + ] + show (TerminatingState connId connThread _handleError) = + concat ([ "TerminatingState " + , show connId + , " " + , show (asyncThreadId connThread) + ] + )-- ++ maybeToList ((' ' :) . show <$> handleError)) + show (TerminatedState _handleError) = + concat (["TerminatedState"] + )-- ++ maybeToList ((' ' :) . show <$> handleError)) + + +-- | Return 'True' for states in which the connection was already closed. +-- +connectionTerminated :: ConnectionState peerAddr handle handleError version m + -> Bool +connectionTerminated TerminatingState {} = True +connectionTerminated TerminatedState {} = True +connectionTerminated _ = False diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs index cb981524c9..0e51c54129 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs @@ -177,7 +177,8 @@ import System.Random (StdGen) import Network.Mux.Types (HasInitiator, HasResponder, MiniProtocolDir) import Network.Mux.Types qualified as Mux -import Ouroboros.Network.ConnectionId (ConnectionId) +import Ouroboros.Network.ConnectionId (ConnectionId (..)) +import Ouroboros.Network.ConnectionManager.ConnMap (ConnMap) import Ouroboros.Network.MuxMode @@ -423,9 +424,9 @@ data HandleErrorType = -- connections. -- type PrunePolicy peerAddr = StdGen - -> Map peerAddr ConnectionType + -> Map (ConnectionId peerAddr) ConnectionType -> Int - -> Set peerAddr + -> Set (ConnectionId peerAddr) -- | The simplest 'PrunePolicy', it should only be used for tests. @@ -500,7 +501,7 @@ type IncludeInboundConnection socket peerAddr handle handleError m -- ^ inbound connections hard limit. -- NOTE: Check TODO over at includeInboundConnectionImpl -- definition. - -> socket -> peerAddr -> m (Connected peerAddr handle handleError) + -> socket -> ConnectionId peerAddr -> m (Connected peerAddr handle handleError) -- | Outbound connection manager API. @@ -509,7 +510,7 @@ data OutboundConnectionManager (muxMode :: Mux.Mode) socket peerAddr handle hand OutboundConnectionManager :: HasInitiator muxMode ~ True => { ocmAcquireConnection :: AcquireOutboundConnection peerAddr handle handleError m - , ocmReleaseConnection :: peerAddr -> m (OperationResult AbstractState) + , ocmReleaseConnection :: ConnectionId peerAddr -> m (OperationResult AbstractState) } -> OutboundConnectionManager muxMode socket peerAddr handle handleError m @@ -522,10 +523,10 @@ data InboundConnectionManager (muxMode :: Mux.Mode) socket peerAddr handle handl InboundConnectionManager :: HasResponder muxMode ~ True => { icmIncludeConnection :: IncludeInboundConnection socket peerAddr handle handleError m - , icmReleaseConnection :: peerAddr -> m (OperationResult DemotedToColdRemoteTr) - , icmPromotedToWarmRemote :: peerAddr -> m (OperationResult AbstractState) + , icmReleaseConnection :: ConnectionId peerAddr -> m (OperationResult DemotedToColdRemoteTr) + , icmPromotedToWarmRemote :: ConnectionId peerAddr -> m (OperationResult AbstractState) , icmDemotedToColdRemote - :: peerAddr -> m (OperationResult AbstractState) + :: ConnectionId peerAddr -> m (OperationResult AbstractState) , icmNumberOfConnections :: STM m Int } -> InboundConnectionManager muxMode socket peerAddr handle handleError m @@ -548,13 +549,13 @@ data ConnectionManager (muxMode :: Mux.Mode) socket peerAddr handle handleError (InboundConnectionManager muxMode socket peerAddr handle handleError m), readState - :: STM m (Map peerAddr AbstractState), + :: STM m (ConnMap peerAddr AbstractState), -- | This STM action will block until the given connection is fully -- closed/terminated. If the connection manager doesn't have any connection to -- that peer it won't block. waitForOutboundDemotion - :: peerAddr + :: ConnectionId peerAddr -> STM m () } @@ -584,7 +585,7 @@ acquireOutboundConnection = releaseOutboundConnection :: HasInitiator muxMode ~ True => ConnectionManager muxMode socket peerAddr handle handleError m - -> peerAddr + -> ConnectionId peerAddr -> m (OperationResult AbstractState) -- ^ reports the from-state. releaseOutboundConnection = @@ -603,7 +604,7 @@ releaseOutboundConnection = promotedToWarmRemote :: HasResponder muxMode ~ True => ConnectionManager muxMode socket peerAddr handle handleError m - -> peerAddr -> m (OperationResult AbstractState) + -> ConnectionId peerAddr -> m (OperationResult AbstractState) promotedToWarmRemote = icmPromotedToWarmRemote . withResponderMode . getConnectionManager @@ -619,7 +620,7 @@ promotedToWarmRemote = demotedToColdRemote :: HasResponder muxMode ~ True => ConnectionManager muxMode socket peerAddr handle handleError m - -> peerAddr -> m (OperationResult AbstractState) + -> ConnectionId peerAddr -> m (OperationResult AbstractState) demotedToColdRemote = icmDemotedToColdRemote . withResponderMode . getConnectionManager @@ -636,7 +637,7 @@ includeInboundConnection includeInboundConnection = icmIncludeConnection . withResponderMode . getConnectionManager --- | Release outbound connection. Returns if the operation was successful. +-- | Release inbound connection. Returns if the operation was successful. -- -- This executes: -- @@ -645,7 +646,7 @@ includeInboundConnection = releaseInboundConnection :: HasResponder muxMode ~ True => ConnectionManager muxMode socket peerAddr handle handleError m - -> peerAddr -> m (OperationResult DemotedToColdRemoteTr) + -> ConnectionId peerAddr -> m (OperationResult DemotedToColdRemoteTr) releaseInboundConnection = icmReleaseConnection . withResponderMode . getConnectionManager @@ -838,10 +839,10 @@ connectionManagerErrorFromException x = do -- data AssertionLocation peerAddr = ReleaseInboundConnection !(Maybe (ConnectionId peerAddr)) !AbstractState - | AcquireOutboundConnection !(Maybe (ConnectionId peerAddr)) !AbstractState + | AcquireOutboundConnection !(Maybe (ConnectionId peerAddr)) !AbstractState | ReleaseOutboundConnection !(Maybe (ConnectionId peerAddr)) !AbstractState - | PromotedToWarmRemote !(Maybe (ConnectionId peerAddr)) !AbstractState - | DemotedToColdRemote !(Maybe (ConnectionId peerAddr)) !AbstractState + | PromotedToWarmRemote !(Maybe (ConnectionId peerAddr)) !AbstractState + | DemotedToColdRemote !(Maybe (ConnectionId peerAddr)) !AbstractState deriving Show @@ -888,7 +889,7 @@ mkAbsTransition from to = Transition { fromState = from } data TransitionTrace' peerAddr state = TransitionTrace - { ttPeerAddr :: peerAddr + { ttPeerAddr :: peerAddr -- TODO: use ConnectionId , ttTransition :: Transition' state } deriving Functor diff --git a/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs b/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs index b4827b2948..93a95045cf 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs @@ -39,8 +39,8 @@ module Ouroboros.Network.InboundGovernor import Control.Applicative (Alternative) import Control.Concurrent.Class.MonadSTM qualified as LazySTM import Control.Concurrent.Class.MonadSTM.Strict -import Control.Exception (SomeAsyncException (..), assert) -import Control.Monad (foldM, when) +import Control.Exception (SomeAsyncException (..)) +import Control.Monad (foldM) import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadTime.SI @@ -211,12 +211,12 @@ with ) <> Map.foldMapWithKey ( firstMuxToFinish + <> firstPeerDemotedToCold + <> firstPeerCommitRemote <> firstMiniProtocolToFinish connectionDataFlow <> firstPeerPromotedToWarm <> firstPeerPromotedToHot <> firstPeerDemotedToWarm - <> firstPeerDemotedToCold - <> firstPeerCommitRemote :: EventSignal muxMode initiatorCtx peerAddr versionData m a b ) @@ -403,8 +403,7 @@ with -- @ -- NOTE: `demotedToColdRemote` doesn't throw, hence exception handling -- is not needed. - res <- demotedToColdRemote connectionManager - (remoteAddress connId) + res <- demotedToColdRemote connectionManager connId traceWith tracer (TrWaitIdleRemote connId res) case res of TerminatedConnection {} -> do @@ -445,24 +444,19 @@ with -- -- NOTE: `promotedToWarmRemote` doesn't throw, hence exception handling -- is not needed. - res <- promotedToWarmRemote connectionManager - (remoteAddress connId) + res <- promotedToWarmRemote connectionManager connId traceWith tracer (TrPromotedToWarmRemote connId res) - when (resultInState res == UnknownConnectionSt) $ do - traceWith tracer (TrUnexpectedlyFalseAssertion - (InboundGovernorLoop - (Just connId) - UnknownConnectionSt) - ) - evaluate (assert False ()) - - let state' = updateRemoteState - connId - RemoteWarm - state - - return (Just connId, state') + case resultInState res of + UnknownConnectionSt -> do + let state' = unregisterConnection connId state + return (Just connId, state') + _ -> do + let state' = updateRemoteState + connId + RemoteWarm + state + return (Just connId, state') RemotePromotedToHot connId -> do traceWith tracer (TrPromotedToHotRemote connId) @@ -479,8 +473,7 @@ with CommitRemote connId -> do -- NOTE: `releaseInboundConnection` doesn't throw, hence exception -- handling is not needed. - res <- releaseInboundConnection connectionManager - (remoteAddress connId) + res <- releaseInboundConnection connectionManager connId traceWith tracer $ TrDemotedToColdRemote connId res case res of UnsupportedState {} -> do diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Server2.hs b/ouroboros-network-framework/src/Ouroboros/Network/Server2.hs index 4ceda99fb2..f8d65d4084 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Server2.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Server2.hs @@ -49,6 +49,7 @@ import Foreign.C.Error import Network.Mux qualified as Mx import Ouroboros.Network.ConnectionHandler +import Ouroboros.Network.ConnectionId (ConnectionId (..)) import Ouroboros.Network.ConnectionManager.InformationChannel (InboundGovernorInfoChannel) import Ouroboros.Network.ConnectionManager.Types @@ -236,15 +237,19 @@ with Arguments { -- no need to use a rethrow policy _ -> throwIO err - (Accepted socket peerAddr, acceptNext) -> + (Accepted socket remoteAddress, acceptNext) -> (do - traceWith tracer (TrAcceptConnection peerAddr) + localAddress' <- getLocalAddr snocket socket + let connId = ConnectionId { localAddress = localAddress', + remoteAddress } + traceWith tracer (TrAcceptConnection connId) async $ - do a <- + do + a <- unmask (includeInboundConnection connectionManager - hardLimit socket peerAddr) + hardLimit socket connId) case a of Connected {} -> pure () Disconnected {} -> close snocket socket @@ -281,7 +286,7 @@ isECONNABORTED _ = False -- data Trace peerAddr - = TrAcceptConnection peerAddr + = TrAcceptConnection (ConnectionId peerAddr) | TrAcceptError SomeException | TrAcceptPolicyTrace AcceptConnectionsPolicyTrace | TrServerStarted [peerAddr] diff --git a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs index 8029d08e50..3a80257926 100644 --- a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs @@ -90,7 +90,7 @@ data Connection m addr = Connection { -- | Attenuated channels of a connection. -- connChannelLocal :: !(AttenuatedChannel m) - , connChannelRemote :: !(AttenuatedChannel m) + , connChannelRemote :: AttenuatedChannel m -- | SDU size of a connection. -- @@ -139,12 +139,36 @@ mkConnection :: ( MonadDelay m , MonadTimer m , MonadThrow m , MonadThrow (STM m) + , Eq addr ) => Tracer m (WithAddr (TestAddress addr) (SnocketTrace m (TestAddress addr))) -> BearerInfo -> ConnectionId (TestAddress addr) -> STM m (Connection m (TestAddress addr)) + +mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } | localAddress == remoteAddress = do + -- we are connecting to onself. On Linux this returns a connection which + -- mirrors all sent data. + qc <- echoQueueChannel + channel <- newAttenuatedChannel + ( ( WithAddr (Just localAddress) (Just remoteAddress) + . STAttenuatedChannelTrace connId + ) + `contramap` tr) + Attenuation + { aReadAttenuation = biOutboundAttenuation bearerInfo + , aWriteAttenuation = biOutboundWriteFailure bearerInfo + } + qc + return Connection { + connChannelLocal = channel, + connChannelRemote = undefined, + connSDUSize = biSDUSize bearerInfo, + connState = SYN_SENT, + connProvider = localAddress + } + mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } = (\(connChannelLocal, connChannelRemote) -> Connection { @@ -895,11 +919,11 @@ mkSnocket state tr = Snocket { getLocalAddr connMap <- readTVar (nsConnections state) case Map.lookup normalisedId connMap of - Just Connection { connState = ESTABLISHED } -> + Just Connection { connState = ESTABLISHED } -> throwSTM (connectedIOError fd_) - Just Connection { connState = SYN_SENT, connProvider } - | connProvider == localAddress -> + Just Connection { connState = SYN_SENT, connProvider } + | connProvider == localAddress -> throwSTM (connectedIOError fd_) -- simultaneous open @@ -979,7 +1003,7 @@ mkSnocket state tr = Snocket { getLocalAddr <$> readTVar (nsConnections state) case lstFd of -- error cases - (Nothing) -> + Nothing -> return (Left (connectIOError connId "no such listening socket")) (Just FDUninitialised {}) -> return (Left (connectIOError connId "unitialised listening socket")) @@ -1009,13 +1033,16 @@ mkSnocket state tr = Snocket { getLocalAddr Just conn@Connection { connState = SYN_SENT } -> do let fd_' = FDConnected connId conn writeTVar fdVarLocal fd_' - writeTBQueue queue - ChannelWithInfo - { cwiAddress = localAddress connId - , cwiSDUSize = biSDUSize bearerInfo - , cwiChannelLocal = connChannelRemote conn - , cwiChannelRemote = connChannelLocal conn - } + when (localAddress connId /= remoteAddress) $ + -- We only write to the accept `queue` if we're not + -- connecting to ourselves. + writeTBQueue queue + ChannelWithInfo + { cwiAddress = localAddress connId + , cwiSDUSize = biSDUSize bearerInfo + , cwiChannelLocal = connChannelRemote conn + , cwiChannelRemote = connChannelLocal conn + } return (Right (fd_', NormalOpen)) Just Connection { connState = FIN } -> do @@ -1049,15 +1076,15 @@ mkSnocket state tr = Snocket { getLocalAddr (\e -> atomically $ modifyTVar (nsConnections state) (Map.delete (normaliseId connId)) >> throwIO e) - $ unmask (atomically $ runFirstToFinish $ - (FirstToFinish $ do + $ unmask . atomically . runFirstToFinish $ + FirstToFinish (do LazySTM.readTVar timeoutVar >>= check modifyTVar (nsConnections state) (Map.delete (normaliseId connId)) return Nothing ) <> - (FirstToFinish $ do + FirstToFinish (do mbConn <- Map.lookup (normaliseId connId) <$> readTVar (nsConnections state) case mbConn of @@ -1072,12 +1099,16 @@ mkSnocket state tr = Snocket { getLocalAddr ++ show (normaliseId connId) Just Connection { connState } -> Just <$> check (connState == ESTABLISHED)) - ) case r of + -- self connect + Nothing | localAddress connId == remoteAddress + -> traceWith' fd (STConnected fd_' o) + Nothing -> do traceWith' fd (STConnectTimeout WaitingToBeAccepted) throwIO (connectIOError connId "connect timeout: when waiting for being accepted") + Just _ -> traceWith' fd (STConnected fd_' o) FDConnecting {} -> diff --git a/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Experiments.hs b/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Experiments.hs index 0116343488..aa962773b6 100644 --- a/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Experiments.hs +++ b/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Experiments.hs @@ -3,6 +3,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} @@ -735,7 +736,9 @@ unidirectionalExperiment stdGen timeouts snocket makeBearer confSock socket clie (numberOfRounds clientAndServerData) (bracket (acquireOutboundConnection connectionManager serverAddr) - (\_ -> releaseOutboundConnection connectionManager serverAddr) + (\case + Connected connId _ _ -> releaseOutboundConnection connectionManager connId + Disconnected {} -> error "unidirectionalExperiment: impossible happened") (\connHandle -> do case connHandle of Connected connId _ (Handle mux muxBundle controlBundle _ @@ -834,10 +837,13 @@ bidirectionalExperiment (acquireOutboundConnection connectionManager0 localAddr1)) - (\_ -> - releaseOutboundConnection - connectionManager0 - localAddr1) + (\case + Connected connId _ _ -> + releaseOutboundConnection + connectionManager0 + connId + Disconnected {} -> + error "bidirectionalExperiment: impossible happened") (\connHandle -> case connHandle of Connected connId _ (Handle mux muxBundle controlBundle _) -> do @@ -856,10 +862,13 @@ bidirectionalExperiment (acquireOutboundConnection connectionManager1 localAddr0)) - (\_ -> - releaseOutboundConnection - connectionManager1 - localAddr0) + (\case + Connected connId _ _ -> + releaseOutboundConnection + connectionManager1 + connId + Disconnected {} -> + error "ibidirectionalExperiment: impossible happened") (\connHandle -> case connHandle of Connected connId _ (Handle mux muxBundle controlBundle _) -> do diff --git a/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Utils.hs b/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Utils.hs index ec6e7fe489..710bc03472 100644 --- a/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Utils.hs +++ b/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Utils.hs @@ -1,4 +1,5 @@ -{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} module Ouroboros.Network.ConnectionManager.Test.Utils where @@ -193,20 +194,22 @@ validTransitionMap t@Transition { fromState, toState } = -- Assuming all transitions in the transition list are valid, we only need to -- look at the 'toState' of the current transition and the 'fromState' of the -- next transition. -verifyAbstractTransitionOrder :: Bool -- ^ Check last transition: useful for +verifyAbstractTransitionOrder :: forall a. Show a + => (a -> AbstractTransition) + -> Bool -- ^ Check last transition: useful for -- distinguish Diffusion layer tests -- vs non-Diffusion ones. - -> [AbstractTransition] + -> [a] -> All -verifyAbstractTransitionOrder _ [] = mempty -verifyAbstractTransitionOrder checkLast (h:t) = go t h +verifyAbstractTransitionOrder _ _ [] = mempty +verifyAbstractTransitionOrder get checkLast (h:t) = go t h where - go :: [AbstractTransition] -> AbstractTransition -> All + go :: [a] -> a -> All -- All transitions must end in the 'UnknownConnectionSt', and since we -- assume that all transitions are valid we do not have to check the -- 'fromState'. - go [] (Transition _ UnknownConnectionSt) = mempty - go [] tr@(Transition _ _) = + go [] a | (Transition _ UnknownConnectionSt) <- get a = mempty + go [] a | tr@(Transition _ _) <- get a = All $ counterexample ("\nUnexpected last transition: " ++ show tr) @@ -214,14 +217,14 @@ verifyAbstractTransitionOrder checkLast (h:t) = go t h -- All transitions have to be in a correct order, which means that the -- current state we are looking at (current toState) needs to be equal to -- the next 'fromState', in order for the transition chain to be correct. - go (next@(Transition nextFromState _) : ts) - curr@(Transition _ currToState) = + go (a : as) b | (Transition nextFromState _) <- get a + , (Transition _ currToState) <- get b = All (counterexample ("\nUnexpected transition order!\nWent from: " - ++ show curr ++ "\nto: " ++ show next) + ++ show b ++ "\nto: " ++ show a) (property (currToState == nextFromState))) - <> go ts next + <> go as a -- | List of all valid transition's names. diff --git a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs index a9b07a1ddf..37796ec5d8 100644 --- a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs +++ b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs @@ -44,7 +44,6 @@ import System.Random (mkStdGen) import Network.DNS.Types qualified as DNS -import Ouroboros.Network.Block (BlockNo (..)) import Ouroboros.Network.BlockFetch (PraosFetchMode (..), TraceFetchClientState (..)) import Ouroboros.Network.ConnectionHandler (ConnectionHandlerTrace) @@ -167,8 +166,6 @@ tests = , nightlyTest $ testProperty "steps" (testWithIOSimPOR prop_churn_steps 10000) ] - , testGroup "unit" - [ nightlyTest $ testProperty "unit cm" unit_cm_valid_transitions ] ] , testGroup "IOSim" [ testProperty "no failure" @@ -278,8 +275,10 @@ testWithIOSim f traceNumber bi ds = sim = diffusionSimulation (toBearerInfo bi) ds iosimTracer + trace = runSimTrace sim in labelDiffusionScript ds - $ f (runSimTrace sim) traceNumber + $ counterexample (Trace.ppTrace show (ppSimEvent 0 0 0) $ Trace.take traceNumber trace) + $ f trace traceNumber testWithIOSimPOR :: (SimTrace Void -> Int -> Property) -> Int @@ -297,154 +296,6 @@ testWithIOSimPOR f traceNumber bi ds = $ exploreSimTrace id sim $ \_ ioSimTrace -> f ioSimTrace traceNumber --- | This test checks a IOSimPOR false positive bug with the connection --- manager state transition traces no longer happens. --- -unit_cm_valid_transitions :: Property -unit_cm_valid_transitions = - let bi = AbsBearerInfo - { abiConnectionDelay = SmallDelay - , abiInboundAttenuation = NoAttenuation FastSpeed - , abiOutboundAttenuation = NoAttenuation FastSpeed - , abiInboundWriteFailure = Nothing - , abiOutboundWriteFailure = Just 0 - , abiAcceptFailure = Nothing - , abiSDUSize = LargeSDU - } - ds = DiffusionScript - (SimArgs 1 10) - (Script ((Map.empty, ShortDelay) :| [(Map.empty, LongDelay)])) - [ ( NodeArgs - (-2) - InitiatorAndResponderDiffusionMode - (Just 269) - (Map.fromList [(RelayAccessAddress "0:71:0:1:0:1:0:1" 65534, - DoAdvertisePeer)]) - GenesisMode - (Script (DontUseBootstrapPeers :| [])) - (TestAddress (IPAddr (read "0:79::1:0:0") 3)) - PeerSharingDisabled - [ (HotValency {getHotValency = 1}, - WarmValency {getWarmValency = 1}, - Map.fromList [(RelayAccessAddress "0:71:0:1:0:1:0:1" 65534, - (DoAdvertisePeer, IsTrustable))]) - ] - (Script (LedgerPools [] :| [])) - (ConsensusModePeerTargets - { deadlineTargets = PeerSelectionTargets - { targetNumberOfRootPeers = 4 - , targetNumberOfKnownPeers = 4 - , targetNumberOfEstablishedPeers = 3 - , targetNumberOfActivePeers = 2 - , targetNumberOfKnownBigLedgerPeers = 4 - , targetNumberOfEstablishedBigLedgerPeers = 1 - , targetNumberOfActiveBigLedgerPeers = 1 - } - , syncTargets = PeerSelectionTargets - { targetNumberOfRootPeers = 0 - , targetNumberOfKnownPeers = 4 - , targetNumberOfEstablishedPeers = 0 - , targetNumberOfActivePeers = 0 - , targetNumberOfKnownBigLedgerPeers = 4 - , targetNumberOfEstablishedBigLedgerPeers = 4 - , targetNumberOfActiveBigLedgerPeers = 3 - } - }) - (Script (DNSTimeout {getDNSTimeout = 0.325} :| [])) - (Script (DNSLookupDelay {getDNSLookupDelay = 0.1} :| - [DNSLookupDelay {getDNSLookupDelay = 0.072}])) - Nothing - False - (Script (FetchModeBulkSync :| [FetchModeBulkSync])) - , [JoinNetwork 0.5] - ) - , ( NodeArgs - 0 - InitiatorAndResponderDiffusionMode - (Just 90) - Map.empty - GenesisMode - (Script (DontUseBootstrapPeers :| [])) - (TestAddress (IPAddr (read "0:71:0:1:0:1:0:1") 65534)) - PeerSharingEnabled - [ (HotValency {getHotValency = 1}, - WarmValency {getWarmValency = 1}, - Map.fromList [(RelayAccessAddress "0:79::1:0:0" 3, - (DoNotAdvertisePeer, IsTrustable))]) - ] - (Script (LedgerPools [] :| [])) - (ConsensusModePeerTargets - { deadlineTargets = PeerSelectionTargets - { targetNumberOfRootPeers = 1 - , targetNumberOfKnownPeers = 1 - , targetNumberOfEstablishedPeers = 1 - , targetNumberOfActivePeers = 1 - , targetNumberOfKnownBigLedgerPeers = 4 - , targetNumberOfEstablishedBigLedgerPeers = 3 - , targetNumberOfActiveBigLedgerPeers = 3 - } - , syncTargets = PeerSelectionTargets - { targetNumberOfRootPeers = 0 - , targetNumberOfKnownPeers = 1 - , targetNumberOfEstablishedPeers = 1 - , targetNumberOfActivePeers = 1 - , targetNumberOfKnownBigLedgerPeers = 4 - , targetNumberOfEstablishedBigLedgerPeers = 2 - , targetNumberOfActiveBigLedgerPeers = 2 - } - }) - (Script (DNSTimeout {getDNSTimeout = 0.18} :| [])) - (Script (DNSLookupDelay {getDNSLookupDelay = 0.125} :| [])) - (Just (BlockNo 2)) - False - (Script (FetchModeDeadline :| [])) - , [JoinNetwork 1.484848484848] - ) - ] - s = ControlAwait - [ ScheduleMod - (RacyThreadId [3,1,3,1,2,3,2,1], 7) - ControlDefault - [ (RacyThreadId [3,1,3,1,2,3,2], 32) - , (RacyThreadId [3,1,3,1,2,3,2], 33) - , (RacyThreadId [2,1,3,1,4], 8) - , (RacyThreadId [2,1,3,1,4], 9) - , (RacyThreadId [2,1,3,1,4], 10) - , (RacyThreadId [2,1,3,1,4], 11) - , (RacyThreadId [2,1,3,1,4], 12) - , (RacyThreadId [2,1,3,1,4,1], 0) - , (RacyThreadId [2,1,3,1,4,1], 1) - , (RacyThreadId [2,1,3,1,4,1], 2) - , (RacyThreadId [2,1,3,1,4,1], 3) - , (RacyThreadId [2,1,3,1,4,1], 4) - , (RacyThreadId [2,1,3,1,4,1], 5) - , (RacyThreadId [2,1,3,1,4,1], 6) - , (RacyThreadId [2,1,3,1,4,1], 7) - , (RacyThreadId [2,1,3,1,4,1], 8) - , (RacyThreadId [2,1,3,1,4,1,1], 0) - , (RacyThreadId [2,1,3,1,4,1,1], 1) - , (RacyThreadId [2,1,3,1,4,1,1], 2) - , (RacyThreadId [2,1,3,1,4,1,1], 3) - , (RacyThreadId [2,1,3,1,4,1], 9) - , (RacyThreadId [2,1,3,1,4,1], 10) - , (RacyThreadId [2,1,3,1,4,1], 11) - , (RacyThreadId [2,1,3,1,4,1], 12) - , (RacyThreadId [2,1,3,1,4,1], 13) - , (RacyThreadId [2,1,3,1,4,1], 14) - , (RacyThreadId [2,1,3,1,4,1], 15) - , (RacyThreadId [2,1,3,1,4], 13) - , (RacyThreadId [2,1,3,1,4], 14) - , (RacyThreadId [2,1,3,1,4], 15) - , (RacyThreadId [2,1,3,1,4], 16) - , (RacyThreadId [3,1,3,1,2,3,2], 34) - ] - ] - sim :: forall s. IOSim s Void - sim = do - exploreRaces - diffusionSimulation (toBearerInfo bi) ds iosimTracer - in exploreSimTrace (\a -> a { explorationReplay = Just s }) sim $ \_ ioSimTrace -> - prop_diffusion_cm_valid_transition_order_iosim_por ioSimTrace 10000 -- | As a basic property we run the governor to explore its state space a bit -- and check it does not throw any exceptions (assertions such as invariant @@ -3192,7 +3043,7 @@ prop_diffusion_cm_valid_transition_order_iosim_por ioSimTrace traceNumber = property . bifoldMap (const mempty) - (verifyAbstractTransitionOrder False) + (verifyAbstractTransitionOrder id False) . fmap (map ttTransition) . groupConns id abstractStateIsFinalTransitionTVarTracing @@ -3225,25 +3076,24 @@ prop_diffusion_cm_valid_transition_order ioSimTrace traceNumber = . last $ evsList in classifySimulatedTime lastTime - $ classifyNumberOfEvents (length evsList) - $ verify_cm_valid_transition_order - $ (\(WithName _ (WithTime _ b)) -> b) - <$> ev + . classifyNumberOfEvents (length evsList) + . verify_cm_valid_transition_order + $ ev ) <$> events where - verify_cm_valid_transition_order :: Trace () DiffusionTestTrace -> Property + verify_cm_valid_transition_order :: Trace () (WithName NtNAddr (WithTime DiffusionTestTrace)) -> Property verify_cm_valid_transition_order events = - let abstractTransitionEvents :: Trace () (AbstractTransitionTrace NtNAddr) + let abstractTransitionEvents :: Trace () (WithName NtNAddr (WithTime (AbstractTransitionTrace NtNAddr))) abstractTransitionEvents = - selectDiffusionConnectionManagerTransitionEvents events + selectDiffusionConnectionManagerTransitionEvents' events in property . bifoldMap (const mempty) - (verifyAbstractTransitionOrder False) - . fmap (map ttTransition) - . groupConns id abstractStateIsFinalTransition + (verifyAbstractTransitionOrder (wtEvent . wnEvent) False) + . fmap (map (fmap (fmap ttTransition))) + . groupConns (wtEvent . wnEvent) abstractStateIsFinalTransition $ abstractTransitionEvents -- | Unit test that checks issue 4258 @@ -4329,6 +4179,18 @@ selectDiffusionConnectionManagerTransitionEvents = _ -> Nothing) . Trace.toList +selectDiffusionConnectionManagerTransitionEvents' + :: Trace () (WithName NtNAddr (WithTime DiffusionTestTrace)) + -> Trace () (WithName NtNAddr (WithTime (AbstractTransitionTrace NtNAddr))) +selectDiffusionConnectionManagerTransitionEvents' = + Trace.fromList () + . mapMaybe + (\case + (WithName addr (WithTime time (DiffusionConnectionManagerTransitionTrace e))) + -> Just (WithName addr (WithTime time e)) + _ -> Nothing) + . Trace.toList + selectDiffusionConnectionManagerTransitionEventsTime :: Trace () (Time, DiffusionTestTrace) -> Trace () (Time, AbstractTransitionTrace NtNAddr) diff --git a/ouroboros-network/src/Ouroboros/Network/PeerSelection/PeerStateActions.hs b/ouroboros-network/src/Ouroboros/Network/PeerSelection/PeerStateActions.hs index 94f0a016e2..fec2e79be8 100644 --- a/ouroboros-network/src/Ouroboros/Network/PeerSelection/PeerStateActions.hs +++ b/ouroboros-network/src/Ouroboros/Network/PeerSelection/PeerStateActions.hs @@ -637,7 +637,7 @@ withPeerStateActions PeerStateActionsArguments { PeerCold -> return Nothing PeerCooling -> do - waitForOutboundDemotion spsConnectionManager (remoteAddress pchConnectionId) + waitForOutboundDemotion spsConnectionManager pchConnectionId writeTVar pchPeerStatus PeerCold return Nothing _ -> @@ -712,14 +712,12 @@ withPeerStateActions PeerStateActionsArguments { -- wrong). Just (WithSomeProtocolTemperature (WithWarm MiniProtocolSuccess {})) -> do isCooling <- closePeerConnection pch - if isCooling - then peerMonitoringLoop pch - else return () + when isCooling + $ peerMonitoringLoop pch Just (WithSomeProtocolTemperature (WithEstablished MiniProtocolSuccess {})) -> do isCooling <- closePeerConnection pch - if isCooling - then peerMonitoringLoop pch - else return () + when isCooling + $ peerMonitoringLoop pch Nothing -> traceWith spsTracer (PeerStatusChanged (CoolingToCold pchConnectionId)) @@ -744,7 +742,7 @@ withPeerStateActions PeerStateActionsArguments { >> throwIO e ShutdownPeer -> throwIO e - Right (Connected connectionId@ConnectionId { localAddress, remoteAddress } + Right (Connected connId@ConnectionId { localAddress, remoteAddress } _dataFlow (Handle mux muxBundle controlMessageBundle versionData)) -> do @@ -757,7 +755,7 @@ withPeerStateActions PeerStateActionsArguments { let connHandle = PeerConnectionHandle { - pchConnectionId = connectionId, + pchConnectionId = connId, pchPeerStatus = peerStateVar, pchMux = mux, pchAppHandles = mkApplicationHandleBundle @@ -782,9 +780,9 @@ withPeerStateActions PeerStateActionsArguments { Nothing -> Just e) (\e -> do atomically $ do - waitForOutboundDemotion spsConnectionManager remoteAddress + waitForOutboundDemotion spsConnectionManager connId writeTVar peerStateVar PeerCold - traceWith spsTracer (PeerMonitoringError connectionId e) + traceWith spsTracer (PeerMonitoringError connId e) throwIO e) (peerMonitoringLoop connHandle $> Nothing)) (return . Just) @@ -1055,7 +1053,7 @@ withPeerStateActions PeerStateActionsArguments { -- 'unregisterOutboundConnection' could only fail to demote the peer if -- connection manager would simultaneously promote it, but this is not -- possible. - _ <- releaseOutboundConnection spsConnectionManager (remoteAddress pchConnectionId) + _ <- releaseOutboundConnection spsConnectionManager pchConnectionId wasWarm <- atomically (updateUnlessCoolingOrCold pchPeerStatus PeerCooling) when wasWarm $ traceWith spsTracer (PeerStatusChanged (WarmToCooling pchConnectionId))