Skip to content

Commit

Permalink
Log communication cost
Browse files Browse the repository at this point in the history
  • Loading branch information
rishi-s8 committed Oct 21, 2024
1 parent a185d8a commit 7abedbe
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/utils/communication/grpc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,29 @@ def __init__(self, super_node_host: str):
self.peer_ids: OrderedDict[int, Dict[str, int | str]] = OrderedDict(
{0: {"rank": 0, "port": port, "ip": ip}}
)
self.communication_cost_received: int = 0
self.communication_cost_sent: int = 0

def update_communcation_cost(self, func):
def wrapper(request, context):
down_cost = request.ByteSize()
return_data = func(self, request, context)
up_cost = return_data.ByteSize()
with self.lock:
self.communication_cost_received += down_cost
self.communication_cost_sent += up_cost
return return_data
return wrapper

def register_self(self, obj: "BaseNode"):
self.base_node = obj

@update_communcation_cost
def send_data(self, request, context) -> comm_pb2.Empty: # type: ignore
self.received_data.put(deserialize_model(request.model.buffer)) # type: ignore
return comm_pb2.Empty() # type: ignore

@update_communcation_cost
def get_rank(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Rank | None:
try:
with self.lock:
Expand All @@ -114,6 +129,7 @@ def get_rank(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> co
except Exception as e:
context.abort(grpc.StatusCode.INTERNAL, f"Error in get_rank: {str(e)}") # type: ignore

@update_communcation_cost
def get_model(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Model | None:
if not self.base_node:
context.abort(grpc.StatusCode.INTERNAL, "Base node not registered") # type: ignore
Expand All @@ -122,6 +138,7 @@ def get_model(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> c
model = comm_pb2.Model(buffer=serialize_model(self.base_node.get_model_weights()))
return model

@update_communcation_cost
def get_current_round(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Round | None:
if not self.base_node:
context.abort(grpc.StatusCode.INTERNAL, "Base node not registered") # type: ignore
Expand Down Expand Up @@ -228,6 +245,8 @@ def register(self):
"""
def callback_fn(stub: comm_pb2_grpc.CommunicationServerStub) -> int:
rank_data = stub.get_rank(comm_pb2.Empty()) # type: ignore
with self.lock:
self.servicer.communication_cost_received += rank_data.ByteSize()
return rank_data.rank # type: ignore

self.rank = self.recv_with_retries(self.super_node_host, callback_fn)
Expand Down Expand Up @@ -341,6 +360,8 @@ def wait_until_rounds_match(self, id: int):
self_round = self.servicer.base_node.get_local_rounds()
def callback_fn(stub: comm_pb2_grpc.CommunicationServerStub) -> int:
round = stub.get_current_round(comm_pb2.Empty()) # type: ignore
with self.lock:
self.servicer.communication_cost_received += round.ByteSize()
return round.round # type: ignore

while True:
Expand Down Expand Up @@ -373,6 +394,8 @@ def receive(self, node_ids: List[int]) -> List[Any]:
items: List[Any] = []
def callback_fn(stub: comm_pb2_grpc.CommunicationServerStub) -> OrderedDict[str, Tensor]:
model = stub.get_model(comm_pb2.Empty()) # type: ignore
with self.lock:
self.servicer.communication_cost_received += model.ByteSize()
return deserialize_model(model.buffer) # type: ignore

for id in node_ids:
Expand Down

0 comments on commit 7abedbe

Please sign in to comment.