Skip to content

Commit

Permalink
Clear all locked object when client crashed.
Browse files Browse the repository at this point in the history
Signed-off-by: vegetableysm <[email protected]>
  • Loading branch information
vegetableysm committed Aug 12, 2024
1 parent e1e1b26 commit 19cd380
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 19 deletions.
28 changes: 20 additions & 8 deletions src/client/rpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,6 @@ Status RPCClient::GetRemoteBlob(const ObjectID& id, const bool unsafe,
std::vector<int> fd_sent;

std::string message_out;
RDMABlobScopeGuard rdmaBlobScopeGuard;
if (rdma_connected_) {
WriteGetRemoteBuffersRequest(std::set<ObjectID>{id}, unsafe, false, true,
message_out);
Expand All @@ -788,14 +787,17 @@ Status RPCClient::GetRemoteBlob(const ObjectID& id, const bool unsafe,
json message_in;
RETURN_ON_ERROR(doRead(message_in));
RETURN_ON_ERROR(ReadGetBuffersReply(message_in, payloads, fd_sent));
RETURN_ON_ASSERT(payloads.size() == 1, "Expects only one payload");

RDMABlobScopeGuard rdmaBlobScopeGuard;
if (rdma_connected_) {
std::unordered_set<ObjectID> ids{payloads[0].object_id};
std::unordered_set<ObjectID> ids{id};
std::function<void(std::unordered_set<ObjectID>)> func = std::bind(
&RPCClient::doReleaseBlobsWithRDMARequest, this, std::placeholders::_1);
rdmaBlobScopeGuard.set(func, ids);
}

RETURN_ON_ASSERT(payloads.size() == 1, "Expects only one payload");

buffer = std::shared_ptr<RemoteBlob>(new RemoteBlob(
payloads[0].object_id, remote_instance_id_, payloads[0].data_size));
// read the actual payload
Expand Down Expand Up @@ -892,7 +894,6 @@ Status RPCClient::GetRemoteBlobs(
std::unordered_set<ObjectID> id_set(ids.begin(), ids.end());
std::vector<Payload> payloads;
std::vector<int> fd_sent;
RDMABlobScopeGuard rdmaBlobScopeGuard;

std::string message_out;
if (rdma_connected_) {
Expand All @@ -905,16 +906,19 @@ Status RPCClient::GetRemoteBlobs(
json message_in;
RETURN_ON_ERROR(doRead(message_in));
RETURN_ON_ERROR(ReadGetBuffersReply(message_in, payloads, fd_sent));
RETURN_ON_ASSERT(payloads.size() == id_set.size(),
"The result size doesn't match with the requested sizes: " +
std::to_string(payloads.size()) + " vs. " +
std::to_string(id_set.size()));

RDMABlobScopeGuard rdmaBlobScopeGuard;
if (rdma_connected_) {
std::function<void(std::unordered_set<ObjectID>)> func = std::bind(
&RPCClient::doReleaseBlobsWithRDMARequest, this, std::placeholders::_1);
rdmaBlobScopeGuard.set(func, id_set);
}

RETURN_ON_ASSERT(payloads.size() == id_set.size(),
"The result size doesn't match with the requested sizes: " +
std::to_string(payloads.size()) + " vs. " +
std::to_string(id_set.size()));

std::unordered_map<ObjectID, std::shared_ptr<RemoteBlob>> id_payload_map;
if (rdma_connected_) {
for (auto const& payload : payloads) {
Expand Down Expand Up @@ -982,6 +986,14 @@ Status RPCClient::GetRemoteBlobs(
json message_in;
RETURN_ON_ERROR(doRead(message_in));
RETURN_ON_ERROR(ReadGetBuffersReply(message_in, payloads, fd_sent));

RDMABlobScopeGuard rdmaBlobScopeGuard;
if (rdma_connected_) {
std::function<void(std::unordered_set<ObjectID>)> func = std::bind(
&RPCClient::doReleaseBlobsWithRDMARequest, this, std::placeholders::_1);
rdmaBlobScopeGuard.set(func, id_set);
}

RETURN_ON_ASSERT(payloads.size() == id_set.size(),
"The result size doesn't match with the requested sizes: " +
std::to_string(payloads.size()) + " vs. " +
Expand Down
6 changes: 6 additions & 0 deletions src/server/async/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ void RPCServer::doVineyardReleaseMemory(VineyardRecvContext* recv_context,

void RPCServer::doVineyardClose(VineyardRecvContext* recv_context) {
VLOG(100) << "Receive close msg!";
if (recv_context == nullptr) {
return;
}
rdma_server_->CloseConnection(recv_context->rdma_conn_id);

std::lock_guard<std::recursive_mutex> scope_lock(this->rdma_mutex_);
Expand Down Expand Up @@ -369,6 +372,9 @@ void RPCServer::doRDMARecv() {
VineyardRecvContext* recv_context =
reinterpret_cast<VineyardRecvContext*>(context);
doVineyardClose(recv_context);
if (recv_context) {
delete recv_context;
}
}
VLOG(100) << "Get RX completion failed! Error:" << status.message();
VLOG(100) << "Retry...";
Expand Down
77 changes: 66 additions & 11 deletions src/server/async/socket_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -786,10 +786,22 @@ bool SocketConnection::doGetRemoteBuffers(const json& root) {

TRY_READ_REQUEST(ReadGetRemoteBuffersRequest, root, ids, unsafe, compress,
use_rdma);
server_ptr_->LockTransmissionObjects(ids);
RESPONSE_ON_ERROR(bulk_store_->GetUnsafe(ids, unsafe, objects));
RESPONSE_ON_ERROR(bulk_store_->AddDependency(
std::unordered_set<ObjectID>(ids.begin(), ids.end()), this->getConnId()));
this->LockTransmissionObjects(ids);
if (!bulk_store_->GetUnsafe(ids, unsafe, objects).ok()) {
this->UnlockTransmissionObjects(ids);
WriteErrorReply(Status::KeyError("Failed to get objects"), message_out);
this->doWrite(message_out);
return false;
}
if (!bulk_store_
->AddDependency(std::unordered_set<ObjectID>(ids.begin(), ids.end()),
this->getConnId())
.ok()) {
this->UnlockTransmissionObjects(ids);
WriteErrorReply(Status::KeyError("Failed to add dependency"), message_out);
this->doWrite(message_out);
return false;
}
WriteGetBuffersReply(objects, {}, compress, message_out);

if (!use_rdma) {
Expand All @@ -802,7 +814,7 @@ bool SocketConnection::doGetRemoteBuffers(const json& root) {
<< "Failed to send buffers to remote client: "
<< status.ToString();
}
self->server_ptr_->UnlockTransmissionObjects(ids);
self->UnlockTransmissionObjects(ids);
return Status::OK();
});
return Status::OK();
Expand Down Expand Up @@ -1846,12 +1858,10 @@ bool SocketConnection::doReleaseBlobsWithRDMA(const json& root) {
std::vector<ObjectID> ids;
TRY_READ_REQUEST(ReadReleaseBlobsWithRDMARequest, root, ids);

boost::asio::post(server_ptr_->GetIOContext(), [self, ids]() {
self->server_ptr_->UnlockTransmissionObjects(ids);
std::string message_out;
WriteReleaseBlobsWithRDMAReply(message_out);
self->doWrite(message_out);
});
this->UnlockTransmissionObjects(ids);
std::string message_out;
WriteReleaseBlobsWithRDMAReply(message_out);
this->doWrite(message_out);

return false;
}
Expand Down Expand Up @@ -1884,6 +1894,7 @@ void SocketConnection::doWrite(std::string&& buf) {
}

void SocketConnection::doStop() {
this->ClearLockedObjects();
if (this->Stop()) {
// drop connection
socket_server_ptr_->RemoveConnection(conn_id_);
Expand Down Expand Up @@ -1928,6 +1939,50 @@ void SocketConnection::doAsyncWrite(std::string&& buf, callback_t<> callback,
});
}

void SocketConnection::LockTransmissionObjects(
const std::vector<ObjectID>& ids) {
{
std::lock_guard<std::mutex> lock(locked_objects_mutex_);
for (auto const& id : ids) {
if (locked_objects_.find(id) == locked_objects_.end()) {
locked_objects_[id] = 1;
} else {
++locked_objects_[id];
}
}
}
server_ptr_->LockTransmissionObjects(ids);
}

void SocketConnection::UnlockTransmissionObjects(
const std::vector<ObjectID>& ids) {
{
std::lock_guard<std::mutex> lock(locked_objects_mutex_);
for (auto const& id : ids) {
if (locked_objects_.find(id) != locked_objects_.end()) {
if (--locked_objects_[id] == 0) {
locked_objects_.erase(id);
}
}
}
}
server_ptr_->UnlockTransmissionObjects(ids);
}

void SocketConnection::ClearLockedObjects() {
std::vector<ObjectID> ids;
{
std::lock_guard<std::mutex> lock(locked_objects_mutex_);
for (auto const& kv : locked_objects_) {
for (int i = 0; i < kv.second; ++i) {
ids.push_back(kv.first);
}
}
locked_objects_.clear();
}
server_ptr_->UnlockTransmissionObjects(ids);
}

SocketServer::SocketServer(std::shared_ptr<VineyardServer> vs_ptr)
: vs_ptr_(vs_ptr), next_conn_id_(0) {}

Expand Down
10 changes: 10 additions & 0 deletions src/server/async/socket_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "common/memory/payload.h"
#include "common/util/asio.h" // IWYU pragma: keep
Expand Down Expand Up @@ -193,6 +194,12 @@ class SocketConnection : public std::enable_shared_from_this<SocketConnection> {
this->server_ptr_ = session;
}

void LockTransmissionObjects(const std::vector<ObjectID>& ids);

void UnlockTransmissionObjects(const std::vector<ObjectID>& ids);

void ClearLockedObjects();

// whether the connection has been correctly "registered"
std::atomic_bool registered_;

Expand All @@ -216,6 +223,9 @@ class SocketConnection : public std::enable_shared_from_this<SocketConnection> {
size_t read_msg_header_;
std::string read_msg_body_;

std::unordered_map<ObjectID, int> locked_objects_;
std::mutex locked_objects_mutex_;

friend class IPCServer;
friend class RPCServer;
};
Expand Down

0 comments on commit 19cd380

Please sign in to comment.