From 1251fab7d78805d2d28cb52793529fbe86d0db87 Mon Sep 17 00:00:00 2001 From: Abhay D Date: Fri, 11 Aug 2023 16:43:57 -0700 Subject: [PATCH] Fix behavior on connection death --- src/network/MissionControlProtocol.cpp | 11 ++++++++ src/network/MissionControlProtocol.h | 1 + src/network/websocket/WebSocketProtocol.cpp | 4 +++ src/network/websocket/WebSocketProtocol.h | 23 +++++++++++------ src/network/websocket/WebSocketServer.cpp | 28 +++++++++++++++++++++ src/network/websocket/WebSocketServer.h | 4 +++ 6 files changed, 63 insertions(+), 8 deletions(-) diff --git a/src/network/MissionControlProtocol.cpp b/src/network/MissionControlProtocol.cpp index 53a992761..fd051bf52 100644 --- a/src/network/MissionControlProtocol.cpp +++ b/src/network/MissionControlProtocol.cpp @@ -30,6 +30,7 @@ using websocket::msghandler_t; using websocket::validator_t; const std::chrono::milliseconds TELEM_REPORT_PERIOD = 100ms; +const std::chrono::milliseconds HEARTBEAT_TIMEOUT_PERIOD = 1000ms; // TODO: possibly use frozen::string for this so we don't have to use raw char ptrs // request keys @@ -264,6 +265,14 @@ void MissionControlProtocol::sendCameraStreamReport( this->_server.sendJSON(Constants::MC_PROTOCOL_NAME, msg); } +void MissionControlProtocol::handleHeartbeatTimeout() { + this->stopAndShutdownPowerRepeat(); + robot::emergencyStop(); + log(LOG_ERROR, "Heartbeat timed out! Emergency stopping.\n"); + Globals::E_STOP = true; + Globals::armIKEnabled = false; +} + void MissionControlProtocol::handleConnection() { // Turn off inverse kinematics on connection Globals::armIKEnabled = false; @@ -376,6 +385,8 @@ MissionControlProtocol::MissionControlProtocol(SingleClientWSServer& server) this->addConnectionHandler(std::bind(&MissionControlProtocol::handleConnection, this)); this->addDisconnectionHandler( std::bind(&MissionControlProtocol::stopAndShutdownPowerRepeat, this)); + + this->setPongTimeoutHandler(HEARTBEAT_TIMEOUT_PERIOD, std::bind(&MissionControlProtocol::handleHeartbeatTimeout, this)); this->_streaming_running = true; this->_streaming_thread = std::thread(&MissionControlProtocol::videoStreamTask, this); diff --git a/src/network/MissionControlProtocol.h b/src/network/MissionControlProtocol.h index d9748a6e1..4ccadabf3 100644 --- a/src/network/MissionControlProtocol.h +++ b/src/network/MissionControlProtocol.h @@ -69,6 +69,7 @@ class MissionControlProtocol : public WebSocketProtocol { // TODO: add documenta void sendJointPositionReport(const std::string& jointName, int32_t position); void sendRoverPos(); void handleConnection(); + void handleHeartbeatTimeout(); void startPowerRepeat(); void stopAndShutdownPowerRepeat(); void setRequestedJointPower(jointid_t joint, double power); diff --git a/src/network/websocket/WebSocketProtocol.cpp b/src/network/websocket/WebSocketProtocol.cpp index 4b50de4c6..8c31c0576 100644 --- a/src/network/websocket/WebSocketProtocol.cpp +++ b/src/network/websocket/WebSocketProtocol.cpp @@ -46,6 +46,10 @@ void WebSocketProtocol::addDisconnectionHandler(const connhandler_t& handler) { disconnectionHandlers.push_back(handler); } +void WebSocketProtocol::setPongTimeoutHandler(std::chrono::milliseconds timeout, const pongtimeouthandler_t& handler) { + pongInfo = {timeout, handler}; +} + void WebSocketProtocol::clientConnected() { for (const auto& f : connectionHandlers) { f(); diff --git a/src/network/websocket/WebSocketProtocol.h b/src/network/websocket/WebSocketProtocol.h index 7b4e4473c..f1a88edb0 100644 --- a/src/network/websocket/WebSocketProtocol.h +++ b/src/network/websocket/WebSocketProtocol.h @@ -1,7 +1,9 @@ #pragma once +#include #include #include +#include #include #include @@ -14,6 +16,7 @@ using nlohmann::json; typedef std::function msghandler_t; typedef std::function validator_t; typedef std::function connhandler_t; +typedef std::function pongtimeouthandler_t; /** * @brief Defines a protocol which will be served at an endpoint of a server. @@ -85,14 +88,7 @@ class WebSocketProtocol { void addDisconnectionHandler(const connhandler_t& handler); - /** - * @brief Process the given JSON object that was sent to this protocol's endpoint. - * Generally, this shouldn't be used by client code. - * - * @param obj The JSON object to be processed by this protocol. It is expected to have a - * "type" key. - */ - void processMessage(const json& obj) const; + void setPongTimeoutHandler(std::chrono::milliseconds timeout, const pongtimeouthandler_t& handler); /** * @brief Invoke all connection handlers for this protocol. @@ -114,11 +110,22 @@ class WebSocketProtocol { std::string getProtocolPath() const; private: + friend class SingleClientWSServer; std::string protocolPath; std::map handlerMap; std::map validatorMap; std::vector connectionHandlers; std::vector disconnectionHandlers; + std::optional> pongInfo; + + /** + * @brief Process the given JSON object that was sent to this protocol's endpoint. + * Generally, this shouldn't be used by client code. + * + * @param obj The JSON object to be processed by this protocol. It is expected to have a + * "type" key. + */ + void processMessage(const json& obj) const; }; } // namespace websocket diff --git a/src/network/websocket/WebSocketServer.cpp b/src/network/websocket/WebSocketServer.cpp index 1cf57b836..c209da213 100644 --- a/src/network/websocket/WebSocketServer.cpp +++ b/src/network/websocket/WebSocketServer.cpp @@ -26,6 +26,8 @@ SingleClientWSServer::SingleClientWSServer(const std::string& serverName, uint16 server.set_validate_handler([&](connection_hdl hdl) { return this->validate(hdl); }); server.set_message_handler( [&](connection_hdl hdl, message_t msg) { this->onMessage(hdl, msg); }); + server.set_pong_timeout_handler( + [&](connection_hdl hdl, std::string payload) { this->onPongTimeout(hdl, payload); }); } SingleClientWSServer::~SingleClientWSServer() { @@ -80,6 +82,16 @@ bool SingleClientWSServer::addProtocol(std::unique_ptr protoc std::string path = protocol->getProtocolPath(); if (protocolMap.find(path) == protocolMap.end()) { protocolMap.emplace(path, std::move(protocol)); + const auto& pongInfo = protocolMap.at(path).protocol->pongInfo; + if (pongInfo.has_value()) { + auto eventID = pingScheduler.scheduleEvent(pongInfo->first / 2, [this, path]() { + const auto& pd = this->protocolMap.at(path); + if (pd.client.has_value()) { + server.ping(pd.client.value(), path); + } + }); + protocolMap.at(path).pingEventID = eventID; + } return true; } else { return false; @@ -148,6 +160,10 @@ void SingleClientWSServer::onClose(connection_hdl hdl) { auto& protocolData = protocolMap.at(path); protocolData.client.reset(); + if (protocolData.pingEventID.has_value()) { + pingScheduler.removeEvent(protocolData.pingEventID.value()); + protocolData.pingEventID.reset(); + } protocolData.protocol->clientDisconnected(); } @@ -162,5 +178,17 @@ void SingleClientWSServer::onMessage(connection_hdl hdl, message_t message) { json obj = json::parse(jsonStr); protocolMap.at(path).protocol->processMessage(obj); } + +void SingleClientWSServer::onPongTimeout(connection_hdl hdl, const std::string& payload) { + auto conn = server.get_con_from_hdl(hdl); + + assert(protocolMap.find(payload) != protocolMap.end()); + + log(LOG_ERROR, "Pong timeout on %s\n", payload.c_str()); + auto& pongInfo = protocolMap.at(payload).protocol->pongInfo; + if (pongInfo.has_value()) { + pongInfo->second(); + } +} } // namespace websocket } // namespace net diff --git a/src/network/websocket/WebSocketServer.h b/src/network/websocket/WebSocketServer.h index f6a787ec7..a8a4f223d 100644 --- a/src/network/websocket/WebSocketServer.h +++ b/src/network/websocket/WebSocketServer.h @@ -1,6 +1,7 @@ #pragma once #include "WebSocketProtocol.h" +#include "../../utils/scheduler.h" #include #include @@ -98,6 +99,7 @@ class SingleClientWSServer { ProtocolData(std::unique_ptr protocol); std::unique_ptr protocol; std::optional client; + std::optional::eventid_t> pingEventID; }; std::string serverName; @@ -106,11 +108,13 @@ class SingleClientWSServer { bool isRunning; std::map protocolMap; std::thread serverThread; + util::PeriodicScheduler<> pingScheduler; bool validate(connection_hdl hdl); void onOpen(connection_hdl hdl); void onClose(connection_hdl hdl); void onMessage(connection_hdl hdl, message_t message); + void onPongTimeout(connection_hdl hdl, const std::string& payload); void serverTask(); }; } // namespace websocket