From 634eeb0f08c1c9ad4295f86c636f4377956fae8e Mon Sep 17 00:00:00 2001 From: Michael Gerhold Date: Thu, 14 Nov 2024 18:27:02 +0100 Subject: [PATCH] feat: transfer player names between clients --- src/network/include/network/message_types.hpp | 3 +- src/network/include/network/messages.hpp | 57 ++++++-- src/network/messages.cpp | 124 ++++++++++++++++-- src/server/server.cpp | 88 ++++++++++--- src/server/server.hpp | 21 ++- .../include/simulator/observer_tetrion.hpp | 4 +- src/simulator/multiplayer_tetrion.cpp | 42 +++++- test/network_tests.cpp | 13 +- 8 files changed, 308 insertions(+), 44 deletions(-) diff --git a/src/network/include/network/message_types.hpp b/src/network/include/network/message_types.hpp index 39616e9..3f7816d 100644 --- a/src/network/include/network/message_types.hpp +++ b/src/network/include/network/message_types.hpp @@ -3,7 +3,8 @@ #include enum class MessageType : std::uint8_t { - Heartbeat = 0, + Connect, + Heartbeat, GridState, GameStart, StateBroadcast, diff --git a/src/network/include/network/messages.hpp b/src/network/include/network/messages.hpp index 9198088..bc6345f 100644 --- a/src/network/include/network/messages.hpp +++ b/src/network/include/network/messages.hpp @@ -47,6 +47,25 @@ struct AbstractMessage { [[nodiscard]] virtual bool equals(AbstractMessage const& other) const = 0; }; +static constexpr auto player_name_buffer_size = usize{ 32 }; + +struct Connect final : AbstractMessage { + std::string player_name; + + explicit Connect(std::string_view player_name); + + [[nodiscard]] static constexpr decltype(MessageHeader::payload_size) max_payload_size() { + return static_cast(player_name_buffer_size); + } + + [[nodiscard]] MessageType type() const override; + [[nodiscard]] decltype(MessageHeader::payload_size) payload_size() const override; + [[nodiscard]] c2k::MessageBuffer serialize() const override; + [[nodiscard]] static Connect deserialize(c2k::MessageBuffer& buffer); + + [[nodiscard]] bool equals(AbstractMessage const& other) const override; +}; + struct Heartbeat final : AbstractMessage { public: std::uint64_t frame; @@ -107,19 +126,36 @@ struct GridState final : AbstractMessage { } }; +struct ClientIdentity final { + u8 client_id; + std::string player_name; + + ClientIdentity(u8 const client_id, std::string player_name) + : client_id{ client_id }, player_name{ std::move(player_name) } {} + + [[nodiscard]] bool operator==(ClientIdentity const& other) const = default; +}; + struct GameStart final : AbstractMessage { std::uint8_t client_id; std::uint64_t start_frame; std::uint64_t random_seed; - std::uint8_t num_players; + std::vector client_identities; GameStart( std::uint8_t const client_id, std::uint64_t const start_frame, std::uint64_t const random_seed, - std::uint8_t const num_players + std::vector client_identities ) - : client_id{ client_id }, start_frame{ start_frame }, random_seed{ random_seed }, num_players{ num_players } {} + : client_id{ client_id }, + start_frame{ start_frame }, + random_seed{ random_seed }, + client_identities{ std::move(client_identities) } { + if (this->client_identities.size() > std::numeric_limits::max()) { + throw std::invalid_argument{ "Number of clients is too high." }; + } + } [[nodiscard]] MessageType type() const override; [[nodiscard]] decltype(MessageHeader::payload_size) payload_size() const override; @@ -127,24 +163,29 @@ struct GameStart final : AbstractMessage { [[nodiscard]] static GameStart deserialize(c2k::MessageBuffer& buffer); [[nodiscard]] static constexpr decltype(MessageHeader::payload_size) max_payload_size() { - return calculate_payload_size(); + return calculate_payload_size(std::numeric_limits::max()); + } + + [[nodiscard]] u8 num_players() const { + return gsl::narrow(client_identities.size()); } private: - [[nodiscard]] static constexpr decltype(MessageHeader::payload_size) calculate_payload_size() { + [[nodiscard]] static constexpr decltype(MessageHeader::payload_size) calculate_payload_size(u8 const num_players) { return static_cast( - sizeof(client_id) + sizeof(start_frame) + sizeof(random_seed) + sizeof(num_players) + sizeof(client_id) + sizeof(start_frame) + sizeof(random_seed) + sizeof(u8) /* num players */ + + num_players * (sizeof(ClientIdentity::client_id) + player_name_buffer_size) ); } [[nodiscard]] bool equals(AbstractMessage const& other) const override { auto const& other_game_start = static_cast(other); - return std::tie(client_id, start_frame, random_seed, num_players) + return std::tie(client_id, start_frame, random_seed, client_identities) == std::tie( other_game_start.client_id, other_game_start.start_frame, other_game_start.random_seed, - other_game_start.num_players + other_game_start.client_identities ); } }; diff --git a/src/network/messages.cpp b/src/network/messages.cpp index cd511aa..d8e56f3 100644 --- a/src/network/messages.cpp +++ b/src/network/messages.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -25,6 +26,8 @@ auto const message_max_payload_size = [message_type] { switch (message_type) { + case MessageType::Connect: + return Connect::max_payload_size(); case MessageType::Heartbeat: return Heartbeat::max_payload_size(); case MessageType::GridState: @@ -70,6 +73,8 @@ try { switch (message_type) { + case MessageType::Connect: + return std::make_unique(Connect::deserialize(buffer)); case MessageType::Heartbeat: return std::make_unique(Heartbeat::deserialize(buffer)); case MessageType::GridState: @@ -87,6 +92,70 @@ std::unreachable(); } +[[nodiscard]] static std::string sanitize(std::string_view const player_name) { + auto sanitized = std::string{}; + auto const max_length = std::min(player_name_buffer_size - 1, player_name.length()); + for (auto i = usize{ 0 }; i < max_length; ++i) { + auto const c = player_name.at(i); + if (not std::isprint(static_cast(c))) { + sanitized += '?'; + } else { + sanitized += c; + } + } + assert(sanitized.length() < player_name_buffer_size - 1); + return sanitized; +} + +Connect::Connect(std::string_view const player_name) + : player_name{ sanitize(player_name) } {} + +[[nodiscard]] MessageType Connect::type() const { + return MessageType::Connect; +} + +decltype(MessageHeader::payload_size) Connect::payload_size() const { + return max_payload_size(); +} + +[[nodiscard]] c2k::MessageBuffer Connect::serialize() const { + auto buffer = c2k::MessageBuffer{}; + buffer << static_cast(MessageType::Connect) << payload_size(); + auto const expected_message_size = buffer.size() + player_name_buffer_size; + for (auto const c : player_name) { + buffer << c; + } + while (buffer.size() < expected_message_size) { + buffer << '\0'; + } + assert(buffer.size() == expected_message_size); + return buffer; +} + +[[nodiscard]] Connect Connect::deserialize(c2k::MessageBuffer& buffer) { + static constexpr auto required_num_bytes = max_payload_size(); // Message has a fixed size. + if (buffer.size() < required_num_bytes) { + throw MessageDeserializationError{ std::format( + "too few bytes to deserialize Connect message ({} needed, {} received)", + required_num_bytes, + buffer.size() + ) }; + } + auto player_name = std::string{}; + while (not buffer.size() == 0) { + auto const c = buffer.try_extract().value(); + if (c == '\0') { + break; + } + player_name += c; + } + return Connect{ sanitize(std::move(player_name)) }; +} + +[[nodiscard]] bool Connect::equals(AbstractMessage const& other) const { + return other.type() == type() and dynamic_cast(other).player_name == player_name; +} + [[nodiscard]] MessageType Heartbeat::type() const { return MessageType::Heartbeat; } @@ -160,19 +229,31 @@ } [[nodiscard]] decltype(MessageHeader::payload_size) GameStart::payload_size() const { - return calculate_payload_size(); + return calculate_payload_size(gsl::narrow(client_identities.size())); } [[nodiscard]] c2k::MessageBuffer GameStart::serialize() const { auto buffer = c2k::MessageBuffer{}; // clang-format off - buffer << static_cast(MessageType::GameStart) - << payload_size() - << client_id - << start_frame - << random_seed - << num_players; + buffer << static_cast(MessageType::GameStart) + << payload_size() + << client_id + << start_frame + << random_seed + << gsl::narrow(client_identities.size()); // clang-format on + for (auto const& [other_client_id, player_name] : client_identities) { + buffer << other_client_id; + auto num_bytes = usize{ 0 }; + for (auto const c : player_name) { + buffer << c; + ++num_bytes; + } + while (num_bytes < player_name_buffer_size) { + buffer << '\0'; + ++num_bytes; + } + } assert(buffer.size() == payload_size() + header_size); return buffer; } @@ -183,7 +264,7 @@ decltype(client_id), decltype(start_frame), decltype(random_seed), - decltype(num_players) + u8 >(); // clang-format on if (buffer.size() < required_num_bytes) { @@ -203,12 +284,35 @@ decltype(GameStart::client_id), decltype(GameStart::start_frame), decltype(GameStart::random_seed), - decltype(GameStart::num_players) + u8 >() .value(); // clang-format on + + auto const num_remaining_bytes = num_players * (sizeof(u8) + player_name_buffer_size); + if (buffer.size() < num_remaining_bytes) { + throw MessageDeserializationError{ std::format( + "too few bytes to deserialize client identities within GameStart message ({} needed, {} received)", + num_remaining_bytes, + buffer.size() + ) }; + } + auto client_identities = std::vector{}; + client_identities.reserve(num_players); + for (auto i = decltype(num_players){ 0 }; i < num_players; ++i) { + auto const other_client_id = buffer.try_extract().value(); + auto player_name = std::string{}; + for (auto j = usize{ 0 }; j < player_name_buffer_size; ++j) { + auto const c = buffer.try_extract().value(); + if (c != '\0') { + player_name += c; + } + } + client_identities.emplace_back(other_client_id, std::move(player_name)); + } assert(buffer.size() == 0); - return GameStart{ client_id, start_frame, random_seed, num_players }; + + return GameStart{ client_id, start_frame, random_seed, std::move(client_identities) }; } StateBroadcast::StateBroadcast(std::uint64_t const frame, std::vector states_per_client) diff --git a/src/server/server.cpp b/src/server/server.cpp index dcc2a1b..f187f26 100644 --- a/src/server/server.cpp +++ b/src/server/server.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include "network/messages.hpp" @@ -21,6 +22,38 @@ void Server::process_client(std::stop_token const& stop_token, Server& self, std using namespace std::chrono_literals; auto& socket = self.m_client_sockets.at(index); + // Wait for Connect message from client... + while (not stop_token.stop_requested() and socket.is_connected()) { + auto message = std::unique_ptr{}; + try { + message = AbstractMessage::from_socket(socket); + } catch (c2k::TimeoutError const&) { + spdlog::error("Waiting for Connect message from client {}...", index); + continue; + } catch (c2k::ReadError const& exception) { + spdlog::error("error while reading from socket: {}", exception.what()); + break; + } + + if (message->type() == MessageType::Connect) { + auto& connect_message = dynamic_cast(*message); + spdlog::info("Client identified itself as '{}'.", connect_message.player_name); + // clang-format off + self.m_client_infos.apply( + [index, name = std::move(connect_message.player_name)] + (std::vector& client_infos) mutable { + client_infos.at(index).player_name = std::move(name); + assert(client_infos.at(index).state == ClientState::Connected); + client_infos.at(index).state = ClientState::Identified; + } + ); + // clang-format on + break; + } + + // todo: This client sent an unexpected message. It should be disconnected. + } + while (not stop_token.stop_requested() and socket.is_connected()) { auto message = std::unique_ptr{}; try { @@ -53,7 +86,7 @@ void Server::process_client(std::stop_token const& stop_token, Server& self, std spdlog::info("client {}:{} disconnected", socket.remote_address().address, socket.remote_address().port); auto const client_id = self.m_client_infos.apply([index](std::vector& client_infos) { auto& client_info = client_infos.at(index); - client_info.is_connected = false; + client_info.state = ClientState::Disconnected; return client_info.id; }); self.broadcast_client_disconnected_message(client_id); @@ -86,29 +119,52 @@ void Server::keep_broadcasting(std::stop_token const& stop_token, Server& self) std::this_thread::sleep_for(10ms); } - // wait for all clients to connect + // Wait for all clients to connect and identify. while (not stop_token.stop_requested()) { - auto const num_connected_clients = - self.m_client_infos.apply([](std::vector const& client_infos) { return client_infos.size(); }); - if (num_connected_clients == self.m_expected_player_count) { + // clang-format off + auto const num_identified_clients = self.m_client_infos.apply( + [](std::vector const& client_infos) { + return gsl::narrow( + std::ranges::count_if( + client_infos, + [&](ClientInfo const& info) { + return info.state == ClientState::Identified; + } + ) + ); + } + ); + // clang-format on + if (num_identified_clients == self.m_expected_player_count) { break; } - spdlog::info("not all clients have connected yet, number of clients: {}", num_connected_clients); + spdlog::info( + "not all clients have connected/identified yet ({} of {})", + num_identified_clients, + self.m_expected_player_count.load() + ); // todo: replace sleep with condition variable std::this_thread::sleep_for(100ms); } - auto i = std::uint8_t{ 0 }; - for (auto& socket : self.m_client_sockets) { + auto const client_identities = self.m_client_infos.apply([](std::vector& client_infos) { + auto identities = std::vector{}; + identities.reserve(client_infos.size()); + for (auto const& [i, client_info] : std::views::enumerate(client_infos)) { + identities.emplace_back(gsl::narrow(i), client_info.player_name); + } + return identities; + }); + + for (auto const& [i, socket] : std::views::enumerate(self.m_client_sockets)) { spdlog::info("assigning id {} to client and sending seed {}", i, self.m_seed); auto const message = GameStart{ - i, + gsl::narrow(i), start_frame, self.m_seed, - gsl::narrow(self.m_client_sockets.size()), + client_identities, }; socket.send(message.serialize()).wait(); - ++i; } auto last_min_num_frames_simulated = std::optional{ std::nullopt }; @@ -118,7 +174,7 @@ void Server::keep_broadcasting(std::stop_token const& stop_token, Server& self) self.m_client_infos.apply([&self, &last_min_num_frames_simulated](std::vector& client_infos) { auto const num_clients_connected = std::ranges::count_if(client_infos, [](ClientInfo const& client_info) { - return client_info.is_connected; + return client_info.is_connected(); }); if (num_clients_connected == 0) { @@ -128,14 +184,14 @@ void Server::keep_broadcasting(std::stop_token const& stop_token, Server& self) // go through all the connected clients and determine the minimum number of key states // that have been queued up for all clients auto const min_num_key_states_queued = std::ranges::min( - client_infos | std::views::filter([](auto const& client_info) { return client_info.is_connected; }) + client_infos | std::views::filter([](auto const& client_info) { return client_info.is_connected(); }) | std::views::transform([](auto const& client_info) { return client_info.key_states.size(); }) ); for (auto i = usize{ 0 }; i < min_num_key_states_queued; ++i) { auto garbage_send_events = std::unordered_map{}; for (auto& client_info : client_infos) { - if (not client_info.is_connected) { + if (not client_info.is_connected()) { continue; } auto const key_state = client_info.key_states.at(i); @@ -162,13 +218,13 @@ void Server::keep_broadcasting(std::stop_token const& stop_token, Server& self) // first we need to find the minimum number of frames simulated by any client that is connected auto const min_num_frames_simulated = std::ranges::min( - client_infos | std::views::filter([](auto const& client_info) { return client_info.is_connected; }) + client_infos | std::views::filter([](auto const& client_info) { return client_info.is_connected(); }) | std::views::transform([](auto const& client_info) { return client_info.tetrion.next_frame(); }) ); // to not block the broadcasting, we will create empty key states for all clients that are not connected for (auto& client_info : client_infos) { - if (not client_info.is_connected) { + if (not client_info.is_connected()) { while (client_info.tetrion.next_frame() < min_num_frames_simulated) { static constexpr auto key_state = KeyState{}; client_info.key_states.push_back(key_state); diff --git a/src/server/server.hpp b/src/server/server.hpp index 04a079c..6e20f25 100644 --- a/src/server/server.hpp +++ b/src/server/server.hpp @@ -8,14 +8,33 @@ #include #include +enum class ClientState { + Connected, + Identified, + Disconnected, +}; + struct ClientInfo final { u8 id; ObpfTetrion tetrion; std::vector key_states; - bool is_connected = true; + ClientState state = ClientState::Connected; + std::string player_name; // Not filled by constructor, because the name is transferred later. explicit ClientInfo(u8 const id, u64 const seed, u64 const start_frame) : id{ id }, tetrion{ seed, start_frame } {} + + [[nodiscard]] bool is_connected() const { + switch (state) { + using enum ClientState; + case Connected: + case Identified: + return true; + case Disconnected: + return false; + } + throw std::logic_error{ "unreachable" }; + } }; class Server final { diff --git a/src/simulator/include/simulator/observer_tetrion.hpp b/src/simulator/include/simulator/observer_tetrion.hpp index 434e0dd..36a5686 100644 --- a/src/simulator/include/simulator/observer_tetrion.hpp +++ b/src/simulator/include/simulator/observer_tetrion.hpp @@ -12,8 +12,8 @@ struct ObserverTetrion final : ObpfTetrion { bool m_is_connected = true; public: - ObserverTetrion(u64 const seed, u64 const start_frame, u8 const m_client_id, Key) - : ObpfTetrion{ seed, start_frame }, m_client_id{ m_client_id } {} + ObserverTetrion(u64 const seed, u64 const start_frame, u8 const m_client_id, std::string player_name, Key) + : ObpfTetrion{ seed, start_frame, std::move(player_name) }, m_client_id{ m_client_id } {} [[nodiscard]] std::optional simulate_next_frame(KeyState) override { return std::nullopt; diff --git a/src/simulator/multiplayer_tetrion.cpp b/src/simulator/multiplayer_tetrion.cpp index 770e554..14a3373 100644 --- a/src/simulator/multiplayer_tetrion.cpp +++ b/src/simulator/multiplayer_tetrion.cpp @@ -11,12 +11,17 @@ NullableUniquePointer MultiplayerTetrion::create( ) { auto socket = c2k::Sockets::create_client(c2k::AddressFamily::Ipv4, server, port); auto message = std::unique_ptr{}; + + // Identify this client... + socket.send(Connect{ player_name }.serialize()).wait(); + + // Wait for the GameStart message coming from the server... while (true) { try { message = AbstractMessage::from_socket(socket); break; - } catch (c2k::TimeoutError const& exception) { - spdlog::error("timeout error while waiting for game start message: {}", exception.what()); + } catch (c2k::TimeoutError const&) { + spdlog::info("waiting for the game to start..."); } catch (c2k::ReadError const& exception) { spdlog::error("error while reading from socket: {}", exception.what()); return nullptr; @@ -32,7 +37,7 @@ NullableUniquePointer MultiplayerTetrion::create( spdlog::info("received game start message"); auto const& game_start_message = dynamic_cast(*message); - auto const num_observers = static_cast(game_start_message.num_players - 1); + auto const num_observers = static_cast(game_start_message.num_players() - 1); auto observers = std::vector>{}; observers.reserve(num_observers); auto observer_id = u8{ 0 }; @@ -40,22 +45,51 @@ NullableUniquePointer MultiplayerTetrion::create( if (observer_id == game_start_message.client_id) { ++observer_id; } + + auto const find_iterator = + std::ranges::find_if(game_start_message.client_identities, [observer_id](auto const& identity) { + return identity.client_id == observer_id; + }); + + using namespace std::string_literals; + // clang-format off + auto observer_name = ( + find_iterator == game_start_message.client_identities.cend() + ? ""s + : find_iterator->player_name + ); + // clang-format on + observers.push_back(std::make_unique( game_start_message.random_seed, game_start_message.start_frame, observer_id, + std::move(observer_name), ObserverTetrion::Key{} )); ++observer_id; } + auto const find_iterator = + std::ranges::find_if(game_start_message.client_identities, [&game_start_message](auto const& identity) { + return identity.client_id == game_start_message.client_id; + }); + + // clang-format off + auto this_player_name = ( + find_iterator == game_start_message.client_identities.cend() + ? player_name + : find_iterator->player_name + ); + // clang-format on + return std::make_unique( std::move(socket), game_start_message.client_id, game_start_message.start_frame, game_start_message.random_seed, std::move(observers), - std::move(player_name), + std::move(this_player_name), Key{} ); } diff --git a/test/network_tests.cpp b/test/network_tests.cpp index 71df707..7125830 100644 --- a/test/network_tests.cpp +++ b/test/network_tests.cpp @@ -62,7 +62,7 @@ TEST(NetworkTests, TooBigHeartbeatMessageFails) { std::ignore = send_receive_buffer_and_deserialize(buffer); FAIL() << "expected MessageSerializationError"; } catch (MessageDeserializationError const& e) { - EXPECT_STREQ(e.what(), "message payload size 24 is too big for message type 0 (maximum is 23)"); + EXPECT_STREQ(e.what(), "message payload size 24 is too big for message type 1 (maximum is 23)"); } catch (...) { FAIL() << "expected MessageSerializationError"; } @@ -140,7 +140,16 @@ TEST(NetworkTests, SlightlyTooBigGridStateMessageFails) { TEST(NetworkTests, GameStartMessage) { auto const random_seed = static_cast(std::random_device{}()); - auto const message = GameStart{ 31, 180, random_seed, 5 }; + auto const message = GameStart{ + 31, + 180, + random_seed, + std::vector{ + ClientIdentity{ 0, "player0" }, + ClientIdentity{ 1, "player1" }, + ClientIdentity{ 2, "player2" }, + }, + }; auto const deserialized_message = send_receive_and_deserialize(message);