Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved connection manager transition tests #5026

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ouroboros-network-framework/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* Added `RawBearer` API (see https://github.com/IntersectMBO/ouroboros-network/pull/4395)
* Connection manager is using `ConnectionId`s to identify connections, this
affects its API.
* Added `connStateSupply` record field to
`Ouroboros.Network.ConnectionManager.Core.Arguments`.

### Non-breaking changes

Expand Down
9 changes: 7 additions & 2 deletions ouroboros-network-framework/demo/connection-manager.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import Ouroboros.Network.ConnectionHandler
import Ouroboros.Network.ConnectionManager.Core qualified as CM
import Ouroboros.Network.ConnectionManager.InformationChannel
(newInformationChannel)
import Ouroboros.Network.ConnectionManager.State qualified as CM
import Ouroboros.Network.ConnectionManager.Types
import Ouroboros.Network.Context
import Ouroboros.Network.IOManager
Expand Down Expand Up @@ -187,6 +188,7 @@ withBidirectionalConnectionManager
-> Mux.MakeBearer m socket
-> socket
-- ^ listening socket
-> CM.ConnStateIdSupply m
-> DiffTime -- protocol idle timeout
-> DiffTime -- wait time timeout
-> Maybe peerAddr
Expand All @@ -201,6 +203,7 @@ withBidirectionalConnectionManager
-> m a)
-> m a
withBidirectionalConnectionManager snocket makeBearer socket
connStateIdSupply
protocolIdleTimeout
timeWaitTimeout
localAddress
Expand Down Expand Up @@ -244,7 +247,8 @@ withBidirectionalConnectionManager snocket makeBearer socket
acceptedConnectionsSoftLimit = maxBound,
acceptedConnectionsDelay = 0
},
CM.updateVersionData = \a _ -> a
CM.updateVersionData = \a _ -> a,
CM.connStateIdSupply
}
(makeConnectionHandler
muxTracer
Expand Down Expand Up @@ -458,8 +462,9 @@ bidirectionalExperiment
localAddr remoteAddr
clientAndServerData = do
stdGen <- Random.newStdGen
connStateIdSupply <- atomically $ CM.newConnStateIdSupply (Proxy @IO)
withBidirectionalConnectionManager
snocket makeBearer socket0
snocket makeBearer socket0 connStateIdSupply
protocolIdleTimeout timeWaitTimeout
(Just localAddr) stdGen clientAndServerData $
\connectionManager _serverAddr -> forever' $ do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import Data.List (intercalate, sortOn)
import Data.Map (Map)
import Data.Map.Strict qualified as Map
import Data.Monoid (All (..))
import Data.Proxy (Proxy (..))
import Data.Text.Lazy qualified as Text
import Data.Void (Void)
import Quiet
Expand Down Expand Up @@ -731,6 +732,7 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap =
experiment = do
labelThisThread "th-main"
snocket <- mkSnocket scheduleMap
connStateIdSupply <- atomically $ CM.newConnStateIdSupply Proxy
let tracer :: Tracer (IOSim s) TestConnectionManagerTrace
tracer = Tracer (say . show)
{--
Expand Down Expand Up @@ -775,7 +777,8 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap =
},
CM.timeWaitTimeout = testTimeWaitTimeout,
CM.outboundIdleTimeout = testOutboundIdleTimeout,
CM.updateVersionData = \a _ -> a
CM.updateVersionData = \a _ -> a,
CM.connStateIdSupply
}
connectionHandler
(\_ -> HandshakeFailure)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import Data.Monoid (Sum (..))
import Data.Monoid.Synchronisation (FirstToFinish (..))
import Data.OrdPSQ (OrdPSQ)
import Data.OrdPSQ qualified as OrdPSQ
import Data.Proxy (Proxy (..))
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Typeable (Typeable)
Expand Down Expand Up @@ -654,64 +655,66 @@ multinodeExperiment inboundTrTracer trTracer inboundTracer debugTracer cmTracer
(MultiNodeScript script _) =
withJobPool $ \jobpool -> do
stdGenVar <- newTVarIO stdGen0
cc <- startServerConnectionHandler stdGenVar MainServer dataFlow0 [accInit] serverAddr jobpool
loop stdGenVar (Map.singleton serverAddr [accInit]) (Map.singleton serverAddr cc) script jobpool
connStateIdSupply <- atomically $ CM.newConnStateIdSupply (Proxy @m)
cc <- startServerConnectionHandler stdGenVar connStateIdSupply MainServer dataFlow0 [accInit] serverAddr jobpool
loop stdGenVar connStateIdSupply (Map.singleton serverAddr [accInit]) (Map.singleton serverAddr cc) script jobpool
where

loop :: StrictTVar m StdGen
-> CM.ConnStateIdSupply m
-> Map.Map peerAddr acc
-> Map.Map peerAddr (StrictTQueue m (ConnectionHandlerMessage peerAddr req))
-> [ConnectionEvent req peerAddr]
-> JobPool () m ()
-> m ()
loop _ _ _ [] _ = threadDelay 3600
loop stdGenVar nodeAccs servers (event : events) jobpool =
loop _ _ _ _ [] _ = threadDelay 3600
loop stdGenVar connStateIdSupply nodeAccs servers (event : events) jobpool =
case event of

StartClient delay localAddr -> do
threadDelay delay
cc <- startClientConnectionHandler stdGenVar (Client localAddr) localAddr jobpool
loop stdGenVar nodeAccs (Map.insert localAddr cc servers) events jobpool
cc <- startClientConnectionHandler stdGenVar connStateIdSupply (Client localAddr) localAddr jobpool
loop stdGenVar connStateIdSupply nodeAccs (Map.insert localAddr cc servers) events jobpool

StartServer delay localAddr nodeAcc -> do
threadDelay delay
cc <- startServerConnectionHandler stdGenVar (Node localAddr) Duplex [nodeAcc] localAddr jobpool
loop stdGenVar (Map.insert localAddr [nodeAcc] nodeAccs) (Map.insert localAddr cc servers) events jobpool
cc <- startServerConnectionHandler stdGenVar connStateIdSupply (Node localAddr) Duplex [nodeAcc] localAddr jobpool
loop stdGenVar connStateIdSupply (Map.insert localAddr [nodeAcc] nodeAccs) (Map.insert localAddr cc servers) events jobpool

InboundConnection delay nodeAddr -> do
threadDelay delay
sendMsg nodeAddr $ NewConnection serverAddr
loop stdGenVar nodeAccs servers events jobpool
loop stdGenVar connStateIdSupply nodeAccs servers events jobpool

OutboundConnection delay nodeAddr -> do
threadDelay delay
sendMsg serverAddr $ NewConnection nodeAddr
loop stdGenVar nodeAccs servers events jobpool
loop stdGenVar connStateIdSupply nodeAccs servers events jobpool

CloseInboundConnection delay remoteAddr -> do
threadDelay delay
sendMsg remoteAddr $ Disconnect serverAddr
loop stdGenVar nodeAccs servers events jobpool
loop stdGenVar connStateIdSupply nodeAccs servers events jobpool

CloseOutboundConnection delay remoteAddr -> do
threadDelay delay
sendMsg serverAddr $ Disconnect remoteAddr
loop stdGenVar nodeAccs servers events jobpool
loop stdGenVar connStateIdSupply nodeAccs servers events jobpool

InboundMiniprotocols delay nodeAddr reqs -> do
threadDelay delay
sendMsg nodeAddr $ RunMiniProtocols serverAddr reqs
loop stdGenVar nodeAccs servers events jobpool
loop stdGenVar connStateIdSupply nodeAccs servers events jobpool

OutboundMiniprotocols delay nodeAddr reqs -> do
threadDelay delay
sendMsg serverAddr $ RunMiniProtocols nodeAddr reqs
loop stdGenVar nodeAccs servers events jobpool
loop stdGenVar connStateIdSupply nodeAccs servers events jobpool

ShutdownClientServer delay nodeAddr -> do
threadDelay delay
sendMsg nodeAddr Shutdown
loop stdGenVar nodeAccs servers events jobpool
loop stdGenVar connStateIdSupply nodeAccs servers events jobpool
where
sendMsg :: peerAddr -> ConnectionHandlerMessage peerAddr req -> m ()
sendMsg addr msg = atomically $
Expand All @@ -731,11 +734,12 @@ multinodeExperiment inboundTrTracer trTracer inboundTracer debugTracer cmTracer
Just qs -> readTQueue (projectBundle tok qs)

startClientConnectionHandler :: StrictTVar m StdGen
-> CM.ConnStateIdSupply m
-> Name peerAddr
-> peerAddr
-> JobPool () m ()
-> m (StrictTQueue m (ConnectionHandlerMessage peerAddr req))
startClientConnectionHandler stdGenVar name localAddr jobpool = do
startClientConnectionHandler stdGenVar connStateIdSupply name localAddr jobpool = do
cc <- atomically newTQueue
labelTQueueIO cc $ "cc/" ++ show name
connVar <- newTVarIO Map.empty
Expand All @@ -746,7 +750,8 @@ multinodeExperiment inboundTrTracer trTracer inboundTracer debugTracer cmTracer
$ Job
( withInitiatorOnlyConnectionManager
name simTimeouts nullTracer nullTracer stdGen
snocket makeBearer (Just localAddr) (mkNextRequests connVar)
snocket makeBearer connStateIdSupply
(Just localAddr) (mkNextRequests connVar)
timeLimitsHandshake acceptedConnLimit
( \ connectionManager ->
connectionLoop SingInitiatorMode localAddr cc connectionManager Map.empty connVar
Expand All @@ -758,13 +763,14 @@ multinodeExperiment inboundTrTracer trTracer inboundTracer debugTracer cmTracer
return cc

startServerConnectionHandler :: StrictTVar m StdGen
-> CM.ConnStateIdSupply m
-> Name peerAddr
-> DataFlow
-> acc
-> peerAddr
-> JobPool () m ()
-> m (StrictTQueue m (ConnectionHandlerMessage peerAddr req))
startServerConnectionHandler stdGenVar name dataFlow serverAcc localAddr jobpool = do
startServerConnectionHandler stdGenVar connStateIdSupply name dataFlow serverAcc localAddr jobpool = do
fd <- Snocket.open snocket addrFamily
Snocket.bind snocket fd localAddr
Snocket.listen snocket fd
Expand All @@ -782,7 +788,8 @@ multinodeExperiment inboundTrTracer trTracer inboundTracer debugTracer cmTracer
inboundTrTracer trTracer cmTracer
inboundTracer debugTracer
stdGen
snocket makeBearer (\_ -> pure ()) fd (Just localAddr) serverAcc
snocket makeBearer connStateIdSupply
(\_ -> pure ()) fd (Just localAddr) serverAcc
(mkNextRequests connVar)
timeLimitsHandshake
acceptedConnLimit
Expand All @@ -799,7 +806,8 @@ multinodeExperiment inboundTrTracer trTracer inboundTracer debugTracer cmTracer
(show name)
Unidirectional ->
Job ( withInitiatorOnlyConnectionManager
name simTimeouts trTracer cmTracer stdGen snocket makeBearer (Just localAddr)
name simTimeouts trTracer cmTracer stdGen snocket makeBearer
connStateIdSupply (Just localAddr)
(mkNextRequests connVar)
timeLimitsHandshake
acceptedConnLimit
Expand Down Expand Up @@ -2182,13 +2190,15 @@ prop_server_accept_error (Fixed rnd) (AbsIOError ioerr) =
Snocket.bind snock socket0 addr
Snocket.listen snock socket0
nextRequests <- oneshotNextRequests pdata
connStateIdSupply <- atomically $ CM.newConnStateIdSupply Proxy
withBidirectionalConnectionManager "node-0" simTimeouts
nullTracer nullTracer
nullTracer nullTracer
nullTracer
(mkStdGen rnd)
snock
makeFDBearer
connStateIdSupply
(\_ -> pure ())
socket0 (Just addr)
[accumulatorInit pdata]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,11 @@ data Arguments handlerTrace socket peerAddr handle handleError versionNumber ver

connectionsLimits :: AcceptedConnectionsLimit,

updateVersionData :: versionData -> DiffusionMode -> versionData
updateVersionData :: versionData -> DiffusionMode -> versionData,

-- | Supply for `ConnStateId`-s.
--
connStateIdSupply :: ConnStateIdSupply m
}


Expand Down Expand Up @@ -397,17 +401,17 @@ with args@Arguments {
connectionDataFlow,
prunePolicy,
connectionsLimits,
updateVersionData
updateVersionData,
connStateIdSupply
}
ConnectionHandler {
connectionHandler
}
classifyHandleError
inboundGovernorInfoChannel
k = do
((connStateIdSupply, stateVar, stdGenVar)
:: ( ConnStateIdSupply m
, StrictTMVar m (ConnectionManagerState peerAddr handle handleError
((stateVar, stdGenVar)
:: ( StrictTMVar m (ConnectionManagerState peerAddr handle handleError
version m)
, StrictTVar m StdGen
))
Expand All @@ -420,9 +424,8 @@ with args@Arguments {
Just st -> Just <$> traverse (inspectTVar (Proxy :: Proxy m) . toLazyTVar . connVar) st
return (TraceString (show st'))

connStateIdSupply <- State.newConnStateIdSupply (Proxy :: Proxy m)
stdGenVar <- newTVar (stdGen args)
return (connStateIdSupply, v, stdGenVar)
return (v, stdGenVar)

let readState
:: STM m (State.ConnMap peerAddr AbstractState)
Expand Down Expand Up @@ -459,8 +462,7 @@ with args@Arguments {
WithInitiatorMode
OutboundConnectionManager {
ocmAcquireConnection =
acquireOutboundConnectionImpl connStateIdSupply stateVar
stdGenVar outboundHandler,
acquireOutboundConnectionImpl stateVar stdGenVar outboundHandler,
ocmReleaseConnection =
releaseOutboundConnectionImpl stateVar stdGenVar
},
Expand All @@ -474,8 +476,7 @@ with args@Arguments {
WithResponderMode
InboundConnectionManager {
icmIncludeConnection =
includeInboundConnectionImpl connStateIdSupply stateVar
inboundHandler,
includeInboundConnectionImpl stateVar inboundHandler,
icmReleaseConnection =
releaseInboundConnectionImpl stateVar,
icmPromotedToWarmRemote =
Expand All @@ -495,15 +496,13 @@ with args@Arguments {
WithInitiatorResponderMode
OutboundConnectionManager {
ocmAcquireConnection =
acquireOutboundConnectionImpl connStateIdSupply stateVar
stdGenVar outboundHandler,
acquireOutboundConnectionImpl stateVar stdGenVar outboundHandler,
ocmReleaseConnection =
releaseOutboundConnectionImpl stateVar stdGenVar
}
InboundConnectionManager {
icmIncludeConnection =
includeInboundConnectionImpl connStateIdSupply stateVar
inboundHandler,
includeInboundConnectionImpl stateVar inboundHandler,
icmReleaseConnection =
releaseInboundConnectionImpl stateVar,
icmPromotedToWarmRemote =
Expand Down Expand Up @@ -846,8 +845,7 @@ with args@Arguments {

includeInboundConnectionImpl
:: HasCallStack
=> ConnStateIdSupply m
-> StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m)
=> StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m)
-> ConnectionHandlerFn handlerTrace socket peerAddr handle handleError version versionData m
-> Word32
-- ^ inbound connections hard limit
Expand All @@ -861,8 +859,7 @@ with args@Arguments {
-> ConnectionId peerAddr
-- ^ connection id used as an identifier of the resource
-> m (Connected peerAddr handle handleError)
includeInboundConnectionImpl connStateIdSupply
stateVar
includeInboundConnectionImpl stateVar
handler
hardLimit
socket
Expand Down Expand Up @@ -1314,14 +1311,13 @@ with args@Arguments {

acquireOutboundConnectionImpl
:: HasCallStack
=> ConnStateIdSupply m
-> StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m)
=> StrictTMVar m (ConnectionManagerState peerAddr handle handleError version m)
-> StrictTVar m StdGen
-> ConnectionHandlerFn handlerTrace socket peerAddr handle handleError version versionData m
-> DiffusionMode
-> peerAddr
-> m (Connected peerAddr handle handleError)
acquireOutboundConnectionImpl connStateIdSupply stateVar stdGenVar handler diffusionMode peerAddr = do
acquireOutboundConnectionImpl stateVar stdGenVar handler diffusionMode peerAddr = do
let provenance = Outbound
traceWith tracer (TrIncludeConnection provenance peerAddr)
(trace, mutableConnState@MutableConnState { connVar, connStateId }
Expand Down
Loading
Loading