Checkpoint Ablation for Flamby #71

Merged: 14 commits, Nov 21, 2023
Changes from 8 commits
1 change: 1 addition & 0 deletions fl4health/checkpointing/checkpointer.py
@@ -24,6 +24,7 @@ def __init__(self, best_checkpoint_dir: str, best_checkpoint_name: str) -> None:

def maybe_checkpoint(self, model: nn.Module, _: Optional[float] = None) -> None:
# Always checkpoint the latest model
log(INFO, "Saving latest checkpoint with LatestTorchCheckpointer")
torch.save(model, self.best_checkpoint_path)
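
For orientation, here is a minimal sketch contrasting the two checkpointing behaviours this PR toggles between. This is not the library's implementation; any names not shown in the diff above are assumptions.

```python
from typing import Optional

import torch
import torch.nn as nn


class LatestCheckpointerSketch:
    """Always overwrites the checkpoint, ignoring any metric (mirrors the LatestTorchCheckpointer behaviour above)."""

    def __init__(self, checkpoint_path: str) -> None:
        self.checkpoint_path = checkpoint_path

    def maybe_checkpoint(self, model: nn.Module, _: Optional[float] = None) -> None:
        # Unconditionally save the most recent model.
        torch.save(model, self.checkpoint_path)


class BestMetricCheckpointerSketch:
    """Only overwrites the checkpoint when the reported metric improves."""

    def __init__(self, checkpoint_path: str, maximize: bool = False) -> None:
        self.checkpoint_path = checkpoint_path
        self.maximize = maximize
        self.best_metric: Optional[float] = None

    def maybe_checkpoint(self, model: nn.Module, metric: Optional[float] = None) -> None:
        if metric is None:
            return
        improved = self.best_metric is None or (
            metric > self.best_metric if self.maximize else metric < self.best_metric
        )
        if improved:
            self.best_metric = metric
            torch.save(model, self.checkpoint_path)
```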


8 changes: 2 additions & 6 deletions fl4health/clients/apfl_client.py
@@ -10,7 +10,7 @@
from fl4health.clients.basic_client import BasicClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
from fl4health.utils.losses import Losses, LossMeter, LossMeterType
from fl4health.utils.losses import Losses, LossMeterType
from fl4health.utils.metrics import Metric, MetricMeter, MetricMeterManager, MetricMeterType


@@ -24,11 +24,7 @@ def __init__(
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
checkpointer: Optional[TorchCheckpointer] = None,
) -> None:
super(BasicClient, self).__init__(data_path, device)
self.metrics = metrics
self.checkpointer = checkpointer
self.train_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)
self.val_loss_meter = LossMeter.get_meter_by_type(loss_meter_type)
super().__init__(data_path, metrics, device, loss_meter_type, metric_meter_type, checkpointer)

# Define mapping from prediction key to meter to pass to MetricMeterManager constructor for train and val
train_key_to_meter_map = {
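A small sketch (simplified class names, not the fl4health code) of why the constructor above was collapsed into a single super().__init__ call: super(BasicClient, self).__init__ starts the method-resolution-order lookup after BasicClient, so BasicClient's own __init__ never runs and its attribute setup had to be duplicated by hand; plain super() delegates to BasicClient directly.

```python
class GrandParent:
    def __init__(self) -> None:
        self.from_grandparent = True


class Parent(GrandParent):
    def __init__(self) -> None:
        super().__init__()
        # Stand-in for the loss/metric meter and checkpointer setup done in BasicClient.__init__.
        self.loss_meter = "configured by Parent"


class ChildOld(Parent):
    def __init__(self) -> None:
        # Old pattern: starts the MRO lookup *after* Parent, so Parent.__init__
        # never runs and its attributes must be recreated by hand here.
        super(Parent, self).__init__()
        self.loss_meter = "configured by ChildOld"  # duplicated setup


class ChildNew(Parent):
    def __init__(self) -> None:
        # New pattern: delegates to Parent, so the shared setup lives in one place.
        super().__init__()


print(ChildOld().loss_meter)  # configured by ChildOld
print(ChildNew().loss_meter)  # configured by Parent
```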
1 change: 0 additions & 1 deletion fl4health/clients/basic_client.py
@@ -152,7 +152,6 @@ def _handle_reporting(
metric_dict: Dict[str, Scalar],
current_round: Optional[int] = None,
) -> None:

# If reporter is None we do not report to wandb and return
if self.wandb_reporter is None:
return
10 changes: 3 additions & 7 deletions fl4health/clients/fenda_client.py
@@ -23,9 +23,9 @@ def __init__(
metric_meter_type: MetricMeterType = MetricMeterType.AVERAGE,
checkpointer: Optional[TorchCheckpointer] = None,
temperature: Optional[float] = 0.5,
perfcl_loss_weights: Optional[Tuple[float, float]] = (0.0, 0.0),
cos_sim_loss_weight: Optional[float] = 0.0,
contrastive_loss_weight: Optional[float] = 0.0,
Review comment (author): If these are not None, then some of these losses (at least PerFCL) are still computed, but just "zeroed" out. This is problematic if you want to use asymmetric latent spaces in a FENDA architecture.

perfcl_loss_weights: Optional[Tuple[float, float]] = None,
cos_sim_loss_weight: Optional[float] = None,
contrastive_loss_weight: Optional[float] = None,
) -> None:
super().__init__(
data_path=data_path,
@@ -68,7 +68,6 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
return FixedLayerExchanger(self.model.layers_to_exchange())

def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:

preds = self.model(input)

if self.contrastive_loss_weight or self.perfcl_loss_weights:
@@ -88,7 +87,6 @@ def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
return preds

def get_parameters(self, config: Config) -> NDArrays:

# Save the parameters of the old model
assert isinstance(self.model, FendaModel)
if self.contrastive_loss_weight or self.perfcl_loss_weights:
@@ -98,7 +96,6 @@ def get_parameters(self, config: Config) -> NDArrays:
return super().get_parameters(config)

def set_parameters(self, parameters: NDArrays, config: Config) -> None:

# Set the parameters of the model
super().set_parameters(parameters, config)

@@ -181,7 +178,6 @@ def get_perfcl_loss(
return contrastive_loss_minimize, contrastive_loss_maximize

def compute_loss(self, preds: Dict[str, torch.Tensor], target: torch.Tensor) -> Losses:

loss = self.criterion(preds["prediction"], target)
total_loss = loss
additional_losses = {}
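The rationale in the review comment above comes down to how Python truth-tests the old defaults. A minimal illustration, assuming the guard mirrors the `if self.contrastive_loss_weight or self.perfcl_loss_weights:` check in predict and get_parameters:

```python
# Old defaults: the tuple is non-empty, hence truthy, even though every weight is 0.0.
perfcl_loss_weights = (0.0, 0.0)
contrastive_loss_weight = 0.0
if contrastive_loss_weight or perfcl_loss_weights:
    print("auxiliary branch runs (extra forward passes / feature copies), only to be zeroed out")

# New defaults: None is falsy, so the auxiliary branch is skipped entirely
# unless a weight is explicitly provided.
perfcl_loss_weights = None
contrastive_loss_weight = None
if contrastive_loss_weight or perfcl_loss_weights:
    print("never reached")
```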
19 changes: 17 additions & 2 deletions research/flamby/fed_heart_disease/apfl/client.py
@@ -13,7 +13,11 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer
from fl4health.checkpointing.checkpointer import (
BestMetricTorchCheckpointer,
LatestTorchCheckpointer,
TorchCheckpointer,
)
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.utils.losses import LossMeterType
@@ -111,17 +115,28 @@ def get_criterion(self, config: Config) -> _Loss:
parser.add_argument(
"--alpha_learning_rate", action="store", type=float, help="Learning rate for the APFL alpha", default=0.01
)
parser.add_argument(
"--no_federated_checkpointing",
action="store_true",
help="boolean to indicate whether we're evaluating an APFL model or not, as those model have special args",
emersodb marked this conversation as resolved.
Show resolved Hide resolved
)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Server Address: {args.server_address}")
log(INFO, f"Learning Rate: {args.learning_rate}")
log(INFO, f"Alpha Learning Rate: {args.alpha_learning_rate}")
log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}")

federated_checkpointing = not args.no_federated_checkpointing
checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
checkpoint_name = f"client_{args.client_number}_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client = FedHeartDiseaseApflClient(
data_path=args.dataset_dir,
1 change: 1 addition & 0 deletions research/flamby/fed_heart_disease/fedadam/config.yaml
@@ -6,3 +6,4 @@ n_clients: 4 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client (NOT USED FOR FLAMBY)
batch_size: 64 # The batch size for client training (NOT USED FOR FLAMBY)
local_steps: 100 # The number of local training steps to perform.
federated_checkpointing: True # True: server checkpoints the best aggregated model (by weighted loss); False: server checkpoints the latest model
14 changes: 11 additions & 3 deletions research/flamby/fed_heart_disease/fedadam/server.py
@@ -10,7 +10,7 @@
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAdam

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchCheckpointer
from fl4health.utils.config import load_config
from research.flamby.flamby_servers.full_exchange_server import FullExchangeServer
from research.flamby.utils import (
@@ -34,7 +34,13 @@ def main(

checkpoint_dir = os.path.join(checkpoint_stub, run_name)
checkpoint_name = "server_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
federated_checkpointing: bool = config.get("federated_checkpointing", True)
log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}")
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client_manager = SimpleClientManager()
model = Baseline()
@@ -62,7 +68,9 @@ def main(
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")
if federated_checkpointing:
assert isinstance(checkpointer, BestMetricTorchCheckpointer)
log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")

# Shutdown the server gracefully
server.shutdown()
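The same server-side pattern repeats in the FedAvg, FedProx, and SCAFFOLD servers below: the flag is read from the YAML config with a default of True, the checkpointer is chosen accordingly, and the best-loss log is guarded because only BestMetricTorchCheckpointer carries best_metric. A minimal standalone sketch of the selection logic, with placeholder paths:

```python
from typing import Any, Dict

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchCheckpointer

# Stand-in for the dict produced by load_config(...) on one of the YAML files in this PR.
config: Dict[str, Any] = {"federated_checkpointing": False}

# A missing key falls back to True, so older configs keep the best-metric behaviour.
federated_checkpointing: bool = config.get("federated_checkpointing", True)

checkpoint_dir = "artifacts/example_run"  # placeholder path
checkpoint_name = "server_best_model.pkl"
checkpointer = (
    BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
    if federated_checkpointing
    else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

# Only the best-metric variant tracks best_metric, hence the isinstance guard
# before logging it after training.
if federated_checkpointing:
    assert isinstance(checkpointer, BestMetricTorchCheckpointer)
    print(checkpointer.best_metric)
```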
1 change: 1 addition & 0 deletions research/flamby/fed_heart_disease/fedavg/config.yaml
@@ -6,3 +6,4 @@ n_clients: 4 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client (NOT USED FOR FLAMBY)
batch_size: 64 # The batch size for client training (NOT USED FOR FLAMBY)
local_steps: 100 # The number of local training steps to perform.
federated_checkpointing: True # True: server checkpoints the best aggregated model (by weighted loss); False: server checkpoints the latest model
14 changes: 11 additions & 3 deletions research/flamby/fed_heart_disease/fedavg/server.py
@@ -10,7 +10,7 @@
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchCheckpointer
from fl4health.utils.config import load_config
from research.flamby.flamby_servers.full_exchange_server import FullExchangeServer
from research.flamby.utils import (
@@ -32,7 +32,13 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_

checkpoint_dir = os.path.join(checkpoint_stub, run_name)
checkpoint_name = "server_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
federated_checkpointing: bool = config.get("federated_checkpointing", True)
log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}")
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client_manager = SimpleClientManager()
model = Baseline()
@@ -60,7 +66,9 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")
if federated_checkpointing:
assert isinstance(checkpointer, BestMetricTorchCheckpointer)
log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")

# Shutdown the server gracefully
server.shutdown()
1 change: 1 addition & 0 deletions research/flamby/fed_heart_disease/fedprox/config.yaml
@@ -6,3 +6,4 @@ n_clients: 4 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client (NOT USED FOR FLAMBY)
batch_size: 64 # The batch size for client training (NOT USED FOR FLAMBY)
local_steps: 100 # The number of local training steps to perform.
federated_checkpointing: True # True: server checkpoints the best aggregated model (by weighted loss); False: server checkpoints the latest model
14 changes: 11 additions & 3 deletions research/flamby/fed_heart_disease/fedprox/server.py
@@ -9,7 +9,7 @@
from flwr.common.logger import log
from flwr.server.client_manager import SimpleClientManager

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchCheckpointer
from fl4health.strategies.fedprox import FedProx
from fl4health.utils.config import load_config
from research.flamby.flamby_servers.fedprox_server import FedProxServer
@@ -32,7 +32,13 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub

checkpoint_dir = os.path.join(checkpoint_stub, run_name)
checkpoint_name = "server_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
federated_checkpointing: bool = config.get("federated_checkpointing", True)
log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}")
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client_manager = SimpleClientManager()
model = Baseline()
@@ -61,7 +67,9 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")
if federated_checkpointing:
assert isinstance(checkpointer, BestMetricTorchCheckpointer)
log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")

# Shutdown the server gracefully
server.shutdown()
19 changes: 17 additions & 2 deletions research/flamby/fed_heart_disease/fenda/client.py
@@ -14,7 +14,11 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer
from fl4health.checkpointing.checkpointer import (
BestMetricTorchCheckpointer,
LatestTorchCheckpointer,
TorchCheckpointer,
)
from fl4health.clients.fenda_client import FendaClient
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import Accuracy, Metric, MetricMeterType
@@ -118,16 +122,27 @@ def get_criterion(self, config: Config) -> _Loss:
parser.add_argument("--cos_sim_loss", action="store_true", help="Activate Cosine Similarity loss")
parser.add_argument("--contrastive_loss", action="store_true", help="Activate Contrastive loss")
parser.add_argument("--perfcl_loss", action="store_true", help="Activate PerFCL loss")
parser.add_argument(
"--no_federated_checkpointing",
action="store_true",
help="boolean to indicate whether we're evaluating an APFL model or not, as those model have special args",
emersodb marked this conversation as resolved.
Show resolved Hide resolved
)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Server Address: {args.server_address}")
log(INFO, f"Learning Rate: {args.learning_rate}")
log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}")

federated_checkpointing = not args.no_federated_checkpointing
checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
checkpoint_name = f"client_{args.client_number}_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client = FedHeartDiseaseFendaClient(
data_path=Path(args.dataset_dir),
1 change: 1 addition & 0 deletions research/flamby/fed_heart_disease/scaffold/config.yaml
@@ -6,3 +6,4 @@ n_clients: 4 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client (NOT USED FOR FLAMBY)
batch_size: 64 # The batch size for client training (NOT USED FOR FLAMBY)
local_steps: 100 # The number of local training steps to perform.
federated_checkpointing: True # True: server checkpoints the best aggregated model (by weighted loss); False: server checkpoints the latest model
14 changes: 11 additions & 3 deletions research/flamby/fed_heart_disease/scaffold/server.py
@@ -8,7 +8,7 @@
from flamby.datasets.fed_heart_disease import Baseline
from flwr.common.logger import log

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchCheckpointer
from fl4health.client_managers.fixed_without_replacement_manager import FixedSamplingByFractionClientManager
from fl4health.strategies.scaffold import Scaffold
from fl4health.utils.config import load_config
@@ -34,7 +34,13 @@ def main(

checkpoint_dir = os.path.join(checkpoint_stub, run_name)
checkpoint_name = "server_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
federated_checkpointing: bool = config.get("federated_checkpointing", True)
log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}")
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client_manager = FixedSamplingByFractionClientManager()
model = Baseline()
@@ -65,7 +71,9 @@ def main(
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")
if federated_checkpointing:
assert isinstance(checkpointer, BestMetricTorchCheckpointer)
log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_metric}")

# Shutdown the server gracefully
server.shutdown()
21 changes: 18 additions & 3 deletions research/flamby/fed_ixi/apfl/client.py
@@ -14,7 +14,11 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer
from fl4health.checkpointing.checkpointer import (
BestMetricTorchCheckpointer,
LatestTorchCheckpointer,
TorchCheckpointer,
)
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.utils.losses import LossMeterType
@@ -50,7 +54,7 @@ def __init__(
self.alpha_learning_rate = alpha_learning_rate
self.client_number = client_number

def get_dataloader(self, config: Config) -> Tuple[DataLoader, DataLoader]:
Review comment (author): Somehow a typo crept into this APFL client

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets(
self.client_number, str(self.data_path)
)
@@ -113,17 +117,28 @@ def get_criterion(self, config: Config) -> _Loss:
parser.add_argument(
"--alpha_learning_rate", action="store", type=float, help="Learning rate for the APFL alpha", default=0.01
)
parser.add_argument(
"--no_federated_checkpointing",
action="store_true",
help="boolean to indicate whether we're evaluating an APFL model or not, as those model have special args",
emersodb marked this conversation as resolved.
Show resolved Hide resolved
)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Server Address: {args.server_address}")
log(INFO, f"Learning Rate: {args.learning_rate}")
log(INFO, f"Alpha Learning Rate: {args.alpha_learning_rate}")
log(INFO, f"Performing Federated Checkpointing: {not args.no_federated_checkpointing}")

federated_checkpointing = not args.no_federated_checkpointing
checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
checkpoint_name = f"client_{args.client_number}_best_model.pkl"
checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
checkpointer = (
BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)
if federated_checkpointing
else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
)

client = FedIxiApflClient(
data_path=Path(args.dataset_dir),
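On the get_dataloader to get_data_loaders rename flagged in the review comment above: a minimal illustration (generic class names, assuming the base client invokes a hook named get_data_loaders during setup) of why a misnamed override is silently ignored rather than flagged at definition time:

```python
from typing import Tuple


class BaseClientSketch:
    def setup(self) -> Tuple[str, str]:
        # The framework only ever calls the correctly named hook.
        return self.get_data_loaders()

    def get_data_loaders(self) -> Tuple[str, str]:
        raise NotImplementedError("subclasses must override get_data_loaders")


class TypoClient(BaseClientSketch):
    # Defining get_dataloader adds a *new* method; it never overrides the hook,
    # so setup() still hits the base class behaviour at run time.
    def get_dataloader(self) -> Tuple[str, str]:
        return ("train_loader", "val_loader")


class FixedClient(BaseClientSketch):
    def get_data_loaders(self) -> Tuple[str, str]:
        return ("train_loader", "val_loader")


print(FixedClient().setup())  # ('train_loader', 'val_loader')
# TypoClient().setup()        # would raise NotImplementedError
```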
1 change: 1 addition & 0 deletions research/flamby/fed_ixi/fedadam/config.yaml
@@ -6,3 +6,4 @@ n_clients: 3 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client (NOT USED FOR FLAMBY)
batch_size: 64 # The batch size for client training (NOT USED FOR FLAMBY)
local_steps: 100 # The number of local training steps to perform.
federated_checkpointing: True # True: server checkpoints the best aggregated model (by weighted loss); False: server checkpoints the latest model