From 65dd8bff66489f13ac228f4fe2103a5f622b00ba Mon Sep 17 00:00:00 2001 From: Abtin Keshavarzian Date: Tue, 26 Nov 2024 11:18:53 -0800 Subject: [PATCH] [udp] update `Socket::Open()` to include `NetifId` (#10964) This commit updates how the `NetifId` is set on a `Udp::Socket`. It is now specified as a parameter in the `Open()` call instead of `Bind()`. This change makes the socket more flexible by allowing the `NetifId` to be known earlier (which can be used for future optimizations). It also allows the desired `NetifId` to be set even when `Bind()` is not explicitly called, such as when `Connect()` or `SendTo()` are used on a socket that is not yet bound. Previously, `kNetifThread` was assumed by default in these situations. The public `otUdp` APIs are left unchanged to ensure backward compatibility. --- src/core/api/udp_api.cpp | 6 ++++-- src/core/coap/coap.cpp | 4 ++-- src/core/meshcop/joiner_router.cpp | 2 +- src/core/meshcop/secure_transport.cpp | 4 ++-- src/core/net/dhcp6_client.cpp | 2 +- src/core/net/dhcp6_server.cpp | 2 +- src/core/net/dns_client.cpp | 4 ++-- src/core/net/dnssd_server.cpp | 4 ++-- src/core/net/sntp_client.cpp | 4 ++-- src/core/net/srp_client.cpp | 3 ++- src/core/net/srp_server.cpp | 4 ++-- src/core/net/udp6.cpp | 25 +++++++++---------------- src/core/net/udp6.hpp | 27 +++++++++++++++++---------- src/core/thread/mle.cpp | 2 +- tests/unit/test_srp_server.cpp | 4 ++-- 15 files changed, 50 insertions(+), 47 deletions(-) diff --git a/src/core/api/udp_api.cpp b/src/core/api/udp_api.cpp index 4f5b937650e..008a402f13c 100644 --- a/src/core/api/udp_api.cpp +++ b/src/core/api/udp_api.cpp @@ -44,7 +44,7 @@ otMessage *otUdpNewMessage(otInstance *aInstance, const otMessageSettings *aSett otError otUdpOpen(otInstance *aInstance, otUdpSocket *aSocket, otUdpReceive aCallback, void *aContext) { - return AsCoreType(aInstance).Get().Open(AsCoreType(aSocket), aCallback, aContext); + return AsCoreType(aInstance).Get().Open(AsCoreType(aSocket), Ip6::kNetifThread, aCallback, aContext); } bool otUdpIsOpen(otInstance *aInstance, const otUdpSocket *aSocket) @@ -59,7 +59,9 @@ otError otUdpClose(otInstance *aInstance, otUdpSocket *aSocket) otError otUdpBind(otInstance *aInstance, otUdpSocket *aSocket, const otSockAddr *aSockName, otNetifIdentifier aNetif) { - return AsCoreType(aInstance).Get().Bind(AsCoreType(aSocket), AsCoreType(aSockName), MapEnum(aNetif)); + AsCoreType(aSocket).SetNetifId(MapEnum(aNetif)); + + return AsCoreType(aInstance).Get().Bind(AsCoreType(aSocket), AsCoreType(aSockName)); } otError otUdpConnect(otInstance *aInstance, otUdpSocket *aSocket, const otSockAddr *aSockName) diff --git a/src/core/coap/coap.cpp b/src/core/coap/coap.cpp index eb72f7a4260..242ece0f163 100644 --- a/src/core/coap/coap.cpp +++ b/src/core/coap/coap.cpp @@ -1672,10 +1672,10 @@ Error Coap::Start(uint16_t aPort, Ip6::NetifIdentifier aNetifIdentifier) VerifyOrExit(!mSocket.IsBound()); - SuccessOrExit(error = mSocket.Open()); + SuccessOrExit(error = mSocket.Open(aNetifIdentifier)); socketOpened = true; - SuccessOrExit(error = mSocket.Bind(aPort, aNetifIdentifier)); + SuccessOrExit(error = mSocket.Bind(aPort)); exit: if (error != kErrorNone && socketOpened) diff --git a/src/core/meshcop/joiner_router.cpp b/src/core/meshcop/joiner_router.cpp index efb466eb439..7c793d5beca 100644 --- a/src/core/meshcop/joiner_router.cpp +++ b/src/core/meshcop/joiner_router.cpp @@ -69,7 +69,7 @@ void JoinerRouter::Start(void) VerifyOrExit(!mSocket.IsBound()); - IgnoreError(mSocket.Open()); + IgnoreError(mSocket.Open(Ip6::kNetifThread)); IgnoreError(mSocket.Bind(port)); IgnoreError(Get().AddUnsecurePort(port)); LogInfo("Joiner Router: start"); diff --git a/src/core/meshcop/secure_transport.cpp b/src/core/meshcop/secure_transport.cpp index 27fa1e9230d..9134bcd3d4d 100644 --- a/src/core/meshcop/secure_transport.cpp +++ b/src/core/meshcop/secure_transport.cpp @@ -131,7 +131,7 @@ Error SecureTransport::Open(ReceiveHandler aReceiveHandler, ConnectedHandler aCo VerifyOrExit(IsStateClosed(), error = kErrorAlready); - SuccessOrExit(error = mSocket.Open()); + SuccessOrExit(error = mSocket.Open(Ip6::kNetifUnspecified)); mConnectedCallback.Set(aConnectedHandler, aContext); mReceiveCallback.Set(aReceiveHandler, aContext); @@ -225,7 +225,7 @@ Error SecureTransport::Bind(uint16_t aPort) VerifyOrExit(IsStateOpen(), error = kErrorInvalidState); VerifyOrExit(!mTransportCallback.IsSet(), error = kErrorAlready); - SuccessOrExit(error = mSocket.Bind(aPort, Ip6::kNetifUnspecified)); + SuccessOrExit(error = mSocket.Bind(aPort)); exit: return error; diff --git a/src/core/net/dhcp6_client.cpp b/src/core/net/dhcp6_client.cpp index 2b0f3f760b5..d6d71afbfc6 100644 --- a/src/core/net/dhcp6_client.cpp +++ b/src/core/net/dhcp6_client.cpp @@ -162,7 +162,7 @@ void Client::Start(void) { VerifyOrExit(!mSocket.IsBound()); - IgnoreError(mSocket.Open()); + IgnoreError(mSocket.Open(Ip6::kNetifThread)); IgnoreError(mSocket.Bind(kDhcpClientPort)); ProcessNextIdentityAssociation(); diff --git a/src/core/net/dhcp6_server.cpp b/src/core/net/dhcp6_server.cpp index 275b0b9997a..c240e27306b 100644 --- a/src/core/net/dhcp6_server.cpp +++ b/src/core/net/dhcp6_server.cpp @@ -130,7 +130,7 @@ void Server::Start(void) { VerifyOrExit(!mSocket.IsOpen()); - IgnoreError(mSocket.Open()); + IgnoreError(mSocket.Open(Ip6::kNetifThread)); IgnoreError(mSocket.Bind(kDhcpServerPort)); exit: diff --git a/src/core/net/dns_client.cpp b/src/core/net/dns_client.cpp index c1df35128db..070fec15674 100644 --- a/src/core/net/dns_client.cpp +++ b/src/core/net/dns_client.cpp @@ -764,8 +764,8 @@ Error Client::Start(void) { Error error; - SuccessOrExit(error = mSocket.Open()); - SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified)); + SuccessOrExit(error = mSocket.Open(Ip6::kNetifUnspecified)); + SuccessOrExit(error = mSocket.Bind(0)); exit: return error; diff --git a/src/core/net/dnssd_server.cpp b/src/core/net/dnssd_server.cpp index 67c8c41c344..d22976cb11f 100644 --- a/src/core/net/dnssd_server.cpp +++ b/src/core/net/dnssd_server.cpp @@ -71,8 +71,8 @@ Error Server::Start(void) VerifyOrExit(!IsRunning()); - SuccessOrExit(error = mSocket.Open()); - SuccessOrExit(error = mSocket.Bind(kPort, kBindUnspecifiedNetif ? Ip6::kNetifUnspecified : Ip6::kNetifThread)); + SuccessOrExit(error = mSocket.Open(kBindUnspecifiedNetif ? Ip6::kNetifUnspecified : Ip6::kNetifThread)); + SuccessOrExit(error = mSocket.Bind(kPort)); #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE Get().HandleDnssdServerStateChange(); diff --git a/src/core/net/sntp_client.cpp b/src/core/net/sntp_client.cpp index 64e978fd835..e06f5a7718b 100644 --- a/src/core/net/sntp_client.cpp +++ b/src/core/net/sntp_client.cpp @@ -54,8 +54,8 @@ Error Client::Start(void) { Error error; - SuccessOrExit(error = mSocket.Open()); - SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified)); + SuccessOrExit(error = mSocket.Open(Ip6::kNetifUnspecified)); + SuccessOrExit(error = mSocket.Bind(0)); exit: return error; diff --git a/src/core/net/srp_client.cpp b/src/core/net/srp_client.cpp index b7150ff7f2c..105905a3ede 100644 --- a/src/core/net/srp_client.cpp +++ b/src/core/net/srp_client.cpp @@ -400,9 +400,10 @@ Error Client::Start(const Ip6::SockAddr &aServerSockAddr, Requester aRequester) VerifyOrExit(GetState() == kStateStopped, error = (aServerSockAddr == GetServerAddress()) ? kErrorNone : kErrorBusy); - SuccessOrExit(error = mSocket.Open()); + SuccessOrExit(error = mSocket.Open(Ip6::kNetifThread)); error = mSocket.Connect(aServerSockAddr); + if (error != kErrorNone) { LogInfo("Failed to connect to server %s: %s", aServerSockAddr.GetAddress().ToString().AsCString(), diff --git a/src/core/net/srp_server.cpp b/src/core/net/srp_server.cpp index 499f8ff8ac6..5a329834d68 100644 --- a/src/core/net/srp_server.cpp +++ b/src/core/net/srp_server.cpp @@ -658,8 +658,8 @@ Error Server::PrepareSocket(void) #endif VerifyOrExit(!mSocket.IsOpen()); - SuccessOrExit(error = mSocket.Open()); - error = mSocket.Bind(mPort, Ip6::kNetifThread); + SuccessOrExit(error = mSocket.Open(Ip6::kNetifThread)); + error = mSocket.Bind(mPort); exit: if (error != kErrorNone) diff --git a/src/core/net/udp6.cpp b/src/core/net/udp6.cpp index bbd91fe66a8..9d8bd0c56bd 100644 --- a/src/core/net/udp6.cpp +++ b/src/core/net/udp6.cpp @@ -89,19 +89,13 @@ Message *Udp::Socket::NewMessage(uint16_t aReserved, const Message::Settings &aS return Get().NewMessage(aReserved, aSettings); } -Error Udp::Socket::Open(void) { return Get().Open(*this, mHandler, mContext); } +Error Udp::Socket::Open(NetifIdentifier aNetifId) { return Get().Open(*this, aNetifId, mHandler, mContext); } bool Udp::Socket::IsOpen(void) const { return Get().IsOpen(*this); } -Error Udp::Socket::Bind(const SockAddr &aSockAddr, NetifIdentifier aNetifIdentifier) -{ - return Get().Bind(*this, aSockAddr, aNetifIdentifier); -} +Error Udp::Socket::Bind(const SockAddr &aSockAddr) { return Get().Bind(*this, aSockAddr); } -Error Udp::Socket::Bind(uint16_t aPort, NetifIdentifier aNetifIdentifier) -{ - return Bind(SockAddr(aPort), aNetifIdentifier); -} +Error Udp::Socket::Bind(uint16_t aPort) { return Bind(SockAddr(aPort)); } Error Udp::Socket::Connect(const SockAddr &aSockAddr) { return Get().Connect(*this, aSockAddr); } @@ -172,13 +166,14 @@ Error Udp::RemoveReceiver(Receiver &aReceiver) return error; } -Error Udp::Open(SocketHandle &aSocket, ReceiveHandler aHandler, void *aContext) +Error Udp::Open(SocketHandle &aSocket, NetifIdentifier aNetifId, ReceiveHandler aHandler, void *aContext) { Error error = kErrorNone; OT_ASSERT(!IsOpen(aSocket)); aSocket.Clear(); + aSocket.SetNetifId(aNetifId); aSocket.mHandler = aHandler; aSocket.mContext = aContext; @@ -193,16 +188,14 @@ Error Udp::Open(SocketHandle &aSocket, ReceiveHandler aHandler, void *aContext) return error; } -Error Udp::Bind(SocketHandle &aSocket, const SockAddr &aSockAddr, NetifIdentifier aNetifIdentifier) +Error Udp::Bind(SocketHandle &aSocket, const SockAddr &aSockAddr) { Error error = kErrorNone; #if OPENTHREAD_CONFIG_PLATFORM_UDP_ENABLE - SuccessOrExit(error = otPlatUdpBindToNetif(&aSocket, MapEnum(aNetifIdentifier))); + SuccessOrExit(error = otPlatUdpBindToNetif(&aSocket, MapEnum(aSocket.GetNetifId()))); #endif - aSocket.mNetifId = MapEnum(aNetifIdentifier); - VerifyOrExit(aSockAddr.GetAddress().IsUnspecified() || Get().HasUnicastAddress(aSockAddr.GetAddress()), error = kErrorInvalidArgs); @@ -237,7 +230,7 @@ Error Udp::Connect(SocketHandle &aSocket, const SockAddr &aSockAddr) if (!aSocket.IsBound()) { - SuccessOrExit(error = Bind(aSocket, aSocket.GetSockName(), kNetifThread)); + SuccessOrExit(error = Bind(aSocket, aSocket.GetSockName())); } #if OPENTHREAD_CONFIG_PLATFORM_UDP_ENABLE @@ -300,7 +293,7 @@ Error Udp::SendTo(SocketHandle &aSocket, Message &aMessage, const MessageInfo &a if (!aSocket.IsBound()) { - SuccessOrExit(error = Bind(aSocket, aSocket.GetSockName(), kNetifThread)); + SuccessOrExit(error = Bind(aSocket, aSocket.GetSockName())); } messageInfoLocal.SetSockPort(aSocket.GetSockName().mPort); diff --git a/src/core/net/udp6.hpp b/src/core/net/udp6.hpp index f8650699d0e..c9563196a7a 100644 --- a/src/core/net/udp6.hpp +++ b/src/core/net/udp6.hpp @@ -135,6 +135,13 @@ class Udp : public InstanceLocator, private NonCopyable */ NetifIdentifier GetNetifId(void) const { return static_cast(mNetifId); } + /** + * Sets the network interface identifier. + * + * @param[in] aNetifId The network interface identifier. + */ + void SetNetifId(NetifIdentifier aNetifId) { mNetifId = static_cast(aNetifId); } + #if OPENTHREAD_FTD && OPENTHREAD_CONFIG_BACKBONE_ROUTER_ENABLE /** * Indicate whether or not the socket is bound to the backbone network interface. @@ -200,10 +207,12 @@ class Udp : public InstanceLocator, private NonCopyable /** * Opens the UDP socket. * + * @param[in] aNetifId The network interface identifier. + * * @retval kErrorNone Successfully opened the socket. * @retval kErrorFailed Failed to open the socket. */ - Error Open(void); + Error Open(NetifIdentifier aNetifId); /** * Returns if the UDP socket is open. @@ -215,25 +224,23 @@ class Udp : public InstanceLocator, private NonCopyable /** * Binds the UDP socket. * - * @param[in] aSockAddr A reference to the socket address. - * @param[in] aNetifIdentifier The network interface identifier. + * @param[in] aSockAddr A reference to the socket address. * * @retval kErrorNone Successfully bound the socket. * @retval kErrorInvalidArgs Unable to bind to Thread network interface with the given address. * @retval kErrorFailed Failed to bind UDP Socket. */ - Error Bind(const SockAddr &aSockAddr, NetifIdentifier aNetifIdentifier = kNetifThread); + Error Bind(const SockAddr &aSockAddr); /** * Binds the UDP socket. * - * @param[in] aPort A port number. - * @param[in] aNetifIdentifier The network interface identifier. + * @param[in] aPort A port number. * * @retval kErrorNone Successfully bound the socket. * @retval kErrorFailed Failed to bind UDP Socket. */ - Error Bind(uint16_t aPort, NetifIdentifier aNetifIdentifier = kNetifThread); + Error Bind(uint16_t aPort); /** * Binds the UDP socket. @@ -479,13 +486,14 @@ class Udp : public InstanceLocator, private NonCopyable * Opens a UDP socket. * * @param[in] aSocket A reference to the socket. + * @param[in] aNetifId A network interface identifier. * @param[in] aHandler A pointer to a function that is called when receiving UDP messages. * @param[in] aContext A pointer to arbitrary context information. * * @retval kErrorNone Successfully opened the socket. * @retval kErrorFailed Failed to open the socket. */ - Error Open(SocketHandle &aSocket, ReceiveHandler aHandler, void *aContext); + Error Open(SocketHandle &aSocket, NetifIdentifier aNetifId, ReceiveHandler aHandler, void *aContext); /** * Returns if a UDP socket is open. @@ -501,13 +509,12 @@ class Udp : public InstanceLocator, private NonCopyable * * @param[in] aSocket A reference to the socket. * @param[in] aSockAddr A reference to the socket address. - * @param[in] aNetifIdentifier The network interface identifier. * * @retval kErrorNone Successfully bound the socket. * @retval kErrorInvalidArgs Unable to bind to Thread network interface with the given address. * @retval kErrorFailed Failed to bind UDP Socket. */ - Error Bind(SocketHandle &aSocket, const SockAddr &aSockAddr, NetifIdentifier aNetifIdentifier); + Error Bind(SocketHandle &aSocket, const SockAddr &aSockAddr); /** * Connects a UDP socket. diff --git a/src/core/thread/mle.cpp b/src/core/thread/mle.cpp index 99838b0c9da..e3d184f04aa 100644 --- a/src/core/thread/mle.cpp +++ b/src/core/thread/mle.cpp @@ -127,7 +127,7 @@ Error Mle::Enable(void) Error error = kErrorNone; UpdateLinkLocalAddress(); - SuccessOrExit(error = mSocket.Open()); + SuccessOrExit(error = mSocket.Open(Ip6::kNetifThread)); SuccessOrExit(error = mSocket.Bind(kUdpPort)); #if OPENTHREAD_CONFIG_PARENT_SEARCH_ENABLE diff --git a/tests/unit/test_srp_server.cpp b/tests/unit/test_srp_server.cpp index cab62b23d9a..ec7fa9abfaa 100644 --- a/tests/unit/test_srp_server.cpp +++ b/tests/unit/test_srp_server.cpp @@ -1076,8 +1076,8 @@ void TestSrpClientDelayedResponse(void) sServerRxCount = 0; - SuccessOrQuit(udpSocket.Open()); - SuccessOrQuit(udpSocket.Bind(kServerPort, Ip6::kNetifThread)); + SuccessOrQuit(udpSocket.Open(Ip6::kNetifThread)); + SuccessOrQuit(udpSocket.Bind(kServerPort)); //- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - // Manually start the client with a message ID based on `testIter`