From d3608df7d451fc4d7defdc18715fd0895caef937 Mon Sep 17 00:00:00 2001 From: Abtin Keshavarzian Date: Wed, 6 Sep 2023 11:22:49 -0700 Subject: [PATCH] [srp-server] process completed update from proxy from taskelt (#9398) This commit enhances `Srp::Server` to process and commit the completed `UpdateMetadata` entries (signaled by the "proxy service handler" calling `HandleServiceUpdateResult()`) from a `Tasklet`. This change is helpful in the case where the `HandleServiceUpdateResult ()` callback is invoked directly from the "update service handler" itself. While `Srp::Server` can handle this situation, the change makes it easier for platform implementations of advertising proxy. In particular, it addresses an issue with the `otbr` advertising proxy implementation. This implementation can potentially access an already freed `Host` object. This can happen because the implementation may hold on to the `Host` object while iterating over its `Service` entries as advertising an earlier `Service` of the same `Host` may fail immediately and invoke the callback directly. This would then cause the `Host` to be freed by `Srp::Server`. --- src/core/common/linked_list.hpp | 20 ++++++++++ src/core/net/srp_server.cpp | 67 +++++++++++++++++++++++---------- src/core/net/srp_server.hpp | 14 +++++-- tests/unit/test_linked_list.cpp | 10 +++++ 4 files changed, 88 insertions(+), 23 deletions(-) diff --git a/src/core/common/linked_list.hpp b/src/core/common/linked_list.hpp index 3f04aef1dad..0d9ecf9a4fc 100644 --- a/src/core/common/linked_list.hpp +++ b/src/core/common/linked_list.hpp @@ -180,6 +180,26 @@ template class LinkedList aPrevEntry.SetNext(&aEntry); } + /** + * Pushes an entry after the tail in the linked list. + * + * @param[in] aEntry A reference to an entry to push into the list. + * + */ + void PushAfterTail(Type &aEntry) + { + Type *tail = GetTail(); + + if (tail == nullptr) + { + Push(aEntry); + } + else + { + PushAfter(aEntry, *tail); + } + } + /** * Pops an entry from head of the linked list. * diff --git a/src/core/net/srp_server.cpp b/src/core/net/srp_server.cpp index 6548506dac1..2a57c9217eb 100644 --- a/src/core/net/srp_server.cpp +++ b/src/core/net/srp_server.cpp @@ -89,6 +89,7 @@ Server::Server(Instance &aInstance) , mSocket(aInstance) , mLeaseTimer(aInstance) , mOutstandingUpdatesTimer(aInstance) + , mCompletedUpdateTask(aInstance) , mServiceUpdateId(Random::NonCrypto::GetUint32()) , mPort(kUdpPortMin) , mState(kStateDisabled) @@ -381,26 +382,26 @@ bool Server::HasNameConflictsWith(Host &aHost) const void Server::HandleServiceUpdateResult(ServiceUpdateId aId, Error aError) { - UpdateMetadata *update = mOutstandingUpdates.FindMatching(aId); + UpdateMetadata *update = mOutstandingUpdates.RemoveMatching(aId); - if (update != nullptr) - { - HandleServiceUpdateResult(update, aError); - } - else + if (update == nullptr) { LogInfo("Delayed SRP host update result, the SRP update has been committed (updateId = %lu)", ToUlong(aId)); + ExitNow(); } -} -void Server::HandleServiceUpdateResult(UpdateMetadata *aUpdate, Error aError) -{ - LogInfo("Handler result of SRP update (id = %lu) is received: %s", ToUlong(aUpdate->GetId()), - ErrorToString(aError)); + update->SetError(aError); + + LogInfo("Handler result of SRP update (id = %lu) is received: %s", ToUlong(update->GetId()), ErrorToString(aError)); - IgnoreError(mOutstandingUpdates.Remove(*aUpdate)); - CommitSrpUpdate(aError, *aUpdate); - aUpdate->Free(); + // We add new `update` at the tail of the `mCompletedUpdates` list + // so that updates are processed in the order we receive the + // `HandleServiceUpdateResult()` callbacks for them. The + // completed updates are processed from `mCompletedUpdateTask` + // and `ProcessCompletedUpdates()`. + + mCompletedUpdates.PushAfterTail(*update); + mCompletedUpdateTask.Post(); if (mOutstandingUpdates.IsEmpty()) { @@ -410,6 +411,19 @@ void Server::HandleServiceUpdateResult(UpdateMetadata *aUpdate, Error aError) { mOutstandingUpdatesTimer.FireAt(mOutstandingUpdates.GetTail()->GetExpireTime()); } + +exit: + return; +} + +void Server::ProcessCompletedUpdates(void) +{ + UpdateMetadata *update; + + while ((update = mCompletedUpdates.Pop()) != nullptr) + { + CommitSrpUpdate(*update); + } } void Server::CommitSrpUpdate(Error aError, Host &aHost, const MessageMetadata &aMessageMetadata) @@ -418,11 +432,13 @@ void Server::CommitSrpUpdate(Error aError, Host &aHost, const MessageMetadata &a aMessageMetadata.mTtlConfig, aMessageMetadata.mLeaseConfig); } -void Server::CommitSrpUpdate(Error aError, UpdateMetadata &aUpdateMetadata) +void Server::CommitSrpUpdate(UpdateMetadata &aUpdateMetadata) { - CommitSrpUpdate(aError, aUpdateMetadata.GetHost(), aUpdateMetadata.GetDnsHeader(), + CommitSrpUpdate(aUpdateMetadata.GetError(), aUpdateMetadata.GetHost(), aUpdateMetadata.GetDnsHeader(), aUpdateMetadata.IsDirectRxFromClient() ? &aUpdateMetadata.GetMessageInfo() : nullptr, aUpdateMetadata.GetTtlConfig(), aUpdateMetadata.GetLeaseConfig()); + + aUpdateMetadata.Free(); } void Server::CommitSrpUpdate(Error aError, @@ -1663,10 +1679,22 @@ void Server::HandleLeaseTimer(void) void Server::HandleOutstandingUpdatesTimer(void) { - while (!mOutstandingUpdates.IsEmpty() && mOutstandingUpdates.GetTail()->GetExpireTime() <= TimerMilli::GetNow()) + TimeMilli now = TimerMilli::GetNow(); + UpdateMetadata *update; + + while ((update = mOutstandingUpdates.GetTail()) != nullptr) { - LogInfo("Outstanding service update timeout (updateId = %lu)", ToUlong(mOutstandingUpdates.GetTail()->GetId())); - HandleServiceUpdateResult(mOutstandingUpdates.GetTail(), kErrorResponseTimeout); + if (update->GetExpireTime() > now) + { + mOutstandingUpdatesTimer.FireAtIfEarlier(update->GetExpireTime()); + break; + } + + LogInfo("Outstanding service update timeout (updateId = %lu)", ToUlong(update->GetId())); + + IgnoreError(mOutstandingUpdates.Remove(*update)); + update->SetError(kErrorResponseTimeout); + CommitSrpUpdate(*update); } } @@ -2097,6 +2125,7 @@ Server::UpdateMetadata::UpdateMetadata(Instance &aInstance, Host &aHost, const M , mTtlConfig(aMessageMetadata.mTtlConfig) , mLeaseConfig(aMessageMetadata.mLeaseConfig) , mHost(aHost) + , mError(kErrorNone) , mIsDirectRxFromClient(aMessageMetadata.IsDirectRxFromClient()) { if (aMessageMetadata.mMessageInfo != nullptr) diff --git a/src/core/net/srp_server.hpp b/src/core/net/srp_server.hpp index e4d7cdf560a..bb73c154a97 100644 --- a/src/core/net/srp_server.hpp +++ b/src/core/net/srp_server.hpp @@ -912,6 +912,8 @@ class Server : public InstanceLocator, private NonCopyable const LeaseConfig &GetLeaseConfig(void) const { return mLeaseConfig; } Host &GetHost(void) { return mHost; } const Ip6::MessageInfo &GetMessageInfo(void) const { return mMessageInfo; } + Error GetError(void) const { return mError; } + void SetError(Error aError) { mError = aError; } bool IsDirectRxFromClient(void) const { return mIsDirectRxFromClient; } bool Matches(ServiceUpdateId aId) const { return mId == aId; } @@ -926,6 +928,7 @@ class Server : public InstanceLocator, private NonCopyable LeaseConfig mLeaseConfig; // Lease config to use when processing the message. Host &mHost; // The `UpdateMetadata` has no ownership of this host. Ip6::MessageInfo mMessageInfo; // Valid when `mIsDirectRxFromClient` is true. + Error mError; bool mIsDirectRxFromClient; }; @@ -948,7 +951,7 @@ class Server : public InstanceLocator, private NonCopyable void InformUpdateHandlerOrCommit(Error aError, Host &aHost, const MessageMetadata &aMetadata); void CommitSrpUpdate(Error aError, Host &aHost, const MessageMetadata &aMessageMetadata); - void CommitSrpUpdate(Error aError, UpdateMetadata &aUpdateMetadata); + void CommitSrpUpdate(UpdateMetadata &aUpdateMetadata); void CommitSrpUpdate(Error aError, Host &aHost, const Dns::UpdateHeader &aDnsHeader, @@ -998,15 +1001,16 @@ class Server : public InstanceLocator, private NonCopyable void HandleLeaseTimer(void); static void HandleOutstandingUpdatesTimer(Timer &aTimer); void HandleOutstandingUpdatesTimer(void); + void ProcessCompletedUpdates(void); - void HandleServiceUpdateResult(UpdateMetadata *aUpdate, Error aError); const UpdateMetadata *FindOutstandingUpdate(const MessageMetadata &aMessageMetadata) const; static const char *AddressModeToString(AddressMode aMode); void UpdateResponseCounters(Dns::Header::Response aResponseCode); - using LeaseTimer = TimerMilliIn; - using UpdateTimer = TimerMilliIn; + using LeaseTimer = TimerMilliIn; + using UpdateTimer = TimerMilliIn; + using CompletedUpdatesTask = TaskletIn; Ip6::Udp::Socket mSocket; @@ -1022,6 +1026,8 @@ class Server : public InstanceLocator, private NonCopyable UpdateTimer mOutstandingUpdatesTimer; LinkedList mOutstandingUpdates; + LinkedList mCompletedUpdates; + CompletedUpdatesTask mCompletedUpdateTask; ServiceUpdateId mServiceUpdateId; uint16_t mPort; diff --git a/tests/unit/test_linked_list.cpp b/tests/unit/test_linked_list.cpp index 40a4f7f0517..0032ee2c1a5 100644 --- a/tests/unit/test_linked_list.cpp +++ b/tests/unit/test_linked_list.cpp @@ -294,6 +294,16 @@ void TestLinkedList(void) list.RemoveAllMatching(kBetaType, removedList); VerifyLinkedListContent(&list, &a, &b, &e, nullptr); VerifyLinkedListContent(&removedList, &f, &d, &c, nullptr); + + list.Clear(); + list.PushAfterTail(a); + VerifyLinkedListContent(&list, &a, nullptr); + list.PushAfterTail(b); + VerifyLinkedListContent(&list, &a, &b, nullptr); + list.PushAfterTail(c); + VerifyLinkedListContent(&list, &a, &b, &c, nullptr); + list.PushAfterTail(d); + VerifyLinkedListContent(&list, &a, &b, &c, &d, nullptr); } void TestOwningList(void)