Feature/SK-946 | Add functionality for user defined server-functions #666

Merged: 25 commits, Oct 31, 2024
add custom aggregator and metadata
viktorvaladi committed Sep 2, 2024
commit 76b59497505cdbf52c458e1c52b324641268fd5d
4 changes: 4 additions & 0 deletions examples/custom-aggregator/.dockerignore
@@ -0,0 +1,4 @@
data
seed.npz
*.tgz
*.tar.gz
6 changes: 6 additions & 0 deletions examples/custom-aggregator/.gitignore
@@ -0,0 +1,6 @@
data
*.npz
*.tgz
*.tar.gz
.mnist-pytorch
client.yaml
8 changes: 8 additions & 0 deletions examples/custom-aggregator/README.rst
@@ -0,0 +1,8 @@
FEDn Project: Custom aggregator (hyperparameter tuning)
--------------------------------------------------------

This README will be updated after the Studio update.

To run a session with the custom aggregator::

    client.start_session(aggregator="custom", function_provider_path="client/aggregator.py", rounds=101)
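For context, a minimal sketch of how the `client` object above might be obtained, assuming the REST client is `fedn.APIClient` and that the controller is reachable at the placeholder host/port below (these connection details are not part of this commit):

from fedn import APIClient

# Connect to the FEDn controller; host and port are placeholders for your deployment.
client = APIClient(host="localhost", port=8092)

# Start a session that uses the custom aggregator shipped in client/aggregator.py.
client.start_session(aggregator="custom", function_provider_path="client/aggregator.py", rounds=101)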
113 changes: 113 additions & 0 deletions examples/custom-aggregator/client/aggregator.py
@@ -0,0 +1,113 @@
import numpy as np

from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.functionproviderbase import FunctionProviderBase


class FunctionProvider(FunctionProviderBase):
"""A FunctionProvider class responsible for aggregating model parameters
from multiple clients and performing hyperparameter tuning by adjusting
the learning rate every 20th round. The class logs the current state of
the model, learning rate, and round to facilitate monitoring and evaluation.
"""

def __init__(self) -> None:
self.current_round = -1
self.initial_parameters = None
self.learning_rates = [0.001, 0.01, 0.0001, 0.1, 0.00001]
self.current_lr_index = -1
self.current_lr = 0  # start with a learning rate of 0 in the first round so the initial parameters are returned unchanged
self.current_parameters = None

# Tracking metrics
self.highest_accuracy = 0
self.highest_accuracy_round = -1
self.highest_accuracy_lr = 0
self.mean_loss_per_lr = []
self.mean_acc_per_lr = []
self.highest_mean_acc = 0
self.highest_mean_acc_round = -1
self.highest_mean_acc_lr = None

def aggregate(self, results: list[tuple[list[np.ndarray], dict]]) -> list[np.ndarray]:
"""Aggregate model parameters using weighted average based on the number of examples each client has.

Args:
----
results (list of tuples): Each tuple contains:
- A list of numpy.ndarrays representing model parameters from a client.
- A dictionary containing client metadata, which must include a key "num_examples" indicating
the number of examples used by the client.

Returns:
-------
list of numpy.ndarrays: Aggregated model parameters as a list of numpy.ndarrays.

"""
total_loss = 0
total_acc = 0
num_clients = len(results)
if self.current_round == -1:
self.initial_parameters = results[0][0]
averaged_parameters = self.initial_parameters  # first round: no updates have been made yet
elif self.current_round % 20 == 0:
if self.mean_loss_per_lr:
logger.info(f"Completed Learning Rate: {self.current_lr}")
logger.info(f"Mean Loss: {np.mean(self.mean_loss_per_lr)}, Highest Accuracy: {np.max(self.mean_acc_per_lr)}")
logger.info(
f"""Highest mean accuracy across rounds: {self.highest_mean_acc}
at round {self.highest_mean_acc_round} with lr {self.highest_mean_acc_lr}"""
)

# Reset tracking for the new learning rate
self.mean_loss_per_lr = []
self.mean_acc_per_lr = []

averaged_parameters = self.initial_parameters
self.current_lr_index += 1
self.current_lr = self.learning_rates[self.current_lr_index]
else:
# Aggregate using fedavg
summed_parameters = [np.zeros_like(param) for param in results[0][0]]
total_weight = 0
for client_params, client_metadata in results:
weight = client_metadata.get("num_examples", 1)
total_weight += weight
for i, param in enumerate(client_params):
summed_parameters[i] += param * weight

total_loss += client_metadata.get("test_loss", 0)
total_acc += client_metadata.get("test_acc", 0)

averaged_parameters = [param / total_weight for param in summed_parameters]

# Calculate average loss and accuracy by number of clients
avg_loss = total_loss / num_clients if num_clients > 0 else 0
avg_acc = total_acc / num_clients if num_clients > 0 else 0

# Update the tracking for the current learning rate
self.mean_loss_per_lr.append(avg_loss)
self.mean_acc_per_lr.append(avg_acc)

# Check if we have a new highest accuracy
if avg_acc > self.highest_accuracy:
self.highest_accuracy = avg_acc
self.highest_accuracy_round = self.current_round
self.highest_accuracy_lr = self.current_lr

# Check if we have a new highest mean accuracy across rounds
if avg_acc > self.highest_mean_acc:
self.highest_mean_acc = avg_acc
self.highest_mean_acc_round = self.current_round
self.highest_mean_acc_lr = self.current_lr

# Print the metrics
logger.info(f"Round {self.current_round} - Learning Rate: {self.current_lr}")
logger.info(f"Average Test Loss: {avg_loss}, Average Test Accuracy: {avg_acc}")
logger.info(f"Highest Accuracy Achieved: {self.highest_accuracy} at round {self.highest_accuracy_round} with lr {self.highest_accuracy_lr}")

self.current_round += 1
return averaged_parameters

def get_model_metadata(self):
return {"learning_rate": self.current_lr, "parameter_tuning": True}
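To make the `results` format described in the aggregate docstring concrete, here is a small local sketch; the dummy parameter arrays and metadata values are invented for illustration, and it assumes the file above is importable as `aggregator` with fedn installed (for its logger import):

import numpy as np

from aggregator import FunctionProvider

# Two hypothetical clients, each reporting two parameter arrays plus the metadata keys the aggregator reads.
params_a = [np.ones((2, 2)), np.zeros(3)]
params_b = [np.full((2, 2), 3.0), np.ones(3)]
results = [
    (params_a, {"num_examples": 100, "test_loss": 0.9, "test_acc": 0.6}),
    (params_b, {"num_examples": 300, "test_loss": 0.7, "test_acc": 0.7}),
]

provider = FunctionProvider()
initial = provider.aggregate(results)   # round -1: returns the first client's parameters unchanged
reset = provider.aggregate(results)     # round 0: selects the first learning rate and returns the initial parameters
averaged = provider.aggregate(results)  # round 1: weighted FedAvg using "num_examples" as weights
print(provider.get_model_metadata())    # {"learning_rate": 0.001, "parameter_tuning": True}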
97 changes: 97 additions & 0 deletions examples/custom-aggregator/client/data.py
@@ -0,0 +1,97 @@
import os
from math import floor

import torch
import torchvision

dir_path = os.path.dirname(os.path.realpath(__file__))
abs_path = os.path.abspath(dir_path)


def get_data(out_dir="data"):
# Make dir if necessary
if not os.path.exists(out_dir):
os.mkdir(out_dir)

# Only download if not already downloaded
if not os.path.exists(f"{out_dir}/train"):
torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor, train=True, download=True)
if not os.path.exists(f"{out_dir}/test"):
torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor, train=False, download=True)


def load_data(data_path, is_train=True):
"""Load data from disk.

:param data_path: Path to data file.
:type data_path: str
:param is_train: Whether to load training or test data.
:type is_train: bool
:return: Tuple of data and labels.
:rtype: tuple
"""
if data_path is None:
data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.pt")

data = torch.load(data_path)

if is_train:
X = data["x_train"]
y = data["y_train"]
else:
X = data["x_test"]
y = data["y_test"]

# Normalize
X = X / 255

return X, y


def splitset(dataset, parts):
n = dataset.shape[0]
local_n = floor(n / parts)
result = []
for i in range(parts):
result.append(dataset[i * local_n : (i + 1) * local_n])
return result


def split(out_dir="data"):
n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))

# Make dir
if not os.path.exists(f"{out_dir}/clients"):
os.mkdir(f"{out_dir}/clients")

# Load and convert to dict
train_data = torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor(), train=True)
test_data = torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor(), train=False)
data = {
"x_train": splitset(train_data.data, n_splits),
"y_train": splitset(train_data.targets, n_splits),
"x_test": splitset(test_data.data, n_splits),
"y_test": splitset(test_data.targets, n_splits),
}

# Make splits
for i in range(n_splits):
subdir = f"{out_dir}/clients/{str(i+1)}"
if not os.path.exists(subdir):
os.mkdir(subdir)
torch.save(
{
"x_train": data["x_train"][i],
"y_train": data["y_train"][i],
"x_test": data["x_test"][i],
"y_test": data["y_test"][i],
},
f"{subdir}/mnist.pt",
)


if __name__ == "__main__":
# Prepare data if not already done
if not os.path.exists(abs_path + "/data/clients/1"):
get_data()
split()
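A short usage sketch of load_data, assuming data.py has already been run so the default client split at data/clients/1/mnist.pt exists (the path follows the FEDN_DATA_PATH fallback above):

from data import load_data

# Passing None falls back to FEDN_DATA_PATH or data/clients/1/mnist.pt.
x_train, y_train = load_data(None, is_train=True)
x_test, y_test = load_data(None, is_train=False)
print(x_train.shape, y_train.shape)  # e.g. torch.Size([30000, 28, 28]) torch.Size([30000]) with two splits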
12 changes: 12 additions & 0 deletions examples/custom-aggregator/client/fedn.yaml
@@ -0,0 +1,12 @@
python_env: python_env.yaml
entry_points:
build:
command: python model.py
startup:
command: python data.py
train:
command: python train.py
validate:
command: python validate.py
predict:
command: python predict.py
76 changes: 76 additions & 0 deletions examples/custom-aggregator/client/model.py
@@ -0,0 +1,76 @@
import collections

import torch

from fedn.utils.helpers.helpers import get_helper

HELPER_MODULE = "numpyhelper"
helper = get_helper(HELPER_MODULE)


def compile_model():
"""Compile the pytorch model.

:return: The compiled model.
:rtype: torch.nn.Module
"""

class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = torch.nn.Linear(784, 64)
self.fc2 = torch.nn.Linear(64, 32)
self.fc3 = torch.nn.Linear(32, 10)

def forward(self, x):
x = torch.nn.functional.relu(self.fc1(x.reshape(x.size(0), 784)))
x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
x = torch.nn.functional.relu(self.fc2(x))
x = torch.nn.functional.log_softmax(self.fc3(x), dim=1)
return x

return Net()


def save_parameters(model, out_path):
"""Save model paramters to file.

:param model: The model to serialize.
:type model: torch.nn.Module
:param out_path: The path to save to.
:type out_path: str
"""
parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()]
helper.save(parameters_np, out_path)


def load_parameters(model_path):
"""Load model parameters from file and populate model.

:param model_path: The path to load from.
:type model_path: str
:return: The loaded model.
:rtype: torch.nn.Module
"""
model = compile_model()
parameters_np = helper.load(model_path)

params_dict = zip(model.state_dict().keys(), parameters_np)
state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
model.load_state_dict(state_dict, strict=True)
return model


def init_seed(out_path="seed.npz"):
"""Initialize seed model and save it to file.

:param out_path: The path to save the seed model to.
:type out_path: str
"""
# Init and save
model = compile_model()
save_parameters(model, out_path)


if __name__ == "__main__":
init_seed("../seed.npz")
37 changes: 37 additions & 0 deletions examples/custom-aggregator/client/predict.py
@@ -0,0 +1,37 @@
import os
import sys

import torch
from data import load_data
from model import load_parameters

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(dir_path))


def predict(in_model_path, out_artifact_path, data_path=None):
"""Validate model.

:param in_model_path: The path to the input model.
:type in_model_path: str
:param out_artifact_path: The path to save the prediction output to.
:type out_artifact_path: str
:param data_path: The path to the data file.
:type data_path: str
"""
# Load data
x_test, y_test = load_data(data_path, is_train=False)

# Load model
model = load_parameters(in_model_path)
model.eval()

# Predict
with torch.no_grad():
y_pred = model(x_test)
# Save prediction to file/artifact; the artifact will be uploaded to the object store by the client
torch.save(y_pred, out_artifact_path)


if __name__ == "__main__":
predict(sys.argv[1], sys.argv[2])
9 changes: 9 additions & 0 deletions examples/custom-aggregator/client/python_env.yaml
@@ -0,0 +1,9 @@
name: mnist-pytorch
build_dependencies:
- pip
- setuptools
- wheel
dependencies:
- torch==2.3.1
- torchvision==0.18.1
- fedn