Skip to content

Commit

Permalink
[mdns-msdnssd] simplify service registration code
Browse files Browse the repository at this point in the history
  • Loading branch information
abtink committed Sep 9, 2023
1 parent 8a280e3 commit 0cce3f6
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 166 deletions.
235 changes: 112 additions & 123 deletions src/mdns/mdns_mdnssd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,16 @@ PublisherMDnsSd::~PublisherMDnsSd(void)

otbrError PublisherMDnsSd::Start(void)
{
DNSServiceErrorType dnsError;

SuccessOrExit(dnsError = DNSServiceCreateConnection(&mHostsRef));
otbrLogDebug("Created new shared DNSServiceRef: %p", mHostsRef);

mState = State::kReady;
mStateCallback(State::kReady);
return OTBR_ERROR_NONE;

exit:
return DnsErrorToOtbrError(dnsError);
}

bool PublisherMDnsSd::IsStarted(void) const
Expand Down Expand Up @@ -270,14 +277,15 @@ void PublisherMDnsSd::Update(MainloopContext &aMainloop)
{
auto &serviceReg = static_cast<DnssdServiceRegistration &>(*kv.second);

assert(serviceReg.GetServiceRef() != nullptr);

int fd = DNSServiceRefSockFD(serviceReg.GetServiceRef());

if (fd != -1)
if (serviceReg.mServiceRef != nullptr)
{
FD_SET(fd, &aMainloop.mReadFdSet);
aMainloop.mMaxFd = std::max(aMainloop.mMaxFd, fd);
int fd = DNSServiceRefSockFD(serviceReg.mServiceRef);

if (fd != -1)
{
FD_SET(fd, &aMainloop.mReadFdSet);
aMainloop.mMaxFd = std::max(aMainloop.mMaxFd, fd);
}
}
}

Expand Down Expand Up @@ -310,11 +318,15 @@ void PublisherMDnsSd::Process(const MainloopContext &aMainloop)
for (auto &kv : mServiceRegistrations)
{
auto &serviceReg = static_cast<DnssdServiceRegistration &>(*kv.second);
int fd = DNSServiceRefSockFD(serviceReg.GetServiceRef());

if (FD_ISSET(fd, &aMainloop.mReadFdSet))
if (serviceReg.mServiceRef != nullptr)
{
readyServices.push_back(serviceReg.GetServiceRef());
int fd = DNSServiceRefSockFD(serviceReg.mServiceRef);

if (FD_ISSET(fd, &aMainloop.mReadFdSet))
{
readyServices.push_back(serviceReg.mServiceRef);
}
}
}

Expand Down Expand Up @@ -360,11 +372,76 @@ void PublisherMDnsSd::Process(const MainloopContext &aMainloop)
return;
}

PublisherMDnsSd::DnssdServiceRegistration::~DnssdServiceRegistration(void)
otbrError PublisherMDnsSd::DnssdServiceRegistration::Register(void)
{
std::string fullHostName;
std::string regType = MakeRegType(mType, mSubTypeList);
const char *hostNameCString = nullptr;
const char *serviceNameCString = nullptr;
DNSServiceErrorType dnsError;

if (!mHostName.empty())
{
fullHostName = MakeFullHostName(mHostName);
hostNameCString = fullHostName.c_str();
}

if (!mName.empty())
{
serviceNameCString = mName.c_str();
}

otbrLogInfo("Registering service %s.%s", mName.c_str(), regType.c_str());

dnsError = DNSServiceRegister(&mServiceRef, kDNSServiceFlagsNoAutoRename, kDNSServiceInterfaceIndexAny,
serviceNameCString, regType.c_str(),
/* domain */ nullptr, hostNameCString, htons(mPort), mTxtData.size(), mTxtData.data(),
HandleRegisterResult, this);

if (dnsError != kDNSServiceErr_NoError)
{
HandleRegisterResult(/* aFlags */ 0, dnsError);
}

return GetPublisher().DnsErrorToOtbrError(dnsError);
}

void PublisherMDnsSd::DnssdServiceRegistration::Unregister(void)
{
if (mServiceRef != nullptr)
{
DNSServiceRefDeallocate(mServiceRef);
mServiceRef = nullptr;
}
}

void PublisherMDnsSd::DnssdServiceRegistration::HandleRegisterResult(DNSServiceRef aServiceRef,
DNSServiceFlags aFlags,
DNSServiceErrorType aError,
const char *aName,
const char *aType,
const char *aDomain,
void *aContext)
{
OTBR_UNUSED_VARIABLE(aServiceRef);
OTBR_UNUSED_VARIABLE(aName);
OTBR_UNUSED_VARIABLE(aType);
OTBR_UNUSED_VARIABLE(aDomain);

static_cast<DnssdServiceRegistration *>(aContext)->HandleRegisterResult(aFlags, aError);
}

void PublisherMDnsSd::DnssdServiceRegistration::HandleRegisterResult(DNSServiceFlags aFlags, DNSServiceErrorType aError)
{
if ((aError == kDNSServiceErr_NoError) && (aFlags & kDNSServiceFlagsAdd))
{
otbrLogInfo("Successfully registered service %s.%s", mName.c_str(), mType.c_str());
Complete(OTBR_ERROR_NONE);
}
else
{
otbrLogErr("Failed to register service %s.%s: %s", mName.c_str(), mType.c_str(), DNSErrorToString(aError));
GetPublisher().RemoveServiceRegistration(mName, mType, DNSErrorToOtbrError(aError));
}
}

Expand Down Expand Up @@ -402,87 +479,25 @@ PublisherMDnsSd::DnssdHostRegistration::~DnssdHostRegistration(void)
return;
}

Publisher::ServiceRegistration *PublisherMDnsSd::FindServiceRegistration(const DNSServiceRef &aServiceRef)
PublisherMDnsSd::DnssdHostRegistration *PublisherMDnsSd::FindHostRegistration(const DNSServiceRef &aServiceRef,
const DNSRecordRef &aRecordRef)
{
ServiceRegistration *result = nullptr;

for (auto &kv : mServiceRegistrations)
{
// We are sure that the service registrations must be instances of `DnssdServiceRegistration`.
auto &serviceReg = static_cast<DnssdServiceRegistration &>(*kv.second);

if (serviceReg.GetServiceRef() == aServiceRef)
{
result = kv.second.get();
break;
}
}

return result;
}

Publisher::HostRegistration *PublisherMDnsSd::FindHostRegistration(const DNSServiceRef &aServiceRef,
const DNSRecordRef &aRecordRef)
{
HostRegistration *result = nullptr;
DnssdHostRegistration *hostReg;

for (auto &kv : mHostRegistrations)
{
// We are sure that the host registrations must be instances of `DnssdHostRegistration`.
auto &hostReg = static_cast<DnssdHostRegistration &>(*kv.second);
hostReg = static_cast<DnssdHostRegistration *>(kv.second.get());

if (hostReg.GetServiceRef() == aServiceRef && hostReg.GetRecordRefMap().count(aRecordRef))
if ((hostReg->mServiceRef == aServiceRef) && hostReg->mRecordRefMap.count(aRecordRef))
{
result = kv.second.get();
break;
ExitNow();
}
}

return result;
}

void PublisherMDnsSd::HandleServiceRegisterResult(DNSServiceRef aService,
const DNSServiceFlags aFlags,
DNSServiceErrorType aError,
const char *aName,
const char *aType,
const char *aDomain,
void *aContext)
{
static_cast<PublisherMDnsSd *>(aContext)->HandleServiceRegisterResult(aService, aFlags, aError, aName, aType,
aDomain);
}

void PublisherMDnsSd::HandleServiceRegisterResult(DNSServiceRef aServiceRef,
const DNSServiceFlags aFlags,
DNSServiceErrorType aError,
const char *aName,
const char *aType,
const char *aDomain)
{
OTBR_UNUSED_VARIABLE(aDomain);

otbrError error = DNSErrorToOtbrError(aError);
ServiceRegistration *serviceReg = FindServiceRegistration(aServiceRef);
serviceReg->mName = aName;

otbrLogInfo("Received reply for service %s.%s, serviceRef = %p", aName, aType, aServiceRef);

VerifyOrExit(serviceReg != nullptr);

if (aError == kDNSServiceErr_NoError && (aFlags & kDNSServiceFlagsAdd))
{
otbrLogInfo("Successfully registered service %s.%s", aName, aType);
serviceReg->Complete(OTBR_ERROR_NONE);
}
else
{
otbrLogErr("Failed to register service %s.%s: %s", aName, aType, DNSErrorToString(aError));
RemoveServiceRegistration(serviceReg->mName, serviceReg->mType, error);
}
hostReg = nullptr;

exit:
return;
return hostReg;
}

otbrError PublisherMDnsSd::PublishServiceImpl(const std::string &aHostName,
Expand All @@ -493,56 +508,30 @@ otbrError PublisherMDnsSd::PublishServiceImpl(const std::string &aHostName,
const TxtData &aTxtData,
ResultCallback &&aCallback)
{
otbrError ret = OTBR_ERROR_NONE;
int error = 0;
SubTypeList sortedSubTypeList = SortSubTypeList(aSubTypeList);
std::string regType = MakeRegType(aType, sortedSubTypeList);
DNSServiceRef serviceRef = nullptr;
std::string fullHostName;
const char *hostNameCString = nullptr;
const char *serviceNameCString = nullptr;

VerifyOrExit(mState == State::kReady, ret = OTBR_ERROR_INVALID_STATE);

if (!aHostName.empty())
{
fullHostName = MakeFullHostName(aHostName);
hostNameCString = fullHostName.c_str();
}
if (!aName.empty())
otbrError error = OTBR_ERROR_NONE;
SubTypeList sortedSubTypeList = SortSubTypeList(aSubTypeList);
std::string regType = MakeRegType(aType, sortedSubTypeList);
DnssdServiceRegistration *serviceReg;

if (mState != State::kReady)
{
serviceNameCString = aName.c_str();
error = OTBR_ERROR_INVALID_STATE;
std::move(aCallback)(error);
ExitNow();
}

aCallback = HandleDuplicateServiceRegistration(aHostName, aName, aType, sortedSubTypeList, aPort, aTxtData,
std::move(aCallback));
VerifyOrExit(!aCallback.IsNull());

otbrLogInfo("Registering new service %s.%s.local, serviceRef = %p", aName.c_str(), regType.c_str(), serviceRef);
SuccessOrExit(error = DNSServiceRegister(&serviceRef, kDNSServiceFlagsNoAutoRename, kDNSServiceInterfaceIndexAny,
serviceNameCString, regType.c_str(),
/* domain */ nullptr, hostNameCString, htons(aPort), aTxtData.size(),
aTxtData.data(), HandleServiceRegisterResult, this));
AddServiceRegistration(std::unique_ptr<DnssdServiceRegistration>(new DnssdServiceRegistration(
aHostName, aName, aType, sortedSubTypeList, aPort, aTxtData, std::move(aCallback), serviceRef, this)));
serviceReg = new DnssdServiceRegistration(aHostName, aName, aType, sortedSubTypeList, aPort, aTxtData,
std::move(aCallback), this);
AddServiceRegistration(std::unique_ptr<DnssdServiceRegistration>(serviceReg));

exit:
if (error != kDNSServiceErr_NoError || ret != OTBR_ERROR_NONE)
{
if (error != kDNSServiceErr_NoError)
{
ret = DNSErrorToOtbrError(error);
otbrLogErr("Failed to publish service %s.%s for mdnssd error: %s!", aName.c_str(), aType.c_str(),
DNSErrorToString(error));
}
error = serviceReg->Register();

if (serviceRef != nullptr)
{
DNSServiceRefDeallocate(serviceRef);
}
std::move(aCallback)(ret);
}
return ret;
exit:
return error;
}

void PublisherMDnsSd::UnpublishService(const std::string &aName, const std::string &aType, ResultCallback &&aCallback)
Expand Down
61 changes: 18 additions & 43 deletions src/mdns/mdns_mdnssd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,32 +106,24 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher
class DnssdServiceRegistration : public ServiceRegistration
{
public:
DnssdServiceRegistration(const std::string &aHostName,
const std::string &aName,
const std::string &aType,
const SubTypeList &aSubTypeList,
uint16_t aPort,
const TxtData &aTxtData,
ResultCallback &&aCallback,
DNSServiceRef aServiceRef,
PublisherMDnsSd *aPublisher)
: ServiceRegistration(aHostName,
aName,
aType,
aSubTypeList,
aPort,
aTxtData,
std::move(aCallback),
aPublisher)
, mServiceRef(aServiceRef)
{
}
using ServiceRegistration::ServiceRegistration; // Inherit base constructor

~DnssdServiceRegistration(void) override;
const DNSServiceRef &GetServiceRef() const { return mServiceRef; }
~DnssdServiceRegistration(void) override { Unregister(); }

private:
DNSServiceRef mServiceRef;
otbrError Register(void);
void Unregister(void);
void HandleRegisterResult(DNSServiceFlags aFlags, DNSServiceErrorType aError);
PublisherMDnsSd &GetPublisher(void) { return *static_cast<PublisherMDnsSd *>(mPublisher); }

static void HandleRegisterResult(DNSServiceRef aServiceRef,
DNSServiceFlags aFlags,
DNSServiceErrorType aError,
const char *aName,
const char *aType,
const char *aDomain,
void *aContext);

DNSServiceRef mServiceRef = nullptr;
};

class DnssdHostRegistration : public HostRegistration
Expand All @@ -154,10 +146,7 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher
const std::map<DNSRecordRef, Ip6Address> &GetRecordRefMap() const { return mRecordRefMap; }
std::map<DNSRecordRef, Ip6Address> &GetRecordRefMap() { return mRecordRefMap; }

private:
DNSServiceRef mServiceRef;

public:
DNSServiceRef mServiceRef;
std::map<DNSRecordRef, Ip6Address> mRecordRefMap;
uint32_t mCallbackCount;
};
Expand Down Expand Up @@ -320,19 +309,6 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher
using ServiceSubscriptionList = std::vector<std::unique_ptr<ServiceSubscription>>;
using HostSubscriptionList = std::vector<std::unique_ptr<HostSubscription>>;

static void HandleServiceRegisterResult(DNSServiceRef aService,
const DNSServiceFlags aFlags,
DNSServiceErrorType aError,
const char *aName,
const char *aType,
const char *aDomain,
void *aContext);
void HandleServiceRegisterResult(DNSServiceRef aService,
const DNSServiceFlags aFlags,
DNSServiceErrorType aError,
const char *aName,
const char *aType,
const char *aDomain);
static void HandleRegisterHostResult(DNSServiceRef aHostsConnection,
DNSRecordRef aHostRecord,
DNSServiceFlags aFlags,
Expand All @@ -345,8 +321,7 @@ class PublisherMDnsSd : public MainloopProcessor, public Publisher

static std::string MakeRegType(const std::string &aType, SubTypeList aSubTypeList);

ServiceRegistration *FindServiceRegistration(const DNSServiceRef &aServiceRef);
HostRegistration *FindHostRegistration(const DNSServiceRef &aServiceRef, const DNSRecordRef &aRecordRef);
DnssdHostRegistration *FindHostRegistration(const DNSServiceRef &aServiceRef, const DNSRecordRef &aRecordRef);

DNSServiceRef mHostsRef;
State mState;
Expand Down

0 comments on commit 0cce3f6

Please sign in to comment.