Skip to content

Commit

Permalink
Fix accessing destroyed objects in the callback of async_wait
Browse files Browse the repository at this point in the history
Fixes apache#358
Fixes apache#359

### Motivation

`async_wait` is not used correctly in some places. A callback that
captures the `this` pointer or reference to `this` is passed to
`async_wait`, if this object is destroyed when the callback is called,
an invalid memory access will happen.

### Modifications

Use the following pattern in all `async_wait` calls.

```c++
std::weak_ptr<T> weakSelf{shared_from_this()};
timer_->async_wait([weakSelf](/* ... */) {
    if (auto self = weakSelf.lock()) {
        self->foo();
    }
});
```
  • Loading branch information
BewareMyPower committed Dec 4, 2023
1 parent 0bbc155 commit 0729a9e
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 30 deletions.
12 changes: 7 additions & 5 deletions lib/ConsumerImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ ConsumerImpl::ConsumerImpl(const ClientImplPtr client, const std::string& topic,
consumerName_(config_.getConsumerName()),
consumerStr_("[" + topic + ", " + subscriptionName + ", " + std::to_string(consumerId_) + "] "),
messageListenerRunning_(true),
negativeAcksTracker_(client, *this, conf),
negativeAcksTracker_(std::make_shared<NegativeAcksTracker>(client, *this, conf)),
readCompacted_(conf.isReadCompacted()),
startMessageId_(startMessageId),
maxPendingChunkedMessage_(conf.getMaxPendingChunkedMessage()),
Expand All @@ -105,6 +105,7 @@ ConsumerImpl::ConsumerImpl(const ClientImplPtr client, const std::string& topic,
} else {
unAckedMessageTrackerPtr_.reset(new UnAckedMessageTrackerDisabled());
}
unAckedMessageTrackerPtr_->start();

// Setup stats reporter.
unsigned int statsIntervalInSeconds = client->getClientConfig().getStatsIntervalInSeconds();
Expand Down Expand Up @@ -1228,7 +1229,7 @@ std::pair<MessageId, bool> ConsumerImpl::prepareCumulativeAck(const MessageId& m

void ConsumerImpl::negativeAcknowledge(const MessageId& messageId) {
unAckedMessageTrackerPtr_->remove(messageId);
negativeAcksTracker_.add(messageId);
negativeAcksTracker_->add(messageId);
}

void ConsumerImpl::disconnectConsumer() {
Expand Down Expand Up @@ -1266,7 +1267,7 @@ void ConsumerImpl::closeAsync(ResultCallback originalCallback) {
if (ackGroupingTrackerPtr_) {
ackGroupingTrackerPtr_->close();
}
negativeAcksTracker_.close();
negativeAcksTracker_->close();

ClientConnectionPtr cnx = getCnx().lock();
if (!cnx) {
Expand Down Expand Up @@ -1304,7 +1305,7 @@ void ConsumerImpl::shutdown() {
if (client) {
client->cleanupConsumer(this);
}
negativeAcksTracker_.close();
negativeAcksTracker_->close();
cancelTimers();
consumerCreatedPromise_.setFailed(ResultAlreadyClosed);
failPendingReceiveCallback();
Expand Down Expand Up @@ -1609,7 +1610,7 @@ void ConsumerImpl::internalGetLastMessageIdAsync(const BackoffPtr& backoff, Time
}

void ConsumerImpl::setNegativeAcknowledgeEnabledForTesting(bool enabled) {
negativeAcksTracker_.setEnabledForTesting(enabled);
negativeAcksTracker_->setEnabledForTesting(enabled);
}

void ConsumerImpl::trackMessage(const MessageId& messageId) {
Expand Down Expand Up @@ -1696,6 +1697,7 @@ void ConsumerImpl::cancelTimers() noexcept {
boost::system::error_code ec;
batchReceiveTimer_->cancel(ec);
checkExpiredChunkedTimer_->cancel(ec);
unAckedMessageTrackerPtr_->stop();
}

void ConsumerImpl::processPossibleToDLQ(const MessageId& messageId, ProcessDLQCallBack cb) {
Expand Down
2 changes: 1 addition & 1 deletion lib/ConsumerImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class ConsumerImpl : public ConsumerImplBase {
CompressionCodecProvider compressionCodecProvider_;
UnAckedMessageTrackerPtr unAckedMessageTrackerPtr_;
BrokerConsumerStatsImpl brokerConsumerStats_;
NegativeAcksTracker negativeAcksTracker_;
std::shared_ptr<NegativeAcksTracker> negativeAcksTracker_;
AckGroupingTrackerPtr ackGroupingTrackerPtr_;

MessageCryptoPtr msgCrypto_;
Expand Down
7 changes: 6 additions & 1 deletion lib/NegativeAcksTracker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ void NegativeAcksTracker::scheduleTimer() {
if (closed_) {
return;
}
std::weak_ptr<NegativeAcksTracker> weakSelf{shared_from_this()};
timer_->expires_from_now(timerInterval_);
timer_->async_wait(std::bind(&NegativeAcksTracker::handleTimer, this, std::placeholders::_1));
timer_->async_wait([weakSelf](const boost::system::error_code &ec) {
if (auto self = weakSelf.lock()) {
self->handleTimer(ec);
}
});
}

void NegativeAcksTracker::handleTimer(const boost::system::error_code &ec) {
Expand Down
2 changes: 1 addition & 1 deletion lib/NegativeAcksTracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ using DeadlineTimerPtr = std::shared_ptr<boost::asio::deadline_timer>;
class ExecutorService;
using ExecutorServicePtr = std::shared_ptr<ExecutorService>;

class NegativeAcksTracker {
class NegativeAcksTracker : public std::enable_shared_from_this<NegativeAcksTracker> {
public:
NegativeAcksTracker(ClientImplPtr client, ConsumerImpl &consumer, const ConsumerConfiguration &conf);

Expand Down
17 changes: 13 additions & 4 deletions lib/PatternMultiTopicsConsumerImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,13 @@ const PULSAR_REGEX_NAMESPACE::regex PatternMultiTopicsConsumerImpl::getPattern()
void PatternMultiTopicsConsumerImpl::resetAutoDiscoveryTimer() {
autoDiscoveryRunning_ = false;
autoDiscoveryTimer_->expires_from_now(seconds(conf_.getPatternAutoDiscoveryPeriod()));
autoDiscoveryTimer_->async_wait(
std::bind(&PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask, this, std::placeholders::_1));

auto weakSelf = weak_from_this();
autoDiscoveryTimer_->async_wait([weakSelf](const boost::system::error_code& err) {
if (auto self = weakSelf.lock()) {
self->autoDiscoveryTimerTask(err);
}
});
}

void PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask(const boost::system::error_code& err) {
Expand Down Expand Up @@ -222,8 +227,12 @@ void PatternMultiTopicsConsumerImpl::start() {

if (conf_.getPatternAutoDiscoveryPeriod() > 0) {
autoDiscoveryTimer_->expires_from_now(seconds(conf_.getPatternAutoDiscoveryPeriod()));
autoDiscoveryTimer_->async_wait(
std::bind(&PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask, this, std::placeholders::_1));
auto weakSelf = weak_from_this();
autoDiscoveryTimer_->async_wait([weakSelf](const boost::system::error_code& err) {
if (auto self = weakSelf.lock()) {
self->autoDiscoveryTimerTask(err);
}
});
}
}

Expand Down
4 changes: 4 additions & 0 deletions lib/PatternMultiTopicsConsumerImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ class PatternMultiTopicsConsumerImpl : public MultiTopicsConsumerImpl {
void onTopicsRemoved(NamespaceTopicsPtr removedTopics, ResultCallback callback);
void handleOneTopicAdded(const Result result, const std::string& topic,
std::shared_ptr<std::atomic<int>> topicsNeedCreate, ResultCallback callback);

std::weak_ptr<PatternMultiTopicsConsumerImpl> weak_from_this() noexcept {
return std::static_pointer_cast<PatternMultiTopicsConsumerImpl>(shared_from_this());
}
};

} // namespace pulsar
Expand Down
19 changes: 10 additions & 9 deletions lib/UnAckedMessageTrackerEnabled.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ void UnAckedMessageTrackerEnabled::timeoutHandler() {
ExecutorServicePtr executorService = client_->getIOExecutorProvider()->get();
timer_ = executorService->createDeadlineTimer();
timer_->expires_from_now(boost::posix_time::milliseconds(tickDurationInMs_));
timer_->async_wait([&](const boost::system::error_code& ec) {
if (ec) {
LOG_DEBUG("Ignoring timer cancelled event, code[" << ec << "]");
} else {
timeoutHandler();
std::weak_ptr<UnAckedMessageTrackerEnabled> weakSelf{shared_from_this()};
timer_->async_wait([weakSelf](const boost::system::error_code& ec) {
auto self = weakSelf.lock();
if (self && !ec) {
self->timeoutHandler();
}
});
}
Expand Down Expand Up @@ -91,10 +91,10 @@ UnAckedMessageTrackerEnabled::UnAckedMessageTrackerEnabled(long timeoutMs, long
std::set<MessageId> msgIds;
timePartitions.push_back(msgIds);
}

timeoutHandler();
}

void UnAckedMessageTrackerEnabled::start() { timeoutHandler(); }

bool UnAckedMessageTrackerEnabled::add(const MessageId& msgId) {
std::lock_guard<std::recursive_mutex> acquire(lock_);
auto id = discardBatch(msgId);
Expand Down Expand Up @@ -172,9 +172,10 @@ void UnAckedMessageTrackerEnabled::clear() {
}
}

UnAckedMessageTrackerEnabled::~UnAckedMessageTrackerEnabled() {
void UnAckedMessageTrackerEnabled::stop() {
boost::system::error_code ec;
if (timer_) {
timer_->cancel();
timer_->cancel(ec);
}
}
} /* namespace pulsar */
19 changes: 11 additions & 8 deletions lib/UnAckedMessageTrackerEnabled.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <boost/asio/deadline_timer.hpp>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <set>

Expand All @@ -34,19 +35,21 @@ class ConsumerImplBase;
using ClientImplPtr = std::shared_ptr<ClientImpl>;
using DeadlineTimerPtr = std::shared_ptr<boost::asio::deadline_timer>;

class UnAckedMessageTrackerEnabled : public UnAckedMessageTrackerInterface {
class UnAckedMessageTrackerEnabled : public std::enable_shared_from_this<UnAckedMessageTrackerEnabled>,
public UnAckedMessageTrackerInterface {
public:
~UnAckedMessageTrackerEnabled();
UnAckedMessageTrackerEnabled(long timeoutMs, ClientImplPtr, ConsumerImplBase&);
UnAckedMessageTrackerEnabled(long timeoutMs, long tickDuration, ClientImplPtr, ConsumerImplBase&);
bool add(const MessageId& msgId);
bool remove(const MessageId& msgId);
void remove(const MessageIdList& msgIds);
void removeMessagesTill(const MessageId& msgId);
void removeTopicMessage(const std::string& topic);
void start() override;
void stop() override;
bool add(const MessageId& msgId) override;
bool remove(const MessageId& msgId) override;
void remove(const MessageIdList& msgIds) override;
void removeMessagesTill(const MessageId& msgId) override;
void removeTopicMessage(const std::string& topic) override;
void timeoutHandler();

void clear();
void clear() override;

protected:
void timeoutHandlerHelper();
Expand Down
2 changes: 2 additions & 0 deletions lib/UnAckedMessageTrackerInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class UnAckedMessageTrackerInterface {
public:
virtual ~UnAckedMessageTrackerInterface() {}
UnAckedMessageTrackerInterface() {}
virtual void start() {}
virtual void stop() {}
virtual bool add(const MessageId& m) = 0;
virtual bool remove(const MessageId& m) = 0;
virtual void remove(const MessageIdList& msgIds) = 0;
Expand Down
2 changes: 1 addition & 1 deletion tests/ConsumerTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1222,7 +1222,7 @@ TEST(ConsumerTest, testNegativeAcksTrackerClose) {

consumer.close();
auto consumerImplPtr = PulsarFriend::getConsumerImplPtr(consumer);
ASSERT_TRUE(consumerImplPtr->negativeAcksTracker_.nackedMessages_.empty());
ASSERT_TRUE(consumerImplPtr->negativeAcksTracker_->nackedMessages_.empty());

client.close();
}
Expand Down

0 comments on commit 0729a9e

Please sign in to comment.