diff --git a/lib/MultiTopicsConsumerImpl.cc b/lib/MultiTopicsConsumerImpl.cc index 80566c86..e95a9ac4 100644 --- a/lib/MultiTopicsConsumerImpl.cc +++ b/lib/MultiTopicsConsumerImpl.cc @@ -338,41 +338,23 @@ void MultiTopicsConsumerImpl::unsubscribeAsync(ResultCallback originalCallback) } state_ = Closing; - std::shared_ptr> consumerUnsubed = std::make_shared>(0); auto self = get_shared_this_ptr(); - int numConsumers = 0; consumers_.forEachValue( - [&numConsumers, &consumerUnsubed, &self, callback](const ConsumerImplPtr& consumer) { - numConsumers++; - consumer->unsubscribeAsync([self, consumerUnsubed, callback](Result result) { - self->handleUnsubscribedAsync(result, consumerUnsubed, callback); + [this, self, callback](const ConsumerImplPtr& consumer, SharedFuture future) { + consumer->unsubscribeAsync([this, self, callback, future](Result result) { + if (result != ResultOk) { + state_ = Failed; + LOG_ERROR("Error Closing one of the consumers in TopicsConsumer, result: " + << result << " subscription - " << subscriptionName_); + } + if (future.tryComplete()) { + LOG_DEBUG("Unsubscribed all of the partition consumer for TopicsConsumer. - " + << consumerStr_); + callback((state_ != Failed) ? ResultOk : ResultUnknownError); + } }); - }); - if (numConsumers == 0) { - // No need to unsubscribe, since the list matching the regex was empty - callback(ResultOk); - } -} - -void MultiTopicsConsumerImpl::handleUnsubscribedAsync(Result result, - std::shared_ptr> consumerUnsubed, - ResultCallback callback) { - (*consumerUnsubed)++; - - if (result != ResultOk) { - state_ = Failed; - LOG_ERROR("Error Closing one of the consumers in TopicsConsumer, result: " - << result << " subscription - " << subscriptionName_); - } - - if (consumerUnsubed->load() == numberTopicPartitions_->load()) { - LOG_DEBUG("Unsubscribed all of the partition consumer for TopicsConsumer. - " << consumerStr_); - Result result1 = (state_ != Failed) ? ResultOk : ResultUnknownError; - // The `callback` is a wrapper of user provided callback, it's not null and will call `shutdown()` if - // unsubscribe succeeds. - callback(result1); - return; - } + }, + [callback] { callback(ResultOk); }); } void MultiTopicsConsumerImpl::unsubscribeOneTopicAsync(const std::string& topic, ResultCallback callback) { @@ -899,50 +881,52 @@ std::shared_ptr MultiTopicsConsumerImpl::topicNamesValid(const std::v return topicNamePtr; } -void MultiTopicsConsumerImpl::seekAsync(const MessageId& msgId, ResultCallback callback) { - callback(ResultOperationNotSupported); -} - -void MultiTopicsConsumerImpl::seekAsync(uint64_t timestamp, ResultCallback callback) { - if (state_ != Ready) { - callback(ResultAlreadyClosed); - return; - } - +void MultiTopicsConsumerImpl::beforeSeek() { duringSeek_.store(true, std::memory_order_release); consumers_.forEachValue([](const ConsumerImplPtr& consumer) { consumer->pauseMessageListener(); }); unAckedMessageTrackerPtr_->clear(); incomingMessages_.clear(); incomingMessagesSize_ = 0L; +} + +void MultiTopicsConsumerImpl::afterSeek() { + duringSeek_.store(false, std::memory_order_release); + auto self = get_shared_this_ptr(); + listenerExecutor_->postWork([this, self] { + consumers_.forEachValue([](const ConsumerImplPtr& consumer) { consumer->resumeMessageListener(); }); + }); +} + +void MultiTopicsConsumerImpl::seekAsync(const MessageId& msgId, ResultCallback callback) { + if (msgId == MessageId::earliest() || msgId == MessageId::latest()) { + return seekAllAsync(msgId, callback); + } + auto optConsumer = consumers_.find(msgId.getTopicName()); + if (!optConsumer) { + LOG_ERROR(getName() << "cannot seek a message id whose topic \"" + msgId.getTopicName() + + "\" is not subscribed"); + callback(ResultOperationNotSupported); + return; + } + + beforeSeek(); auto weakSelf = weak_from_this(); - auto numConsumersLeft = std::make_shared>(consumers_.size()); - auto wrappedCallback = [this, weakSelf, callback, numConsumersLeft](Result result) { + optConsumer.get()->seekAsync(msgId, [this, weakSelf, callback](Result result) { auto self = weakSelf.lock(); - if (PULSAR_UNLIKELY(!self)) { - callback(result); - return; - } - if (result != ResultOk) { - *numConsumersLeft = 0; // skip the following callbacks + if (self) { + afterSeek(); callback(result); - return; - } - if (--*numConsumersLeft > 0) { - return; + } else { + callback(ResultAlreadyClosed); } - duringSeek_.store(false, std::memory_order_release); - listenerExecutor_->postWork([this, self] { - consumers_.forEachValue( - [](const ConsumerImplPtr& consumer) { consumer->resumeMessageListener(); }); - }); - callback(ResultOk); - }; - consumers_.forEachValue([timestamp, &wrappedCallback](const ConsumerImplPtr& consumer) { - consumer->seekAsync(timestamp, wrappedCallback); }); } +void MultiTopicsConsumerImpl::seekAsync(uint64_t timestamp, ResultCallback callback) { + seekAllAsync(timestamp, callback); +} + void MultiTopicsConsumerImpl::setNegativeAcknowledgeEnabledForTesting(bool enabled) { consumers_.forEachValue([enabled](const ConsumerImplPtr& consumer) { consumer->setNegativeAcknowledgeEnabledForTesting(enabled); diff --git a/lib/MultiTopicsConsumerImpl.h b/lib/MultiTopicsConsumerImpl.h index b5c51ec9..6763942f 100644 --- a/lib/MultiTopicsConsumerImpl.h +++ b/lib/MultiTopicsConsumerImpl.h @@ -25,7 +25,7 @@ #include #include "Commands.h" -#include "ConsumerImplBase.h" +#include "ConsumerImpl.h" #include "ConsumerInterceptors.h" #include "Future.h" #include "Latch.h" @@ -38,7 +38,6 @@ namespace pulsar { typedef std::shared_ptr> ConsumerSubResultPromisePtr; -class ConsumerImpl; using ConsumerImplPtr = std::shared_ptr; class ClientImpl; using ClientImplPtr = std::shared_ptr; @@ -152,8 +151,6 @@ class MultiTopicsConsumerImpl : public ConsumerImplBase { void handleSingleConsumerCreated(Result result, ConsumerImplBaseWeakPtr consumerImplBaseWeakPtr, std::shared_ptr> partitionsNeedCreate, ConsumerSubResultPromisePtr topicSubResultPromise); - void handleUnsubscribedAsync(Result result, std::shared_ptr> consumerUnsubed, - ResultCallback callback); void handleOneTopicUnsubscribedAsync(Result result, std::shared_ptr> consumerUnsubed, int numberPartitions, TopicNamePtr topicNamePtr, std::string& topicPartitionName, ResultCallback callback); @@ -179,6 +176,16 @@ class MultiTopicsConsumerImpl : public ConsumerImplBase { return std::static_pointer_cast(shared_from_this()); } + template +#if __cplusplus >= 202002L + requires std::convertible_to || + std::same_as>, MessageId> +#endif + void seekAllAsync(const SeekArg& seekArg, ResultCallback callback); + + void beforeSeek(); + void afterSeek(); + FRIEND_TEST(ConsumerTest, testMultiTopicsConsumerUnAckedMessageRedelivery); FRIEND_TEST(ConsumerTest, testPartitionedConsumerUnAckedMessageRedelivery); FRIEND_TEST(ConsumerTest, testAcknowledgeCumulativeWithPartition); @@ -187,5 +194,42 @@ class MultiTopicsConsumerImpl : public ConsumerImplBase { }; typedef std::shared_ptr MultiTopicsConsumerImplPtr; + +template +#if __cplusplus >= 202002L + requires std::convertible_to || + std::same_as>, MessageId> +#endif + inline void MultiTopicsConsumerImpl::seekAllAsync(const SeekArg& seekArg, ResultCallback callback) { + if (state_ != Ready) { + callback(ResultAlreadyClosed); + return; + } + beforeSeek(); + auto weakSelf = weak_from_this(); + auto failed = std::make_shared(false); + consumers_.forEachValue( + [this, weakSelf, &seekArg, callback, failed](const ConsumerImplPtr& consumer, SharedFuture future) { + consumer->seekAsync(seekArg, [this, weakSelf, callback, failed, future](Result result) { + auto self = weakSelf.lock(); + if (!self || failed->load(std::memory_order_acquire)) { + callback(result); + return; + } + if (result != ResultOk) { + failed->store(true, std::memory_order_release); // skip the following callbacks + afterSeek(); + callback(result); + return; + } + if (future.tryComplete()) { + afterSeek(); + callback(ResultOk); + } + }); + }, + [callback] { callback(ResultOk); }); +} + } // namespace pulsar #endif // PULSAR_MULTI_TOPICS_CONSUMER_HEADER diff --git a/lib/SynchronizedHashMap.h b/lib/SynchronizedHashMap.h index 082aeaf4..e224913b 100644 --- a/lib/SynchronizedHashMap.h +++ b/lib/SynchronizedHashMap.h @@ -18,8 +18,10 @@ */ #pragma once +#include #include #include +#include #include #include #include @@ -27,6 +29,16 @@ namespace pulsar { +class SharedFuture { + public: + SharedFuture(size_t size) : count_(std::make_shared(size)) {} + + bool tryComplete() const { return --*count_ == 0; } + + private: + std::shared_ptr count_; +}; + // V must be default constructible and copyable template class SynchronizedHashMap { @@ -60,10 +72,57 @@ class SynchronizedHashMap { } } - void forEachValue(std::function f) const { - Lock lock(mutex_); - for (const auto& kv : data_) { - f(kv.second); + template +#if __cplusplus >= 202002L + requires requires(ValueFunc&& each, const V& value) { + each(value); + } +#endif + void forEachValue(ValueFunc&& each) { + Lock lock{mutex_}; + for (auto&& kv : data_) { + each(kv.second); + } + } + + // This override provides a convenient approach to execute tasks on each consumer concurrently and + // supports checking if all tasks are done in the `each` callback. + // + // All map values will be passed as the 1st argument to the `each` function. The 2nd argument is a shared + // future whose `tryComplete` method marks this task as completed. If users want to check if all task are + // completed in the `each` function, this method must be called. + // + // For example, given a `SynchronizedHashMap` object `m` and the following call: + // + // ```c++ + // m.forEachValue([](const std::string& s, SharedFuture future) { + // std::cout << s << std::endl; + // if (future.tryComplete()) { + // std::cout << "done" << std::endl; + // } + // }, [] { std::cout << "empty map" << std::endl; }); + // ``` + // + // If the map is empty, only "empty map" will be printed. Otherwise, all values will be printed + // and "done" will be printed after that. + template +#if __cplusplus >= 202002L + requires requires(ValueFunc&& each, const V& value, SharedFuture count, EmptyFunc emptyFunc) { + each(value, count); + emptyFunc(); + } +#endif + void forEachValue(ValueFunc&& each, EmptyFunc&& emptyFunc) { + std::unique_lock lock{mutex_}; + if (data_.empty()) { + lock.unlock(); + emptyFunc(); + return; + } + SharedFuture future{data_.size()}; + for (auto&& kv : data_) { + const auto& value = kv.second; + each(value, future); } } diff --git a/tests/ConsumerSeekTest.cc b/tests/ConsumerSeekTest.cc new file mode 100644 index 00000000..f03ea5e3 --- /dev/null +++ b/tests/ConsumerSeekTest.cc @@ -0,0 +1,205 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include +#include +#include + +#include "HttpHelper.h" +#include "lib/LogUtils.h" + +DECLARE_LOG_OBJECT() + +static const std::string lookupUrl = "pulsar://localhost:6650"; +static const std::string adminUrl = "http://localhost:8080/"; + +extern std::string unique_str(); + +namespace pulsar { + +class ConsumerSeekTest : public ::testing::TestWithParam { + public: + void SetUp() override { client_ = Client{lookupUrl}; } + + void TearDown() override { client_.close(); } + + protected: + Client client_{lookupUrl}; + ProducerConfiguration producerConf_; + + std::vector initProducersForPartitionedTopic(const std::string& topic) { + constexpr int numPartitions = 3; + int res = makePutRequest(adminUrl + "admin/v2/persistent/public/default/" + topic + "/partitions", + std::to_string(numPartitions)); + if (res != 204 && res != 409) { + throw std::runtime_error("Failed to create partitioned topic: " + std::to_string(res)); + } + + std::vector producers(numPartitions); + for (int i = 0; i < numPartitions; i++) { + auto result = client_.createProducer(topic + "-partition-" + std::to_string(i), producers[i]); + if (result != ResultOk) { + throw std::runtime_error(std::string{"Failed to create producer: "} + strResult(result)); + } + } + return producers; + } + + Consumer createConsumer(const std::string& topic) { + Consumer consumer; + ConsumerConfiguration conf; + conf.setStartMessageIdInclusive(GetParam()); + auto result = client_.subscribe(topic, "sub", conf, consumer); + if (result != ResultOk) { + throw std::runtime_error(std::string{"Failed to subscribe: "} + strResult(result)); + } + return consumer; + } +}; + +TEST_P(ConsumerSeekTest, testSeekForMessageId) { + Client client(lookupUrl); + + const std::string topic = "test-seek-for-message-id-" + std::string((GetParam() ? "batch-" : "")) + + std::to_string(time(nullptr)); + + Producer producer; + ASSERT_EQ(ResultOk, client.createProducer(topic, producerConf_, producer)); + + Consumer consumerExclusive; + ASSERT_EQ(ResultOk, client.subscribe(topic, "sub-0", consumerExclusive)); + + Consumer consumerInclusive; + ASSERT_EQ(ResultOk, + client.subscribe(topic, "sub-1", ConsumerConfiguration().setStartMessageIdInclusive(true), + consumerInclusive)); + + const auto numMessages = 100; + MessageId seekMessageId; + + int r = (rand() % (numMessages - 1)); + for (int i = 0; i < numMessages; i++) { + MessageId id; + ASSERT_EQ(ResultOk, + producer.send(MessageBuilder().setContent("msg-" + std::to_string(i)).build(), id)); + + if (i == r) { + seekMessageId = id; + } + } + + LOG_INFO("The seekMessageId is: " << seekMessageId << ", r : " << r); + + consumerExclusive.seek(seekMessageId); + Message msg0; + ASSERT_EQ(ResultOk, consumerExclusive.receive(msg0, 3000)); + + consumerInclusive.seek(seekMessageId); + Message msg1; + ASSERT_EQ(ResultOk, consumerInclusive.receive(msg1, 3000)); + + LOG_INFO("consumerExclusive received " << msg0.getDataAsString() << " from " << msg0.getMessageId()); + LOG_INFO("consumerInclusive received " << msg1.getDataAsString() << " from " << msg1.getMessageId()); + + ASSERT_EQ(msg0.getDataAsString(), "msg-" + std::to_string(r + 1)); + ASSERT_EQ(msg1.getDataAsString(), "msg-" + std::to_string(r)); + + consumerInclusive.close(); + consumerExclusive.close(); + producer.close(); +} + +TEST_P(ConsumerSeekTest, testMultiTopicsSeekAll) { + std::string topic = "consumer-seek-test-multi-topics-seek-all-" + unique_str(); + auto producers = initProducersForPartitionedTopic(topic); + auto consumer = createConsumer(topic); + const auto numPartitions = producers.size(); + + auto receive = [&consumer, numPartitions] { + std::set values; + for (int i = 0; i < numPartitions; i++) { + Message msg; + auto result = consumer.receive(msg, 3000); + if (result != ResultOk) { + throw std::runtime_error(std::string{"Receive failed: "} + strResult(result)); + } + values.emplace(msg.getDataAsString()); + } + return values; + }; + + for (int i = 0; i < numPartitions; i++) { + producers[i].send(MessageBuilder().setContent("msg-" + std::to_string(i) + "-0").build()); + } + ASSERT_EQ(receive(), (std::set{"msg-0-0", "msg-1-0", "msg-2-0"})); + + // Seek to earliest + ASSERT_EQ(ResultOk, consumer.seek(MessageId::earliest())); + ASSERT_EQ(receive(), (std::set{"msg-0-0", "msg-1-0", "msg-2-0"})); + + // Seek to latest + for (int i = 0; i < numPartitions; i++) { + producers[i].send(MessageBuilder().setContent("msg-" + std::to_string(i) + "-1").build()); + } + ASSERT_EQ(ResultOk, consumer.seek(MessageId::latest())); + + for (int i = 0; i < numPartitions; i++) { + producers[i].send(MessageBuilder().setContent("msg-" + std::to_string(i) + "-2").build()); + } + ASSERT_EQ(receive(), (std::set{"msg-0-2", "msg-1-2", "msg-2-2"})); +} + +TEST_P(ConsumerSeekTest, testMultiTopicsSeekSingle) { + std::string topic = "consumer-seek-test-multi-topics-seek-single-" + unique_str(); + auto producers = initProducersForPartitionedTopic(topic); + auto consumer = createConsumer(topic); + + MessageId msgId; + producers[0].send(MessageBuilder().setContent("msg-0").build(), msgId); + ASSERT_EQ(ResultOperationNotSupported, consumer.seek(msgId)); + producers[0].send(MessageBuilder().setContent("msg-1").build(), msgId); + ASSERT_EQ(ResultOperationNotSupported, consumer.seek(msgId)); + + std::vector msgIds; + Message msg; + for (int i = 0; i < 2; i++) { + ASSERT_EQ(ResultOk, consumer.receive(msg, 3000)); + msgIds.emplace_back(msg.getMessageId()); + } + + ASSERT_EQ(ResultOk, consumer.seek(msgIds[0])); + ASSERT_EQ(ResultOk, consumer.receive(msg, 3000)); + if (GetParam()) { + ASSERT_EQ(msg.getMessageId(), msgIds[0]); + } else { + ASSERT_EQ(msg.getMessageId(), msgIds[1]); + } +} + +TEST_F(ConsumerSeekTest, testNoInternalConsumer) { + Consumer consumer; + ASSERT_EQ(ResultOk, client_.subscribeWithRegex("testNoInternalConsumer.*", "sub", consumer)); + ASSERT_EQ(ResultOk, consumer.seek(MessageId::earliest())); +} + +INSTANTIATE_TEST_SUITE_P(Pulsar, ConsumerSeekTest, ::testing::Values(true, false)); + +} // namespace pulsar diff --git a/tests/ConsumerTest.cc b/tests/ConsumerTest.cc index 2aab7229..f9840f97 100644 --- a/tests/ConsumerTest.cc +++ b/tests/ConsumerTest.cc @@ -1136,69 +1136,6 @@ TEST(ConsumerTest, testPatternSubscribeTopic) { client.close(); } -class ConsumerSeekTest : public ::testing::TestWithParam { - public: - void SetUp() override { producerConf_ = ProducerConfiguration().setBatchingEnabled(GetParam()); } - - void TearDown() override { client_.close(); } - - protected: - Client client_{lookupUrl}; - ProducerConfiguration producerConf_; -}; - -TEST_P(ConsumerSeekTest, testSeekForMessageId) { - Client client(lookupUrl); - - const std::string topic = "test-seek-for-message-id-" + std::string((GetParam() ? "batch-" : "")) + - std::to_string(time(nullptr)); - - Producer producer; - ASSERT_EQ(ResultOk, client.createProducer(topic, producerConf_, producer)); - - Consumer consumerExclusive; - ASSERT_EQ(ResultOk, client.subscribe(topic, "sub-0", consumerExclusive)); - - Consumer consumerInclusive; - ASSERT_EQ(ResultOk, - client.subscribe(topic, "sub-1", ConsumerConfiguration().setStartMessageIdInclusive(true), - consumerInclusive)); - - const auto numMessages = 100; - MessageId seekMessageId; - - int r = (rand() % (numMessages - 1)); - for (int i = 0; i < numMessages; i++) { - MessageId id; - ASSERT_EQ(ResultOk, - producer.send(MessageBuilder().setContent("msg-" + std::to_string(i)).build(), id)); - - if (i == r) { - seekMessageId = id; - } - } - - LOG_INFO("The seekMessageId is: " << seekMessageId << ", r : " << r); - - consumerExclusive.seek(seekMessageId); - Message msg0; - ASSERT_EQ(ResultOk, consumerExclusive.receive(msg0, 3000)); - - consumerInclusive.seek(seekMessageId); - Message msg1; - ASSERT_EQ(ResultOk, consumerInclusive.receive(msg1, 3000)); - - LOG_INFO("consumerExclusive received " << msg0.getDataAsString() << " from " << msg0.getMessageId()); - LOG_INFO("consumerInclusive received " << msg1.getDataAsString() << " from " << msg1.getMessageId()); - - ASSERT_EQ(msg0.getDataAsString(), "msg-" + std::to_string(r + 1)); - ASSERT_EQ(msg1.getDataAsString(), "msg-" + std::to_string(r)); - - consumerInclusive.close(); - consumerExclusive.close(); - producer.close(); -} - TEST(ConsumerTest, testNegativeAcksTrackerClose) { Client client(lookupUrl); auto topicName = "testNegativeAcksTrackerClose"; @@ -1252,8 +1189,6 @@ TEST(ConsumerTest, testAckNotPersistentTopic) { client.close(); } -INSTANTIATE_TEST_CASE_P(Pulsar, ConsumerSeekTest, ::testing::Values(true, false)); - class InterceptorForNegAckDeadlock : public ConsumerInterceptor { public: Message beforeConsume(const Consumer& consumer, const Message& message) override { return message; } diff --git a/tests/SynchronizedHashMapTest.cc b/tests/SynchronizedHashMapTest.cc index 85378e03..cf184d9f 100644 --- a/tests/SynchronizedHashMapTest.cc +++ b/tests/SynchronizedHashMapTest.cc @@ -91,6 +91,28 @@ TEST(SynchronizedHashMapTest, testForEach) { m.forEach([&pairs](const int& key, const int& value) { pairs.emplace_back(key, value); }); PairVector expectedPairs({{1, 100}, {2, 200}, {3, 300}}); ASSERT_EQ(sort(pairs), expectedPairs); + + m.clear(); + int result = 0; + values.clear(); + m.forEachValue([&values](int value, SharedFuture) { values.emplace_back(value); }, + [&result] { result = 1; }); + ASSERT_TRUE(values.empty()); + ASSERT_EQ(result, 1); + + m.emplace(1, 100); + m.forEachValue([&values](int value, SharedFuture) { values.emplace_back(value); }, + [&result] { result = 2; }); + ASSERT_EQ(values, (std::vector({100}))); + ASSERT_EQ(result, 1); + + values.clear(); + m.emplace(2, 200); + m.forEachValue([&values](int value, SharedFuture) { values.emplace_back(value); }, + [&result] { result = 2; }); + std::sort(values.begin(), values.end()); + ASSERT_EQ(values, (std::vector({100, 200}))); + ASSERT_EQ(result, 1); } TEST(SynchronizedHashMap, testRecursiveMutex) {