Checkpoint Ablation for Flamby #71

Merged
merged 14 commits on Nov 21, 2023
10 changes: 7 additions & 3 deletions .github/pull_request_template.md
@@ -1,8 +1,12 @@
# PR Type
[Feature | Fix | Documentation | Other() ]
[Feature | Fix | Documentation | Other ]

# Short Description
...

Clickup Ticket(s): Link(s) if applicable.

Add a short description of what is in this PR.

# Tests Added
...

Describe the tests that have been added to ensure the code's correctness, if applicable.
1 change: 1 addition & 0 deletions fl4health/checkpointing/checkpointer.py
@@ -24,6 +24,7 @@ def __init__(self, best_checkpoint_dir: str, best_checkpoint_name: str) -> None:

def maybe_checkpoint(self, model: nn.Module, _: Optional[float] = None) -> None:
# Always checkpoint the latest model
log(INFO, "Saving latest checkpoint with LatestTorchCheckpointer")
torch.save(model, self.best_checkpoint_path)


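For context on the ablation, here is a minimal sketch of the two checkpointing behaviours being compared. The class names echo the checkpointers referenced in this PR, but the bodies are illustrative assumptions rather than the fl4health source:

import os
from typing import Optional

import torch
import torch.nn as nn


class LatestCheckpointerSketch:
    """Always overwrites the checkpoint with the most recent model (latest checkpointing)."""

    def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None:
        self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)

    def maybe_checkpoint(self, model: nn.Module, _: Optional[float] = None) -> None:
        # The metric argument is ignored: the latest model is always saved.
        torch.save(model, self.checkpoint_path)


class BestMetricCheckpointerSketch:
    """Only overwrites the checkpoint when the tracked metric improves (federated checkpointing)."""

    def __init__(self, checkpoint_dir: str, checkpoint_name: str, maximize: bool = False) -> None:
        self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        self.maximize = maximize
        self.best_metric: Optional[float] = None

    def maybe_checkpoint(self, model: nn.Module, metric: Optional[float] = None) -> None:
        if metric is None:
            return
        improved = (
            self.best_metric is None
            or (self.maximize and metric > self.best_metric)
            or (not self.maximize and metric < self.best_metric)
        )
        if improved:
            self.best_metric = metric
            torch.save(model, self.checkpoint_path)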
8 changes: 3 additions & 5 deletions fl4health/clients/fenda_client.py
@@ -22,9 +22,9 @@ def __init__(
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpointer: Optional[TorchCheckpointer] = None,
temperature: Optional[float] = 0.5,
perfcl_loss_weights: Optional[Tuple[float, float]] = (0.0, 0.0),
cos_sim_loss_weight: Optional[float] = 0.0,
contrastive_loss_weight: Optional[float] = 0.0,
Author comment:
If these are not None, then some of these losses (at least perfcl) are still computed, but just "zeroed" out. This is problematic if you want to use asymmetric latent spaces in a FENDA architecture.

perfcl_loss_weights: Optional[Tuple[float, float]] = None,
cos_sim_loss_weight: Optional[float] = None,
contrastive_loss_weight: Optional[float] = None,
) -> None:
super().__init__(
data_path=data_path,
@@ -95,7 +95,6 @@ def predict(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[st
return preds, features

def get_parameters(self, config: Config) -> NDArrays:

# Save the parameters of the old model
assert isinstance(self.model, FendaModel)
if self.contrastive_loss_weight or self.perfcl_loss_weights:
@@ -105,7 +104,6 @@ def get_parameters(self, config: Config) -> NDArrays:
return super().get_parameters(config)

def set_parameters(self, parameters: NDArrays, config: Config) -> None:

# Set the parameters of the model
super().set_parameters(parameters, config)

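As the author's comment above notes, the important change is that a weight of None skips a loss term entirely, rather than computing it and multiplying by 0.0. Below is a hedged sketch of that guard pattern; the function name and feature keys are placeholders, and the real FendaClient loss computation in fl4health may differ in detail:

from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F


def auxiliary_loss(
    features: Dict[str, torch.Tensor],
    cos_sim_loss_weight: Optional[float] = None,
    contrastive_loss_weight: Optional[float] = None,
    perfcl_loss_weights: Optional[Tuple[float, float]] = None,
) -> torch.Tensor:
    # Each term is only evaluated when its weight is not None. With 0.0 defaults,
    # the term is still evaluated (and merely scaled to zero), which fails outright
    # when local and global features have different dimensions, i.e. asymmetric
    # latent spaces in a FENDA architecture.
    total = torch.tensor(0.0)
    if cos_sim_loss_weight is not None:
        cos = F.cosine_similarity(features["local"], features["global"], dim=1)
        total = total + cos_sim_loss_weight * cos.abs().mean()
    if contrastive_loss_weight is not None:
        # Stand-in contrastive term: pull local features toward their positives.
        total = total + contrastive_loss_weight * (features["local"] - features["positive"]).pow(2).mean()
    if perfcl_loss_weights is not None:
        w_global, w_local = perfcl_loss_weights
        sim_local = F.cosine_similarity(features["local"], features["old_local"], dim=1).mean()
        sim_global = F.cosine_similarity(features["global"], features["old_global"], dim=1).mean()
        total = total + w_local * (1 - sim_local) + w_global * (1 - sim_global)
    return total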
1 change: 0 additions & 1 deletion fl4health/server/tabular_feature_alignment_server.py
@@ -93,7 +93,6 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> History:
# the feature information needed to perform feature alignment. Then the server
# gathers information from the clients that is necessary for initializing the global model.
if not self.initial_polls_complete:

# If the server does not have the needed feature info a priori,
# then it requests such information from the clients before the
# very first fitting round.
19 changes: 17 additions & 2 deletions research/flamby/fed_heart_disease/apfl/client.py
@@ -13,7 +13,11 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer
from fl4health.checkpointing.checkpointer import (
BestMetricTorchCheckpointer,
LatestTorchCheckpointer,
TorchCheckpointer,
)
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.utils.losses import LossMeterType
@@ -109,17 +113,28 @@ def get_criterion(self, config: Config) -> _Loss:
parser.add_argument(
"--alpha_learning_rate", action="store", type=float, help="Learning rate for the APFL alpha", default=0.01
)
parser.add_argument(
"--no_federated_checkpointing",
action="store_true",
help="flag to disable checkpointing of the best model on this client; when set, the latest model is checkpointed instead",
)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Server Address: {args.server_address}")
log(INFO, f"Learning Rate: {args.learning_rate}")
log(INFO, f"Alpha Learning Rate: {args.alpha_learning_rate}")
log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}")

federated_checkpointing = not args.no_federated_checkpointing
checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
checkpoint_name = f"client_{args.client_number}_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client = FedHeartDiseaseApflClient(
data_path=args.dataset_dir,
1 change: 1 addition & 0 deletions research/flamby/fed_heart_disease/fedadam/config.yaml
@@ -6,3 +6,4 @@ n_clients: 4 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client (NOT USED FOR FLAMBY)
batch_size: 64 # The batch size for client training (NOT USED FOR FLAMBY)
local_steps: 100 # The number of local training steps to perform.
federated_checkpointing: True # Indicates whether intermediate or latest checkpointing is used on the server side
14 changes: 11 additions & 3 deletions research/flamby/fed_heart_disease/fedadam/server.py
@@ -10,7 +10,7 @@
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAdam

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchCheckpointer
from fl4health.utils.config import load_config
from research.flamby.flamby_servers.full_exchange_server import FullExchangeServer
from research.flamby.utils import (
@@ -34,7 +34,13 @@ def main(

checkpoint_dir = os.path.join(checkpoint_stub, run_name)
checkpoint_name = "server_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
federated_checkpointing: bool = config.get("federated_checkpointing", True)
log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}")
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client_manager = SimpleClientManager()
model = Baseline()
@@ -62,7 +68,9 @@ def main(
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")
if federated_checkpointing:
assert isinstance(checkpointer, BestMetricTorchCheckpointer)
log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")

# Shutdown the server gracefully
server.shutdown()
1 change: 1 addition & 0 deletions research/flamby/fed_heart_disease/fedavg/config.yaml
@@ -6,3 +6,4 @@ n_clients: 4 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client (NOT USED FOR FLAMBY)
batch_size: 64 # The batch size for client training (NOT USED FOR FLAMBY)
local_steps: 100 # The number of local training steps to perform.
federated_checkpointing: True # Indicates whether intermediate or latest checkpointing is used on the server side
14 changes: 11 additions & 3 deletions research/flamby/fed_heart_disease/fedavg/server.py
@@ -10,7 +10,7 @@
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchCheckpointer
from fl4health.utils.config import load_config
from research.flamby.flamby_servers.full_exchange_server import FullExchangeServer
from research.flamby.utils import (
@@ -32,7 +32,13 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_

checkpoint_dir = os.path.join(checkpoint_stub, run_name)
checkpoint_name = "server_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
federated_checkpointing: bool = config.get("federated_checkpointing", True)
log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}")
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client_manager = SimpleClientManager()
model = Baseline()
@@ -60,7 +66,9 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")
if federated_checkpointing:
assert isinstance(checkpointer, BestMetricTorchCheckpointer)
log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")

# Shutdown the server gracefully
server.shutdown()
1 change: 1 addition & 0 deletions research/flamby/fed_heart_disease/fedprox/config.yaml
@@ -6,3 +6,4 @@ n_clients: 4 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client (NOT USED FOR FLAMBY)
batch_size: 64 # The batch size for client training (NOT USED FOR FLAMBY)
local_steps: 100 # The number of local training steps to perform.
federated_checkpointing: True # Indicates whether intermediate or latest checkpointing is used on the server side
14 changes: 11 additions & 3 deletions research/flamby/fed_heart_disease/fedprox/server.py
@@ -9,7 +9,7 @@
from flwr.common.logger import log
from flwr.server.client_manager import SimpleClientManager

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchCheckpointer
from fl4health.strategies.fedprox import FedProx
from fl4health.utils.config import load_config
from research.flamby.flamby_servers.fedprox_server import FedProxServer
@@ -32,7 +32,13 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub

checkpoint_dir = os.path.join(checkpoint_stub, run_name)
checkpoint_name = "server_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
federated_checkpointing: bool = config.get("federated_checkpointing", True)
log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}")
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client_manager = SimpleClientManager()
model = Baseline()
@@ -61,7 +67,9 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")
if federated_checkpointing:
assert isinstance(checkpointer, BestMetricTorchCheckpointer)
log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")

# Shutdown the server gracefully
server.shutdown()
19 changes: 17 additions & 2 deletions research/flamby/fed_heart_disease/fenda/client.py
@@ -14,7 +14,11 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer
from fl4health.checkpointing.checkpointer import (
BestMetricTorchCheckpointer,
LatestTorchCheckpointer,
TorchCheckpointer,
)
from fl4health.clients.fenda_client import FendaClient
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import Accuracy, Metric
@@ -116,16 +120,27 @@ def get_criterion(self, config: Config) -> _Loss:
parser.add_argument("--cos_sim_loss", action="store_true", help="Activate Cosine Similarity loss")
parser.add_argument("--contrastive_loss", action="store_true", help="Activate Contrastive loss")
parser.add_argument("--perfcl_loss", action="store_true", help="Activate PerFCL loss")
parser.add_argument(
"--no_federated_checkpointing",
action="store_true",
help="flag to disable checkpointing of the best model on this client; when set, the latest model is checkpointed instead",
)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Server Address: {args.server_address}")
log(INFO, f"Learning Rate: {args.learning_rate}")
log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}")

federated_checkpointing = not args.no_federated_checkpointing
checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
checkpoint_name = f"client_{args.client_number}_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client = FedHeartDiseaseFendaClient(
data_path=Path(args.dataset_dir),
1 change: 1 addition & 0 deletions research/flamby/fed_heart_disease/scaffold/config.yaml
@@ -6,3 +6,4 @@ n_clients: 4 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client (NOT USED FOR FLAMBY)
batch_size: 64 # The batch size for client training (NOT USED FOR FLAMBY)
local_steps: 100 # The number of local training steps to perform.
federated_checkpointing: True # Indicates whether intermediate or latest checkpointing is used on the server side
14 changes: 11 additions & 3 deletions research/flamby/fed_heart_disease/scaffold/server.py
@@ -8,7 +8,7 @@
from flamby.datasets.fed_heart_disease import Baseline
from flwr.common.logger import log

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchCheckpointer
from fl4health.client_managers.fixed_without_replacement_manager import FixedSamplingByFractionClientManager
from fl4health.strategies.scaffold import Scaffold
from fl4health.utils.config import load_config
@@ -34,7 +34,13 @@ def main(

checkpoint_dir = os.path.join(checkpoint_stub, run_name)
checkpoint_name = "server_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
federated_checkpointing: bool = config.get("federated_checkpointing", True)
log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}")
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client_manager = FixedSamplingByFractionClientManager()
model = Baseline()
@@ -65,7 +71,9 @@ def main(
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")
if federated_checkpointing:
assert isinstance(checkpointer, BestMetricTorchCheckpointer)
log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")

# Shutdown the server gracefully
server.shutdown()
21 changes: 18 additions & 3 deletions research/flamby/fed_ixi/apfl/client.py
@@ -14,7 +14,11 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer
from fl4health.checkpointing.checkpointer import (
BestMetricTorchCheckpointer,
LatestTorchCheckpointer,
TorchCheckpointer,
)
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.utils.losses import LossMeterType
@@ -48,7 +52,7 @@ def __init__(
self.alpha_learning_rate = alpha_learning_rate
self.client_number = client_number

def get_dataloader(self, config: Config) -> Tuple[DataLoader, DataLoader]:
Author comment:
Somehow a typo crept into this APFL client

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets(
self.client_number, str(self.data_path)
)
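
Why the rename matters: the base client presumably calls get_data_loaders, so a method named get_dataloader never overrides that hook and is simply dead code. A small illustration of that failure mode, using an assumed base class rather than the actual fl4health code:

from typing import List, Tuple


class BaseClientSketch:
    def get_data_loaders(self, config: dict) -> Tuple[List, List]:
        # The hook the framework actually calls; subclasses are expected to override it.
        raise NotImplementedError

    def setup(self, config: dict) -> None:
        self.train_loader, self.val_loader = self.get_data_loaders(config)


class TypoClient(BaseClientSketch):
    def get_dataloader(self, config: dict) -> Tuple[List, List]:  # typo: not an override
        return ([1, 2], [3])


TypoClient().setup({})  # raises NotImplementedError because the typo'd method is never called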
@@ -111,17 +115,28 @@ def get_criterion(self, config: Config) -> _Loss:
parser.add_argument(
"--alpha_learning_rate", action="store", type=float, help="Learning rate for the APFL alpha", default=0.01
)
parser.add_argument(
"--no_federated_checkpointing",
action="store_true",
help="flag to disable checkpointing of the best model on this client; when set, the latest model is checkpointed instead",
)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Server Address: {args.server_address}")
log(INFO, f"Learning Rate: {args.learning_rate}")
log(INFO, f"Alpha Learning Rate: {args.alpha_learning_rate}")
log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}")

federated_checkpointing = not args.no_federated_checkpointing
checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
checkpoint_name = f"client_{args.client_number}_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client = FedIxiApflClient(
data_path=Path(args.dataset_dir),
1 change: 1 addition & 0 deletions research/flamby/fed_ixi/fedadam/config.yaml
@@ -6,3 +6,4 @@ n_clients: 3 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client (NOT USED FOR FLAMBY)
batch_size: 64 # The batch size for client training (NOT USED FOR FLAMBY)
local_steps: 100 # The number of local training steps to perform.
federated_checkpointing: True # Indicates whether intermediate or latest checkpointing is used on the server side
14 changes: 11 additions & 3 deletions research/flamby/fed_ixi/fedadam/server.py
@@ -9,7 +9,7 @@
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAdam

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchCheckpointer
from fl4health.utils.config import load_config
from research.flamby.fed_ixi.fedadam.fedadam_model import FedAdamUNet
from research.flamby.flamby_servers.full_exchange_server import FullExchangeServer
@@ -34,7 +34,13 @@ def main(

checkpoint_dir = os.path.join(checkpoint_stub, run_name)
checkpoint_name = "server_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
federated_checkpointing: bool = config.get("federated_checkpointing", True)
log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}")
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client_manager = SimpleClientManager()
model = FedAdamUNet()
@@ -62,7 +68,9 @@ def main(
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")
if federated_checkpointing:
assert isinstance(checkpointer, BestMetricTorchCheckpointer)
log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")

# Shutdown the server gracefully
server.shutdown()