From 041a8db34c77451294e9c899220ceae86c89e0d0 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 18 Nov 2024 07:08:39 +0100 Subject: [PATCH 01/15] connection-manager: moved ConnectionState to its own module --- .../ouroboros-network-framework.cabal | 1 + .../Network/ConnectionManager/Core.hs | 233 +--------------- .../Network/ConnectionManager/State.hs | 260 ++++++++++++++++++ 3 files changed, 262 insertions(+), 232 deletions(-) create mode 100644 ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs diff --git a/ouroboros-network-framework/ouroboros-network-framework.cabal b/ouroboros-network-framework/ouroboros-network-framework.cabal index 18f64c26ce..ee07835882 100644 --- a/ouroboros-network-framework/ouroboros-network-framework.cabal +++ b/ouroboros-network-framework/ouroboros-network-framework.cabal @@ -31,6 +31,7 @@ library Ouroboros.Network.ConnectionId Ouroboros.Network.ConnectionManager.Core Ouroboros.Network.ConnectionManager.InformationChannel + Ouroboros.Network.ConnectionManager.State Ouroboros.Network.ConnectionManager.Types Ouroboros.Network.Context Ouroboros.Network.Driver diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs index 90083bac36..296f4c5be7 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs @@ -8,8 +8,6 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} --- Undecidable instances are need for 'Show' instance of 'ConnectionState'. -{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE UndecidableInstances #-} @@ -41,9 +39,7 @@ import Control.Monad.Class.MonadTimer.SI import Control.Monad.Fix import Control.Tracer (Tracer, contramap, traceWith) import Data.Foldable (foldMap', traverse_) -import Data.Function (on) import Data.Functor (void, ($>)) -import Data.Maybe (maybeToList) import Data.Proxy (Proxy (..)) import Data.Typeable (Typeable) import GHC.Stack (CallStack, HasCallStack, callStack) @@ -66,11 +62,11 @@ import Ouroboros.Network.ConnectionManager.InformationChannel (InformationChannel) import Ouroboros.Network.ConnectionManager.InformationChannel qualified as InfoChannel import Ouroboros.Network.ConnectionManager.Types +import Ouroboros.Network.ConnectionManager.State import Ouroboros.Network.InboundGovernor.Event (NewConnectionInfo (..)) import Ouroboros.Network.MuxMode import Ouroboros.Network.Server.RateLimiting (AcceptedConnectionsLimit (..)) import Ouroboros.Network.Snocket -import Ouroboros.Network.Testing.Utils (WithName (..)) -- | Arguments for a 'ConnectionManager' which are independent of 'MuxMode'. @@ -151,106 +147,6 @@ data Arguments handlerTrace socket peerAddr handle handleError versionNumber ver } --- | 'MutableConnState', which supplies a unique identifier. --- --- TODO: We can get away without id, by tracking connections in --- `TerminatingState` using a separate priority search queue. --- -data MutableConnState peerAddr handle handleError version m = MutableConnState { - -- | A unique identifier - -- - connStateId :: !Int - - , -- | Mutable state - -- - connVar :: !(StrictTVar m (ConnectionState peerAddr handle handleError - version m)) - } - - -instance Eq (MutableConnState peerAddr handle handleError version m) where - (==) = (==) `on` connStateId - - --- | A supply of fresh id's. --- --- We use a fresh ids for 'MutableConnState'. --- -newtype FreshIdSupply m = FreshIdSupply { getFreshId :: STM m Int } - - --- | Create a 'FreshIdSupply' inside an 'STM' monad. --- -newFreshIdSupply :: forall m. MonadSTM m - => Proxy m -> STM m (FreshIdSupply m) -newFreshIdSupply _ = do - (v :: StrictTVar m Int) <- newTVar 0 - let getFreshId :: STM m Int - getFreshId = do - c <- readTVar v - writeTVar v (succ c) - return c - return $ FreshIdSupply { getFreshId } - - -newMutableConnState :: forall peerAddr handle handleError version m. - ( MonadTraceSTM m - , Typeable peerAddr - ) - => peerAddr - -> FreshIdSupply m - -> ConnectionState peerAddr handle handleError - version m - -> STM m (MutableConnState peerAddr handle handleError - version m) -newMutableConnState peerAddr freshIdSupply connState = do - connStateId <- getFreshId freshIdSupply - connVar <- newTVar connState - -- This tracing is a no op in IO. - -- - -- We need this for IOSimPOR testing of connection manager state - -- transition tests. It can happen that the transitions happen - -- correctly but IOSimPOR reorders the threads that log the transitions. - -- This is a false positive and we don't want that to happen. - -- - -- The simplest way to do so is to leverage the `traceTVar` IOSim - -- capabilities. These trace messages won't be reordered by IOSimPOR - -- since these happen atomically in STM. - -- - traceTVar - (Proxy @m) connVar - (\mbPrev curr -> - let currAbs = abstractState (Known curr) - in case mbPrev of - Just prev | - let prevAbs = abstractState (Known prev) - , prevAbs /= currAbs -> pure - $ TraceDynamic - $ WithName connStateId - $ TransitionTrace peerAddr - $ mkAbsTransition prevAbs - currAbs - Nothing -> pure - $ TraceDynamic - $ WithName connStateId - $ TransitionTrace peerAddr - $ mkAbsTransition TerminatedSt - currAbs - _ -> pure DontTrace - ) - return $ MutableConnState { connStateId, connVar } - - --- | 'ConnectionManager' state: for each peer we keep a 'ConnectionState' in --- a mutable variable, which reduces congestion on the 'TMVar' which keeps --- 'ConnectionManagerState'. --- --- It is important we can lookup by remote @peerAddr@; this way we can find if --- the connection manager is already managing a connection towards that --- @peerAddr@ and reuse the 'ConnectionState'. --- -type ConnectionManagerState peerAddr handle handleError version m - = Map peerAddr (MutableConnState peerAddr handle handleError version m) connectionManagerStateToCounters :: Map peerAddr (ConnectionState peerAddr handle handleError version m) @@ -258,44 +154,6 @@ connectionManagerStateToCounters connectionManagerStateToCounters = foldMap' connectionStateToCounters --- | State of a connection. --- -data ConnectionState peerAddr handle handleError version m = - -- | Each outbound connections starts in this state. - ReservedOutboundState - - -- | Each inbound connection starts in this state, outbound connection - -- reach this state once `connect` call returns. - -- - -- note: the async handle is lazy, because it's passed with 'mfix'. - | UnnegotiatedState !Provenance - !(ConnectionId peerAddr) - (Async m ()) - - -- | @OutboundState Unidirectional@ state. - | OutboundUniState !(ConnectionId peerAddr) !(Async m ()) !handle - - -- | Either @OutboundState Duplex@ or @OutboundState^\tau Duplex@. - | OutboundDupState !(ConnectionId peerAddr) !(Async m ()) !handle !TimeoutExpired - - -- | Before connection is reset it is put in 'OutboundIdleState' for the - -- duration of 'outboundIdleTimeout'. - -- - | OutboundIdleState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow - | InboundIdleState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow - | InboundState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow - | DuplexState !(ConnectionId peerAddr) !(Async m ()) !handle - | TerminatingState !(ConnectionId peerAddr) !(Async m ()) !(Maybe handleError) - | TerminatedState !(Maybe handleError) - - --- | Return 'True' for states in which the connection was already closed. --- -connectionTerminated :: ConnectionState peerAddr handle handleError version m - -> Bool -connectionTerminated TerminatingState {} = True -connectionTerminated TerminatedState {} = True -connectionTerminated _ = False -- | Perform counting from an 'AbstractState' @@ -349,76 +207,6 @@ connectionStateToCounters state = outboundConn = ConnectionManagerCounters 0 0 0 0 1 -instance ( Show peerAddr - , Show handleError - , MonadAsync m - ) - => Show (ConnectionState peerAddr handle handleError version m) where - show ReservedOutboundState = "ReservedOutboundState" - show (UnnegotiatedState pr connId connThread) = - concat ["UnnegotiatedState " - , show pr - , " " - , show connId - , " " - , show (asyncThreadId connThread) - ] - show (OutboundUniState connId connThread _handle) = - concat [ "OutboundState Unidirectional " - , show connId - , " " - , show (asyncThreadId connThread) - ] - show (OutboundDupState connId connThread _handle expired) = - concat [ "OutboundState " - , show connId - , " " - , show (asyncThreadId connThread) - , " " - , show expired - ] - show (OutboundIdleState connId connThread _handle df) = - concat [ "OutboundIdleState " - , show connId - , " " - , show (asyncThreadId connThread) - , " " - , show df - ] - show (InboundIdleState connId connThread _handle df) = - concat [ "InboundIdleState " - , show connId - , " " - , show (asyncThreadId connThread) - , " " - , show df - ] - show (InboundState connId connThread _handle df) = - concat [ "InboundState " - , show connId - , " " - , show (asyncThreadId connThread) - , " " - , show df - ] - show (DuplexState connId connThread _handle) = - concat [ "DuplexState " - , show connId - , " " - , show (asyncThreadId connThread) - ] - show (TerminatingState connId connThread handleError) = - concat ([ "TerminatingState " - , show connId - , " " - , show (asyncThreadId connThread) - ] - ++ maybeToList ((' ' :) . show <$> handleError)) - show (TerminatedState handleError) = - concat (["TerminatedState"] - ++ maybeToList ((' ' :) . show <$> handleError)) - - getConnThread :: ConnectionState peerAddr handle handleError version m -> Maybe (Async m ()) getConnThread ReservedOutboundState = Nothing @@ -465,25 +253,6 @@ isInboundConn DuplexState {} = True isInboundConn TerminatingState {} = False isInboundConn TerminatedState {} = False - -abstractState :: MaybeUnknown (ConnectionState muxMode peerAddr m a b) -> AbstractState -abstractState = \case - Unknown -> UnknownConnectionSt - Race s' -> go s' - Known s' -> go s' - where - go :: ConnectionState muxMode peerAddr m a b -> AbstractState - go ReservedOutboundState {} = ReservedOutboundSt - go (UnnegotiatedState pr _ _) = UnnegotiatedSt pr - go (OutboundUniState _ _ _) = OutboundUniSt - go (OutboundDupState _ _ _ te) = OutboundDupSt te - go (OutboundIdleState _ _ _ df) = OutboundIdleSt df - go (InboundIdleState _ _ _ df) = InboundIdleSt df - go (InboundState _ _ _ df) = InboundSt df - go DuplexState {} = DuplexSt - go TerminatingState {} = TerminatingSt - go TerminatedState {} = TerminatedSt - -- | The default value for 'timeWaitTimeout'. -- defaultTimeWaitTimeout :: DiffTime diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs new file mode 100644 index 0000000000..8947db5a7c --- /dev/null +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs @@ -0,0 +1,260 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +-- Undecidable instances are need for 'Show' instance of 'ConnectionState'. +{-# LANGUAGE QuantifiedConstraints #-} + +module Ouroboros.Network.ConnectionManager.State + ( ConnectionManagerState + , MutableConnState (..) + , FreshIdSupply + , newFreshIdSupply + , newMutableConnState + , abstractState + , ConnectionState (..) + , connectionTerminated + ) where + +import Control.Monad.Class.MonadAsync +import Control.Concurrent.Class.MonadSTM.Strict +import Data.Function (on) +import Data.Map (Map) +import Data.Maybe (maybeToList) +import Data.Proxy (Proxy (..)) +import Data.Typeable (Typeable) + +import Ouroboros.Network.ConnectionId +import Ouroboros.Network.ConnectionManager.Types + +import Ouroboros.Network.Testing.Utils (WithName (..)) + +-- | 'ConnectionManager' state: for each peer we keep a 'ConnectionState' in +-- a mutable variable, which reduces congestion on the 'TMVar' which keeps +-- 'ConnectionManagerState'. +-- +-- It is important we can lookup by remote @peerAddr@; this way we can find if +-- the connection manager is already managing a connection towards that +-- @peerAddr@ and reuse the 'ConnectionState'. +-- +type ConnectionManagerState peerAddr handle handleError version m + = Map peerAddr (MutableConnState peerAddr handle handleError version m) + + +-- | 'MutableConnState', which supplies a unique identifier. +-- +-- TODO: We can get away without id, by tracking connections in +-- `TerminatingState` using a separate priority search queue. +-- +data MutableConnState peerAddr handle handleError version m = MutableConnState { + -- | A unique identifier + -- + connStateId :: !Int + + , -- | Mutable state + -- + connVar :: !(StrictTVar m (ConnectionState peerAddr handle handleError + version m)) + } + + +instance Eq (MutableConnState peerAddr handle handleError version m) where + (==) = (==) `on` connStateId + + +-- | A supply of fresh id's. +-- +-- We use a fresh ids for 'MutableConnState'. +-- +newtype FreshIdSupply m = FreshIdSupply { getFreshId :: STM m Int } + + +-- | Create a 'FreshIdSupply' inside an 'STM' monad. +-- +newFreshIdSupply :: forall m. MonadSTM m + => Proxy m -> STM m (FreshIdSupply m) +newFreshIdSupply _ = do + (v :: StrictTVar m Int) <- newTVar 0 + let getFreshId :: STM m Int + getFreshId = do + c <- readTVar v + writeTVar v (succ c) + return c + return $ FreshIdSupply { getFreshId } + + +newMutableConnState :: forall peerAddr handle handleError version m. + ( MonadTraceSTM m + , Typeable peerAddr + ) + => peerAddr + -> FreshIdSupply m + -> ConnectionState peerAddr handle handleError + version m + -> STM m (MutableConnState peerAddr handle handleError + version m) +newMutableConnState peerAddr freshIdSupply connState = do + connStateId <- getFreshId freshIdSupply + connVar <- newTVar connState + -- This tracing is a no op in IO. + -- + -- We need this for IOSimPOR testing of connection manager state + -- transition tests. It can happen that the transitions happen + -- correctly but IOSimPOR reorders the threads that log the transitions. + -- This is a false positive and we don't want that to happen. + -- + -- The simplest way to do so is to leverage the `traceTVar` IOSim + -- capabilities. These trace messages won't be reordered by IOSimPOR + -- since these happen atomically in STM. + -- + traceTVar + (Proxy @m) connVar + (\mbPrev curr -> + let currAbs = abstractState (Known curr) + in case mbPrev of + Just prev | + let prevAbs = abstractState (Known prev) + , prevAbs /= currAbs -> pure + $ TraceDynamic + $ WithName connStateId + $ TransitionTrace peerAddr + $ mkAbsTransition prevAbs + currAbs + Nothing -> pure + $ TraceDynamic + $ WithName connStateId + $ TransitionTrace peerAddr + $ mkAbsTransition TerminatedSt + currAbs + _ -> pure DontTrace + ) + return $ MutableConnState { connStateId, connVar } + + +abstractState :: MaybeUnknown (ConnectionState muxMode peerAddr m a b) -> AbstractState +abstractState = \case + Unknown -> UnknownConnectionSt + Race s' -> go s' + Known s' -> go s' + where + go :: ConnectionState muxMode peerAddr m a b -> AbstractState + go ReservedOutboundState {} = ReservedOutboundSt + go (UnnegotiatedState pr _ _) = UnnegotiatedSt pr + go (OutboundUniState _ _ _) = OutboundUniSt + go (OutboundDupState _ _ _ te) = OutboundDupSt te + go (OutboundIdleState _ _ _ df) = OutboundIdleSt df + go (InboundIdleState _ _ _ df) = InboundIdleSt df + go (InboundState _ _ _ df) = InboundSt df + go DuplexState {} = DuplexSt + go TerminatingState {} = TerminatingSt + go TerminatedState {} = TerminatedSt + + +-- | State of a connection. +-- +data ConnectionState peerAddr handle handleError version m = + -- | Each outbound connections starts in this state. + ReservedOutboundState + + -- | Each inbound connection starts in this state, outbound connection + -- reach this state once `connect` call returns. + -- + -- note: the async handle is lazy, because it's passed with 'mfix'. + | UnnegotiatedState !Provenance + !(ConnectionId peerAddr) + (Async m ()) + + -- | @OutboundState Unidirectional@ state. + | OutboundUniState !(ConnectionId peerAddr) !(Async m ()) !handle + + -- | Either @OutboundState Duplex@ or @OutboundState^\tau Duplex@. + | OutboundDupState !(ConnectionId peerAddr) !(Async m ()) !handle !TimeoutExpired + + -- | Before connection is reset it is put in 'OutboundIdleState' for the + -- duration of 'outboundIdleTimeout'. + -- + | OutboundIdleState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow + | InboundIdleState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow + | InboundState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow + | DuplexState !(ConnectionId peerAddr) !(Async m ()) !handle + | TerminatingState !(ConnectionId peerAddr) !(Async m ()) !(Maybe handleError) + | TerminatedState !(Maybe handleError) + + +instance ( Show peerAddr + , Show handleError + , MonadAsync m + ) + => Show (ConnectionState peerAddr handle handleError version m) where + show ReservedOutboundState = "ReservedOutboundState" + show (UnnegotiatedState pr connId connThread) = + concat ["UnnegotiatedState " + , show pr + , " " + , show connId + , " " + , show (asyncThreadId connThread) + ] + show (OutboundUniState connId connThread _handle) = + concat [ "OutboundState Unidirectional " + , show connId + , " " + , show (asyncThreadId connThread) + ] + show (OutboundDupState connId connThread _handle expired) = + concat [ "OutboundState " + , show connId + , " " + , show (asyncThreadId connThread) + , " " + , show expired + ] + show (OutboundIdleState connId connThread _handle df) = + concat [ "OutboundIdleState " + , show connId + , " " + , show (asyncThreadId connThread) + , " " + , show df + ] + show (InboundIdleState connId connThread _handle df) = + concat [ "InboundIdleState " + , show connId + , " " + , show (asyncThreadId connThread) + , " " + , show df + ] + show (InboundState connId connThread _handle df) = + concat [ "InboundState " + , show connId + , " " + , show (asyncThreadId connThread) + , " " + , show df + ] + show (DuplexState connId connThread _handle) = + concat [ "DuplexState " + , show connId + , " " + , show (asyncThreadId connThread) + ] + show (TerminatingState connId connThread handleError) = + concat ([ "TerminatingState " + , show connId + , " " + , show (asyncThreadId connThread) + ] + ++ maybeToList ((' ' :) . show <$> handleError)) + show (TerminatedState handleError) = + concat (["TerminatedState"] + ++ maybeToList ((' ' :) . show <$> handleError)) + + +-- | Return 'True' for states in which the connection was already closed. +-- +connectionTerminated :: ConnectionState peerAddr handle handleError version m + -> Bool +connectionTerminated TerminatingState {} = True +connectionTerminated TerminatedState {} = True +connectionTerminated _ = False From 3b570fd9348a495bb3bb6f1a0aa7a438be9f07cf Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 18 Nov 2024 07:14:21 +0100 Subject: [PATCH 02/15] connection-manager: use strict map --- .../src/Ouroboros/Network/ConnectionManager/Core.hs | 4 ++-- .../src/Ouroboros/Network/ConnectionManager/State.hs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs index 296f4c5be7..2744093fa3 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs @@ -45,8 +45,8 @@ import Data.Typeable (Typeable) import GHC.Stack (CallStack, HasCallStack, callStack) import System.Random (StdGen, split) -import Data.Map (Map) -import Data.Map qualified as Map +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as Map import Data.Set qualified as Set import Data.Monoid.Synchronisation diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs index 8947db5a7c..8b9e91389a 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs @@ -19,7 +19,7 @@ module Ouroboros.Network.ConnectionManager.State import Control.Monad.Class.MonadAsync import Control.Concurrent.Class.MonadSTM.Strict import Data.Function (on) -import Data.Map (Map) +import Data.Map.Strict (Map) import Data.Maybe (maybeToList) import Data.Proxy (Proxy (..)) import Data.Typeable (Typeable) From 12e28b9e2e8ab59ea9d7055093f0ecd537807c00 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 2 Dec 2024 17:07:07 +0100 Subject: [PATCH 03/15] simulated snocket: changed semantics of self connections On Linux when a server listens on a given (addr, port), and we connect a socket which is bound to the same address: * the accept call never returns * the connect returns, and the socket acts as a mirror: whatever is sent over it, can be received from it. This patch implements the same behaviour of simulated snockets. It also provides a test which verifies it. --- .../Network/Mux/Bearer/AttenuatedChannel.hs | 16 +++++ .../Test/Simulation/Network/Snocket.hs | 45 ++++++++++++++ .../src/Simulation/Network/Snocket.hs | 59 ++++++++++++++----- 3 files changed, 106 insertions(+), 14 deletions(-) diff --git a/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs b/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs index e7d51104ab..13620703c3 100644 --- a/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs +++ b/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs @@ -8,6 +8,9 @@ module Network.Mux.Bearer.AttenuatedChannel , Size , SuccessOrFailure (..) , Attenuation (..) + , QueueChannel + , newAttenuatedChannel + , echoQueueChannel , newConnectedAttenuatedChannelPair , attenuationChannelAsBearer -- * Trace @@ -60,6 +63,19 @@ data QueueChannel m = QueueChannel { qcWrite :: StrictTVar m (Maybe (StrictTQueue m Message)) } + +-- A `QueueChannel` which receives what is written to it. +-- +echoQueueChannel :: MonadSTM m => STM m (QueueChannel m) +echoQueueChannel = do + q <- newTQueue + v <- newTVar (Just q) + return QueueChannel { + qcRead = v, + qcWrite = v + } + + -- -- QueueChannel API -- diff --git a/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs b/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs index a02fa86e35..1c5e11399f 100644 --- a/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs @@ -40,6 +40,7 @@ import Data.ByteString.Lazy (ByteString) import Data.Foldable (traverse_) import Data.Functor (void) import Data.Map qualified as Map +import Data.Maybe (isNothing) import Data.Set (Set) import Data.Set qualified as Set import Text.Printf @@ -50,6 +51,7 @@ import Ouroboros.Network.Snocket import Simulation.Network.Snocket import Network.Mux as Mx +import Network.Mux.Types as Mx import Network.TypedProtocol.Codec.CBOR import Network.TypedProtocol.Core import Network.TypedProtocol.Peer @@ -81,6 +83,8 @@ tests = prop_connect_to_not_listening_socket , testProperty "simultaneous_open" prop_simultaneous_open + , testProperty "self connect" + prop_self_connect ] type TestAddr = TestAddress Int @@ -543,6 +547,47 @@ prop_simultaneous_open defaultBearerInfo = snocket getState wait clientAsync + +-- | Check that when we bind both outbound and inbound socket to the same +-- address, and connect the outbound to inbound: +-- +-- * accept loop never returns +-- * the outbound socket acts as a mirror +-- +-- This is how socket API behaves on Linux. +-- +prop_self_connect :: ByteString -> Property +prop_self_connect payload = + runSimOrThrow sim + where + addr :: TestAddress Int + addr = TestAddress 0 + + sim :: forall s. IOSim s Property + sim = + withSnocket nullTracer noAttenuation Map.empty + $ \snocket _getState -> + withAsync (runServer addr snocket + (close snocket) acceptOne return) + $ \serverThread -> do + bracket (openToConnect snocket addr) + (close snocket) + $ \fd -> do + bind snocket fd addr + connect snocket fd addr + bearer <- getBearer makeFDBearer 10 nullTracer fd + let channel = bearerAsChannel bearer (MiniProtocolNum 0) InitiatorDir + send channel payload + payload' <- recv channel + threadDelay 1 + serverResult <- atomically $ + (Just <$> waitSTM serverThread) + `orElse` + pure Nothing + return $ Just payload === payload' + .&&. isNothing serverResult + + -- -- Utils -- diff --git a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs index 8029d08e50..9f522101d2 100644 --- a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs @@ -90,7 +90,7 @@ data Connection m addr = Connection { -- | Attenuated channels of a connection. -- connChannelLocal :: !(AttenuatedChannel m) - , connChannelRemote :: !(AttenuatedChannel m) + , connChannelRemote :: AttenuatedChannel m -- | SDU size of a connection. -- @@ -139,12 +139,36 @@ mkConnection :: ( MonadDelay m , MonadTimer m , MonadThrow m , MonadThrow (STM m) + , Eq addr ) => Tracer m (WithAddr (TestAddress addr) (SnocketTrace m (TestAddress addr))) -> BearerInfo -> ConnectionId (TestAddress addr) -> STM m (Connection m (TestAddress addr)) + +mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } | localAddress == remoteAddress = do + -- we are connecting to onself. On Linux this returns a connection which + -- mirrors all sent data. + qc <- echoQueueChannel + channel <- newAttenuatedChannel + ( ( WithAddr (Just localAddress) (Just remoteAddress) + . STAttenuatedChannelTrace connId + ) + `contramap` tr) + Attenuation + { aReadAttenuation = biOutboundAttenuation bearerInfo + , aWriteAttenuation = biOutboundWriteFailure bearerInfo + } + qc + return Connection { + connChannelLocal = channel, + connChannelRemote = undefined, + connSDUSize = biSDUSize bearerInfo, + connState = SYN_SENT, + connProvider = localAddress + } + mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } = (\(connChannelLocal, connChannelRemote) -> Connection { @@ -895,11 +919,11 @@ mkSnocket state tr = Snocket { getLocalAddr connMap <- readTVar (nsConnections state) case Map.lookup normalisedId connMap of - Just Connection { connState = ESTABLISHED } -> + Just Connection { connState = ESTABLISHED } -> throwSTM (connectedIOError fd_) - Just Connection { connState = SYN_SENT, connProvider } - | connProvider == localAddress -> + Just Connection { connState = SYN_SENT, connProvider } + | connProvider == localAddress -> throwSTM (connectedIOError fd_) -- simultaneous open @@ -979,7 +1003,7 @@ mkSnocket state tr = Snocket { getLocalAddr <$> readTVar (nsConnections state) case lstFd of -- error cases - (Nothing) -> + Nothing -> return (Left (connectIOError connId "no such listening socket")) (Just FDUninitialised {}) -> return (Left (connectIOError connId "unitialised listening socket")) @@ -1009,13 +1033,16 @@ mkSnocket state tr = Snocket { getLocalAddr Just conn@Connection { connState = SYN_SENT } -> do let fd_' = FDConnected connId conn writeTVar fdVarLocal fd_' - writeTBQueue queue - ChannelWithInfo - { cwiAddress = localAddress connId - , cwiSDUSize = biSDUSize bearerInfo - , cwiChannelLocal = connChannelRemote conn - , cwiChannelRemote = connChannelLocal conn - } + when (localAddress connId /= remoteAddress) $ + -- We only write to the accept `queue` if we're not + -- connecting to ourselves. + writeTBQueue queue + ChannelWithInfo + { cwiAddress = localAddress connId + , cwiSDUSize = biSDUSize bearerInfo + , cwiChannelLocal = connChannelRemote conn + , cwiChannelRemote = connChannelLocal conn + } return (Right (fd_', NormalOpen)) Just Connection { connState = FIN } -> do @@ -1049,7 +1076,7 @@ mkSnocket state tr = Snocket { getLocalAddr (\e -> atomically $ modifyTVar (nsConnections state) (Map.delete (normaliseId connId)) >> throwIO e) - $ unmask (atomically $ runFirstToFinish $ + $ unmask . atomically . runFirstToFinish $ (FirstToFinish $ do LazySTM.readTVar timeoutVar >>= check modifyTVar (nsConnections state) @@ -1072,12 +1099,16 @@ mkSnocket state tr = Snocket { getLocalAddr ++ show (normaliseId connId) Just Connection { connState } -> Just <$> check (connState == ESTABLISHED)) - ) case r of + -- self connect + Nothing | localAddress connId == remoteAddress + -> traceWith' fd (STConnected fd_' o) + Nothing -> do traceWith' fd (STConnectTimeout WaitingToBeAccepted) throwIO (connectIOError connId "connect timeout: when waiting for being accepted") + Just _ -> traceWith' fd (STConnected fd_' o) FDConnecting {} -> From 2afa4b090e60fa844f28ec4f922f573974f5d8b7 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 2 Dec 2024 17:09:54 +0100 Subject: [PATCH 04/15] simulated snocket: code style --- ouroboros-network-framework/src/Simulation/Network/Snocket.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs index 9f522101d2..3a80257926 100644 --- a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs @@ -1077,14 +1077,14 @@ mkSnocket state tr = Snocket { getLocalAddr (Map.delete (normaliseId connId)) >> throwIO e) $ unmask . atomically . runFirstToFinish $ - (FirstToFinish $ do + FirstToFinish (do LazySTM.readTVar timeoutVar >>= check modifyTVar (nsConnections state) (Map.delete (normaliseId connId)) return Nothing ) <> - (FirstToFinish $ do + FirstToFinish (do mbConn <- Map.lookup (normaliseId connId) <$> readTVar (nsConnections state) case mbConn of From 4fd1cfa82302b8e44eb4c6cb4e305d572c607e25 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 18 Nov 2024 09:04:20 +0100 Subject: [PATCH 05/15] connection-manager: identify connections by local & remote address Connections in side connection manager state must be identified by their `ConnectionId`. Since the state map also reserves places for outbound connections for which we only know the remote address the state is slightly more general than `Map (ConnectionId peerAddr) ...`. The data type used by the state is provided in `Ouroboros.Network.ConnectionManager.ConnMap` module, the `Ouroboros.Network.ConnectionManager.State` module provides type alias for the state type and additional APIs needed by the connection manager. --- .../ouroboros-network-framework.cabal | 1 + .../src/Ouroboros/Network/ConnectionId.hs | 3 +- .../Network/ConnectionManager/ConnMap.hs | 269 +++++++++++ .../Network/ConnectionManager/Core.hs | 455 +++++++++--------- .../Network/ConnectionManager/State.hs | 51 +- .../Network/ConnectionManager/Types.hs | 39 +- .../src/Ouroboros/Network/InboundGovernor.hs | 9 +- .../src/Ouroboros/Network/Server2.hs | 15 +- .../Network/PeerSelection/PeerStateActions.hs | 22 +- 9 files changed, 566 insertions(+), 298 deletions(-) create mode 100644 ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/ConnMap.hs diff --git a/ouroboros-network-framework/ouroboros-network-framework.cabal b/ouroboros-network-framework/ouroboros-network-framework.cabal index ee07835882..580e3acfbb 100644 --- a/ouroboros-network-framework/ouroboros-network-framework.cabal +++ b/ouroboros-network-framework/ouroboros-network-framework.cabal @@ -29,6 +29,7 @@ library Ouroboros.Network.Channel Ouroboros.Network.ConnectionHandler Ouroboros.Network.ConnectionId + Ouroboros.Network.ConnectionManager.ConnMap Ouroboros.Network.ConnectionManager.Core Ouroboros.Network.ConnectionManager.InformationChannel Ouroboros.Network.ConnectionManager.State diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionId.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionId.hs index b607fbff1c..b505db4af0 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionId.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionId.hs @@ -33,7 +33,8 @@ data ConnectionId addr = ConnectionId { -- -- /Note:/ we relay on the fact that `remoteAddress` is an order -- preserving map (which allows us to use `Map.mapKeysMonotonic` in some --- cases). +-- cases. We also relay on this particular order in +-- `Ouroboros.Network.ConnectionManager.State.liveConnections` -- instance Ord addr => Ord (ConnectionId addr) where conn `compare` conn' = diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/ConnMap.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/ConnMap.hs new file mode 100644 index 0000000000..22c4e79ea3 --- /dev/null +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/ConnMap.hs @@ -0,0 +1,269 @@ +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Ouroboros.Network.ConnectionManager.ConnMap + ( ConnMap (..) + , LocalAddr (..) + , toList + , toMap + , empty + , insert + , insertUnknownLocalAddr + , delete + -- , deleteUnknownLocalAddr + , deleteAtRemoteAddr + , lookup + , lookupByRemoteAddr + , updateLocalAddr + , traverseMaybeWithKey + ) where + +import Prelude hiding (lookup) + +import Data.Foldable qualified as Foldable +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as Map +import System.Random (RandomGen, uniformR) + +import Ouroboros.Network.ConnectionId + +data LocalAddr peerAddr = + -- | A reserved slot for an outbound connection which is being created. The + -- outbound connection must be in the `ReservedOutbound` state. + UnknownLocalAddr + -- | All connections which are not in the `ReservedOutbound` state use + -- `LocalAddr`, since for them the local address is known. + | LocalAddr peerAddr + deriving (Show, Eq, Ord) + + +-- | The outer map keys are remote addresses, the internal ones are local +-- addresses. +-- +newtype ConnMap peerAddr a + = ConnMap { + getConnMap :: + Map peerAddr + (Map (LocalAddr peerAddr) a) + } + deriving (Show, Functor, Foldable) + +instance Traversable (ConnMap peerAddr) where + traverse f (ConnMap m) = ConnMap <$> traverse (traverse f) m + + +toList :: ConnMap m a -> [a] +toList = Foldable.toList + + +-- | Create a map of all connections with a known `ConnectionId`. +-- +toMap :: forall peerAddr a. + Ord peerAddr + => ConnMap peerAddr a + -> Map (ConnectionId peerAddr) a +toMap = + -- We can use `fromAscList` because of the `Ord` instance of `ConnectionId`. + -- /NOTE:/ if `fromAscList` is used on input which doesn't satisfy its + -- precondition, then `Map.lookup` might fail when it shouldn't. + Map.fromAscList + . Map.foldrWithKey + (\remoteAddress st conns -> + Map.foldrWithKey + (\localAddr conn conns' -> + case localAddr of + UnknownLocalAddr -> conns' + LocalAddr localAddress -> + (ConnectionId { remoteAddress, localAddress }, conn) : conns' + ) + conns + st + ) + [] + . getConnMap + + +empty :: ConnMap peerAddr a +empty = ConnMap Map.empty + + +insert :: Ord peerAddr + => ConnectionId peerAddr + -> a + -> ConnMap peerAddr a + -> ConnMap peerAddr a +insert ConnectionId { remoteAddress, localAddress } a = + ConnMap + . Map.alter + (\case + Nothing -> Just $! Map.singleton (LocalAddr localAddress) a + Just st -> Just $! Map.insert (LocalAddr localAddress) a st) + remoteAddress + . getConnMap + + +insertUnknownLocalAddr + :: Ord peerAddr + => peerAddr + -> a + -> ConnMap peerAddr a + -> ConnMap peerAddr a +insertUnknownLocalAddr remoteAddress a = + ConnMap + . Map.alter + (\case + Nothing -> Just $! Map.singleton UnknownLocalAddr a + Just st -> Just $! Map.insert UnknownLocalAddr a st + ) + remoteAddress + . getConnMap + + +delete :: Ord peerAddr + => ConnectionId peerAddr + -> ConnMap peerAddr a + -> ConnMap peerAddr a +delete ConnectionId { remoteAddress, localAddress } = + ConnMap + . Map.alter + (\case + Nothing -> Nothing + Just st -> + let st' = Map.delete (LocalAddr localAddress) st + in if Map.null st' + then Nothing + else Just st' + ) + remoteAddress + . getConnMap + + +deleteAtRemoteAddr + :: (Ord peerAddr, Eq a) + => peerAddr + -- ^ remoteAddr + -> a + -- ^ element to remove + -> ConnMap peerAddr a + -> ConnMap peerAddr a +deleteAtRemoteAddr remoteAddress a = + ConnMap + . Map.alter + (\case + Nothing -> Nothing + Just st -> + let st' = Map.filter (/=a) st in + if Map.null st' + then Nothing + else Just st' + ) + remoteAddress + . getConnMap + + + +lookup :: Ord peerAddr + => ConnectionId peerAddr + -> ConnMap peerAddr a + -> Maybe a +lookup ConnectionId { remoteAddress, localAddress } (ConnMap st) = + case remoteAddress `Map.lookup` st of + Nothing -> Nothing + Just st' -> LocalAddr localAddress `Map.lookup` st' + + +-- | Find a random entry for a given remote address. +-- +-- NOTE: the outbound governor will only ask for a connection to a peer if it +-- doesn't have one (and one isn't being created). This property, simplifies +-- `lookupOutbound`: we can pick (randomly) one of the connections to the +-- remote peer. When the outbound governor is asking, likely all of the +-- connections are in an inbound state. The outbound governor will has a grace +-- period after demoting a peer, so a race (when a is being demoted and +-- promoted at the same time) is unlikely. +-- +lookupByRemoteAddr + :: forall rnd peerAddr a. + ( Ord peerAddr + , RandomGen rnd + ) + => rnd + -- ^ a fresh `rnd` (it must come from a `split`) + -> peerAddr + -- ^ remote address + -> ConnMap peerAddr a + -> (Maybe a) +lookupByRemoteAddr rnd remoteAddress (ConnMap st) = + case remoteAddress `Map.lookup` st of + Nothing -> Nothing + Just st' -> + case UnknownLocalAddr `Map.lookup` st' of + Just a -> Just a + Nothing -> + if Map.null st' + then Nothing + else let (indx, _rnd') = uniformR (0, Map.size st' - 1) rnd + in Just $ snd $ Map.elemAt indx st' + + +-- | Promote `UnknownLocalAddr` to `LocalAddr`. +-- +updateLocalAddr + :: Ord peerAddr + => ConnectionId peerAddr + -> ConnMap peerAddr a + -> (Bool, ConnMap peerAddr a) + -- ^ Return `True` iff the entry was updated. +updateLocalAddr ConnectionId { remoteAddress, localAddress } (ConnMap m) = + ConnMap + <$> Map.alterF + (\case + Nothing -> (False, Nothing) + Just m' -> + let -- delete & lookup for entry in `UnknownLocalAddr` + (ma, m'') = + Map.alterF + (\x -> (x,Nothing)) + UnknownLocalAddr + m' + in + case ma of + -- there was no entry, so no need to update the inner map + Nothing -> (False, Just m') + -- we have an entry: put it in the `LocalAddr`, but only if it's + -- not present in the map. + Just {} -> + fmap Just + . Map.alterF + (\case + Nothing -> (True, ma) + a@Just{} -> (False, a) + ) + (LocalAddr localAddress) + $ m'' + ) + remoteAddress + m + + +traverseMaybeWithKey + :: Applicative f + => (Either peerAddr (ConnectionId peerAddr) -> a -> f (Maybe b)) + -> ConnMap peerAddr a + -> f [b] +traverseMaybeWithKey fn = + fmap (concat . Map.elems) + . Map.traverseMaybeWithKey + (\remoteAddress st -> + fmap (Just . Map.elems) + . Map.traverseMaybeWithKey + (\case + UnknownLocalAddr -> fn (Left remoteAddress) + LocalAddr localAddress -> fn (Right ConnectionId { remoteAddress, + localAddress }) + ) + $ st + ) + . getConnMap diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs index 2744093fa3..837816e2ff 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs @@ -1,15 +1,14 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE UndecidableInstances #-} -- | The implementation of connection manager. -- @@ -23,14 +22,14 @@ module Ouroboros.Network.ConnectionManager.Core , defaultProtocolIdleTimeout , defaultResetTimeout , ConnectionState (..) - , abstractState + , State.abstractState ) where import Control.Applicative (Alternative) import Control.Concurrent.Class.MonadSTM qualified as LazySTM import Control.Concurrent.Class.MonadSTM.Strict import Control.Exception (assert) -import Control.Monad (forM_, guard, when, (>=>)) +import Control.Monad (forM_, guard, unless, when, (>=>)) import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadFork (throwTo) import Control.Monad.Class.MonadThrow hiding (handle) @@ -51,6 +50,7 @@ import Data.Set qualified as Set import Data.Monoid.Synchronisation import Data.Set (Set) +import Data.Tuple (swap) import Data.Wedge import Data.Word (Word32) @@ -61,8 +61,10 @@ import Ouroboros.Network.ConnectionId import Ouroboros.Network.ConnectionManager.InformationChannel (InformationChannel) import Ouroboros.Network.ConnectionManager.InformationChannel qualified as InfoChannel +import Ouroboros.Network.ConnectionManager.State (ConnectionManagerState, + ConnectionState (..), FreshIdSupply, MutableConnState (..)) +import Ouroboros.Network.ConnectionManager.State qualified as State import Ouroboros.Network.ConnectionManager.Types -import Ouroboros.Network.ConnectionManager.State import Ouroboros.Network.InboundGovernor.Event (NewConnectionInfo (..)) import Ouroboros.Network.MuxMode import Ouroboros.Network.Server.RateLimiting (AcceptedConnectionsLimit (..)) @@ -149,10 +151,9 @@ data Arguments handlerTrace socket peerAddr handle handleError versionNumber ver connectionManagerStateToCounters - :: Map peerAddr (ConnectionState peerAddr handle handleError version m) + :: State.ConnMap peerAddr (ConnectionState peerAddr handle handleError version m) -> ConnectionManagerCounters -connectionManagerStateToCounters = - foldMap' connectionStateToCounters +connectionManagerStateToCounters = foldMap' connectionStateToCounters @@ -168,10 +169,10 @@ connectionStateToCounters state = UnnegotiatedState Outbound _ _ -> outboundConn - OutboundUniState _ _ _ -> unidirectionalConn + OutboundUniState {} -> unidirectionalConn <> outboundConn - OutboundDupState _ _ _ _ -> duplexConn + OutboundDupState {} -> duplexConn <> outboundConn OutboundIdleState _ _ _ Unidirectional -> unidirectionalConn @@ -192,13 +193,13 @@ connectionStateToCounters state = InboundState _ _ _ Duplex -> duplexConn <> inboundConn - DuplexState _ _ _ -> fullDuplexConn + DuplexState {} -> fullDuplexConn <> duplexConn - <> inboundConn + <> inboundConn <> outboundConn - TerminatingState _ _ _ -> mempty - TerminatedState _ -> mempty + TerminatingState {} -> mempty + TerminatedState {} -> mempty where fullDuplexConn = ConnectionManagerCounters 1 0 0 0 0 duplexConn = ConnectionManagerCounters 0 1 0 0 0 @@ -403,7 +404,7 @@ with args@Arguments { , StrictTVar m StdGen )) <- atomically $ do - v <- newTMVar Map.empty + v <- newTMVar State.empty labelTMVar v "cm-state" traceTMVar (Proxy :: Proxy m) v $ \old new -> @@ -415,26 +416,20 @@ with args@Arguments { (Just Nothing, Just _) -> pure (TraceString "cm-state: released") (_, _) -> pure DontTrace - freshIdSupply <- newFreshIdSupply (Proxy :: Proxy m) + freshIdSupply <- State.newFreshIdSupply (Proxy :: Proxy m) stdGenVar <- newTVar (stdGen args) return (freshIdSupply, v, stdGenVar) let readState - :: STM m (Map peerAddr AbstractState) - readState = do - state <- readTMVar stateVar - traverse ( fmap (abstractState . Known) - . readTVar - . connVar - ) - state + :: STM m (State.ConnMap peerAddr AbstractState) + readState = readTMVar stateVar >>= State.readAbstractStateMap waitForOutboundDemotion - :: peerAddr + :: ConnectionId peerAddr -> STM m () - waitForOutboundDemotion addr = do + waitForOutboundDemotion connId = do state <- readState - case Map.lookup addr state of + case State.lookup connId state of Nothing -> return () Just UnknownConnectionSt -> return () Just InboundIdleSt {} -> return () @@ -461,7 +456,7 @@ with args@Arguments { OutboundConnectionManager { ocmAcquireConnection = acquireOutboundConnectionImpl freshIdSupply stateVar - outboundHandler, + stdGenVar outboundHandler, ocmReleaseConnection = releaseOutboundConnectionImpl stateVar stdGenVar }, @@ -497,7 +492,7 @@ with args@Arguments { OutboundConnectionManager { ocmAcquireConnection = acquireOutboundConnectionImpl freshIdSupply stateVar - outboundHandler, + stdGenVar outboundHandler, ocmReleaseConnection = releaseOutboundConnectionImpl stateVar stdGenVar } @@ -529,9 +524,9 @@ with args@Arguments { -- Spawning one thread for each connection cleanup avoids spending time -- waiting for locks and cleanup logic that could delay closing the -- connections and making us not respecting certain timeouts. - asyncs <- Map.elems - <$> Map.traverseMaybeWithKey - (\peerAddr MutableConnState { connVar } -> do + asyncs <- State.traverseMaybeWithKey + (\peerAddrOrConnId MutableConnState { connVar } -> do + let remoteAddr = either id remoteAddress peerAddrOrConnId -- cleanup handler for that thread will close socket associated -- with the thread. We put each connection in 'TerminatedState' to -- try that none of the connection threads will enter @@ -546,12 +541,12 @@ with args@Arguments { connState <- readTVar connVar let connState' = TerminatedState Nothing trT = - TransitionTrace peerAddr (mkTransition connState connState') - absConnState = abstractState (Known connState) + TransitionTrace remoteAddr (mkTransition connState connState') + absConnState = State.abstractState (Known connState) shouldTraceTerminated = absConnState /= TerminatedSt shouldTraceUnknown = absConnState == ReservedOutboundSt trU = TransitionTrace - peerAddr + remoteAddr (Transition { fromState = Known connState' , toState = Unknown }) @@ -587,8 +582,8 @@ with args@Arguments { where traceCounters :: StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) -> m () traceCounters stateVar = do - mState <- atomically $ readTMVar stateVar >>= traverse (readTVar . connVar) - traceWith tracer (TrConnectionManagerCounters (connectionManagerStateToCounters mState)) + state <- atomically $ readTMVar stateVar >>= State.readConnectionStates + traceWith tracer (TrConnectionManagerCounters (connectionManagerStateToCounters state)) countIncomingConnections :: ConnectionManagerState peerAddr handle handleError version m @@ -596,7 +591,7 @@ with args@Arguments { countIncomingConnections st = inboundConns . connectionManagerStateToCounters - <$> traverse (readTVar . connVar) st + <$> State.readConnectionStates st -- Fork connection thread. @@ -701,11 +696,11 @@ with args@Arguments { , toState = Unknown } mbTransition <- modifyTMVar stateVar $ \state -> - case Map.lookup peerAddr state of + case State.lookup connId state of Nothing -> pure (state, Nothing) Just v -> if mutableConnState == v - then pure (Map.delete peerAddr state , Just transition) + then pure (State.delete connId state , Just transition) else pure (state , Nothing) traverse_ (traceWith trTracer) mbTransition @@ -736,7 +731,7 @@ with args@Arguments { trs <- atomically $ do connState <- readTVar connVar let transition' = transition { fromState = Known connState } - shouldTrace = abstractState (Known connState) + shouldTrace = State.abstractState (Known connState) /= TerminatedSt writeTVar connVar (TerminatedState Nothing) -- We have to be careful when deleting it from @@ -745,12 +740,12 @@ with args@Arguments { modifyTMVarPure stateVar ( \state -> - case Map.lookup peerAddr state of + case State.lookup connId state of Nothing -> (state, False) Just v -> if mutableConnState == v - then (Map.delete peerAddr state , True) - else (state , False) + then (State.delete connId state, True) + else (state , False) ) if updated @@ -785,7 +780,7 @@ with args@Arguments { -- their state to 'TerminatedState'; -- * an io action which logs and cancels all the connection handler -- threads. - mkPruneAction :: peerAddr + mkPruneAction :: ConnectionId peerAddr -> Int -- ^ number of connections to prune -> ConnectionManagerState peerAddr handle handleError version m @@ -797,27 +792,28 @@ with args@Arguments { -> STM m (Bool, PruneAction m) -- ^ return if the connection was choose to be pruned and the -- 'PruneAction' - mkPruneAction peerAddr numberToPrune state connState' connVar stdGenVar connThread = do - (choiceMap' :: Map peerAddr ( ConnectionType - , Async m () - , StrictTVar m - (ConnectionState - peerAddr - handle handleError - version m) - )) - <- flip Map.traverseMaybeWithKey state $ \_peerAddr MutableConnState { connVar = connVar' } -> - (\cs -> do - -- this expression returns @Maybe (connType, connThread)@; - -- 'traverseMaybeWithKey' collects all 'Just' cases. - guard (isInboundConn cs) - (,,connVar') <$> getConnType cs - <*> getConnThread cs) - <$> readTVar connVar' - let choiceMap = + mkPruneAction connId numberToPrune state connState' connVar stdGenVar connThread = do + choiceMap' + <- Map.traverseMaybeWithKey + (\_ MutableConnState { connVar = connVar' } -> + (\cs -> do + -- this expression returns @Maybe (connType, connThread)@; + -- 'traverseMaybeWithKey' collects all 'Just' cases. + guard (isInboundConn cs) + (,,connVar') <$> getConnType cs + <*> getConnThread cs) + <$> readTVar connVar' + ) + (State.toMap state) + let choiceMap :: Map (ConnectionId peerAddr) + ( ConnectionType + , Async m () + , StrictTVar m (ConnectionState peerAddr handle handleError version m) + ) + choiceMap = case getConnType connState' of Nothing -> assert False choiceMap' - Just a -> Map.insert peerAddr (a, connThread, connVar) + Just a -> Map.insert connId (a, connThread, connVar) choiceMap' stdGen <- stateTVar stdGenVar split @@ -830,7 +826,7 @@ with args@Arguments { forM_ pruneMap $ \(_, _, connVar') -> writeTVar connVar' (TerminatedState Nothing) - return ( peerAddr `Set.member` pruneSet + return ( connId `Set.member` pruneSet , PruneAction $ do traceWith tracer (TrPruneConnections (Map.keysSet pruneMap) numberToPrune @@ -858,29 +854,26 @@ with args@Arguments { -- whether we can include the connection or not. -> socket -- ^ resource to include in the state - -> peerAddr - -- ^ remote address used as an identifier of the resource + -> ConnectionId peerAddr + -- ^ connection id used as an identifier of the resource -> m (Connected peerAddr handle handleError) includeInboundConnectionImpl freshIdSupply stateVar handler hardLimit socket - peerAddr = do - (r, connId) <- modifyTMVar stateVar $ \state -> do - localAddress <- getLocalAddr snocket socket + connId = do + r <- modifyTMVar stateVar $ \state -> do numberOfCons <- atomically $ countIncomingConnections state - let connId = ConnectionId { localAddress, remoteAddress = peerAddr } - - -- Check if after accepting this connection we get above the + let -- Check if after accepting this connection we get above the -- hard limit canAccept = numberOfCons + 1 <= fromIntegral hardLimit if canAccept then do let provenance = Inbound - traceWith tracer (TrIncludeConnection provenance peerAddr) + traceWith tracer (TrIncludeConnection provenance (remoteAddress connId)) (reader, writer) <- newEmptyPromiseIO (connThread, connVar, connState0, connState) <- mfix $ \ ~(connThread, _mutableConnVar, _connState0, _connState) -> do @@ -906,11 +899,11 @@ with args@Arguments { let connState' = UnnegotiatedState provenance connId connThread (mutableConnVar', connState0') <- atomically $ do - let v0 = Map.lookup peerAddr state + let v0 = State.lookup connId state case v0 of Nothing -> do -- 'Accepted' - v <- newMutableConnState peerAddr freshIdSupply connState' + v <- State.newMutableConnState (remoteAddress connId) freshIdSupply connState' labelTVar (connVar v) ("conn-state-" ++ show connId) return (v, Nothing) Just v -> do @@ -938,8 +931,8 @@ with args@Arguments { InboundState {} -> writeTVar (connVar v) connState' $> assert False v - TerminatingState {} -> newMutableConnState peerAddr freshIdSupply connState' - TerminatedState {} -> newMutableConnState peerAddr freshIdSupply connState' + TerminatingState {} -> State.newMutableConnState (remoteAddress connId) freshIdSupply connState' + TerminatedState {} -> State.newMutableConnState (remoteAddress connId) freshIdSupply connState' labelTVar (connVar v') ("conn-state-" ++ show connId) return (v', Just connState0') @@ -949,16 +942,16 @@ with args@Arguments { stateVar mutableConnVar' socket connId writer handler return (connThread', mutableConnVar', connState0', connState') - traceWith trTracer (TransitionTrace peerAddr + traceWith trTracer (TransitionTrace (remoteAddress connId) Transition { fromState = maybe Unknown Known connState0 , toState = Known connState }) - return ( Map.insert peerAddr connVar state - , (Just (connVar, connThread, reader), connId) + return ( State.insert connId connVar state + , Just (connVar, connThread, reader) ) else return ( state - , (Nothing, connId) + , Nothing ) case r of @@ -972,10 +965,10 @@ with args@Arguments { res <- atomically $ readPromise reader case res of Left handleError -> do - terminateInboundWithErrorOrQuery connId connVar connThread peerAddr stateVar mutableConnState $ Just handleError + terminateInboundWithErrorOrQuery connId connVar connThread stateVar mutableConnState $ Just handleError Right HandshakeConnectionQuery -> do - terminateInboundWithErrorOrQuery connId connVar connThread peerAddr stateVar mutableConnState Nothing + terminateInboundWithErrorOrQuery connId connVar connThread stateVar mutableConnState Nothing Right (HandshakeConnectionResult handle (_version, versionData)) -> do let dataFlow = connectionDataFlow versionData @@ -985,7 +978,7 @@ with args@Arguments { -- Inbound connections cannot be found in this state at this -- stage. ReservedOutboundState -> - throwSTM (withCallStack (ImpossibleState peerAddr)) + throwSTM (withCallStack (ImpossibleState (remoteAddress connId))) -- -- The common case. @@ -1019,23 +1012,23 @@ with args@Arguments { ) InboundIdleState {} -> - throwSTM (withCallStack (ImpossibleState peerAddr)) + throwSTM (withCallStack (ImpossibleState (remoteAddress connId))) -- At this stage the inbound connection cannot be in -- 'InboundState', it would mean that there was another thread -- that included that connection, but this would violate @TCP@ -- constraints. InboundState {} -> - throwSTM (withCallStack (ImpossibleState peerAddr)) + throwSTM (withCallStack (ImpossibleState (remoteAddress connId))) DuplexState {} -> - throwSTM (withCallStack (ImpossibleState peerAddr)) + throwSTM (withCallStack (ImpossibleState (remoteAddress connId))) TerminatingState {} -> return (False, Nothing, Inbound) TerminatedState {} -> return (False, Nothing, Inbound) - traverse_ (traceWith trTracer . TransitionTrace peerAddr) mbTransition + traverse_ (traceWith trTracer . TransitionTrace (remoteAddress connId)) mbTransition traceCounters stateVar -- Note that we don't set a timeout thread here which would @@ -1065,7 +1058,16 @@ with args@Arguments { else return $ Disconnected connId Nothing - terminateInboundWithErrorOrQuery connId connVar connThread peerAddr stateVar mutableConnState handleErrorM = do + terminateInboundWithErrorOrQuery + :: ConnectionId peerAddr + -> StrictTVar m (ConnectionState peerAddr handle handleError version m) + -> Async m () + -> StrictTMVar + m (ConnectionManagerState peerAddr handle handleError version m) + -> MutableConnState peerAddr handle handleError version m + -> Maybe handleError + -> m (Connected peerAddr handle1 handleError) + terminateInboundWithErrorOrQuery connId connVar connThread stateVar mutableConnState handleErrorM = do transitions <- atomically $ do connState <- readTVar connVar @@ -1081,14 +1083,14 @@ with args@Arguments { TerminatingState connId connThread handleErrorM transition = mkTransition connState connState' - absConnState = abstractState (Known connState) + absConnState = State.abstractState (Known connState) shouldTrace = absConnState /= TerminatedSt updated <- modifyTMVarSTM stateVar ( \state -> - case Map.lookup peerAddr state of + case State.lookup connId state of Nothing -> return (state, False) Just mutableConnState' -> if mutableConnState' == mutableConnState @@ -1107,8 +1109,9 @@ with args@Arguments { -- tracing accordingly. writeTVar connVar connState' - return (Map.delete peerAddr state , True) - else return (state , False) + return (State.delete connId state, True) + else + return (state , False) ) if updated @@ -1139,7 +1142,7 @@ with args@Arguments { -- overwriting. else return [ ] - traverse_ (traceWith trTracer . TransitionTrace peerAddr) transitions + traverse_ (traceWith trTracer . TransitionTrace (remoteAddress connId)) transitions traceCounters stateVar return (Disconnected connId handleErrorM) @@ -1149,13 +1152,13 @@ with args@Arguments { -- action. releaseInboundConnectionImpl :: StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) - -> peerAddr + -> ConnectionId peerAddr -> m (OperationResult DemotedToColdRemoteTr) - releaseInboundConnectionImpl stateVar peerAddr = mask_ $ do - traceWith tracer (TrReleaseConnection Inbound peerAddr) + releaseInboundConnectionImpl stateVar connId = mask_ $ do + traceWith tracer (TrReleaseConnection Inbound connId) (mbThread, mbTransition, result, mbAssertion) <- atomically $ do state <- readTMVar stateVar - case Map.lookup peerAddr state of + case State.lookup connId state of Nothing -> do -- Note: this can happen if the inbound connection manager is -- notified late about the connection which has already terminated @@ -1167,7 +1170,7 @@ with args@Arguments { ) Just MutableConnState { connVar } -> do connState <- readTVar connVar - let st = abstractState (Known connState) + let st = State.abstractState (Known connState) case connState of -- In any of the following two states releasing is not -- supported. 'includeInboundConnection' is a synchronous @@ -1190,7 +1193,7 @@ with args@Arguments { -- TimeoutExpired : OutboundState^\tau Duplex -- → OutboundState Duplex -- @ - OutboundDupState connId connThread handle Ticking -> do + OutboundDupState _connId connThread handle Ticking -> do let connState' = OutboundDupState connId connThread handle Expired writeTVar connVar connState' return ( Nothing @@ -1198,7 +1201,7 @@ with args@Arguments { , OperationSuccess KeepTr , Nothing ) - OutboundDupState connId _connThread _handle Expired -> + OutboundDupState _connId _connThread _handle Expired -> assert False $ return ( Nothing , Nothing @@ -1218,7 +1221,7 @@ with args@Arguments { -- unexpected state, this state is reachable only from outbound -- states - OutboundIdleState connId _connThread _handle _dataFlow -> + OutboundIdleState _connId _connThread _handle _dataFlow -> return ( Nothing , Nothing , OperationSuccess CommitTr @@ -1234,7 +1237,7 @@ with args@Arguments { -- @ -- -- Note: the 'TrDemotedToColdRemote' is logged by the server. - InboundIdleState connId connThread _handle _dataFlow -> do + InboundIdleState _connId connThread _handle _dataFlow -> do let connState' = TerminatingState connId connThread Nothing writeTVar connVar connState' return ( Just connThread @@ -1245,7 +1248,7 @@ with args@Arguments { -- the inbound protocol governor was supposed to call -- 'demotedToColdRemote' first. - InboundState connId connThread _handle _dataFlow -> do + InboundState _connId connThread _handle _dataFlow -> do let connState' = TerminatingState connId connThread Nothing writeTVar connVar connState' return ( Just connThread @@ -1259,7 +1262,7 @@ with args@Arguments { -- the inbound connection governor ought to call -- 'demotedToColdRemote' first. - DuplexState connId connThread handle -> do + DuplexState _connId connThread handle -> do let connState' = OutboundDupState connId connThread handle Ticking writeTVar connVar connState' return ( Nothing @@ -1290,7 +1293,7 @@ with args@Arguments { , Nothing ) - traverse_ (traceWith trTracer . TransitionTrace peerAddr) mbTransition + traverse_ (traceWith trTracer . TransitionTrace (remoteAddress connId)) mbTransition traceCounters stateVar -- 'throwTo' avoids blocking until 'timeWaitTimeout' expires. @@ -1308,19 +1311,21 @@ with args@Arguments { :: HasCallStack => FreshIdSupply m -> StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) + -> StrictTVar m StdGen -> ConnectionHandlerFn handlerTrace socket peerAddr handle handleError (version, versionData) m -> peerAddr -> m (Connected peerAddr handle handleError) - acquireOutboundConnectionImpl freshIdSupply stateVar handler peerAddr = do + acquireOutboundConnectionImpl freshIdSupply stateVar stdGenVar handler peerAddr = do let provenance = Outbound traceWith tracer (TrIncludeConnection provenance peerAddr) (trace, mutableConnState@MutableConnState { connVar } , eHandleWedge) <- atomically $ do state <- readTMVar stateVar - case Map.lookup peerAddr state of + stdGen <- stateTVar stdGenVar split + case State.lookupByRemoteAddr stdGen peerAddr state of Just mutableConnState@MutableConnState { connVar } -> do connState <- readTVar connVar - let st = abstractState (Known connState) + let st = State.abstractState (Known connState) case connState of ReservedOutboundState -> return ( Just (Right (TrConnectionExists provenance peerAddr st)) @@ -1360,7 +1365,7 @@ with args@Arguments { ) OutboundIdleState _connId _connThread _handle _dataFlow -> - let tr = abstractState (Known connState) in + let tr = State.abstractState (Known connState) in return ( Just (Right (TrForbiddenOperation peerAddr tr)) , mutableConnState , Left (withCallStack (ForbiddenOperation peerAddr tr)) @@ -1432,23 +1437,16 @@ with args@Arguments { let connState' = ReservedOutboundState (mutableConnState :: MutableConnState peerAddr handle handleError version m) - <- newMutableConnState peerAddr freshIdSupply connState' + <- State.newMutableConnState peerAddr freshIdSupply connState' -- TODO: label `connVar` using 'ConnectionId' labelTVar (connVar mutableConnState) ("conn-state-" ++ show peerAddr) - -- record the @connVar@ in 'ConnectionManagerState' we can use - -- 'swapTMVar' as we did not use 'takeTMVar' at the beginning of - -- this transaction. Since we already 'readTMVar', it will not - -- block. - (mbConnState - :: Maybe (ConnectionState peerAddr handle handleError version m)) - <- swapTMVar stateVar - (Map.insert peerAddr mutableConnState state) - >>= traverse (readTVar . connVar) . Map.lookup peerAddr + writeTMVar stateVar + (State.insertUnknownLocalAddr peerAddr mutableConnState state) return ( Just (Left (TransitionTrace peerAddr Transition { - fromState = maybe Unknown Known mbConnState, + fromState = Unknown, toState = Known connState' })) , mutableConnState @@ -1492,51 +1490,17 @@ with args@Arguments { (\socket -> uninterruptibleMask_ $ do close snocket socket trs <- atomically $ modifyTMVarSTM stateVar $ \state -> do - case Map.lookup peerAddr state of - -- Lookup failed, which means connection was already - -- removed. So we just update the connVar and trace - -- accordingly. - Nothing -> do - connState <- readTVar connVar - let connState' = TerminatedState Nothing - writeTVar connVar connState' - return - ( state - , [ mkTransition connState connState' - , Transition (Known connState') Unknown - ] - ) - - -- Current connVar. - Just mutableConnState' -> do - connState <- readTVar connVar - case connState of - -- Update the state only if the connection was in - -- 'ReservedOutboundState'. This covers the case - -- when we connect to ourselves, in which case: we - -- first set the connection state to - -- `ReservedOutboundState`, then race connect - -- & accept calls. If the connection was - -- accepted it, it will use the same - -- 'MutableConnState', and if the inbound side is - -- using the connection the state will be - -- different than `ReservedOutboundState`. - ReservedOutboundState | mutableConnState' == mutableConnState -> do - let state' = Map.delete peerAddr state - connState' = TerminatedState Nothing - writeTVar connVar connState' - return - ( state' - , [ mkTransition connState connState' - , Transition (Known connState') - Unknown - ] - ) - - -- self connection: the connection might have - -- been accepted, in such case do not modify its - -- state. - _ -> return (state, []) + connState <- readTVar connVar + let state' = State.deleteAtRemoteAddr peerAddr mutableConnState state + connState' = TerminatedState Nothing + writeTVar connVar connState' + return + ( state' + , [ mkTransition connState connState' + , Transition (Known connState') + Unknown + ] + ) traverse_ (traceWith trTracer . TransitionTrace peerAddr) trs traceCounters stateVar @@ -1568,6 +1532,21 @@ with args@Arguments { let connId = ConnectionId { localAddress , remoteAddress = peerAddr } + updated <- atomically $ modifyTMVarPure stateVar (swap . State.updateLocalAddr connId) + unless updated $ + -- there exists a connection with exact same + -- `ConnectionId` + -- + -- NOTE: + -- When we are connecting from our own `(ip, port)` to + -- itself. In this case on linux, the `connect` + -- returns, while `accept` doesn't. The outbound + -- socket is connected to itself (simultaneuos TCP + -- open?). Since the `accept` call never returns, the + -- `connId` slot must have been available, and thus + -- `State.updateLocalAddr` must have returned `True`. + throwIO (withCallStack $ ConnectionExists provenance peerAddr) + return (socket, connId) -- @@ -1616,7 +1595,7 @@ with args@Arguments { , Just (TrUnexpectedlyFalseAssertion (AcquireOutboundConnection (Just connId) - (abstractState (Known connState)) + (State.abstractState (Known connState)) ) ) ) @@ -1701,7 +1680,7 @@ with args@Arguments { TerminatedState _ -> return Nothing _ -> - let st = abstractState (Known connState) in + let st = State.abstractState (Known connState) in throwSTM (withCallStack (ForbiddenOperation peerAddr st)) traverse_ (traceWith trTracer . TransitionTrace peerAddr) mbTransition @@ -1735,7 +1714,7 @@ with args@Arguments { throwSTM (withCallStack (ConnectionExists provenance connId)) OutboundIdleState _connId _connThread _handle _dataFlow -> - let tr = abstractState (Known connState) in + let tr = State.abstractState (Known connState) in throwSTM (withCallStack (ForbiddenOperation peerAddr tr)) InboundIdleState _connId connThread handle dataFlow@Duplex -> do @@ -1830,7 +1809,7 @@ with args@Arguments { Nothing -> TerminatedState handleErrorM transition = mkTransition connState connState' - absConnState = abstractState (Known connState) + absConnState = State.abstractState (Known connState) shouldTrace = absConnState /= TerminatedSt -- 'handleError' might be either a handshake negotiation @@ -1843,12 +1822,12 @@ with args@Arguments { modifyTMVarPure stateVar ( \state -> - case Map.lookup peerAddr state of + case State.lookup connId state of Nothing -> (state, False) Just mutableConnState' -> if mutableConnState' == mutableConnState - then (Map.delete peerAddr state , True) - else (state , False) + then (State.delete connId state, True) + else (state , False) ) if updated @@ -1890,14 +1869,14 @@ with args@Arguments { :: StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) -> StrictTVar m StdGen - -> peerAddr + -> ConnectionId peerAddr -> m (OperationResult AbstractState) - releaseOutboundConnectionImpl stateVar stdGenVar peerAddr = do - traceWith tracer (TrReleaseConnection Outbound peerAddr) + releaseOutboundConnectionImpl stateVar stdGenVar connId = do + traceWith tracer (TrReleaseConnection Outbound connId) (transition, mbAssertion) <- atomically $ do state <- readTMVar stateVar - case Map.lookup peerAddr state of + case State.lookup connId state of -- if the connection errored, it will remove itself from the state. -- Calling 'releaseOutboundConnection' is a no-op in this case. Nothing -> pure ( DemoteToColdLocalNoop Nothing UnknownConnectionSt @@ -1905,7 +1884,7 @@ with args@Arguments { Just MutableConnState { connVar } -> do connState <- readTVar connVar - let st = abstractState (Known connState) + let st = State.abstractState (Known connState) case connState of -- In any of the following three states releaseing is not -- supported. 'acquireOutboundConnection' is a synchronous @@ -1914,20 +1893,20 @@ with args@Arguments { ReservedOutboundState -> return ( DemoteToColdLocalError - (TrForbiddenOperation peerAddr st) + (TrForbiddenOperation (remoteAddress connId) st) st , Nothing ) - UnnegotiatedState _ _ _ -> + UnnegotiatedState {} -> return ( DemoteToColdLocalError - (TrForbiddenOperation peerAddr st) + (TrForbiddenOperation (remoteAddress connId) st) st , Nothing ) - OutboundUniState connId connThread handle -> do + OutboundUniState _connId connThread handle -> do -- @ -- DemotedToCold^{Unidirectional}_{Local} -- : OutboundState Unidirectional @@ -1941,7 +1920,7 @@ with args@Arguments { , Nothing ) - OutboundDupState connId connThread handle Expired -> do + OutboundDupState _connId connThread handle Expired -> do -- @ -- DemotedToCold^{Duplex}_{Local} -- : OutboundState Duplex @@ -1955,7 +1934,7 @@ with args@Arguments { , Nothing ) - OutboundDupState connId connThread handle Ticking -> do + OutboundDupState _connId connThread handle Ticking -> do let connState' = InboundIdleState connId connThread handle Duplex tr = mkTransition connState connState' @@ -1973,7 +1952,7 @@ with args@Arguments { if numberToPrune > 0 then do (_, prune) - <- mkPruneAction peerAddr numberToPrune state connState' connVar stdGenVar connThread + <- mkPruneAction connId numberToPrune state connState' connVar stdGenVar connThread return ( PruneConnections prune (Left connState) , Nothing @@ -2004,7 +1983,7 @@ with args@Arguments { return ( DemoteToColdLocalNoop Nothing st , Nothing ) - InboundState connId _connThread _handle dataFlow -> do + InboundState _connId _connThread _handle dataFlow -> do let mbAssertion = if dataFlow == Duplex then Nothing @@ -2015,12 +1994,12 @@ with args@Arguments { ) return ( DemoteToColdLocalError - (TrForbiddenOperation peerAddr st) + (TrForbiddenOperation (remoteAddress connId) st) st , mbAssertion ) - DuplexState connId connThread handle -> do + DuplexState _connId connThread handle -> do -- @ -- DemotedToCold^{Duplex}_{Local} : DuplexState -- → InboundState Duplex @@ -2055,8 +2034,8 @@ with args@Arguments { pure () case transition of - DemotedToColdLocal connId connThread connVar tr -> do - traceWith trTracer (TransitionTrace peerAddr tr) + DemotedToColdLocal _connId connThread connVar tr -> do + traceWith trTracer (TransitionTrace (remoteAddress connId) tr) traceCounters stateVar timeoutVar <- registerDelay outboundIdleTimeout r <- atomically $ runFirstToFinish $ @@ -2075,7 +2054,7 @@ with args@Arguments { Right connState -> do let connState' = TerminatingState connId connThread Nothing atomically $ writeTVar connVar connState' - traceWith trTracer (TransitionTrace peerAddr + traceWith trTracer (TransitionTrace (remoteAddress connId) (mkTransition connState connState')) traceCounters stateVar -- We rely on the `finally` handler of connection thread to: @@ -2085,26 +2064,26 @@ with args@Arguments { -- - 'throwTo' avoids blocking until 'timeWaitTimeout' expires. throwTo (asyncThreadId connThread) AsyncCancelled - return (OperationSuccess (abstractState $ Known connState')) + return (OperationSuccess (State.abstractState $ Known connState')) - Left connState | connectionTerminated connState + Left connState | State.connectionTerminated connState -> - return (OperationSuccess (abstractState $ Known connState)) + return (OperationSuccess (State.abstractState $ Known connState)) Left connState -> - return (UnsupportedState (abstractState $ Known connState)) + return (UnsupportedState (State.abstractState $ Known connState)) PruneConnections prune eTr -> do - traverse_ (traceWith trTracer . TransitionTrace peerAddr) eTr + traverse_ (traceWith trTracer . TransitionTrace (remoteAddress connId)) eTr runPruneAction prune traceCounters stateVar - return (OperationSuccess (abstractState (either Known fromState eTr))) + return (OperationSuccess (State.abstractState (either Known fromState eTr))) DemoteToColdLocalError trace st -> do traceWith tracer trace return (UnsupportedState st) DemoteToColdLocalNoop tr a -> do - traverse_ (traceWith trTracer) (TransitionTrace peerAddr <$> tr) + traverse_ (traceWith trTracer . TransitionTrace (remoteAddress connId)) tr traceCounters stateVar return (OperationSuccess a) @@ -2115,12 +2094,12 @@ with args@Arguments { promotedToWarmRemoteImpl :: StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) -> StrictTVar m StdGen - -> peerAddr + -> ConnectionId peerAddr -> m (OperationResult AbstractState) - promotedToWarmRemoteImpl stateVar stdGenVar peerAddr = mask_ $ do + promotedToWarmRemoteImpl stateVar stdGenVar connId = mask_ $ do (result, pruneTr, mbAssertion) <- atomically $ do state <- readTMVar stateVar - let mbConnVar = Map.lookup peerAddr state + let mbConnVar = State.lookup connId state case mbConnVar of Nothing -> return ( UnsupportedState UnknownConnectionSt , Nothing @@ -2128,7 +2107,7 @@ with args@Arguments { ) Just MutableConnState { connVar } -> do connState <- readTVar connVar - let st = abstractState (Known connState) + let st = State.abstractState (Known connState) case connState of ReservedOutboundState {} -> do return ( UnsupportedState st @@ -2139,7 +2118,7 @@ with args@Arguments { st) ) ) - UnnegotiatedState _ connId _ -> + UnnegotiatedState _ _connId _ -> return ( UnsupportedState st , Nothing , Just (TrUnexpectedlyFalseAssertion @@ -2148,7 +2127,7 @@ with args@Arguments { st) ) ) - OutboundUniState connId _connThread _handle -> + OutboundUniState _connId _connThread _handle -> return ( UnsupportedState st , Nothing , Just (TrUnexpectedlyFalseAssertion @@ -2157,7 +2136,7 @@ with args@Arguments { st) ) ) - OutboundDupState connId connThread handle _expired -> do + OutboundDupState _connId connThread handle _expired -> do -- @ -- PromotedToWarm^{Duplex}_{Remote} : OutboundState Duplex -- → DuplexState @@ -2187,7 +2166,7 @@ with args@Arguments { if numberToPrune > 0 then do (pruneSelf, prune) - <- mkPruneAction peerAddr numberToPrune state connState' connVar stdGenVar connThread + <- mkPruneAction connId numberToPrune state connState' connVar stdGenVar connThread when (not pruneSelf) $ writeTVar connVar connState' @@ -2204,7 +2183,7 @@ with args@Arguments { , Nothing , Nothing ) - OutboundIdleState connId connThread handle dataFlow@Duplex -> do + OutboundIdleState _connId connThread handle dataFlow@Duplex -> do -- @ -- Awake^{Duplex}_{Remote} : OutboundIdleState^\tau Duplex -- → InboundState Duplex @@ -2226,7 +2205,7 @@ with args@Arguments { if numberToPrune > 0 then do (pruneSelf, prune) - <- mkPruneAction peerAddr numberToPrune state connState' connVar stdGenVar connThread + <- mkPruneAction connId numberToPrune state connState' connVar stdGenVar connThread when (not pruneSelf) $ writeTVar connVar connState' @@ -2247,7 +2226,7 @@ with args@Arguments { , Nothing , Nothing ) - InboundIdleState connId connThread handle dataFlow -> do + InboundIdleState _connId connThread handle dataFlow -> do -- @ -- Awake^{dataFlow}_{Remote} : InboundIdleState Duplex -- → InboundState Duplex @@ -2258,7 +2237,7 @@ with args@Arguments { , Nothing , Nothing ) - InboundState connId _ _ _ -> + InboundState _connId _ _ _ -> return ( OperationSuccess (mkTransition connState connState) , Nothing -- already in 'InboundState'? @@ -2292,32 +2271,32 @@ with args@Arguments { -- trace transition case (result, pruneTr) of (OperationSuccess tr, Nothing) -> do - traceWith trTracer (TransitionTrace peerAddr tr) + traceWith trTracer (TransitionTrace (remoteAddress connId) tr) traceCounters stateVar (OperationSuccess tr, Just prune) -> do - traceWith trTracer (TransitionTrace peerAddr tr) + traceWith trTracer (TransitionTrace (remoteAddress connId) tr) runPruneAction prune traceCounters stateVar _ -> return () - return (abstractState . fromState <$> result) + return (State.abstractState . fromState <$> result) demotedToColdRemoteImpl :: StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m) - -> peerAddr + -> ConnectionId peerAddr -> m (OperationResult AbstractState) - demotedToColdRemoteImpl stateVar peerAddr = do + demotedToColdRemoteImpl stateVar connId = do (result, mbAssertion) <- atomically $ do - mbConnVar <- Map.lookup peerAddr <$> readTMVar stateVar + mbConnVar <- State.lookup connId <$> readTMVar stateVar case mbConnVar of Nothing -> return ( UnsupportedState UnknownConnectionSt , Nothing ) Just MutableConnState { connVar } -> do connState <- readTVar connVar - let st = abstractState (Known connState) + let st = State.abstractState (Known connState) case connState of ReservedOutboundState {} -> do return ( UnsupportedState st @@ -2327,7 +2306,7 @@ with args@Arguments { st) ) ) - UnnegotiatedState _ connId _ -> + UnnegotiatedState _ _connId _ -> return ( UnsupportedState st , Just (TrUnexpectedlyFalseAssertion (DemotedToColdRemote @@ -2335,7 +2314,7 @@ with args@Arguments { st) ) ) - OutboundUniState connId _connThread _handle -> + OutboundUniState _connId _connThread _handle -> return ( UnsupportedState st , Just (TrUnexpectedlyFalseAssertion (DemotedToColdRemote @@ -2363,7 +2342,7 @@ with args@Arguments { -- : InboundState dataFlow -- → InboundIdleState^\tau dataFlow -- @ - InboundState connId connThread handle dataFlow -> do + InboundState _connId connThread handle dataFlow -> do let connState' = InboundIdleState connId connThread handle dataFlow writeTVar connVar connState' return ( OperationSuccess (mkTransition connState connState') @@ -2375,7 +2354,7 @@ with args@Arguments { -- : DuplexState -- → OutboundState^\tau Duplex -- @ - DuplexState connId connThread handle -> do + DuplexState _connId connThread handle -> do let connState' = OutboundDupState connId connThread handle Ticking writeTVar connVar connState' return ( OperationSuccess (mkTransition connState connState') @@ -2399,11 +2378,11 @@ with args@Arguments { -- trace transition case result of OperationSuccess tr -> - traceWith trTracer (TransitionTrace peerAddr tr) + traceWith trTracer (TransitionTrace (remoteAddress connId) tr) _ -> return () traceCounters stateVar - return (abstractState . fromState <$> result) + return (State.abstractState . fromState <$> result) -- @@ -2466,7 +2445,7 @@ withCallStack k = k callStack -- data Trace peerAddr handlerTrace = TrIncludeConnection Provenance peerAddr - | TrReleaseConnection Provenance peerAddr + | TrReleaseConnection Provenance (ConnectionId peerAddr) | TrConnect (Maybe peerAddr) -- ^ local address peerAddr -- ^ remote address | TrConnectError (Maybe peerAddr) -- ^ local address @@ -2481,14 +2460,14 @@ data Trace peerAddr handlerTrace | TrConnectionFailure (ConnectionId peerAddr) | TrConnectionNotFound Provenance peerAddr | TrForbiddenOperation peerAddr AbstractState - | TrPruneConnections (Set peerAddr) -- ^ pruning set + | TrPruneConnections (Set (ConnectionId peerAddr)) -- ^ pruning set Int -- ^ number connections that must be pruned - (Set peerAddr) -- ^ choice set + (Set (ConnectionId peerAddr)) -- ^ choice set | TrConnectionCleanup (ConnectionId peerAddr) | TrConnectionTimeWait (ConnectionId peerAddr) | TrConnectionTimeWaitDone (ConnectionId peerAddr) | TrConnectionManagerCounters ConnectionManagerCounters - | TrState (Map peerAddr AbstractState) + | TrState (State.ConnMap peerAddr AbstractState) -- ^ traced on SIGUSR1 signal, installed in 'runDataDiffusion' | TrUnexpectedlyFalseAssertion (AssertionLocation peerAddr) -- ^ This case is unexpected at call site. diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs index 8b9e91389a..427c095597 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs @@ -6,25 +6,32 @@ {-# LANGUAGE QuantifiedConstraints #-} module Ouroboros.Network.ConnectionManager.State - ( ConnectionManagerState + ( -- * ConnectionManagerState API + ConnectionManagerState + , module ConnMap + -- ** Monadic API + , readConnectionStates + , readAbstractStateMap + -- * MutableConnState , MutableConnState (..) , FreshIdSupply , newFreshIdSupply , newMutableConnState - , abstractState + -- * ConnectionState , ConnectionState (..) + , abstractState , connectionTerminated ) where -import Control.Monad.Class.MonadAsync import Control.Concurrent.Class.MonadSTM.Strict +import Control.Monad.Class.MonadAsync import Data.Function (on) -import Data.Map.Strict (Map) -import Data.Maybe (maybeToList) import Data.Proxy (Proxy (..)) import Data.Typeable (Typeable) +import Prelude hiding (lookup) import Ouroboros.Network.ConnectionId +import Ouroboros.Network.ConnectionManager.ConnMap as ConnMap import Ouroboros.Network.ConnectionManager.Types import Ouroboros.Network.Testing.Utils (WithName (..)) @@ -33,13 +40,22 @@ import Ouroboros.Network.Testing.Utils (WithName (..)) -- a mutable variable, which reduces congestion on the 'TMVar' which keeps -- 'ConnectionManagerState'. -- --- It is important we can lookup by remote @peerAddr@; this way we can find if --- the connection manager is already managing a connection towards that --- @peerAddr@ and reuse the 'ConnectionState'. --- -type ConnectionManagerState peerAddr handle handleError version m - = Map peerAddr (MutableConnState peerAddr handle handleError version m) +type ConnectionManagerState peerAddr handle handleError version m = + ConnMap peerAddr (MutableConnState peerAddr handle handleError version m) + + +readConnectionStates + :: MonadSTM m + => ConnectionManagerState peerAddr handle handleError version m + -> STM m (ConnMap peerAddr (ConnectionState peerAddr handle handleError version m)) +readConnectionStates = traverse (readTVar . connVar) + +readAbstractStateMap + :: MonadSTM m + => ConnectionManagerState peerAddr handle handleError version m + -> STM m (ConnMap peerAddr AbstractState) +readAbstractStateMap = traverse (fmap (abstractState . Known) . readTVar . connVar) -- | 'MutableConnState', which supplies a unique identifier. -- @@ -131,7 +147,8 @@ newMutableConnState peerAddr freshIdSupply connState = do return $ MutableConnState { connStateId, connVar } -abstractState :: MaybeUnknown (ConnectionState muxMode peerAddr m a b) -> AbstractState +abstractState :: MaybeUnknown (ConnectionState muxMode peerAddr m a b) + -> AbstractState abstractState = \case Unknown -> UnknownConnectionSt Race s' -> go s' @@ -182,7 +199,7 @@ data ConnectionState peerAddr handle handleError version m = instance ( Show peerAddr - , Show handleError + -- , Show handleError , MonadAsync m ) => Show (ConnectionState peerAddr handle handleError version m) where @@ -239,16 +256,16 @@ instance ( Show peerAddr , " " , show (asyncThreadId connThread) ] - show (TerminatingState connId connThread handleError) = + show (TerminatingState connId connThread _handleError) = concat ([ "TerminatingState " , show connId , " " , show (asyncThreadId connThread) ] - ++ maybeToList ((' ' :) . show <$> handleError)) - show (TerminatedState handleError) = + )-- ++ maybeToList ((' ' :) . show <$> handleError)) + show (TerminatedState _handleError) = concat (["TerminatedState"] - ++ maybeToList ((' ' :) . show <$> handleError)) + )-- ++ maybeToList ((' ' :) . show <$> handleError)) -- | Return 'True' for states in which the connection was already closed. diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs index cb981524c9..0e51c54129 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs @@ -177,7 +177,8 @@ import System.Random (StdGen) import Network.Mux.Types (HasInitiator, HasResponder, MiniProtocolDir) import Network.Mux.Types qualified as Mux -import Ouroboros.Network.ConnectionId (ConnectionId) +import Ouroboros.Network.ConnectionId (ConnectionId (..)) +import Ouroboros.Network.ConnectionManager.ConnMap (ConnMap) import Ouroboros.Network.MuxMode @@ -423,9 +424,9 @@ data HandleErrorType = -- connections. -- type PrunePolicy peerAddr = StdGen - -> Map peerAddr ConnectionType + -> Map (ConnectionId peerAddr) ConnectionType -> Int - -> Set peerAddr + -> Set (ConnectionId peerAddr) -- | The simplest 'PrunePolicy', it should only be used for tests. @@ -500,7 +501,7 @@ type IncludeInboundConnection socket peerAddr handle handleError m -- ^ inbound connections hard limit. -- NOTE: Check TODO over at includeInboundConnectionImpl -- definition. - -> socket -> peerAddr -> m (Connected peerAddr handle handleError) + -> socket -> ConnectionId peerAddr -> m (Connected peerAddr handle handleError) -- | Outbound connection manager API. @@ -509,7 +510,7 @@ data OutboundConnectionManager (muxMode :: Mux.Mode) socket peerAddr handle hand OutboundConnectionManager :: HasInitiator muxMode ~ True => { ocmAcquireConnection :: AcquireOutboundConnection peerAddr handle handleError m - , ocmReleaseConnection :: peerAddr -> m (OperationResult AbstractState) + , ocmReleaseConnection :: ConnectionId peerAddr -> m (OperationResult AbstractState) } -> OutboundConnectionManager muxMode socket peerAddr handle handleError m @@ -522,10 +523,10 @@ data InboundConnectionManager (muxMode :: Mux.Mode) socket peerAddr handle handl InboundConnectionManager :: HasResponder muxMode ~ True => { icmIncludeConnection :: IncludeInboundConnection socket peerAddr handle handleError m - , icmReleaseConnection :: peerAddr -> m (OperationResult DemotedToColdRemoteTr) - , icmPromotedToWarmRemote :: peerAddr -> m (OperationResult AbstractState) + , icmReleaseConnection :: ConnectionId peerAddr -> m (OperationResult DemotedToColdRemoteTr) + , icmPromotedToWarmRemote :: ConnectionId peerAddr -> m (OperationResult AbstractState) , icmDemotedToColdRemote - :: peerAddr -> m (OperationResult AbstractState) + :: ConnectionId peerAddr -> m (OperationResult AbstractState) , icmNumberOfConnections :: STM m Int } -> InboundConnectionManager muxMode socket peerAddr handle handleError m @@ -548,13 +549,13 @@ data ConnectionManager (muxMode :: Mux.Mode) socket peerAddr handle handleError (InboundConnectionManager muxMode socket peerAddr handle handleError m), readState - :: STM m (Map peerAddr AbstractState), + :: STM m (ConnMap peerAddr AbstractState), -- | This STM action will block until the given connection is fully -- closed/terminated. If the connection manager doesn't have any connection to -- that peer it won't block. waitForOutboundDemotion - :: peerAddr + :: ConnectionId peerAddr -> STM m () } @@ -584,7 +585,7 @@ acquireOutboundConnection = releaseOutboundConnection :: HasInitiator muxMode ~ True => ConnectionManager muxMode socket peerAddr handle handleError m - -> peerAddr + -> ConnectionId peerAddr -> m (OperationResult AbstractState) -- ^ reports the from-state. releaseOutboundConnection = @@ -603,7 +604,7 @@ releaseOutboundConnection = promotedToWarmRemote :: HasResponder muxMode ~ True => ConnectionManager muxMode socket peerAddr handle handleError m - -> peerAddr -> m (OperationResult AbstractState) + -> ConnectionId peerAddr -> m (OperationResult AbstractState) promotedToWarmRemote = icmPromotedToWarmRemote . withResponderMode . getConnectionManager @@ -619,7 +620,7 @@ promotedToWarmRemote = demotedToColdRemote :: HasResponder muxMode ~ True => ConnectionManager muxMode socket peerAddr handle handleError m - -> peerAddr -> m (OperationResult AbstractState) + -> ConnectionId peerAddr -> m (OperationResult AbstractState) demotedToColdRemote = icmDemotedToColdRemote . withResponderMode . getConnectionManager @@ -636,7 +637,7 @@ includeInboundConnection includeInboundConnection = icmIncludeConnection . withResponderMode . getConnectionManager --- | Release outbound connection. Returns if the operation was successful. +-- | Release inbound connection. Returns if the operation was successful. -- -- This executes: -- @@ -645,7 +646,7 @@ includeInboundConnection = releaseInboundConnection :: HasResponder muxMode ~ True => ConnectionManager muxMode socket peerAddr handle handleError m - -> peerAddr -> m (OperationResult DemotedToColdRemoteTr) + -> ConnectionId peerAddr -> m (OperationResult DemotedToColdRemoteTr) releaseInboundConnection = icmReleaseConnection . withResponderMode . getConnectionManager @@ -838,10 +839,10 @@ connectionManagerErrorFromException x = do -- data AssertionLocation peerAddr = ReleaseInboundConnection !(Maybe (ConnectionId peerAddr)) !AbstractState - | AcquireOutboundConnection !(Maybe (ConnectionId peerAddr)) !AbstractState + | AcquireOutboundConnection !(Maybe (ConnectionId peerAddr)) !AbstractState | ReleaseOutboundConnection !(Maybe (ConnectionId peerAddr)) !AbstractState - | PromotedToWarmRemote !(Maybe (ConnectionId peerAddr)) !AbstractState - | DemotedToColdRemote !(Maybe (ConnectionId peerAddr)) !AbstractState + | PromotedToWarmRemote !(Maybe (ConnectionId peerAddr)) !AbstractState + | DemotedToColdRemote !(Maybe (ConnectionId peerAddr)) !AbstractState deriving Show @@ -888,7 +889,7 @@ mkAbsTransition from to = Transition { fromState = from } data TransitionTrace' peerAddr state = TransitionTrace - { ttPeerAddr :: peerAddr + { ttPeerAddr :: peerAddr -- TODO: use ConnectionId , ttTransition :: Transition' state } deriving Functor diff --git a/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs b/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs index b4827b2948..3fb30be73d 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs @@ -403,8 +403,7 @@ with -- @ -- NOTE: `demotedToColdRemote` doesn't throw, hence exception handling -- is not needed. - res <- demotedToColdRemote connectionManager - (remoteAddress connId) + res <- demotedToColdRemote connectionManager connId traceWith tracer (TrWaitIdleRemote connId res) case res of TerminatedConnection {} -> do @@ -445,8 +444,7 @@ with -- -- NOTE: `promotedToWarmRemote` doesn't throw, hence exception handling -- is not needed. - res <- promotedToWarmRemote connectionManager - (remoteAddress connId) + res <- promotedToWarmRemote connectionManager connId traceWith tracer (TrPromotedToWarmRemote connId res) when (resultInState res == UnknownConnectionSt) $ do @@ -479,8 +477,7 @@ with CommitRemote connId -> do -- NOTE: `releaseInboundConnection` doesn't throw, hence exception -- handling is not needed. - res <- releaseInboundConnection connectionManager - (remoteAddress connId) + res <- releaseInboundConnection connectionManager connId traceWith tracer $ TrDemotedToColdRemote connId res case res of UnsupportedState {} -> do diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Server2.hs b/ouroboros-network-framework/src/Ouroboros/Network/Server2.hs index 4ceda99fb2..f8d65d4084 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Server2.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Server2.hs @@ -49,6 +49,7 @@ import Foreign.C.Error import Network.Mux qualified as Mx import Ouroboros.Network.ConnectionHandler +import Ouroboros.Network.ConnectionId (ConnectionId (..)) import Ouroboros.Network.ConnectionManager.InformationChannel (InboundGovernorInfoChannel) import Ouroboros.Network.ConnectionManager.Types @@ -236,15 +237,19 @@ with Arguments { -- no need to use a rethrow policy _ -> throwIO err - (Accepted socket peerAddr, acceptNext) -> + (Accepted socket remoteAddress, acceptNext) -> (do - traceWith tracer (TrAcceptConnection peerAddr) + localAddress' <- getLocalAddr snocket socket + let connId = ConnectionId { localAddress = localAddress', + remoteAddress } + traceWith tracer (TrAcceptConnection connId) async $ - do a <- + do + a <- unmask (includeInboundConnection connectionManager - hardLimit socket peerAddr) + hardLimit socket connId) case a of Connected {} -> pure () Disconnected {} -> close snocket socket @@ -281,7 +286,7 @@ isECONNABORTED _ = False -- data Trace peerAddr - = TrAcceptConnection peerAddr + = TrAcceptConnection (ConnectionId peerAddr) | TrAcceptError SomeException | TrAcceptPolicyTrace AcceptConnectionsPolicyTrace | TrServerStarted [peerAddr] diff --git a/ouroboros-network/src/Ouroboros/Network/PeerSelection/PeerStateActions.hs b/ouroboros-network/src/Ouroboros/Network/PeerSelection/PeerStateActions.hs index 94f0a016e2..fec2e79be8 100644 --- a/ouroboros-network/src/Ouroboros/Network/PeerSelection/PeerStateActions.hs +++ b/ouroboros-network/src/Ouroboros/Network/PeerSelection/PeerStateActions.hs @@ -637,7 +637,7 @@ withPeerStateActions PeerStateActionsArguments { PeerCold -> return Nothing PeerCooling -> do - waitForOutboundDemotion spsConnectionManager (remoteAddress pchConnectionId) + waitForOutboundDemotion spsConnectionManager pchConnectionId writeTVar pchPeerStatus PeerCold return Nothing _ -> @@ -712,14 +712,12 @@ withPeerStateActions PeerStateActionsArguments { -- wrong). Just (WithSomeProtocolTemperature (WithWarm MiniProtocolSuccess {})) -> do isCooling <- closePeerConnection pch - if isCooling - then peerMonitoringLoop pch - else return () + when isCooling + $ peerMonitoringLoop pch Just (WithSomeProtocolTemperature (WithEstablished MiniProtocolSuccess {})) -> do isCooling <- closePeerConnection pch - if isCooling - then peerMonitoringLoop pch - else return () + when isCooling + $ peerMonitoringLoop pch Nothing -> traceWith spsTracer (PeerStatusChanged (CoolingToCold pchConnectionId)) @@ -744,7 +742,7 @@ withPeerStateActions PeerStateActionsArguments { >> throwIO e ShutdownPeer -> throwIO e - Right (Connected connectionId@ConnectionId { localAddress, remoteAddress } + Right (Connected connId@ConnectionId { localAddress, remoteAddress } _dataFlow (Handle mux muxBundle controlMessageBundle versionData)) -> do @@ -757,7 +755,7 @@ withPeerStateActions PeerStateActionsArguments { let connHandle = PeerConnectionHandle { - pchConnectionId = connectionId, + pchConnectionId = connId, pchPeerStatus = peerStateVar, pchMux = mux, pchAppHandles = mkApplicationHandleBundle @@ -782,9 +780,9 @@ withPeerStateActions PeerStateActionsArguments { Nothing -> Just e) (\e -> do atomically $ do - waitForOutboundDemotion spsConnectionManager remoteAddress + waitForOutboundDemotion spsConnectionManager connId writeTVar peerStateVar PeerCold - traceWith spsTracer (PeerMonitoringError connectionId e) + traceWith spsTracer (PeerMonitoringError connId e) throwIO e) (peerMonitoringLoop connHandle $> Nothing)) (return . Just) @@ -1055,7 +1053,7 @@ withPeerStateActions PeerStateActionsArguments { -- 'unregisterOutboundConnection' could only fail to demote the peer if -- connection manager would simultaneously promote it, but this is not -- possible. - _ <- releaseOutboundConnection spsConnectionManager (remoteAddress pchConnectionId) + _ <- releaseOutboundConnection spsConnectionManager pchConnectionId wasWarm <- atomically (updateUnlessCoolingOrCold pchPeerStatus PeerCooling) when wasWarm $ traceWith spsTracer (PeerStatusChanged (WarmToCooling pchConnectionId)) From 26da29ef17f5eaecaa986f7be6a3dd27d80bebcb Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 25 Nov 2024 08:34:49 +0100 Subject: [PATCH 06/15] connection-manager: small code refactoring --- .../Ouroboros/Network/ConnectionManager.hs | 6 ++-- .../Ouroboros/Network/ConnectionHandler.hs | 8 ++--- .../Network/ConnectionManager/Core.hs | 30 +++++++++---------- .../Network/ConnectionManager/State.hs | 2 +- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs index d8e04f422c..fcbe6a2745 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs @@ -422,7 +422,7 @@ mkSnocket scheduleMap = do ) . getSchedule - <$> (getScheduleMap scheduleMap) + <$> getScheduleMap scheduleMap v <- newTVarIO inboundSchedule return $ Snocket { getLocalAddr, @@ -449,10 +449,10 @@ mkSnocket scheduleMap = do $> x getLocalAddr (FD v) = - fdLocalAddress <$> atomically (readTVar v) + fdLocalAddress <$> readTVarIO v getRemoteAddr (FD v) = do - mbRemote <- fdRemoteAddress <$> atomically (readTVar v) + mbRemote <- fdRemoteAddress <$> readTVarIO v case mbRemote of Nothing -> throwIO InvalidArgumentError Just addr -> pure addr diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs index 6971e051b3..e2969ab874 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs @@ -411,10 +411,10 @@ makeConnectionHandler muxTracer singMuxMode -- --- | 'ConnectionHandlerTrace' is embedded into 'ConnectionManagerTrace' with --- 'Ouroboros.Network.ConnectionManager.Types.ConnectionHandlerTrace' --- constructor. It already includes 'ConnectionId' so we don't need to take --- care of it here. +-- | 'ConnectionHandlerTrace' is embedded into +-- 'Ouroboros.Network.ConnectionManager.Core.Trace' with +-- 'Ouroboros.Network.ConnectionManager.Types.TrConnectionHandler' constructor. +-- It already includes 'ConnectionId' so we don't need to take care of it here. -- -- TODO: when 'Handshake' will get its own tracer, independent of 'Mux', it -- should be embedded into 'ConnectionHandlerTrace'. diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs index 837816e2ff..3247d9da59 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs @@ -376,9 +376,9 @@ with -- will be closed. -> m a with args@Arguments { - tracer = tracer, - trTracer = trTracer, - muxTracer = muxTracer, + tracer, + trTracer, + muxTracer, ipv4Address, ipv6Address, addressType, @@ -636,7 +636,7 @@ with args@Arguments { -- hits there we will update `connVar`. uninterruptibleMask $ \unmask -> do traceWith tracer (TrConnectionCleanup connId) - eTransition <- modifyTMVar stateVar $ \state -> do + mbTransition <- modifyTMVar stateVar $ \state -> do eTransition <- atomically $ do connState <- readTVar connVar let connState' = TerminatedState Nothing @@ -677,17 +677,16 @@ with args@Arguments { traverse_ (traceWith trTracer) mbTransition close snocket socket return ( state - , Left () + , Nothing ) Right transition -> do close snocket socket return ( state - , Right transition + , Just transition ) - case eTransition of - Left () -> do - + case mbTransition of + Nothing -> do let transition = TransitionTrace peerAddr @@ -695,7 +694,7 @@ with args@Arguments { { fromState = Known (TerminatedState Nothing) , toState = Unknown } - mbTransition <- modifyTMVar stateVar $ \state -> + mbTransition' <- modifyTMVar stateVar $ \state -> case State.lookup connId state of Nothing -> pure (state, Nothing) Just v -> @@ -703,9 +702,10 @@ with args@Arguments { then pure (State.delete connId state , Just transition) else pure (state , Nothing) - traverse_ (traceWith trTracer) mbTransition + traverse_ (traceWith trTracer) mbTransition' traceCounters stateVar - Right transition -> + + Just transition -> do traceWith tracer (TrConnectionTimeWait connId) when (timeWaitTimeout > 0) $ let -- make sure we wait at least 'timeWaitTimeout', we @@ -1075,13 +1075,13 @@ with args@Arguments { case classifyHandleError <$> handleErrorM of Just HandshakeFailure -> TerminatingState connId connThread - handleErrorM + handleErrorM Just HandshakeProtocolViolation -> TerminatedState handleErrorM -- On inbound query, connection is terminating. Nothing -> TerminatingState connId connThread - handleErrorM + handleErrorM transition = mkTransition connState connState' absConnState = State.abstractState (Known connState) shouldTrace = absConnState /= TerminatedSt @@ -1802,7 +1802,7 @@ with args@Arguments { case classifyHandleError <$> handleErrorM of Just HandshakeFailure -> TerminatingState connId connThread - handleErrorM + handleErrorM Just HandshakeProtocolViolation -> TerminatedState handleErrorM -- On outbound query, connection is terminated. diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs index 427c095597..10031f5875 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/State.hs @@ -157,7 +157,7 @@ abstractState = \case go :: ConnectionState muxMode peerAddr m a b -> AbstractState go ReservedOutboundState {} = ReservedOutboundSt go (UnnegotiatedState pr _ _) = UnnegotiatedSt pr - go (OutboundUniState _ _ _) = OutboundUniSt + go OutboundUniState {} = OutboundUniSt go (OutboundDupState _ _ _ te) = OutboundDupSt te go (OutboundIdleState _ _ _ df) = OutboundIdleSt df go (InboundIdleState _ _ _ df) = InboundIdleSt df From e36f9e2b3104a759b22b2b557a12429e0a191fc9 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 25 Nov 2024 16:10:33 +0100 Subject: [PATCH 07/15] connection-manager: updated ouroboros-network-framework tests & demos --- .../demo/connection-manager.hs | 2 +- .../Ouroboros/Network/ConnectionManager.hs | 19 +++++++------ .../Test/Ouroboros/Network/Server2/Sim.hs | 3 ++- .../ConnectionManager/Test/Experiments.hs | 27 ++++++++++++------- 4 files changed, 32 insertions(+), 19 deletions(-) diff --git a/ouroboros-network-framework/demo/connection-manager.hs b/ouroboros-network-framework/demo/connection-manager.hs index 577dc37d7e..f05892a373 100644 --- a/ouroboros-network-framework/demo/connection-manager.hs +++ b/ouroboros-network-framework/demo/connection-manager.hs @@ -485,7 +485,7 @@ bidirectionalExperiment muxBundle res <- releaseOutboundConnection - connectionManager remoteAddr + connectionManager connId case res of UnsupportedState inState -> do traceWith debugTracer ( "initiator-loop" diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs index fcbe6a2745..0e6c9e49ed 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs @@ -830,7 +830,7 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap = Right (Just (Disconnected {})) -> pure () - Right (Just (Connected _ _ _)) -> do + Right (Just (Connected connId _ _)) -> do threadDelay (either id id (seActiveDelay conn)) -- if this outbound connection is not -- executed within inbound connection, @@ -850,11 +850,11 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap = -- successful. void $ releaseInboundConnection - connectionManager addr + connectionManager connId res <- releaseOutboundConnection - connectionManager addr + connectionManager connId case res of UnsupportedState st -> throwIO (UnsupportedStateError @@ -879,17 +879,20 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap = (Accepted fd' addr', acceptNext) -> do thread <- async $ do + localAddress <- getLocalAddr snocket fd' + let connId = ConnectionId { localAddress, + remoteAddress = addr' } labelThisThread ("th-inbound-" ++ show (getTestAddress addr)) Just conn' <- fdScheduleEntry - <$> atomically (readTVar (fdState fd')) + <$> readTVarIO (fdState fd') when (addr /= addr' && seIdx conn /= seIdx conn') $ throwIO (MismatchedScheduleEntry (addr, seIdx conn) (addr', seIdx conn')) _ <- includeInboundConnection - connectionManager maxBound fd' addr + connectionManager maxBound fd' connId t <- getMonotonicTime let activeDelay = either id id (seActiveDelay conn) @@ -902,11 +905,11 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap = threadDelay x _ <- promotedToWarmRemote - connectionManager addr + connectionManager connId threadDelay y _ <- demotedToColdRemote - connectionManager addr + connectionManager connId return () ) (threadDelay activeDelay) @@ -930,7 +933,7 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap = -- TODO: should we run 'unregisterInboundConnection' depending on 'seActiveDelay' void $ releaseInboundConnection - connectionManager addr + connectionManager connId go (thread : threads) acceptNext conns' (AcceptFailure err, _acceptNext) -> throwIO err diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs index 26980ae994..f38d37391d 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs @@ -884,7 +884,7 @@ multinodeExperiment inboundTrTracer trTracer inboundTracer debugTracer cmTracer m <- readTVar connVar check (Map.member (connId remoteAddr) m) writeTVar connVar (Map.delete (connId remoteAddr) m) - void (releaseOutboundConnection cm remoteAddr) + void (releaseOutboundConnection cm (connId remoteAddr)) go (Map.delete remoteAddr connMap) RunMiniProtocols remoteAddr reqs -> do atomically $ do @@ -1159,6 +1159,7 @@ prop_connection_manager_valid_transition_order (Fixed rnd) serverAcc (ArbDataFlo in tabulate "ConnectionEvents" (map showConnectionEvents events) . counterexample (ppScript mns) . counterexample (Trace.ppTrace show show abstractTransitionEvents) + . counterexample (ppTrace trace) . bifoldMap ( \ case MainReturn {} -> mempty diff --git a/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Experiments.hs b/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Experiments.hs index 0116343488..aa962773b6 100644 --- a/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Experiments.hs +++ b/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Experiments.hs @@ -3,6 +3,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} @@ -735,7 +736,9 @@ unidirectionalExperiment stdGen timeouts snocket makeBearer confSock socket clie (numberOfRounds clientAndServerData) (bracket (acquireOutboundConnection connectionManager serverAddr) - (\_ -> releaseOutboundConnection connectionManager serverAddr) + (\case + Connected connId _ _ -> releaseOutboundConnection connectionManager connId + Disconnected {} -> error "unidirectionalExperiment: impossible happened") (\connHandle -> do case connHandle of Connected connId _ (Handle mux muxBundle controlBundle _ @@ -834,10 +837,13 @@ bidirectionalExperiment (acquireOutboundConnection connectionManager0 localAddr1)) - (\_ -> - releaseOutboundConnection - connectionManager0 - localAddr1) + (\case + Connected connId _ _ -> + releaseOutboundConnection + connectionManager0 + connId + Disconnected {} -> + error "bidirectionalExperiment: impossible happened") (\connHandle -> case connHandle of Connected connId _ (Handle mux muxBundle controlBundle _) -> do @@ -856,10 +862,13 @@ bidirectionalExperiment (acquireOutboundConnection connectionManager1 localAddr0)) - (\_ -> - releaseOutboundConnection - connectionManager1 - localAddr0) + (\case + Connected connId _ _ -> + releaseOutboundConnection + connectionManager1 + connId + Disconnected {} -> + error "ibidirectionalExperiment: impossible happened") (\connHandle -> case connHandle of Connected connId _ (Handle mux muxBundle controlBundle _) -> do From 72a72b0d4efa96c5af425cb5c4549a93cd112786 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 25 Nov 2024 16:11:19 +0100 Subject: [PATCH 08/15] connection-manager: fixed tests --- .../src/Ouroboros/Network/ConnectionManager/Core.hs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs index 3247d9da59..fa13251d9d 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs @@ -1810,13 +1810,14 @@ with args@Arguments { TerminatedState handleErrorM transition = mkTransition connState connState' absConnState = State.abstractState (Known connState) - shouldTrace = absConnState /= TerminatedSt + shouldTransition = absConnState /= TerminatedSt -- 'handleError' might be either a handshake negotiation -- a protocol failure (an IO exception, a timeout or -- codec failure). In the first case we should not reset -- the connection as this is not a protocol error. - writeTVar connVar connState' + when shouldTransition $ do + writeTVar connVar connState' updated <- modifyTMVarPure @@ -1835,7 +1836,7 @@ with args@Arguments { -- Key was present in the dictionary (stateVar) and -- removed so we trace the removal. return $ - if shouldTrace + if shouldTransition then [ transition , Transition { fromState = Known (TerminatedState Nothing) From 0e53e1934170ad99be307e1438aeb86dff45a52b Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 25 Nov 2024 16:12:39 +0100 Subject: [PATCH 09/15] connection-manager: added TODO Can we remove the transition tracer, since we use `traceTVar` in `newMutableConnState`? --- .../src/Ouroboros/Network/ConnectionManager/Core.hs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs index fa13251d9d..27c24ad924 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs @@ -81,6 +81,8 @@ data Arguments handlerTrace socket peerAddr handle handleError versionNumber ver -- | Trace state transitions. -- + -- TODO: do we need this tracer? In some tests we relay on `traceTVar` in + -- `newNetworkMutableState` instead. trTracer :: Tracer m (TransitionTrace peerAddr (ConnectionState peerAddr handle handleError versionNumber m)), From f11758176591c9868beb06a42e7f22c9ac36d9b0 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 25 Nov 2024 21:51:29 +0100 Subject: [PATCH 10/15] inbound-governor: handle unknown connection IOSimPOR discovered that a connection can be removed from the connection manager (by the connection clean-up function), while it is promoted to warm remote. --- .../Test/Ouroboros/Network/Server2/Sim.hs | 1 + .../src/Ouroboros/Network/InboundGovernor.hs | 28 ++++++++----------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs index f38d37391d..5b89b8949a 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs @@ -1039,6 +1039,7 @@ prop_connection_manager_valid_transitions_racy (Fixed rnd) serverAcc (ArbDataFlow dataFlow) defaultBearerInfo mns@(MultiNodeScript events attenuationMap) = exploreSimTrace id sim $ \_ trace -> + counterexample (ppTrace trace) $ validate_transitions mns trace where sim :: IOSim s () diff --git a/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs b/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs index 3fb30be73d..506d362da2 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs @@ -39,8 +39,8 @@ module Ouroboros.Network.InboundGovernor import Control.Applicative (Alternative) import Control.Concurrent.Class.MonadSTM qualified as LazySTM import Control.Concurrent.Class.MonadSTM.Strict -import Control.Exception (SomeAsyncException (..), assert) -import Control.Monad (foldM, when) +import Control.Exception (SomeAsyncException (..)) +import Control.Monad (foldM) import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadTime.SI @@ -447,20 +447,16 @@ with res <- promotedToWarmRemote connectionManager connId traceWith tracer (TrPromotedToWarmRemote connId res) - when (resultInState res == UnknownConnectionSt) $ do - traceWith tracer (TrUnexpectedlyFalseAssertion - (InboundGovernorLoop - (Just connId) - UnknownConnectionSt) - ) - evaluate (assert False ()) - - let state' = updateRemoteState - connId - RemoteWarm - state - - return (Just connId, state') + case resultInState res of + UnknownConnectionSt -> do + let state' = unregisterConnection connId state + return (Just connId, state') + _ -> do + let state' = updateRemoteState + connId + RemoteWarm + state + return (Just connId, state') RemotePromotedToHot connId -> do traceWith tracer (TrPromotedToHotRemote connId) From 699bd90ce457b5f4c34fdf2ae328f39d03a8f0e5 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 25 Nov 2024 22:10:00 +0100 Subject: [PATCH 11/15] inbound-governor: changed order of events It wasn't enough to fix IOSimPOR issue fixed in the previous commit. --- .../src/Ouroboros/Network/InboundGovernor.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs b/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs index 506d362da2..93a95045cf 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/InboundGovernor.hs @@ -211,12 +211,12 @@ with ) <> Map.foldMapWithKey ( firstMuxToFinish + <> firstPeerDemotedToCold + <> firstPeerCommitRemote <> firstMiniProtocolToFinish connectionDataFlow <> firstPeerPromotedToWarm <> firstPeerPromotedToHot <> firstPeerDemotedToWarm - <> firstPeerDemotedToCold - <> firstPeerCommitRemote :: EventSignal muxMode initiatorCtx peerAddr versionData m a b ) From 7fffa470d1d76e4808e47ac079bc14f52d1ba04d Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Thu, 28 Nov 2024 13:22:55 +0100 Subject: [PATCH 12/15] connection-manager: removed IOSimPOR unit test The fixed schedule doesn't satisfy IOSimPOR invariants anymore. --- .../Test/Ouroboros/Network/Testnet.hs | 151 ------------------ 1 file changed, 151 deletions(-) diff --git a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs index a9b07a1ddf..338040b335 100644 --- a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs +++ b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs @@ -44,7 +44,6 @@ import System.Random (mkStdGen) import Network.DNS.Types qualified as DNS -import Ouroboros.Network.Block (BlockNo (..)) import Ouroboros.Network.BlockFetch (PraosFetchMode (..), TraceFetchClientState (..)) import Ouroboros.Network.ConnectionHandler (ConnectionHandlerTrace) @@ -167,8 +166,6 @@ tests = , nightlyTest $ testProperty "steps" (testWithIOSimPOR prop_churn_steps 10000) ] - , testGroup "unit" - [ nightlyTest $ testProperty "unit cm" unit_cm_valid_transitions ] ] , testGroup "IOSim" [ testProperty "no failure" @@ -297,154 +294,6 @@ testWithIOSimPOR f traceNumber bi ds = $ exploreSimTrace id sim $ \_ ioSimTrace -> f ioSimTrace traceNumber --- | This test checks a IOSimPOR false positive bug with the connection --- manager state transition traces no longer happens. --- -unit_cm_valid_transitions :: Property -unit_cm_valid_transitions = - let bi = AbsBearerInfo - { abiConnectionDelay = SmallDelay - , abiInboundAttenuation = NoAttenuation FastSpeed - , abiOutboundAttenuation = NoAttenuation FastSpeed - , abiInboundWriteFailure = Nothing - , abiOutboundWriteFailure = Just 0 - , abiAcceptFailure = Nothing - , abiSDUSize = LargeSDU - } - ds = DiffusionScript - (SimArgs 1 10) - (Script ((Map.empty, ShortDelay) :| [(Map.empty, LongDelay)])) - [ ( NodeArgs - (-2) - InitiatorAndResponderDiffusionMode - (Just 269) - (Map.fromList [(RelayAccessAddress "0:71:0:1:0:1:0:1" 65534, - DoAdvertisePeer)]) - GenesisMode - (Script (DontUseBootstrapPeers :| [])) - (TestAddress (IPAddr (read "0:79::1:0:0") 3)) - PeerSharingDisabled - [ (HotValency {getHotValency = 1}, - WarmValency {getWarmValency = 1}, - Map.fromList [(RelayAccessAddress "0:71:0:1:0:1:0:1" 65534, - (DoAdvertisePeer, IsTrustable))]) - ] - (Script (LedgerPools [] :| [])) - (ConsensusModePeerTargets - { deadlineTargets = PeerSelectionTargets - { targetNumberOfRootPeers = 4 - , targetNumberOfKnownPeers = 4 - , targetNumberOfEstablishedPeers = 3 - , targetNumberOfActivePeers = 2 - , targetNumberOfKnownBigLedgerPeers = 4 - , targetNumberOfEstablishedBigLedgerPeers = 1 - , targetNumberOfActiveBigLedgerPeers = 1 - } - , syncTargets = PeerSelectionTargets - { targetNumberOfRootPeers = 0 - , targetNumberOfKnownPeers = 4 - , targetNumberOfEstablishedPeers = 0 - , targetNumberOfActivePeers = 0 - , targetNumberOfKnownBigLedgerPeers = 4 - , targetNumberOfEstablishedBigLedgerPeers = 4 - , targetNumberOfActiveBigLedgerPeers = 3 - } - }) - (Script (DNSTimeout {getDNSTimeout = 0.325} :| [])) - (Script (DNSLookupDelay {getDNSLookupDelay = 0.1} :| - [DNSLookupDelay {getDNSLookupDelay = 0.072}])) - Nothing - False - (Script (FetchModeBulkSync :| [FetchModeBulkSync])) - , [JoinNetwork 0.5] - ) - , ( NodeArgs - 0 - InitiatorAndResponderDiffusionMode - (Just 90) - Map.empty - GenesisMode - (Script (DontUseBootstrapPeers :| [])) - (TestAddress (IPAddr (read "0:71:0:1:0:1:0:1") 65534)) - PeerSharingEnabled - [ (HotValency {getHotValency = 1}, - WarmValency {getWarmValency = 1}, - Map.fromList [(RelayAccessAddress "0:79::1:0:0" 3, - (DoNotAdvertisePeer, IsTrustable))]) - ] - (Script (LedgerPools [] :| [])) - (ConsensusModePeerTargets - { deadlineTargets = PeerSelectionTargets - { targetNumberOfRootPeers = 1 - , targetNumberOfKnownPeers = 1 - , targetNumberOfEstablishedPeers = 1 - , targetNumberOfActivePeers = 1 - , targetNumberOfKnownBigLedgerPeers = 4 - , targetNumberOfEstablishedBigLedgerPeers = 3 - , targetNumberOfActiveBigLedgerPeers = 3 - } - , syncTargets = PeerSelectionTargets - { targetNumberOfRootPeers = 0 - , targetNumberOfKnownPeers = 1 - , targetNumberOfEstablishedPeers = 1 - , targetNumberOfActivePeers = 1 - , targetNumberOfKnownBigLedgerPeers = 4 - , targetNumberOfEstablishedBigLedgerPeers = 2 - , targetNumberOfActiveBigLedgerPeers = 2 - } - }) - (Script (DNSTimeout {getDNSTimeout = 0.18} :| [])) - (Script (DNSLookupDelay {getDNSLookupDelay = 0.125} :| [])) - (Just (BlockNo 2)) - False - (Script (FetchModeDeadline :| [])) - , [JoinNetwork 1.484848484848] - ) - ] - s = ControlAwait - [ ScheduleMod - (RacyThreadId [3,1,3,1,2,3,2,1], 7) - ControlDefault - [ (RacyThreadId [3,1,3,1,2,3,2], 32) - , (RacyThreadId [3,1,3,1,2,3,2], 33) - , (RacyThreadId [2,1,3,1,4], 8) - , (RacyThreadId [2,1,3,1,4], 9) - , (RacyThreadId [2,1,3,1,4], 10) - , (RacyThreadId [2,1,3,1,4], 11) - , (RacyThreadId [2,1,3,1,4], 12) - , (RacyThreadId [2,1,3,1,4,1], 0) - , (RacyThreadId [2,1,3,1,4,1], 1) - , (RacyThreadId [2,1,3,1,4,1], 2) - , (RacyThreadId [2,1,3,1,4,1], 3) - , (RacyThreadId [2,1,3,1,4,1], 4) - , (RacyThreadId [2,1,3,1,4,1], 5) - , (RacyThreadId [2,1,3,1,4,1], 6) - , (RacyThreadId [2,1,3,1,4,1], 7) - , (RacyThreadId [2,1,3,1,4,1], 8) - , (RacyThreadId [2,1,3,1,4,1,1], 0) - , (RacyThreadId [2,1,3,1,4,1,1], 1) - , (RacyThreadId [2,1,3,1,4,1,1], 2) - , (RacyThreadId [2,1,3,1,4,1,1], 3) - , (RacyThreadId [2,1,3,1,4,1], 9) - , (RacyThreadId [2,1,3,1,4,1], 10) - , (RacyThreadId [2,1,3,1,4,1], 11) - , (RacyThreadId [2,1,3,1,4,1], 12) - , (RacyThreadId [2,1,3,1,4,1], 13) - , (RacyThreadId [2,1,3,1,4,1], 14) - , (RacyThreadId [2,1,3,1,4,1], 15) - , (RacyThreadId [2,1,3,1,4], 13) - , (RacyThreadId [2,1,3,1,4], 14) - , (RacyThreadId [2,1,3,1,4], 15) - , (RacyThreadId [2,1,3,1,4], 16) - , (RacyThreadId [3,1,3,1,2,3,2], 34) - ] - ] - sim :: forall s. IOSim s Void - sim = do - exploreRaces - diffusionSimulation (toBearerInfo bi) ds iosimTracer - in exploreSimTrace (\a -> a { explorationReplay = Just s }) sim $ \_ ioSimTrace -> - prop_diffusion_cm_valid_transition_order_iosim_por ioSimTrace 10000 -- | As a basic property we run the governor to explore its state space a bit -- and check it does not throw any exceptions (assertions such as invariant From 9a8d128a99a776d47d5ce5234555e629cd83b1bc Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Fri, 29 Nov 2024 14:59:49 +0100 Subject: [PATCH 13/15] connection-manager: trace state changes through `traceTMVar` Note that this is only effective in `IOSim`. --- .../Ouroboros/Network/ConnectionManager/Core.hs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs index 27c24ad924..8b461a6e7f 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs @@ -408,15 +408,11 @@ with args@Arguments { <- atomically $ do v <- newTMVar State.empty labelTMVar v "cm-state" - traceTMVar (Proxy :: Proxy m) v - $ \old new -> - case (old, new) of - (Nothing, _) -> pure DontTrace - -- taken - (Just (Just _), Nothing) -> pure (TraceString "cm-state: taken") - -- released - (Just Nothing, Just _) -> pure (TraceString "cm-state: released") - (_, _) -> pure DontTrace + traceTMVar (Proxy :: Proxy m) v $ \_ mbst -> do + st' <- case mbst of + Nothing -> pure Nothing + Just st -> Just <$> traverse (inspectTVar (Proxy :: Proxy m) . toLazyTVar . connVar) st + return (TraceString (show st')) freshIdSupply <- State.newFreshIdSupply (Proxy :: Proxy m) stdGenVar <- newTVar (stdGen args) From bc4df7e2420125120efd5a6b5c0ac47de4c84563 Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Fri, 29 Nov 2024 15:17:16 +0100 Subject: [PATCH 14/15] ouroboros-network-framework: updated CHANGELOG file --- ouroboros-network-framework/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ouroboros-network-framework/CHANGELOG.md b/ouroboros-network-framework/CHANGELOG.md index 49f4022aa2..823d534dc3 100644 --- a/ouroboros-network-framework/CHANGELOG.md +++ b/ouroboros-network-framework/CHANGELOG.md @@ -12,6 +12,8 @@ `unregister{Inbound,Outbound}Connection` to `release{Inbound,Outbound}Connection`. `AssertionLocation` constructors were renamed as well. * Added `RawBearer` API (see https://github.com/IntersectMBO/ouroboros-network/pull/4395) +* Connection manager is using `ConnectionId`s to identify connections, this + affects its API. ### Non-breaking changes From e73179a83f826736dde228e70f1cc707e95041de Mon Sep 17 00:00:00 2001 From: Marcin Szamotulski Date: Mon, 2 Dec 2024 17:11:16 +0100 Subject: [PATCH 15/15] testnet: provide additional context in counterexamples It's useful to provide not only the transitions that didn't match, but also the time and the server name to make it easier to locate the transition in a trace. --- .../Test/Ouroboros/Network/Server2/Sim.hs | 4 +- .../Network/ConnectionManager/Test/Utils.hs | 27 ++++++++------ .../Test/Ouroboros/Network/Testnet.hs | 37 +++++++++++++------ 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs index 5b89b8949a..bbb968f1b0 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Server2/Sim.hs @@ -1166,7 +1166,7 @@ prop_connection_manager_valid_transition_order (Fixed rnd) serverAcc (ArbDataFlo MainReturn {} -> mempty _ -> All False ) - (verifyAbstractTransitionOrder True) + (verifyAbstractTransitionOrder id True) . fmap (map ttTransition) . groupConns id abstractStateIsFinalTransition $ abstractTransitionEvents @@ -1206,7 +1206,7 @@ prop_connection_manager_valid_transition_order_racy (Fixed rnd) serverAcc (ArbDa MainReturn {} -> mempty _ -> All False ) - (verifyAbstractTransitionOrder True) + (verifyAbstractTransitionOrder id True) . fmap (map ttTransition) . groupConns id abstractStateIsFinalTransition $ abstractTransitionEvents diff --git a/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Utils.hs b/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Utils.hs index ec6e7fe489..710bc03472 100644 --- a/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Utils.hs +++ b/ouroboros-network-framework/testlib/Ouroboros/Network/ConnectionManager/Test/Utils.hs @@ -1,4 +1,5 @@ -{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE ScopedTypeVariables #-} module Ouroboros.Network.ConnectionManager.Test.Utils where @@ -193,20 +194,22 @@ validTransitionMap t@Transition { fromState, toState } = -- Assuming all transitions in the transition list are valid, we only need to -- look at the 'toState' of the current transition and the 'fromState' of the -- next transition. -verifyAbstractTransitionOrder :: Bool -- ^ Check last transition: useful for +verifyAbstractTransitionOrder :: forall a. Show a + => (a -> AbstractTransition) + -> Bool -- ^ Check last transition: useful for -- distinguish Diffusion layer tests -- vs non-Diffusion ones. - -> [AbstractTransition] + -> [a] -> All -verifyAbstractTransitionOrder _ [] = mempty -verifyAbstractTransitionOrder checkLast (h:t) = go t h +verifyAbstractTransitionOrder _ _ [] = mempty +verifyAbstractTransitionOrder get checkLast (h:t) = go t h where - go :: [AbstractTransition] -> AbstractTransition -> All + go :: [a] -> a -> All -- All transitions must end in the 'UnknownConnectionSt', and since we -- assume that all transitions are valid we do not have to check the -- 'fromState'. - go [] (Transition _ UnknownConnectionSt) = mempty - go [] tr@(Transition _ _) = + go [] a | (Transition _ UnknownConnectionSt) <- get a = mempty + go [] a | tr@(Transition _ _) <- get a = All $ counterexample ("\nUnexpected last transition: " ++ show tr) @@ -214,14 +217,14 @@ verifyAbstractTransitionOrder checkLast (h:t) = go t h -- All transitions have to be in a correct order, which means that the -- current state we are looking at (current toState) needs to be equal to -- the next 'fromState', in order for the transition chain to be correct. - go (next@(Transition nextFromState _) : ts) - curr@(Transition _ currToState) = + go (a : as) b | (Transition nextFromState _) <- get a + , (Transition _ currToState) <- get b = All (counterexample ("\nUnexpected transition order!\nWent from: " - ++ show curr ++ "\nto: " ++ show next) + ++ show b ++ "\nto: " ++ show a) (property (currToState == nextFromState))) - <> go ts next + <> go as a -- | List of all valid transition's names. diff --git a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs index 338040b335..37796ec5d8 100644 --- a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs +++ b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs @@ -275,8 +275,10 @@ testWithIOSim f traceNumber bi ds = sim = diffusionSimulation (toBearerInfo bi) ds iosimTracer + trace = runSimTrace sim in labelDiffusionScript ds - $ f (runSimTrace sim) traceNumber + $ counterexample (Trace.ppTrace show (ppSimEvent 0 0 0) $ Trace.take traceNumber trace) + $ f trace traceNumber testWithIOSimPOR :: (SimTrace Void -> Int -> Property) -> Int @@ -3041,7 +3043,7 @@ prop_diffusion_cm_valid_transition_order_iosim_por ioSimTrace traceNumber = property . bifoldMap (const mempty) - (verifyAbstractTransitionOrder False) + (verifyAbstractTransitionOrder id False) . fmap (map ttTransition) . groupConns id abstractStateIsFinalTransitionTVarTracing @@ -3074,25 +3076,24 @@ prop_diffusion_cm_valid_transition_order ioSimTrace traceNumber = . last $ evsList in classifySimulatedTime lastTime - $ classifyNumberOfEvents (length evsList) - $ verify_cm_valid_transition_order - $ (\(WithName _ (WithTime _ b)) -> b) - <$> ev + . classifyNumberOfEvents (length evsList) + . verify_cm_valid_transition_order + $ ev ) <$> events where - verify_cm_valid_transition_order :: Trace () DiffusionTestTrace -> Property + verify_cm_valid_transition_order :: Trace () (WithName NtNAddr (WithTime DiffusionTestTrace)) -> Property verify_cm_valid_transition_order events = - let abstractTransitionEvents :: Trace () (AbstractTransitionTrace NtNAddr) + let abstractTransitionEvents :: Trace () (WithName NtNAddr (WithTime (AbstractTransitionTrace NtNAddr))) abstractTransitionEvents = - selectDiffusionConnectionManagerTransitionEvents events + selectDiffusionConnectionManagerTransitionEvents' events in property . bifoldMap (const mempty) - (verifyAbstractTransitionOrder False) - . fmap (map ttTransition) - . groupConns id abstractStateIsFinalTransition + (verifyAbstractTransitionOrder (wtEvent . wnEvent) False) + . fmap (map (fmap (fmap ttTransition))) + . groupConns (wtEvent . wnEvent) abstractStateIsFinalTransition $ abstractTransitionEvents -- | Unit test that checks issue 4258 @@ -4178,6 +4179,18 @@ selectDiffusionConnectionManagerTransitionEvents = _ -> Nothing) . Trace.toList +selectDiffusionConnectionManagerTransitionEvents' + :: Trace () (WithName NtNAddr (WithTime DiffusionTestTrace)) + -> Trace () (WithName NtNAddr (WithTime (AbstractTransitionTrace NtNAddr))) +selectDiffusionConnectionManagerTransitionEvents' = + Trace.fromList () + . mapMaybe + (\case + (WithName addr (WithTime time (DiffusionConnectionManagerTransitionTrace e))) + -> Just (WithName addr (WithTime time e)) + _ -> Nothing) + . Trace.toList + selectDiffusionConnectionManagerTransitionEventsTime :: Trace () (Time, DiffusionTestTrace) -> Trace () (Time, AbstractTransitionTrace NtNAddr)