Skip to content

Commit

Permalink
Refactor/SK-1144 | Simplify on_train and on_validate (#731)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminastrand authored Oct 30, 2024
1 parent b1dcdbe commit f3ec362
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 286 deletions.
142 changes: 63 additions & 79 deletions fedn/network/clients/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,111 +26,95 @@ Step-by-Step Instructions
import argparse
import json
import threading
import uuid
import fedn.network.grpc.fedn_pb2 as fedn
from fedn.network.clients.client_api import ClientAPI, ConnectToApiResult
client_api = ClientAPI()
def get_api_url(api_url: str, api_port: int):
url = f"{api_url}:{api_port}" if api_port else api_url
if not url.endswith("/"):
url += "/"
return url
def on_train(in_model):
training_metadata = {
"num_examples": 1,
"batch_size": 1,
"epochs": 1,
"lr": 1,
}
config = {
"round_id": 1,
}
metadata = {
"training_metadata": training_metadata,
"config": json.dumps(config),
}
# Do your training here, out_model is your result...
out_model = in_model
return out_model, metadata
def on_validate(in_model):
# Calculate metrics here...
metrics = {
"test_accuracy": 0.9,
"test_loss": 0.1,
"train_accuracy": 0.8,
"train_loss": 0.2,
}
return metrics
def main(api_url: str, api_port: int, token: str = None):
print(f"API URL: {api_url}")
print(f"API Token: {token or "-"}")
print(f"API Port: {api_port or "-"}")
client_api = ClientAPI(train_callback=on_train, validate_callback=on_validate)
url = get_api_url(api_url, api_port)
name = input("Enter Client Name: ")
url = f"{api_url}:{api_port}" if api_port else api_url
if not url.endswith("/"):
url += "/"
print(f"Client Name: {name}")
client_api.set_name(name)
client_id = str(uuid.uuid4())
print("Connecting to API...")
client_options = {
"name": "client_example",
client_api.set_client_id(client_id)
controller_config = {
"name": name,
"client_id": client_id,
"package": "local",
"preferred_combiner": "",
}
result, combiner_config = client_api.connect_to_api(url, token, client_options)
result, combiner_config = client_api.connect_to_api(url, token, controller_config)
if result != ConnectToApiResult.Assigned:
print("Failed to connect to API, exiting.")
return
print("Connected to API")
result: bool = client_api.init_grpchandler(config=combiner_config, client_name=client_id, token=token)
if not result:
return
threading.Thread(target=client_api.send_heartbeats, kwargs={"client_name": name, "client_id": client_id}, daemon=True).start()
def on_train(request):
print("Received train request")
model_id: str = request.model_id
model = client_api.get_model_from_combiner(id=str(model_id), client_name=name)
# Do your training here, out_model is your result...
out_model = model
updated_model_id = uuid.uuid4()
client_api.send_model_to_combiner(out_model, str(updated_model_id))
training_metadata = {
"num_examples": 1,
"batch_size": 1,
"epochs": 1,
"lr": 1,
}
config = {
"round_id": 1,
}
client_api.send_model_update(
sender_name=name,
sender_role=fedn.WORKER,
client_id=client_id,
model_id=model_id,
model_update_id=str(updated_model_id),
receiver_name=request.sender.name,
receiver_role=request.sender.role,
meta={
"training_metadata": training_metadata,
"config": json.dumps(config),
},
)
client_api.subscribe("train", on_train)
threading.Thread(target=client_api.listen_to_task_stream, kwargs={"client_name": name, "client_id": client_id}, daemon=True).start()
stop_event = threading.Event()
try:
stop_event.wait()
except KeyboardInterrupt:
print("Client stopped by user.")
client_api.run()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Client Example")
parser.add_argument("--api-url", type=str, required=True, help="The API URL")
parser.add_argument("--api-port", type=int, required=False, help="The API Port")
parser.add_argument("--token", type=str, required=False, help="The API Token")
args = parser.parse_args()
main(args.api_url, args.api_port, args.token)
4. **Run the client**: Run the client by executing the following command:

.. code-block:: bash
Expand Down
Loading

0 comments on commit f3ec362

Please sign in to comment.