Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/SK-1144 | Simplify on_train and on_validate #731

Merged
merged 7 commits into from
Oct 30, 2024
142 changes: 63 additions & 79 deletions fedn/network/clients/README.rst
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could include a callback function for training in this README example. It doesn't need to do anything more than return the input model. The end goal is to make users understand how they can develop clients that help with their use case.

Original file line number Diff line number Diff line change
Expand Up @@ -26,111 +26,95 @@ Step-by-Step Instructions

import argparse
import json
import threading
import uuid

import fedn.network.grpc.fedn_pb2 as fedn

from fedn.network.clients.client_api import ClientAPI, ConnectToApiResult

client_api = ClientAPI()




def get_api_url(api_url: str, api_port: int):
url = f"{api_url}:{api_port}" if api_port else api_url
if not url.endswith("/"):
url += "/"
return url

def on_train(in_model):
training_metadata = {
"num_examples": 1,
"batch_size": 1,
"epochs": 1,
"lr": 1,
}

config = {
"round_id": 1,
}

metadata = {
"training_metadata": training_metadata,
"config": json.dumps(config),
}

# Do your training here, out_model is your result...
out_model = in_model

return out_model, metadata

def on_validate(in_model):

# Calculate metrics here...
metrics = {
"test_accuracy": 0.9,
"test_loss": 0.1,
"train_accuracy": 0.8,
"train_loss": 0.2,
}
return metrics

def main(api_url: str, api_port: int, token: str = None):
print(f"API URL: {api_url}")
print(f"API Token: {token or "-"}")
print(f"API Port: {api_port or "-"}")


client_api = ClientAPI(train_callback=on_train, validate_callback=on_validate)

url = get_api_url(api_url, api_port)

name = input("Enter Client Name: ")

url = f"{api_url}:{api_port}" if api_port else api_url

if not url.endswith("/"):
url += "/"

print(f"Client Name: {name}")

client_api.set_name(name)

client_id = str(uuid.uuid4())

print("Connecting to API...")

client_options = {
"name": "client_example",
client_api.set_client_id(client_id)

controller_config = {
"name": name,
"client_id": client_id,
"package": "local",
"preferred_combiner": "",
}
result, combiner_config = client_api.connect_to_api(url, token, client_options)

result, combiner_config = client_api.connect_to_api(url, token, controller_config)

if result != ConnectToApiResult.Assigned:
print("Failed to connect to API, exiting.")
return

print("Connected to API")


result: bool = client_api.init_grpchandler(config=combiner_config, client_name=client_id, token=token)

if not result:
return

threading.Thread(target=client_api.send_heartbeats, kwargs={"client_name": name, "client_id": client_id}, daemon=True).start()

def on_train(request):
print("Received train request")
model_id: str = request.model_id

model = client_api.get_model_from_combiner(id=str(model_id), client_name=name)

# Do your training here, out_model is your result...
out_model = model
updated_model_id = uuid.uuid4()
client_api.send_model_to_combiner(out_model, str(updated_model_id))

training_metadata = {
"num_examples": 1,
"batch_size": 1,
"epochs": 1,
"lr": 1,
}

config = {
"round_id": 1,
}

client_api.send_model_update(
sender_name=name,
sender_role=fedn.WORKER,
client_id=client_id,
model_id=model_id,
model_update_id=str(updated_model_id),
receiver_name=request.sender.name,
receiver_role=request.sender.role,
meta={
"training_metadata": training_metadata,
"config": json.dumps(config),
},
)

client_api.subscribe("train", on_train)

threading.Thread(target=client_api.listen_to_task_stream, kwargs={"client_name": name, "client_id": client_id}, daemon=True).start()

stop_event = threading.Event()
try:
stop_event.wait()
except KeyboardInterrupt:
print("Client stopped by user.")



client_api.run()

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Client Example")
parser.add_argument("--api-url", type=str, required=True, help="The API URL")
parser.add_argument("--api-port", type=int, required=False, help="The API Port")
parser.add_argument("--token", type=str, required=False, help="The API Token")

args = parser.parse_args()
main(args.api_url, args.api_port, args.token)


4. **Run the client**: Run the client by executing the following command:

.. code-block:: bash
Expand Down
Loading
Loading