diff --git a/ntcore/src/generate/java/NetworkTableInstance.java.jinja b/ntcore/src/generate/java/NetworkTableInstance.java.jinja index 4c4da68a444..47158fd6026 100644 --- a/ntcore/src/generate/java/NetworkTableInstance.java.jinja +++ b/ntcore/src/generate/java/NetworkTableInstance.java.jinja @@ -68,6 +68,7 @@ public final class NetworkTableInstance implements AutoCloseable { if (m_owned && m_handle != 0) { m_listeners.close(); NetworkTablesJNI.destroyInstance(m_handle); + m_handle = 0; } } @@ -986,5 +987,5 @@ public final class NetworkTableInstance implements AutoCloseable { } private boolean m_owned; - private final int m_handle; + private int m_handle; } diff --git a/ntcore/src/main/native/cpp/LocalStorage.cpp b/ntcore/src/main/native/cpp/LocalStorage.cpp index 9f015f8de72..c645e21b7a0 100644 --- a/ntcore/src/main/native/cpp/LocalStorage.cpp +++ b/ntcore/src/main/native/cpp/LocalStorage.cpp @@ -25,6 +25,12 @@ using namespace nt; +// maximum number of local publishers / subscribers to any given topic +static constexpr size_t kMaxPublishers = 512; +static constexpr size_t kMaxSubscribers = 512; +static constexpr size_t kMaxMultiSubscribers = 512; +static constexpr size_t kMaxListeners = 512; + namespace { // Utility wrapper for making a set-like vector @@ -495,15 +501,19 @@ void LSImpl::NotifyValue(TopicData* topic, unsigned int eventFlags) { if (subscriber->active) { subscriber->pollStorage.emplace_back(topic->lastValue); subscriber->handle.Set(); - m_listenerStorage.Notify(subscriber->valueListeners, eventFlags, - topic->handle, 0, topic->lastValue); + if (!subscriber->valueListeners.empty()) { + m_listenerStorage.Notify(subscriber->valueListeners, eventFlags, + topic->handle, 0, topic->lastValue); + } } } for (auto&& subscriber : topic->multiSubscribers) { subscriber->handle.Set(); - m_listenerStorage.Notify(subscriber->valueListeners, eventFlags, - topic->handle, 0, topic->lastValue); + if (!subscriber->valueListeners.empty()) { + m_listenerStorage.Notify(subscriber->valueListeners, eventFlags, + topic->handle, 0, topic->lastValue); + } } } @@ -889,6 +899,12 @@ std::unique_ptr LSImpl::RemoveMultiSubscriber( void LSImpl::AddListenerImpl(NT_Listener listenerHandle, TopicData* topic, unsigned int eventMask) { + if (topic->localSubscribers.size() >= kMaxSubscribers) { + ERROR( + "reached maximum number of subscribers to '{}', ignoring listener add", + topic->name); + return; + } // subscribe to make sure topic updates are received PubSubConfig config; config.topicsOnly = (eventMask & NT_EVENT_VALUE_ALL) == 0; @@ -906,6 +922,12 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, auto topic = subscriber->topic; if ((eventMask & NT_EVENT_TOPIC) != 0) { + if (topic->listeners.size() >= kMaxListeners) { + ERROR("reached maximum number of listeners to '{}', not adding listener", + topic->name); + return; + } + m_listenerStorage.Activate( listenerHandle, eventMask & (NT_EVENT_TOPIC | NT_EVENT_IMMEDIATE)); @@ -922,6 +944,11 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, } if ((eventMask & NT_EVENT_VALUE_ALL) != 0) { + if (subscriber->valueListeners.size() >= kMaxListeners) { + ERROR("reached maximum number of listeners to '{}', not adding listener", + topic->name); + return; + } m_listenerStorage.Activate( listenerHandle, eventMask & (NT_EVENT_VALUE_ALL | NT_EVENT_IMMEDIATE), [subentryHandle](unsigned int mask, Event* event) { @@ -968,6 +995,11 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, } if ((eventMask & NT_EVENT_TOPIC) != 0) { + if (m_topicPrefixListeners.size() >= kMaxListeners) { + ERROR("reached maximum number of listeners, not adding listener"); + return; + } + m_listenerStorage.Activate( listenerHandle, eventMask & (NT_EVENT_TOPIC | NT_EVENT_IMMEDIATE)); @@ -989,6 +1021,11 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, } if ((eventMask & NT_EVENT_VALUE_ALL) != 0) { + if (subscriber->valueListeners.size() >= kMaxListeners) { + ERROR("reached maximum number of listeners, not adding listener"); + return; + } + m_listenerStorage.Activate( listenerHandle, eventMask & (NT_EVENT_VALUE_ALL | NT_EVENT_IMMEDIATE), [subentryHandle = subscriber->handle.GetHandle()](unsigned int mask, @@ -1018,6 +1055,10 @@ void LSImpl::AddListenerImpl(NT_Listener listenerHandle, void LSImpl::AddListener(NT_Listener listenerHandle, std::span prefixes, unsigned int eventMask) { + if (m_multiSubscribers.size() >= kMaxMultiSubscribers) { + ERROR("reached maximum number of multi-subscribers, not adding listener"); + return; + } // subscribe to make sure topic updates are received PubSubOptions options; options.topicsOnly = (eventMask & NT_EVENT_VALUE_ALL) == 0; @@ -1548,6 +1589,13 @@ NT_Subscriber LocalStorage::Subscribe(NT_Topic topicHandle, NT_Type type, return 0; } + if (topic->localSubscribers.size() >= kMaxSubscribers) { + WPI_ERROR(m_impl->m_logger, + "reached maximum number of subscribers to '{}', not subscribing", + topic->name); + return 0; + } + // Create subscriber return m_impl->AddLocalSubscriber(topic, PubSubConfig{type, typeStr, options}) ->handle; @@ -1562,6 +1610,13 @@ NT_MultiSubscriber LocalStorage::SubscribeMultiple( std::span prefixes, std::span options) { std::scoped_lock lock{m_mutex}; + + if (m_impl->m_multiSubscribers.size() >= kMaxMultiSubscribers) { + WPI_ERROR(m_impl->m_logger, + "reached maximum number of multi-subscribers, not subscribing"); + return 0; + } + PubSubOptions opts{options}; opts.prefixMatch = true; return m_impl->AddMultiSubscriber(prefixes, opts)->handle; @@ -1594,6 +1649,13 @@ NT_Publisher LocalStorage::Publish(NT_Topic topicHandle, NT_Type type, return 0; } + if (topic->localPublishers.size() >= kMaxPublishers) { + WPI_ERROR(m_impl->m_logger, + "reached maximum number of publishers to '{}', not publishing", + topic->name); + return 0; + } + return m_impl ->AddLocalPublisher(topic, properties, PubSubConfig{type, typeStr, options}) @@ -1627,6 +1689,14 @@ NT_Entry LocalStorage::GetEntry(NT_Topic topicHandle, NT_Type type, return 0; } + if (topic->localSubscribers.size() >= kMaxSubscribers) { + WPI_ERROR( + m_impl->m_logger, + "reached maximum number of subscribers to '{}', not creating entry", + topic->name); + return 0; + } + // Create subscriber auto subscriber = m_impl->AddLocalSubscriber(topic, PubSubConfig{type, typeStr, options}); @@ -2010,6 +2080,14 @@ NT_Entry LocalStorage::GetEntry(std::string_view name) { auto* topic = m_impl->GetOrCreateTopic(name); if (topic->entry == 0) { + if (topic->localSubscribers.size() >= kMaxSubscribers) { + WPI_ERROR( + m_impl->m_logger, + "reached maximum number of subscribers to '{}', not creating entry", + topic->name); + return 0; + } + // Create subscriber auto* subscriber = m_impl->AddLocalSubscriber(topic, {}); diff --git a/ntcore/src/main/native/cpp/net/ServerImpl.cpp b/ntcore/src/main/native/cpp/net/ServerImpl.cpp index f99a2f47e32..c18c9812073 100644 --- a/ntcore/src/main/native/cpp/net/ServerImpl.cpp +++ b/ntcore/src/main/native/cpp/net/ServerImpl.cpp @@ -620,8 +620,20 @@ void ClientData4Base::ClientSubscribe(int64_t subuid, sub->periodMs = kMinPeriodMs; } + // update periodic sender (if not local) + if (!m_local) { + if (m_periodMs == UINT32_MAX) { + m_periodMs = sub->periodMs; + } else { + m_periodMs = std::gcd(m_periodMs, sub->periodMs); + } + if (m_periodMs < kMinPeriodMs) { + m_periodMs = kMinPeriodMs; + } + m_setPeriodic(m_periodMs); + } + // see if this immediately subscribes to any topics - bool updatedPeriodic = false; for (auto&& topic : m_server.m_topics) { bool removed = false; if (replace) { @@ -647,14 +659,6 @@ void ClientData4Base::ClientSubscribe(int64_t subuid, m_server.UpdateMetaTopicSub(topic.get()); } - if (added || removed) { - // update periodic sender (if not local) - if (!m_local) { - m_periodMs = std::gcd(m_periodMs, sub->periodMs); - updatedPeriodic = true; - } - } - if (!wasSubscribed && added && !removed) { // announce topic to client DEBUG4("client {}: announce {}", m_id, topic->name); @@ -667,12 +671,6 @@ void ClientData4Base::ClientSubscribe(int64_t subuid, } } } - if (updatedPeriodic) { - if (m_periodMs < kMinPeriodMs) { - m_periodMs = kMinPeriodMs; - } - m_setPeriodic(m_periodMs); - } // update meta data UpdateMetaClientSub(); diff --git a/ntcore/src/main/native/cpp/networktables/NetworkTableInstance.cpp b/ntcore/src/main/native/cpp/networktables/NetworkTableInstance.cpp index 47203b20840..12f36d32571 100644 --- a/ntcore/src/main/native/cpp/networktables/NetworkTableInstance.cpp +++ b/ntcore/src/main/native/cpp/networktables/NetworkTableInstance.cpp @@ -101,9 +101,44 @@ void NetworkTableInstance::SetServer(std::span servers, SetServer(serversArr); } +NT_Listener NetworkTableInstance::AddListener(Topic topic, + unsigned int eventMask, + ListenerCallback listener) { + if (::nt::GetInstanceFromHandle(topic.GetHandle()) != m_handle) { + fmt::print(stderr, "AddListener: topic is not from this instance\n"); + return 0; + } + return ::nt::AddListener(topic.GetHandle(), eventMask, std::move(listener)); +} + +NT_Listener NetworkTableInstance::AddListener(Subscriber& subscriber, + unsigned int eventMask, + ListenerCallback listener) { + if (::nt::GetInstanceFromHandle(subscriber.GetHandle()) != m_handle) { + fmt::print(stderr, "AddListener: subscriber is not from this instance\n"); + return 0; + } + return ::nt::AddListener(subscriber.GetHandle(), eventMask, + std::move(listener)); +} + +NT_Listener NetworkTableInstance::AddListener(NetworkTableEntry& entry, + int eventMask, + ListenerCallback listener) { + if (::nt::GetInstanceFromHandle(entry.GetHandle()) != m_handle) { + fmt::print(stderr, "AddListener: entry is not from this instance\n"); + return 0; + } + return ::nt::AddListener(entry.GetHandle(), eventMask, std::move(listener)); +} + NT_Listener NetworkTableInstance::AddListener(MultiSubscriber& subscriber, int eventMask, ListenerCallback listener) { + if (::nt::GetInstanceFromHandle(subscriber.GetHandle()) != m_handle) { + fmt::print(stderr, "AddListener: subscriber is not from this instance\n"); + return 0; + } return ::nt::AddListener(subscriber.GetHandle(), eventMask, std::move(listener)); } diff --git a/ntcore/src/main/native/cpp/ntcore_cpp.cpp b/ntcore/src/main/native/cpp/ntcore_cpp.cpp index b14862eda8b..512ceeccc4a 100644 --- a/ntcore/src/main/native/cpp/ntcore_cpp.cpp +++ b/ntcore/src/main/native/cpp/ntcore_cpp.cpp @@ -497,6 +497,14 @@ NT_Listener AddPolledListener(NT_ListenerPoller poller, NT_Listener AddPolledListener(NT_ListenerPoller poller, NT_Handle handle, unsigned int mask) { if (auto ii = InstanceImpl::GetTyped(poller, Handle::kListenerPoller)) { + if (Handle{handle}.GetInst() != Handle{poller}.GetInst()) { + WPI_ERROR( + ii->logger, + "AddPolledListener(): trying to listen to handle {} (instance {}) " + "with poller {} (instance {}), ignored due to different instance", + handle, Handle{handle}.GetInst(), poller, Handle{poller}.GetInst()); + return {}; + } auto listener = ii->listenerStorage.AddListener(poller); DoAddListener(*ii, listener, handle, mask); return listener; diff --git a/ntcore/src/main/native/include/networktables/NetworkTableInstance.h b/ntcore/src/main/native/include/networktables/NetworkTableInstance.h index 7c9ed6e810f..02778b1ee47 100644 --- a/ntcore/src/main/native/include/networktables/NetworkTableInstance.h +++ b/ntcore/src/main/native/include/networktables/NetworkTableInstance.h @@ -132,7 +132,7 @@ class NetworkTableInstance final { * * @param inst Instance */ - static void Destroy(NetworkTableInstance inst); + static void Destroy(NetworkTableInstance& inst); /** * Gets the native handle for the entry. diff --git a/ntcore/src/main/native/include/networktables/NetworkTableInstance.inc b/ntcore/src/main/native/include/networktables/NetworkTableInstance.inc index 361b8eee97f..62d320c9a6c 100644 --- a/ntcore/src/main/native/include/networktables/NetworkTableInstance.inc +++ b/ntcore/src/main/native/include/networktables/NetworkTableInstance.inc @@ -27,9 +27,10 @@ inline NetworkTableInstance NetworkTableInstance::Create() { return NetworkTableInstance{CreateInstance()}; } -inline void NetworkTableInstance::Destroy(NetworkTableInstance inst) { +inline void NetworkTableInstance::Destroy(NetworkTableInstance& inst) { if (inst.m_handle != 0) { DestroyInstance(inst.m_handle); + inst.m_handle = 0; } } @@ -99,22 +100,6 @@ inline NT_Listener NetworkTableInstance::AddConnectionListener( std::move(callback)); } -inline NT_Listener NetworkTableInstance::AddListener( - Topic topic, unsigned int eventMask, ListenerCallback listener) { - return ::nt::AddListener(topic.GetHandle(), eventMask, std::move(listener)); -} - -inline NT_Listener NetworkTableInstance::AddListener( - Subscriber& subscriber, unsigned int eventMask, ListenerCallback listener) { - return ::nt::AddListener(subscriber.GetHandle(), eventMask, - std::move(listener)); -} - -inline NT_Listener NetworkTableInstance::AddListener( - NetworkTableEntry& entry, int eventMask, ListenerCallback listener) { - return ::nt::AddListener(entry.GetHandle(), eventMask, std::move(listener)); -} - inline NT_Listener NetworkTableInstance::AddListener( std::span prefixes, int eventMask, ListenerCallback listener) { diff --git a/ntcore/src/main/native/include/networktables/NetworkTableListener.inc b/ntcore/src/main/native/include/networktables/NetworkTableListener.inc index 3595f206511..ed7006ff8e7 100644 --- a/ntcore/src/main/native/include/networktables/NetworkTableListener.inc +++ b/ntcore/src/main/native/include/networktables/NetworkTableListener.inc @@ -72,7 +72,11 @@ inline NetworkTableListener::NetworkTableListener(NetworkTableListener&& rhs) inline NetworkTableListener& NetworkTableListener::operator=( NetworkTableListener&& rhs) { - std::swap(m_handle, rhs.m_handle); + if (m_handle != 0) { + nt::RemoveListener(m_handle); + } + m_handle = rhs.m_handle; + rhs.m_handle = 0; return *this; } @@ -102,7 +106,11 @@ inline NetworkTableListenerPoller::NetworkTableListenerPoller( inline NetworkTableListenerPoller& NetworkTableListenerPoller::operator=( NetworkTableListenerPoller&& rhs) { - std::swap(m_handle, rhs.m_handle); + if (m_handle != 0) { + nt::DestroyListenerPoller(m_handle); + } + m_handle = rhs.m_handle; + rhs.m_handle = 0; return *this; } diff --git a/ntcore/src/test/native/cpp/ValueListenerTest.cpp b/ntcore/src/test/native/cpp/ValueListenerTest.cpp index 7b10476f352..dbe3201389a 100644 --- a/ntcore/src/test/native/cpp/ValueListenerTest.cpp +++ b/ntcore/src/test/native/cpp/ValueListenerTest.cpp @@ -346,4 +346,64 @@ TEST_F(ValueListenerTest, PollImmediateSubMultiple) { EXPECT_EQ(valueData->value, nt::Value::MakeDouble(1.0)); } +TEST_F(ValueListenerTest, TwoSubOneListener) { + auto topic = nt::GetTopic(m_inst, "foo"); + auto pub = nt::Publish(topic, NT_DOUBLE, "double"); + auto sub1 = nt::Subscribe(topic, NT_DOUBLE, "double"); + auto sub2 = nt::Subscribe(topic, NT_DOUBLE, "double"); + auto sub3 = nt::SubscribeMultiple(m_inst, {{"foo"}}); + + auto poller = nt::CreateListenerPoller(m_inst); + auto h = nt::AddPolledListener(poller, sub1, nt::EventFlags::kValueLocal); + (void)sub2; + (void)sub3; + + nt::SetDouble(pub, 0); + + bool timedOut = false; + ASSERT_TRUE(wpi::WaitForObject(poller, 1.0, &timedOut)); + ASSERT_FALSE(timedOut); + auto results = nt::ReadListenerQueue(poller); + + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].flags & nt::EventFlags::kValueLocal, + nt::EventFlags::kValueLocal); + EXPECT_EQ(results[0].listener, h); + auto valueData = results[0].GetValueEventData(); + ASSERT_TRUE(valueData); + EXPECT_EQ(valueData->subentry, sub1); + EXPECT_EQ(valueData->topic, topic); + EXPECT_EQ(valueData->value, nt::Value::MakeDouble(0.0)); +} + +TEST_F(ValueListenerTest, TwoSubOneMultiListener) { + auto topic = nt::GetTopic(m_inst, "foo"); + auto pub = nt::Publish(topic, NT_DOUBLE, "double"); + auto sub1 = nt::Subscribe(topic, NT_DOUBLE, "double"); + auto sub2 = nt::Subscribe(topic, NT_DOUBLE, "double"); + auto sub3 = nt::SubscribeMultiple(m_inst, {{"foo"}}); + + auto poller = nt::CreateListenerPoller(m_inst); + auto h = nt::AddPolledListener(poller, sub3, nt::EventFlags::kValueLocal); + (void)sub1; + (void)sub2; + + nt::SetDouble(pub, 0); + + bool timedOut = false; + ASSERT_TRUE(wpi::WaitForObject(poller, 1.0, &timedOut)); + ASSERT_FALSE(timedOut); + auto results = nt::ReadListenerQueue(poller); + + ASSERT_EQ(results.size(), 1u); + EXPECT_EQ(results[0].flags & nt::EventFlags::kValueLocal, + nt::EventFlags::kValueLocal); + EXPECT_EQ(results[0].listener, h); + auto valueData = results[0].GetValueEventData(); + ASSERT_TRUE(valueData); + EXPECT_EQ(valueData->subentry, sub3); + EXPECT_EQ(valueData->topic, topic); + EXPECT_EQ(valueData->value, nt::Value::MakeDouble(0.0)); +} + } // namespace nt