From 0cce3f66b3737aa11d913cbff1de8783b4ad137f Mon Sep 17 00:00:00 2001 From: Abtin Keshavarzian Date: Fri, 8 Sep 2023 17:18:04 -0700 Subject: [PATCH] [mdns-msdnssd] simplify service registration code --- src/mdns/mdns_mdnssd.cpp | 235 +++++++++++++++++++-------------------- src/mdns/mdns_mdnssd.hpp | 61 +++------- 2 files changed, 130 insertions(+), 166 deletions(-) diff --git a/src/mdns/mdns_mdnssd.cpp b/src/mdns/mdns_mdnssd.cpp index 227e8caccbf..0c1dca577ca 100644 --- a/src/mdns/mdns_mdnssd.cpp +++ b/src/mdns/mdns_mdnssd.cpp @@ -227,9 +227,16 @@ PublisherMDnsSd::~PublisherMDnsSd(void) otbrError PublisherMDnsSd::Start(void) { + DNSServiceErrorType dnsError; + + SuccessOrExit(dnsError = DNSServiceCreateConnection(&mHostsRef)); + otbrLogDebug("Created new shared DNSServiceRef: %p", mHostsRef); + mState = State::kReady; mStateCallback(State::kReady); - return OTBR_ERROR_NONE; + +exit: + return DnsErrorToOtbrError(dnsError); } bool PublisherMDnsSd::IsStarted(void) const @@ -270,14 +277,15 @@ void PublisherMDnsSd::Update(MainloopContext &aMainloop) { auto &serviceReg = static_cast(*kv.second); - assert(serviceReg.GetServiceRef() != nullptr); - - int fd = DNSServiceRefSockFD(serviceReg.GetServiceRef()); - - if (fd != -1) + if (serviceReg.mServiceRef != nullptr) { - FD_SET(fd, &aMainloop.mReadFdSet); - aMainloop.mMaxFd = std::max(aMainloop.mMaxFd, fd); + int fd = DNSServiceRefSockFD(serviceReg.mServiceRef); + + if (fd != -1) + { + FD_SET(fd, &aMainloop.mReadFdSet); + aMainloop.mMaxFd = std::max(aMainloop.mMaxFd, fd); + } } } @@ -310,11 +318,15 @@ void PublisherMDnsSd::Process(const MainloopContext &aMainloop) for (auto &kv : mServiceRegistrations) { auto &serviceReg = static_cast(*kv.second); - int fd = DNSServiceRefSockFD(serviceReg.GetServiceRef()); - if (FD_ISSET(fd, &aMainloop.mReadFdSet)) + if (serviceReg.mServiceRef != nullptr) { - readyServices.push_back(serviceReg.GetServiceRef()); + int fd = DNSServiceRefSockFD(serviceReg.mServiceRef); + + if (FD_ISSET(fd, &aMainloop.mReadFdSet)) + { + readyServices.push_back(serviceReg.mServiceRef); + } } } @@ -360,11 +372,76 @@ void PublisherMDnsSd::Process(const MainloopContext &aMainloop) return; } -PublisherMDnsSd::DnssdServiceRegistration::~DnssdServiceRegistration(void) +otbrError PublisherMDnsSd::DnssdServiceRegistration::Register(void) +{ + std::string fullHostName; + std::string regType = MakeRegType(mType, mSubTypeList); + const char *hostNameCString = nullptr; + const char *serviceNameCString = nullptr; + DNSServiceErrorType dnsError; + + if (!mHostName.empty()) + { + fullHostName = MakeFullHostName(mHostName); + hostNameCString = fullHostName.c_str(); + } + + if (!mName.empty()) + { + serviceNameCString = mName.c_str(); + } + + otbrLogInfo("Registering service %s.%s", mName.c_str(), regType.c_str()); + + dnsError = DNSServiceRegister(&mServiceRef, kDNSServiceFlagsNoAutoRename, kDNSServiceInterfaceIndexAny, + serviceNameCString, regType.c_str(), + /* domain */ nullptr, hostNameCString, htons(mPort), mTxtData.size(), mTxtData.data(), + HandleRegisterResult, this); + + if (dnsError != kDNSServiceErr_NoError) + { + HandleRegisterResult(/* aFlags */ 0, dnsError); + } + + return GetPublisher().DnsErrorToOtbrError(dnsError); +} + +void PublisherMDnsSd::DnssdServiceRegistration::Unregister(void) { if (mServiceRef != nullptr) { DNSServiceRefDeallocate(mServiceRef); + mServiceRef = nullptr; + } +} + +void PublisherMDnsSd::DnssdServiceRegistration::HandleRegisterResult(DNSServiceRef aServiceRef, + DNSServiceFlags aFlags, + DNSServiceErrorType aError, + const char *aName, + const char *aType, + const char *aDomain, + void *aContext) +{ + OTBR_UNUSED_VARIABLE(aServiceRef); + OTBR_UNUSED_VARIABLE(aName); + OTBR_UNUSED_VARIABLE(aType); + OTBR_UNUSED_VARIABLE(aDomain); + + static_cast(aContext)->HandleRegisterResult(aFlags, aError); +} + +void PublisherMDnsSd::DnssdServiceRegistration::HandleRegisterResult(DNSServiceFlags aFlags, DNSServiceErrorType aError) +{ + if ((aError == kDNSServiceErr_NoError) && (aFlags & kDNSServiceFlagsAdd)) + { + otbrLogInfo("Successfully registered service %s.%s", mName.c_str(), mType.c_str()); + Complete(OTBR_ERROR_NONE); + } + else + { + otbrLogErr("Failed to register service %s.%s: %s", mName.c_str(), mType.c_str(), DNSErrorToString(aError)); + GetPublisher().RemoveServiceRegistration(mName, mType, DNSErrorToOtbrError(aError)); } } @@ -402,87 +479,25 @@ PublisherMDnsSd::DnssdHostRegistration::~DnssdHostRegistration(void) return; } -Publisher::ServiceRegistration *PublisherMDnsSd::FindServiceRegistration(const DNSServiceRef &aServiceRef) +PublisherMDnsSd::DnssdHostRegistration *PublisherMDnsSd::FindHostRegistration(const DNSServiceRef &aServiceRef, + const DNSRecordRef &aRecordRef) { - ServiceRegistration *result = nullptr; - - for (auto &kv : mServiceRegistrations) - { - // We are sure that the service registrations must be instances of `DnssdServiceRegistration`. - auto &serviceReg = static_cast(*kv.second); - - if (serviceReg.GetServiceRef() == aServiceRef) - { - result = kv.second.get(); - break; - } - } - - return result; -} - -Publisher::HostRegistration *PublisherMDnsSd::FindHostRegistration(const DNSServiceRef &aServiceRef, - const DNSRecordRef &aRecordRef) -{ - HostRegistration *result = nullptr; + DnssdHostRegistration *hostReg; for (auto &kv : mHostRegistrations) { - // We are sure that the host registrations must be instances of `DnssdHostRegistration`. - auto &hostReg = static_cast(*kv.second); + hostReg = static_cast(kv.second.get()); - if (hostReg.GetServiceRef() == aServiceRef && hostReg.GetRecordRefMap().count(aRecordRef)) + if ((hostReg->mServiceRef == aServiceRef) && hostReg->mRecordRefMap.count(aRecordRef)) { - result = kv.second.get(); - break; + ExitNow(); } } - return result; -} - -void PublisherMDnsSd::HandleServiceRegisterResult(DNSServiceRef aService, - const DNSServiceFlags aFlags, - DNSServiceErrorType aError, - const char *aName, - const char *aType, - const char *aDomain, - void *aContext) -{ - static_cast(aContext)->HandleServiceRegisterResult(aService, aFlags, aError, aName, aType, - aDomain); -} - -void PublisherMDnsSd::HandleServiceRegisterResult(DNSServiceRef aServiceRef, - const DNSServiceFlags aFlags, - DNSServiceErrorType aError, - const char *aName, - const char *aType, - const char *aDomain) -{ - OTBR_UNUSED_VARIABLE(aDomain); - - otbrError error = DNSErrorToOtbrError(aError); - ServiceRegistration *serviceReg = FindServiceRegistration(aServiceRef); - serviceReg->mName = aName; - - otbrLogInfo("Received reply for service %s.%s, serviceRef = %p", aName, aType, aServiceRef); - - VerifyOrExit(serviceReg != nullptr); - - if (aError == kDNSServiceErr_NoError && (aFlags & kDNSServiceFlagsAdd)) - { - otbrLogInfo("Successfully registered service %s.%s", aName, aType); - serviceReg->Complete(OTBR_ERROR_NONE); - } - else - { - otbrLogErr("Failed to register service %s.%s: %s", aName, aType, DNSErrorToString(aError)); - RemoveServiceRegistration(serviceReg->mName, serviceReg->mType, error); - } + hostReg = nullptr; exit: - return; + return hostReg; } otbrError PublisherMDnsSd::PublishServiceImpl(const std::string &aHostName, @@ -493,56 +508,30 @@ otbrError PublisherMDnsSd::PublishServiceImpl(const std::string &aHostName, const TxtData &aTxtData, ResultCallback &&aCallback) { - otbrError ret = OTBR_ERROR_NONE; - int error = 0; - SubTypeList sortedSubTypeList = SortSubTypeList(aSubTypeList); - std::string regType = MakeRegType(aType, sortedSubTypeList); - DNSServiceRef serviceRef = nullptr; - std::string fullHostName; - const char *hostNameCString = nullptr; - const char *serviceNameCString = nullptr; - - VerifyOrExit(mState == State::kReady, ret = OTBR_ERROR_INVALID_STATE); - - if (!aHostName.empty()) - { - fullHostName = MakeFullHostName(aHostName); - hostNameCString = fullHostName.c_str(); - } - if (!aName.empty()) + otbrError error = OTBR_ERROR_NONE; + SubTypeList sortedSubTypeList = SortSubTypeList(aSubTypeList); + std::string regType = MakeRegType(aType, sortedSubTypeList); + DnssdServiceRegistration *serviceReg; + + if (mState != State::kReady) { - serviceNameCString = aName.c_str(); + error = OTBR_ERROR_INVALID_STATE; + std::move(aCallback)(error); + ExitNow(); } aCallback = HandleDuplicateServiceRegistration(aHostName, aName, aType, sortedSubTypeList, aPort, aTxtData, std::move(aCallback)); VerifyOrExit(!aCallback.IsNull()); - otbrLogInfo("Registering new service %s.%s.local, serviceRef = %p", aName.c_str(), regType.c_str(), serviceRef); - SuccessOrExit(error = DNSServiceRegister(&serviceRef, kDNSServiceFlagsNoAutoRename, kDNSServiceInterfaceIndexAny, - serviceNameCString, regType.c_str(), - /* domain */ nullptr, hostNameCString, htons(aPort), aTxtData.size(), - aTxtData.data(), HandleServiceRegisterResult, this)); - AddServiceRegistration(std::unique_ptr(new DnssdServiceRegistration( - aHostName, aName, aType, sortedSubTypeList, aPort, aTxtData, std::move(aCallback), serviceRef, this))); + serviceReg = new DnssdServiceRegistration(aHostName, aName, aType, sortedSubTypeList, aPort, aTxtData, + std::move(aCallback), this); + AddServiceRegistration(std::unique_ptr(serviceReg)); -exit: - if (error != kDNSServiceErr_NoError || ret != OTBR_ERROR_NONE) - { - if (error != kDNSServiceErr_NoError) - { - ret = DNSErrorToOtbrError(error); - otbrLogErr("Failed to publish service %s.%s for mdnssd error: %s!", aName.c_str(), aType.c_str(), - DNSErrorToString(error)); - } + error = serviceReg->Register(); - if (serviceRef != nullptr) - { - DNSServiceRefDeallocate(serviceRef); - } - std::move(aCallback)(ret); - } - return ret; +exit: + return error; } void PublisherMDnsSd::UnpublishService(const std::string &aName, const std::string &aType, ResultCallback &&aCallback) diff --git a/src/mdns/mdns_mdnssd.hpp b/src/mdns/mdns_mdnssd.hpp index 37af069ff1d..e82a952c1ce 100644 --- a/src/mdns/mdns_mdnssd.hpp +++ b/src/mdns/mdns_mdnssd.hpp @@ -106,32 +106,24 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher class DnssdServiceRegistration : public ServiceRegistration { public: - DnssdServiceRegistration(const std::string &aHostName, - const std::string &aName, - const std::string &aType, - const SubTypeList &aSubTypeList, - uint16_t aPort, - const TxtData &aTxtData, - ResultCallback &&aCallback, - DNSServiceRef aServiceRef, - PublisherMDnsSd *aPublisher) - : ServiceRegistration(aHostName, - aName, - aType, - aSubTypeList, - aPort, - aTxtData, - std::move(aCallback), - aPublisher) - , mServiceRef(aServiceRef) - { - } + using ServiceRegistration::ServiceRegistration; // Inherit base constructor - ~DnssdServiceRegistration(void) override; - const DNSServiceRef &GetServiceRef() const { return mServiceRef; } + ~DnssdServiceRegistration(void) override { Unregister(); } - private: - DNSServiceRef mServiceRef; + otbrError Register(void); + void Unregister(void); + void HandleRegisterResult(DNSServiceFlags aFlags, DNSServiceErrorType aError); + PublisherMDnsSd &GetPublisher(void) { return *static_cast(mPublisher); } + + static void HandleRegisterResult(DNSServiceRef aServiceRef, + DNSServiceFlags aFlags, + DNSServiceErrorType aError, + const char *aName, + const char *aType, + const char *aDomain, + void *aContext); + + DNSServiceRef mServiceRef = nullptr; }; class DnssdHostRegistration : public HostRegistration @@ -154,10 +146,7 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher const std::map &GetRecordRefMap() const { return mRecordRefMap; } std::map &GetRecordRefMap() { return mRecordRefMap; } - private: - DNSServiceRef mServiceRef; - - public: + DNSServiceRef mServiceRef; std::map mRecordRefMap; uint32_t mCallbackCount; }; @@ -320,19 +309,6 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher using ServiceSubscriptionList = std::vector>; using HostSubscriptionList = std::vector>; - static void HandleServiceRegisterResult(DNSServiceRef aService, - const DNSServiceFlags aFlags, - DNSServiceErrorType aError, - const char *aName, - const char *aType, - const char *aDomain, - void *aContext); - void HandleServiceRegisterResult(DNSServiceRef aService, - const DNSServiceFlags aFlags, - DNSServiceErrorType aError, - const char *aName, - const char *aType, - const char *aDomain); static void HandleRegisterHostResult(DNSServiceRef aHostsConnection, DNSRecordRef aHostRecord, DNSServiceFlags aFlags, @@ -345,8 +321,7 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher static std::string MakeRegType(const std::string &aType, SubTypeList aSubTypeList); - ServiceRegistration *FindServiceRegistration(const DNSServiceRef &aServiceRef); - HostRegistration *FindHostRegistration(const DNSServiceRef &aServiceRef, const DNSRecordRef &aRecordRef); + DnssdHostRegistration *FindHostRegistration(const DNSServiceRef &aServiceRef, const DNSRecordRef &aRecordRef); DNSServiceRef mHostsRef; State mState;