Skip to content

Commit

Permalink
refactor in anticipation of testing requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
crocodile-dentist committed Dec 9, 2024
1 parent 6d5d221 commit f4e401f
Showing 1 changed file with 131 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ module Ouroboros.Network.PeerSelection.RootPeersDNS.DNSActions
, Resource (..)
, retryResource
, constantResource
-- ** Exposed for testing purposes
, lookupSRVWithTTL
, dispatchLookupWithTTL
-- ** Error type
, DNSorIOError (..)
, DNSLookupResult (..)
Expand Down Expand Up @@ -49,7 +52,7 @@ import System.Directory (getModificationTime)
#endif

import Data.Map qualified as Map
import Network.DNS (DNSError)
import Network.DNS (DNSError, DNSMessage)
import Network.DNS qualified as DNS

import Ouroboros.Network.PeerSelection.RelayAccessPoint
Expand Down Expand Up @@ -221,9 +224,14 @@ ioDNSActions =
\reqs -> DNSActions {
dnsResolverResource = resolverResource,
dnsAsyncResolverResource = asyncResolverResource,
dnsLookupWithTTL = dispatchLookupWithTTL reqs
dnsLookupWithTTL = dispatchLookupWithTTL reqs mkIOAction4
}
where
mkIOAction4 resolver resolvConf domain ofType =
timeout (microsecondsAsIntToDiffTime
$ DNS.resolvTimeout resolvConf)
(DNS.lookupRaw resolver domain ofType)

-- |
--
-- TODO: it could be useful for `publicRootPeersProvider`.
Expand Down Expand Up @@ -334,86 +342,36 @@ ioDNSActions =
return (Right resolver', go filePath resourceVar)


-- | Like 'DNS.lookupA' but also return the TTL for the results.
--
-- DNS library timeouts do not work reliably on Windows (#1873), hence the
-- additional timeout.
--
lookupAWithTTL :: DNS.ResolvConf
-> DNS.Resolver
-> DNS.Domain
-> IO (Either DNS.DNSError [(IP, DNS.TTL)])
lookupAWithTTL resolvConf resolver domain = do
reply <- timeout (microsecondsAsIntToDiffTime
$ DNS.resolvTimeout resolvConf)
(DNS.lookupRaw resolver domain DNS.A)
case reply of
Nothing -> return (Left DNS.TimeoutExpired)
Just (Left err) -> return (Left err)
Just (Right ans) -> return (DNS.fromDNSMessage ans selectA)
--TODO: we can get the SOA TTL on NXDOMAIN here if we want to
where
selectA DNS.DNSMessage { DNS.answer } =
[ (IPv4 addr, fixupTTL ttl)
| DNS.ResourceRecord {
DNS.rdata = DNS.RD_A addr,
DNS.rrttl = ttl
} <- answer
]


lookupAAAAWithTTL :: DNS.ResolvConf
-> DNS.Resolver
-> DNS.Domain
-> IO (Either DNS.DNSError [(IP, DNS.TTL)])
lookupAAAAWithTTL resolvConf resolver domain = do
reply <- timeout (microsecondsAsIntToDiffTime
$ DNS.resolvTimeout resolvConf)
(DNS.lookupRaw resolver domain DNS.AAAA)
case reply of
Nothing -> return (Left DNS.TimeoutExpired)
Just (Left err) -> return (Left err)
Just (Right ans) -> return (DNS.fromDNSMessage ans selectAAAA)
--TODO: we can get the SOA TTL on NXDOMAIN here if we want to
where
selectAAAA DNS.DNSMessage { DNS.answer } =
[ (IPv6 addr, fixupTTL ttl)
| DNS.ResourceRecord {
DNS.rdata = DNS.RD_AAAA addr,
DNS.rrttl = ttl
} <- answer
]

lookupSRVWithTTL :: DNSLookupType
-> DNS.Domain
-> DNS.ResolvConf
-> DNS.Resolver
-> StdGen
-> IO (DNSLookupResult IP)
lookupSRVWithTTL ofType domain0 resolvConf resolver rng = do
reply <- timeout (microsecondsAsIntToDiffTime
$ DNS.resolvTimeout resolvConf)
(DNS.lookupRaw resolver domain0 DNS.SRV)
case reply of
Nothing -> return $ DNSLookupSRV (domain0, [DNS.TimeoutExpired], Nothing)
Just (Left err) -> return $ DNSLookupSRV (domain0, [err], Nothing)
Just (Right msg) ->
case DNS.fromDNSMessage msg selectSRV of
Left err -> return $ DNSLookupSRV (domain0, [err], Nothing)
Right services -> do
let srvByPriority = sortOn priority services
grouped = NE.groupWith priority srvByPriority

case listToMaybe grouped of
Just topPriority -> do
case topPriority of
(single, _, _, port) NE.:| [] ->
DNSLookupSRV
. inclDomainAndPort single port <$>
lookupWithTTL ofType single resolvConf resolver
many ->
DNSLookupSRV <$> runWeightedLookup (NE.toList many)
Nothing -> return $ DNSLookupSRV (domain0, [], Nothing)
lookupSRVWithTTL :: (MonadAsync m)
=> DNSLookupType
-> DNS.Domain
-> ( DNS.Domain
-> DNS.TYPE
-> m (Maybe (Either DNSError DNSMessage)))
-> StdGen
-> m (DNSLookupResult IP)
lookupSRVWithTTL ofType domain0 mkAction2 rng = do
reply <- mkAction2 domain0 DNS.SRV
case reply of
Nothing -> return $ DNSLookupSRV (domain0, [DNS.TimeoutExpired], Nothing)
Just (Left err) -> return $ DNSLookupSRV (domain0, [err], Nothing)
Just (Right msg) ->
case DNS.fromDNSMessage msg selectSRV of
Left err -> return $ DNSLookupSRV (domain0, [err], Nothing)
Right services -> do
let srvByPriority = sortOn priority services
grouped = NE.groupWith priority srvByPriority

case listToMaybe grouped of
Just topPriority -> do
case topPriority of
(single, _, _, port) NE.:| [] ->
DNSLookupSRV
. inclDomainAndPort single port <$>
lookupWithTTL ofType (mkAction2 single)
many ->
DNSLookupSRV <$> runWeightedLookup (NE.toList many)
Nothing -> return $ DNSLookupSRV (domain0, [], Nothing)

where
inclDomainAndPort domain !port (e, ipsttls) = (domain, e, Just (fromIntegral port, ipsttls))
Expand All @@ -429,7 +387,7 @@ ioDNSActions =
(pick, _) = randomR (0, upperBound) rng
(winner, _, _, port) = snd . fromJust $ Map.lookupGE pick mapCdf
in inclDomainAndPort winner port <$>
lookupWithTTL ofType winner resolvConf resolver
lookupWithTTL ofType (mkAction2 winner)

selectSRV DNS.DNSMessage { DNS.answer } =
[ (domain', priority', weight', port)
Expand All @@ -441,45 +399,96 @@ ioDNSActions =
weight (_, _, w, _) = w
priority (_, p, _, _) = p

dispatchLookupWithTTL :: DNSLookupType
-> DomainAccessPoint
-> DNS.ResolvConf
-> DNS.Resolver
-> StdGen
-> IO (DNSLookupResult IP)
dispatchLookupWithTTL lookupType domain conf resolver rng =
case domain of
DomainAccessPoint d -> wrap <$> lookupWithTTL lookupType (dapDomain d) conf resolver
where
wrap (a, b) = DNSLookup (d, a, b)
DomainSRVAccessPoint d -> lookupSRVWithTTL lookupType (srvDomain d) conf resolver rng

lookupWithTTL :: DNSLookupType
-> DNS.Domain
-> DNS.ResolvConf
-> DNS.Resolver
-> IO ([DNS.DNSError], [(IP, DNS.TTL)])
lookupWithTTL LookupReqAOnly domain resolvConf resolver = do
res <- lookupAWithTTL resolvConf resolver domain
case res of
Left err -> return ([err], [])
Right r -> return ([], r)

lookupWithTTL LookupReqAAAAOnly domain resolvConf resolver = do
res <- lookupAAAAWithTTL resolvConf resolver domain
case res of
Left err -> return ([err], [])
Right r -> return ([], r)

lookupWithTTL LookupReqAAndAAAA domain resolvConf resolver = do
(r_ipv6, r_ipv4) <- concurrently (lookupAAAAWithTTL resolvConf resolver domain)
(lookupAWithTTL resolvConf resolver domain)
case (r_ipv6, r_ipv4) of
(Left e6, Left e4) -> return ([e6, e4], [])
(Right r6, Left e4) -> return ([e4], r6)
(Left e6, Right r4) -> return ([e6], r4)
(Right r6, Right r4) -> return ([], r6 <> r4)

dispatchLookupWithTTL :: (MonadAsync m)
=> DNSLookupType
-> ( resolver
-> resolvConf
-> DNS.Domain
-> DNS.TYPE
-> m (Maybe (Either DNSError DNSMessage)))
-> DomainAccessPoint
-> resolvConf
-> resolver
-> StdGen
-> m (DNSLookupResult IP)
dispatchLookupWithTTL lookupType mkAction4 domain conf resolver rng =
let mkAction2 = mkAction4 resolver conf
in case domain of
DomainAccessPoint d -> push <$> lookupWithTTL lookupType (mkAction2 dapDomain)
where
DomainPlain { dapDomain } = d
push (a, b) = DNSLookup (d, a, b)
DomainSRVAccessPoint d -> lookupSRVWithTTL lookupType (srvDomain d) mkAction2 rng

lookupWithTTL :: (MonadAsync m)
=> DNSLookupType
-> ( DNS.TYPE
-> m (Maybe (Either DNSError DNSMessage)))
-> m ([DNS.DNSError], [(IP, DNS.TTL)])
lookupWithTTL LookupReqAOnly action1 = do
res <- lookupAWithTTL (action1 DNS.A)
case res of
Left err -> return ([err], [])
Right r -> return ([], r)

lookupWithTTL LookupReqAAAAOnly action1 = do
res <- lookupAAAAWithTTL (action1 DNS.AAAA)
case res of
Left err -> return ([err], [])
Right r -> return ([], r)

lookupWithTTL LookupReqAAndAAAA action1 = do
(r_ipv6, r_ipv4) <- concurrently (lookupAAAAWithTTL (action1 DNS.AAAA))
(lookupAWithTTL (action1 DNS.A))
case (r_ipv6, r_ipv4) of
(Left e6, Left e4) -> return ([e6, e4], [])
(Right r6, Left e4) -> return ([e4], r6)
(Left e6, Right r4) -> return ([e6], r4)
(Right r6, Right r4) -> return ([], r6 <> r4)

-- | Like 'DNS.lookupA' but also return the TTL for the results.
--
-- DNS library timeouts do not work reliably on Windows (#1873), hence the
-- additional timeout.
--
lookupAWithTTL :: (Monad m)
=> m (Maybe (Either DNSError DNSMessage))
-> m (Either DNS.DNSError [(IP, DNS.TTL)])
lookupAWithTTL action = do
reply <- action
case reply of
Nothing -> return (Left DNS.TimeoutExpired)
Just (Left err) -> return (Left err)
Just (Right ans) -> return (DNS.fromDNSMessage ans selectA)
--TODO: we can get the SOA TTL on NXDOMAIN here if we want to
where
selectA DNS.DNSMessage { DNS.answer } =
[ (IPv4 addr, fixupTTL ttl)
| DNS.ResourceRecord {
DNS.rdata = DNS.RD_A addr,
DNS.rrttl = ttl
} <- answer
]


lookupAAAAWithTTL :: (Monad m)
=> m (Maybe (Either DNSError DNSMessage))
-> m (Either DNS.DNSError [(IP, DNS.TTL)])
lookupAAAAWithTTL action = do
reply <- action
case reply of
Nothing -> return (Left DNS.TimeoutExpired)
Just (Left err) -> return (Left err)
Just (Right ans) -> return (DNS.fromDNSMessage ans selectAAAA)
--TODO: we can get the SOA TTL on NXDOMAIN here if we want to
where
selectAAAA DNS.DNSMessage { DNS.answer } =
[ (IPv6 addr, fixupTTL ttl)
| DNS.ResourceRecord {
DNS.rdata = DNS.RD_AAAA addr,
DNS.rrttl = ttl
} <- answer
]

--
-- Utils
Expand Down

0 comments on commit f4e401f

Please sign in to comment.