diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml
index 0772bd138..3fa23b877 100644
--- a/.github/workflows/linting.yml
+++ b/.github/workflows/linting.yml
@@ -23,9 +23,14 @@ jobs:
- name: Install dependencies
- run: pip install isort black==22.3.0
+ run: |
+     pip install isort black==22.3.0
+     pip install flake8
- name: Run black
- run: black --check .
+ run: black --line-length=89 --check .
+
+ - name: Run flake8
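+ # F401 (imported but unused) is ignored in __init__.py files, which re-export names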
+ run: flake8 --max-line-length=89 --per-file-ignores="*/__init__.py:F401" ./flamby
- name: Run isort
run: isort .
diff --git a/.github/workflows/pr_validation.yml b/.github/workflows/pr_validation.yml
index ba1ae9385..4ec3eb1a1 100644
--- a/.github/workflows/pr_validation.yml
+++ b/.github/workflows/pr_validation.yml
@@ -15,6 +15,8 @@ jobs:
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
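+ # Pin the Python version used by this workflow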
+ with:
+ python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
diff --git a/docs/conf.py b/docs/conf.py
index 4a2ceafd3..3bd1ad39d 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -175,7 +175,7 @@
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
- (master_doc, "FLamby.tex", "FLamby Documentation", "Collaboration", "manual"),
+ (master_doc, "FLamby.tex", "FLamby Documentation", "Collaboration", "manual")
]
@@ -200,7 +200,7 @@
"FLamby",
"One line description of project.",
"Miscellaneous",
- ),
+ )
]
diff --git a/flamby/benchmarks/benchmark_utils.py b/flamby/benchmarks/benchmark_utils.py
index a5f888fa9..f9e46f2d8 100644
--- a/flamby/benchmarks/benchmark_utils.py
+++ b/flamby/benchmarks/benchmark_utils.py
@@ -568,9 +568,7 @@ def ensemble_perf_from_predictions(
return ensemble_perf
-def set_dataset_specific_config(
- dataset_name, compute_ensemble_perf=False, use_gpu=True
-):
+def set_dataset_specific_config(dataset_name, compute_ensemble_perf=False, use_gpu=True):
"""_summary_
Parameters
diff --git a/flamby/benchmarks/conf.py b/flamby/benchmarks/conf.py
index b35cf8b7f..0eaab9f20 100644
--- a/flamby/benchmarks/conf.py
+++ b/flamby/benchmarks/conf.py
@@ -85,8 +85,7 @@ def get_dataset_args(
for param in params:
try:
p = getattr(
- __import__(f"flamby.datasets.{dataset_name}", fromlist=param),
- param,
+ __import__(f"flamby.datasets.{dataset_name}", fromlist=param), param
)
except AttributeError:
p = None
diff --git a/flamby/benchmarks/fed_benchmark.py b/flamby/benchmarks/fed_benchmark.py
index e4b285d15..852762a19 100644
--- a/flamby/benchmarks/fed_benchmark.py
+++ b/flamby/benchmarks/fed_benchmark.py
@@ -464,10 +464,7 @@ def main(args_cli):
parser = argparse.ArgumentParser()
parser.add_argument(
- "--GPU",
- type=int,
- default=0,
- help="GPU to run the training on (if available)",
+ "--GPU", type=int, default=0, help="GPU to run the training on (if available)"
)
parser.add_argument(
"--cpu-only",
@@ -488,10 +485,7 @@ def main(args_cli):
help="Do 0 round and 0 epoch to check if the script is working",
)
parser.add_argument(
- "--workers",
- type=int,
- default=0,
- help="Numbers of workers for the dataloader",
+ "--workers", type=int, default=0, help="Numbers of workers for the dataloader"
)
parser.add_argument(
"--learning_rate",
diff --git a/flamby/create_dataset_config.py b/flamby/create_dataset_config.py
index 4d8d33f2c..1d10db829 100644
--- a/flamby/create_dataset_config.py
+++ b/flamby/create_dataset_config.py
@@ -5,10 +5,7 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- "--path",
- type=str,
- help="The path where the dataset is located",
- required=True,
+ "--path", type=str, help="The path where the dataset is located", required=True
)
parser.add_argument(
"--dataset-name",
diff --git a/flamby/datasets/fed_camelyon16/dataset.py b/flamby/datasets/fed_camelyon16/dataset.py
index b658ee4ad..587bca18a 100644
--- a/flamby/datasets/fed_camelyon16/dataset.py
+++ b/flamby/datasets/fed_camelyon16/dataset.py
@@ -86,7 +86,8 @@ def __init__(
self.features_centers = []
self.features_sets = []
self.perms = {}
- # We need this ist to be sorted for reproducibility but shuffled to avoid weirdness
+ # We need this list to be sorted for reproducibility but shuffled to
+ # avoid weirdness
npys_list = sorted(self.tiles_dir.glob("*.npy"))
random.seed(0)
random.shuffle(npys_list)
diff --git a/flamby/datasets/fed_camelyon16/dataset_creation_scripts/download.py b/flamby/datasets/fed_camelyon16/dataset_creation_scripts/download.py
index 68483e887..5a874df5f 100644
--- a/flamby/datasets/fed_camelyon16/dataset_creation_scripts/download.py
+++ b/flamby/datasets/fed_camelyon16/dataset_creation_scripts/download.py
@@ -5,6 +5,7 @@
import sys
from pathlib import Path
+import numpy as np
import pandas as pd
from google_client import create_service
from googleapiclient.errors import HttpError
@@ -65,12 +66,11 @@ def main(path_to_secret, output_folder, port=6006, debug=False):
len(train_df.index) + len(test_df.index)
)
downloaded_images_status_file["Slide"] = None
- downloaded_images_status_file.Slide.iloc[: len(train_df.index)] = train_df[
- "name"
- ]
- downloaded_images_status_file.Slide.iloc[len(train_df.index) :] = test_df[
- "name"
- ]
+ total_size = len(train_df.index) + len(test_df.index)
+ train_idxs = np.arange(0, len(train_df.index))
+ test_idxs = np.arange(len(train_df.index), total_size)
+ downloaded_images_status_file.Slide.iloc[train_idxs] = train_df["name"]
+ downloaded_images_status_file.Slide.iloc[test_idxs] = test_df["name"]
downloaded_images_status_file.to_csv(
downloaded_images_status_file_path, index=False
)
@@ -92,7 +92,9 @@ def main(path_to_secret, output_folder, port=6006, debug=False):
port=port,
)
regex = "(?<=https://drive.google.com/file/d/)[a-zA-Z0-9]+"
- # Resourcekey is now mandatory (credit @Kris in: https://stackoverflow.com/questions/71343002/downloading-files-from-public-google-drive-in-python-scoping-issues)
+ # Resourcekey is now mandatory (credit @Kris in:
+ # https://stackoverflow.com/questions/71343002/
+ # downloading-files-from-public-google-drive-in-python-scoping-issues)
regex_rkey = "(?<=resourcekey=).+"
for current_df in [train_df, test_df]:
for i in tqdm(range(len(current_df.index))):
diff --git a/flamby/datasets/fed_camelyon16/dataset_creation_scripts/tiling_slides.py b/flamby/datasets/fed_camelyon16/dataset_creation_scripts/tiling_slides.py
index bae07cd0a..0acdce2dd 100644
--- a/flamby/datasets/fed_camelyon16/dataset_creation_scripts/tiling_slides.py
+++ b/flamby/datasets/fed_camelyon16/dataset_creation_scripts/tiling_slides.py
@@ -52,9 +52,7 @@ def __len__(self):
def __getitem__(self, idx):
pil_image = self.slide.read_region(
- self.coords[idx].astype("int_"),
- self.level,
- (self.tile_size, self.tile_size),
+ self.coords[idx].astype("int_"), self.level, (self.tile_size, self.tile_size)
).convert("RGB")
if self.transform is not None:
pil_image = self.transform(pil_image)
diff --git a/flamby/datasets/fed_camelyon16/model.py b/flamby/datasets/fed_camelyon16/model.py
index 1ce7400a4..73d328b11 100644
--- a/flamby/datasets/fed_camelyon16/model.py
+++ b/flamby/datasets/fed_camelyon16/model.py
@@ -10,15 +10,13 @@ class Baseline(nn.Module):
def __init__(self):
super(Baseline, self).__init__()
# As per the article
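+ # Renamed from 'O', presumably because flake8 flags ambiguous single-letter
+ # names such as 'l', 'O' and 'I' (E741).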
- self.O = 2048 # Original dimension of the input embeddings
+ self.Od = 2048 # Original dimension of the input embeddings
self.M = 128 # New dimension of the input embedding
self.L = 128 # Dimension of the new features after query and value projections
self.K = 1000 # Number of elements in each bag
- self.feature_extractor_part1 = nn.Sequential(
- nn.Linear(self.O, self.M),
- )
+ self.feature_extractor_part1 = nn.Sequential(nn.Linear(self.Od, self.M))
# The Gated Attention using tanh and sigmoid from Eq 9
# from https://arxiv.org/abs/1802.04712
diff --git a/flamby/datasets/fed_dummy_dataset.py b/flamby/datasets/fed_dummy_dataset.py
index fcb958a73..ffe4e356a 100644
--- a/flamby/datasets/fed_dummy_dataset.py
+++ b/flamby/datasets/fed_dummy_dataset.py
@@ -24,8 +24,9 @@ def __len__(self):
return self.size
def __getitem__(self, idx):
- return torch.rand(3, 224, 224).to(self.X_dtype), torch.randint(0, 2, (1,)).to(
- self.y_dtype
+ return (
+ torch.rand(3, 224, 224).to(self.X_dtype),
+ torch.randint(0, 2, (1,)).to(self.y_dtype),
)
@@ -53,10 +54,7 @@ def forward(self, X):
m = Baseline()
lo = BaselineLoss()
dl = DataLoader(
- FedDummyDataset(center=1, train=True),
- batch_size=32,
- shuffle=True,
- num_workers=0,
+ FedDummyDataset(center=1, train=True), batch_size=32, shuffle=True, num_workers=0
)
it = iter(dl)
X, y = next(it)
diff --git a/flamby/datasets/fed_heart_disease/dataset.py b/flamby/datasets/fed_heart_disease/dataset.py
index f7dc5cff2..62c367b2a 100644
--- a/flamby/datasets/fed_heart_disease/dataset.py
+++ b/flamby/datasets/fed_heart_disease/dataset.py
@@ -73,12 +73,7 @@ def __init__(
self.y_dtype = y_dtype
self.debug = debug
- self.centers_number = {
- "cleveland": 0,
- "hungarian": 1,
- "switzerland": 2,
- "va": 3,
- }
+ self.centers_number = {"cleveland": 0, "hungarian": 1, "switzerland": 2, "va": 3}
self.features = pd.DataFrame()
self.labels = pd.DataFrame()
@@ -165,9 +160,7 @@ def __init__(
}
# We finally broadcast the means and stds over all datasets
- self.mean_of_features = torch.zeros(
- (len(self.features), 13), dtype=self.X_dtype
- )
+ self.mean_of_features = torch.zeros((len(self.features), 13), dtype=self.X_dtype)
self.std_of_features = torch.ones((len(self.features), 13), dtype=self.X_dtype)
for i in range(self.mean_of_features.shape[0]):
self.mean_of_features[i] = self.centers_stats[self.centers[i]]["mean"]
@@ -177,8 +170,7 @@ def __init__(
to_select = [(self.sets[idx] == "train") for idx, _ in enumerate(self.features)]
features_train = [fp for idx, fp in enumerate(self.features) if to_select[idx]]
features_tensor_train = torch.cat(
- [features_train[i][None, :] for i in range(len(features_train))],
- axis=0,
+ [features_train[i][None, :] for i in range(len(features_train))], axis=0
)
self.mean_of_features_pooled_train = features_tensor_train.mean(axis=0)
self.std_of_features_pooled_train = features_tensor_train.std(axis=0)
diff --git a/flamby/datasets/fed_heart_disease/dataset_creation_scripts/download.py b/flamby/datasets/fed_heart_disease/dataset_creation_scripts/download.py
index 8014a8305..db249ba34 100755
--- a/flamby/datasets/fed_heart_disease/dataset_creation_scripts/download.py
+++ b/flamby/datasets/fed_heart_disease/dataset_creation_scripts/download.py
@@ -20,12 +20,9 @@ def main(output_folder, debug=False):
# location of the files in the UCI archive
accept_license(
- "https://archive-beta.ics.uci.edu/ml/datasets/heart+disease",
- "fed_heart_disease",
- )
- base_url = (
- "https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/"
+ "https://archive-beta.ics.uci.edu/ml/datasets/heart+disease", "fed_heart_disease"
)
+ base_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/"
centers = ["cleveland", "hungarian", "switzerland", "va"]
md5_hashes = [
"2d91a8ff69cfd9616aa47b59d6f843db",
@@ -69,9 +66,7 @@ def main(output_folder, debug=False):
sys.exit()
# get status of download
- downloaded_status_file_path = os.path.join(
- output_folder, "download_status_file.csv"
- )
+ downloaded_status_file_path = os.path.join(output_folder, "download_status_file.csv")
if not (os.path.exists(downloaded_status_file_path)):
downloaded_status_file = pd.DataFrame()
downloaded_status_file["Status"] = ["Not found"] * 4
diff --git a/flamby/datasets/fed_heart_disease/metric.py b/flamby/datasets/fed_heart_disease/metric.py
index bd13ee11a..885d90bc3 100644
--- a/flamby/datasets/fed_heart_disease/metric.py
+++ b/flamby/datasets/fed_heart_disease/metric.py
@@ -1,5 +1,4 @@
import numpy as np
-from sklearn.metrics import roc_auc_score
def metric(y_true, y_pred):
diff --git a/flamby/datasets/fed_isic2019/benchmark.py b/flamby/datasets/fed_isic2019/benchmark.py
index e34c42f1e..7bf34b710 100644
--- a/flamby/datasets/fed_isic2019/benchmark.py
+++ b/flamby/datasets/fed_isic2019/benchmark.py
@@ -22,14 +22,7 @@
def train_model(
- model,
- optimizer,
- scheduler,
- dataloaders,
- dataset_sizes,
- device,
- lossfunc,
- num_epochs,
+ model, optimizer, scheduler, dataloaders, dataset_sizes, device, lossfunc, num_epochs
):
"""Training function
Parameters
@@ -224,16 +217,10 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument(
- "--GPU",
- type=int,
- default=0,
- help="GPU to run the training on (if available)",
+ "--GPU", type=int, default=0, help="GPU to run the training on (if available)"
)
parser.add_argument(
- "--workers",
- type=int,
- default=4,
- help="Numbers of workers for the dataloader",
+ "--workers", type=int, default=4, help="Numbers of workers for the dataloader"
)
args = parser.parse_args()
@@ -243,10 +230,7 @@ def main(args):
sz = 200
test_aug = albumentations.Compose(
- [
- albumentations.CenterCrop(sz, sz),
- albumentations.Normalize(always_apply=True),
- ]
+ [albumentations.CenterCrop(sz, sz), albumentations.Normalize(always_apply=True)]
)
test_dataset = dataset.FedIsic2019(train=False, pooled=True)
test_dataloader = torch.utils.data.DataLoader(
diff --git a/flamby/datasets/fed_isic2019/dataset_creation_scripts/download_isic.py b/flamby/datasets/fed_isic2019/dataset_creation_scripts/download_isic.py
index b56e3a05e..247bdd161 100644
--- a/flamby/datasets/fed_isic2019/dataset_creation_scripts/download_isic.py
+++ b/flamby/datasets/fed_isic2019/dataset_creation_scripts/download_isic.py
@@ -76,23 +76,17 @@
for i, row in ISIC_2019_Training_Metadata.iterrows():
if pd.isnull(row["lesion_id"]):
image = row["image"]
- os.system(
- "rm " + data_directory + "/ISIC_2019_Training_Input/" + image + ".jpg"
- )
+ os.system("rm " + data_directory + "/ISIC_2019_Training_Input/" + image + ".jpg")
if image != ISIC_2019_Training_GroundTruth["image"][i]:
print("Mismatch between Metadata and Ground Truth")
ISIC_2019_Training_GroundTruth = ISIC_2019_Training_GroundTruth.drop(i)
ISIC_2019_Training_Metadata = ISIC_2019_Training_Metadata.drop(i)
# generating dataset field from lesion_id field in the metadata dataframe
-ISIC_2019_Training_Metadata["dataset"] = ISIC_2019_Training_Metadata["lesion_id"].str[
- :4
-]
+ISIC_2019_Training_Metadata["dataset"] = ISIC_2019_Training_Metadata["lesion_id"].str[:4]
# join with HAM10000 metadata in order to expand the HAM datacenters
-result = pd.merge(
- ISIC_2019_Training_Metadata, HAM10000_metadata, how="left", on="image"
-)
+result = pd.merge(ISIC_2019_Training_Metadata, HAM10000_metadata, how="left", on="image")
result["dataset"] = result["dataset_x"] + result["dataset_y"].astype(str)
result.drop(["dataset_x", "dataset_y", "lesion_id"], axis=1, inplace=True)
diff --git a/flamby/datasets/fed_isic2019/heterogeneity_pic.py b/flamby/datasets/fed_isic2019/heterogeneity_pic.py
index d91818f8d..5edfa8fde 100644
--- a/flamby/datasets/fed_isic2019/heterogeneity_pic.py
+++ b/flamby/datasets/fed_isic2019/heterogeneity_pic.py
@@ -42,22 +42,13 @@ def forward(self, image):
parser = argparse.ArgumentParser()
parser.add_argument(
- "--GPU",
- type=int,
- default=0,
- help="GPU to run the training on (if available)",
+ "--GPU", type=int, default=0, help="GPU to run the training on (if available)"
)
parser.add_argument(
- "--workers",
- type=int,
- default=0,
- help="Numbers of workers for the dataloader",
+ "--workers", type=int, default=0, help="Numbers of workers for the dataloader"
)
parser.add_argument(
- "--seed",
- type=int,
- default=42,
- help="The seed for the UMPA and dataloading",
+ "--seed", type=int, default=42, help="The seed for the UMPA and dataloading"
)
args = parser.parse_args()
np.random.seed(args.seed)
diff --git a/flamby/datasets/fed_ixi/dataset_creation_scripts/download.py b/flamby/datasets/fed_ixi/dataset_creation_scripts/download.py
index 74c2ccca1..fd6c041de 100644
--- a/flamby/datasets/fed_ixi/dataset_creation_scripts/download.py
+++ b/flamby/datasets/fed_ixi/dataset_creation_scripts/download.py
@@ -21,10 +21,16 @@ def dl_ixi_tiny(output_folder, debug=False):
The folder where to download the dataset.
"""
print(
- "The IXI dataset is made available under the Creative Commons CC BY-SA 3.0 license.\n\
- If you use the IXI data please acknowledge the source of the IXI data, e.g. the following website: https://brain-development.org/ixi-dataset/\n\
- IXI Tiny is derived from the same source. Acknowledge the following reference on TorchIO : https://torchio.readthedocs.io/datasets.html#ixitiny\n\
- Pérez-García F, Sparks R, Ourselin S. TorchIO: a Python library for efficient loading, preprocessing, augmentation and patch-based sampling of medical images in deep learning. arXiv:2003.04696 [cs, eess, stat]. 2020. https://doi.org/10.48550/arXiv.2003.04696"
+ "The IXI dataset is made available under the Creative Commons CC BY-SA \
+ 3.0 license.\n\
+ If you use the IXI data please acknowledge the source of the IXI data, e.g.\
+ the following website: https://brain-development.org/ixi-dataset/\
+ IXI Tiny is derived from the same source. Acknowledge the following reference\
+ on TorchIO : https://torchio.readthedocs.io/datasets.html#ixitiny\
+ Pérez-García F, Sparks R, Ourselin S. TorchIO: a Python library for \
+ efficient loading, preprocessing, augmentation and patch-based sampling \
+ of medical images in deep learning. arXiv:2003.04696 [cs, eess, stat]. \
+ 2020. https://doi.org/10.48550/arXiv.2003.04696"
)
accept_license("https://brain-development.org/ixi-dataset/", "fed_ixi")
os.makedirs(output_folder, exist_ok=True)
diff --git a/flamby/datasets/fed_ixi/ixi_plotting.py b/flamby/datasets/fed_ixi/ixi_plotting.py
index 3ed60c496..e791f5efd 100644
--- a/flamby/datasets/fed_ixi/ixi_plotting.py
+++ b/flamby/datasets/fed_ixi/ixi_plotting.py
@@ -53,11 +53,7 @@ def plot_histogram(axis, array, num_positions=100, label=None, alpha=0.05, color
[0], [0], color="r", lw=4, label="Client {}".format(CENTER_LABELS_CORRESP["HH"])
),
Line2D(
- [0],
- [0],
- color="b",
- lw=4,
- label="Client {}".format(CENTER_LABELS_CORRESP["IOP"]),
+ [0], [0], color="b", lw=4, label="Client {}".format(CENTER_LABELS_CORRESP["IOP"])
),
]
ax.legend(handles=legend_elements, loc="upper right")
diff --git a/flamby/datasets/fed_ixi/loss.py b/flamby/datasets/fed_ixi/loss.py
index 8a3cea80d..daafb82b3 100644
--- a/flamby/datasets/fed_ixi/loss.py
+++ b/flamby/datasets/fed_ixi/loss.py
@@ -7,7 +7,9 @@ def __init__(self):
super(BaselineLoss, self).__init__()
def forward(self, output: torch.Tensor, target: torch.Tensor):
- """Get dice loss to evaluate the semantic segmentation model. Its value lies between 0 and 1. The more the loss is close to 0, the more the performance is good.
+ """Get dice loss to evaluate the semantic segmentation model.
+ Its value lies between 0 and 1. The closer the loss is to 0,
+ the better the performance.
Parameters
----------
@@ -26,7 +28,9 @@ def forward(self, output: torch.Tensor, target: torch.Tensor):
def get_dice_score(output, target, epsilon=1e-9):
- """Get dice score to evaluate the semantic segmentation model. Its value lies between 0 and 1. The more the score is close to 1, the more the performance is good.
+ """Get dice score to evaluate the semantic segmentation model.
+ Its value lies between 0 and 1. The closer the score is to 1,
+ the better the performance.
Parameters
----------
diff --git a/flamby/datasets/fed_ixi/metric.py b/flamby/datasets/fed_ixi/metric.py
index 32ee2aa5f..005759d5d 100644
--- a/flamby/datasets/fed_ixi/metric.py
+++ b/flamby/datasets/fed_ixi/metric.py
@@ -1,6 +1,4 @@
import numpy as np
-import torch
-import torch.nn.functional as F
def metric(y_true, y_pred):
diff --git a/flamby/datasets/fed_ixi/model.py b/flamby/datasets/fed_ixi/model.py
index db90ee1b3..f2e34ceaf 100644
--- a/flamby/datasets/fed_ixi/model.py
+++ b/flamby/datasets/fed_ixi/model.py
@@ -12,7 +12,7 @@
Requires: Python >=3.6
"""
-#### UNet
+# UNet
class Baseline(nn.Module):
@@ -117,11 +117,7 @@ def __init__(
elif dimensions == 3:
in_channels = 2 * out_channels_first_layer
self.classifier = ConvolutionalBlock(
- dimensions,
- in_channels,
- out_classes,
- kernel_size=1,
- activation=None,
+ dimensions, in_channels, out_classes, kernel_size=1, activation=None
)
def forward(self, x):
@@ -135,7 +131,7 @@ def forward(self, x):
return x
-#### Conv ####
+# Conv
class ConvolutionalBlock(nn.Module):
@@ -218,16 +214,10 @@ def add_if_not_none(module_list, module):
module_list.append(module)
-#### Decoding ####
+# Decoding
CHANNELS_DIMENSION = 1
-UPSAMPLING_MODES = (
- "nearest",
- "linear",
- "bilinear",
- "bicubic",
- "trilinear",
-)
+UPSAMPLING_MODES = ("nearest", "linear", "bilinear", "bicubic", "trilinear")
class Decoder(nn.Module):
@@ -371,11 +361,7 @@ def get_upsampling_layer(upsampling_type: str) -> nn.Upsample:
message = 'Upsampling type is "{}"' " but should be one of the following: {}"
message = message.format(upsampling_type, UPSAMPLING_MODES)
raise ValueError(message)
- upsample = nn.Upsample(
- scale_factor=2,
- mode=upsampling_type,
- align_corners=False,
- )
+ upsample = nn.Upsample(scale_factor=2, mode=upsampling_type, align_corners=False)
return upsample
@@ -395,7 +381,7 @@ def fix_upsampling_type(upsampling_type: str, dimensions: int):
return upsampling_type
-#### Encoding ####
+# Encoding
class Encoder(nn.Module):
@@ -555,9 +541,7 @@ def out_channels(self):
def get_downsampling_layer(
- dimensions: int,
- pooling_type: str,
- kernel_size: int = 2,
+ dimensions: int, pooling_type: str, kernel_size: int = 2
) -> nn.Module:
class_name = "{}Pool{}d".format(pooling_type.capitalize(), dimensions)
class_ = getattr(nn, class_name)
diff --git a/flamby/datasets/fed_ixi/utils.py b/flamby/datasets/fed_ixi/utils.py
index 5ee73b248..6229daeae 100644
--- a/flamby/datasets/fed_ixi/utils.py
+++ b/flamby/datasets/fed_ixi/utils.py
@@ -1,6 +1,6 @@
"""Federated IXI Dataset utils
-
-A set of function that allow data management suited for the `IXI dataset `_.
+A set of functions that allow data management suited for the
+`IXI dataset `_.
"""
@@ -14,11 +14,8 @@
from zipfile import ZipFile
import nibabel as nib
-import nibabel.processing as processing
import numpy
import numpy as np
-from nibabel import Nifti1Header
-from numpy import ndarray
def _get_id_from_filename(x, verify_single_matches=True) -> Union[List[int], int]:
@@ -55,7 +52,8 @@ def _get_id_from_filename(x, verify_single_matches=True) -> Union[List[int], int
def _assembly_nifti_filename_regex(
patient_id: int, modality: str
) -> Union[str, PathLike, Path]:
- """Assembles NIFTI filename regular expression using the standard in the IXI dataset based on id and modality.
+ """Assembles NIFTI filename regular expression using the standard in the
+ IXI dataset based on id and modality.
Parameters
----------
@@ -78,7 +76,8 @@ def _assembly_nifti_filename_regex(
def _assembly_nifti_img_and_label_regex(
patient_id: int, modality: str
) -> Tuple[Union[str, PathLike, Path], Union[str, PathLike, Path]]:
- """Assembles NIFTI filename regular expression for image and label using the standard in the IXI tiny dataset based on id and modality.
+ """Assembles NIFTI filename regular expression for image and label using
+ the standard in the IXI tiny dataset based on id and modality.
Parameters
----------
@@ -104,7 +103,8 @@ def _assembly_nifti_img_and_label_regex(
def _find_file_in_tar(tar_file: TarFile, patient_id: int, modality) -> str:
- """Searches the file in a TAR file that corresponds to a particular regular expression.
+ """Searches the file in a TAR file that corresponds to a particular
+ regular expression.
Parameters
----------
@@ -134,7 +134,8 @@ def _find_file_in_tar(tar_file: TarFile, patient_id: int, modality) -> str:
def _find_files_in_zip(zip_file: ZipFile, patient_id: int, modality) -> Tuple[str]:
- """Searches the files in a ZIP file that corresponds to particular regular expressions.
+ """Searches the files in a ZIP file that corresponds to particular regular
+ expressions.
Parameters
----------
@@ -165,7 +166,7 @@ def _find_files_in_zip(zip_file: ZipFile, patient_id: int, modality) -> Tuple[st
result[1] = re.match(regex_label, filename).group()
except AttributeError:
continue
- if result[0] != None and result[1] != None:
+ if result[0] is not None and result[1] is not None:
return tuple(result)
raise FileNotFoundError(
@@ -176,7 +177,8 @@ def _find_files_in_zip(zip_file: ZipFile, patient_id: int, modality) -> Tuple[st
def _extract_center_name_from_filename(filename: str):
"""Extracts center name from file dataset.
- Unfortunately, IXI has the center encoded in the namefile rather than in the demographics information.
+ Unfortunately, IXI has the center encoded in the namefile rather than in
+ the demographics information.
Parameters
----------
@@ -189,7 +191,8 @@ def _extract_center_name_from_filename(filename: str):
Name of the center where the data comes from (e.g. Guys for the previous example)
"""
- # We decided to wrap a function for this for clarity and easier modularity for future expansion
+ # We decided to wrap a function for this for clarity and easier modularity
+ # for future expansion
return filename.split("-")[1]
@@ -201,7 +204,8 @@ def _load_nifti_image_by_id(
Parameters
----------
tar_file : TarFile
- `TarFile `_ object
+ `TarFile
+ `_ object
patient_id : int
Patient's ID whose image is to be extracted.
modality : str
@@ -214,7 +218,8 @@ def _load_nifti_image_by_id(
img : ndarray
NumPy array containing the intensities of the voxels.
center_name : str
- Name of the center the file comes from. In IXI this is encoded only in the filename.
+ Name of the center the file comes from. In IXI this is encoded only
+ in the filename.
"""
filename = _find_file_in_tar(tar_file, patient_id, modality)
with tempfile.TemporaryDirectory() as td:
@@ -237,7 +242,8 @@ def _load_nifti_image_and_label_by_id(
Parameters
----------
zip_file : ZipFile
- `ZipFile `_ object
+ `ZipFile
+ `_ object
patient_id : int
Patient's ID whose image is to be extracted.
modality : str
@@ -252,7 +258,8 @@ def _load_nifti_image_and_label_by_id(
label : ndarray
NumPy array containing the intensities of the voxels.
center_name : str
- Name of the center the file comes from. In IXI this is encoded only in the filename.
+ Name of the center the file comes from. In IXI this is encoded only in
+ the filename.
"""
img_filename, label_filename = _find_files_in_zip(zip_file, patient_id, modality)
with tempfile.TemporaryDirectory() as td:
@@ -313,11 +320,20 @@ def _create_train_test_split(
Returns
-------
train_test_hh : list
- A list containing randomly generated dichotomous values. The size is the number of images from HH hospital. Dichotomous values (train and test) follow a train test split threshold (e. g. 70%).
+ A list containing randomly generated dichotomous values.
+ The size is the number of images from HH hospital.
+ Dichotomous values (train and test) follow a train test split
+ threshold (e. g. 70%).
train_test_guys : list
- A list containing randomly generated dichotomous values. The size is the number of images from Guys hospital. Dichotomous values (train and test) follow a train test split threshold (e. g. 70%).
+ A list containing randomly generated dichotomous values.
+ The size is the number of images from Guys hospital.
+ Dichotomous values (train and test) follow a train test split
+ threshold (e. g. 70%).
train_test_iop : list
- A list containing randomly generated dichotomous values. The size is the number of images from IOP hospital. Dichotomous values (train and test) follow a train test split threshold (e. g. 70%).
+ A list containing randomly generated dichotomous values.
+ The size is the number of images from IOP hospital.
+ Dichotomous values (train and test) follow a train test split
+ threshold (e. g. 70%).
"""
split_ratio = 0.8
diff --git a/flamby/datasets/fed_kits19/benchmark.py b/flamby/datasets/fed_kits19/benchmark.py
index 45f1c94d0..43ea8a154 100644
--- a/flamby/datasets/fed_kits19/benchmark.py
+++ b/flamby/datasets/fed_kits19/benchmark.py
@@ -22,14 +22,7 @@
def train_model(
- model,
- optimizer,
- scheduler,
- dataloaders,
- dataset_sizes,
- device,
- lossfunc,
- num_epochs,
+ model, optimizer, scheduler, dataloaders, dataset_sizes, device, lossfunc, num_epochs
):
"""Training function
Parameters
@@ -105,9 +98,7 @@ def train_model(
best_model_wts = copy.deepcopy(model.state_dict())
print(
- "Training Loss: {:.4f} Validation Acc: {:.4f} ".format(
- epoch_loss, epoch_acc
- )
+ "Training Loss: {:.4f} Validation Acc: {:.4f} ".format(epoch_loss, epoch_acc)
)
training_loss_list.append(epoch_loss)
training_dice_list.append(epoch_acc)
@@ -134,7 +125,7 @@ def main(args):
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU)
torch.use_deterministic_algorithms(False)
- dict = check_dataset_from_config(dataset_name="fed_kits19", debug=False)
+ check_dataset_from_config(dataset_name="fed_kits19", debug=False)
train_dataset = FedKits19(train=True, pooled=True)
train_dataloader = torch.utils.data.DataLoader(
@@ -152,7 +143,6 @@ def main(args):
dataloaders = {"train": train_dataloader, "test": test_dataloader}
dataset_sizes = {"train": len(train_dataset), "test": len(test_dataset)}
- # device = torch.device("cuda:"+str(args.GPU) if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device", device)
@@ -167,9 +157,7 @@ def main(args):
lossfunc = BaselineLoss()
- optimizer = torch.optim.Adam(
- model.parameters(), LR, weight_decay=3e-5, amsgrad=True
- )
+ optimizer = torch.optim.Adam(model.parameters(), LR, weight_decay=3e-5, amsgrad=True)
scheduler = lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="min",
@@ -201,29 +189,15 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument(
- "--GPU",
- type=int,
- default=0,
- help="GPU to run the training on (if available)",
- )
- parser.add_argument(
- "--workers",
- type=int,
- default=1,
- help="Numbers of workers for the dataloader",
+ "--GPU", type=int, default=0, help="GPU to run the training on (if available)"
)
parser.add_argument(
- "--epochs",
- type=int,
- default=NUM_EPOCHS_POOLED,
- help="Numbers of Epochs",
+ "--workers", type=int, default=1, help="Numbers of workers for the dataloader"
)
parser.add_argument(
- "--seed",
- type=int,
- default=0,
- help="Seed",
+ "--epochs", type=int, default=NUM_EPOCHS_POOLED, help="Numbers of Epochs"
)
+ parser.add_argument("--seed", type=int, default=0, help="Seed")
args = parser.parse_args()
main(args)
diff --git a/flamby/datasets/fed_kits19/dataset.py b/flamby/datasets/fed_kits19/dataset.py
index 07090fc06..458614994 100644
--- a/flamby/datasets/fed_kits19/dataset.py
+++ b/flamby/datasets/fed_kits19/dataset.py
@@ -25,7 +25,11 @@
import numpy as np
import pandas as pd
import torch
-from batchgenerators.utilities.file_and_folder_operations import *
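+# Only the helpers that are actually used are imported; a star import would
+# trip flake8 (F403/F405).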
+from batchgenerators.utilities.file_and_folder_operations import (
+ isfile,
+ join,
+ load_pickle,
+)
from nnunet.training.data_augmentation.default_data_augmentation import (
default_3D_augmentation_params,
get_patch_size,
@@ -33,11 +37,9 @@
from torch.utils.data import Dataset
import flamby.datasets.fed_kits19
-from flamby.datasets.fed_kits19.dataset_creation_scripts.utils.data_augmentations import (
- transformations,
-)
-from flamby.datasets.fed_kits19.dataset_creation_scripts.utils.set_environment_variables import (
+from flamby.datasets.fed_kits19.dataset_creation_scripts.utils import (
set_environment_variables,
+ transformations,
)
from flamby.utils import check_dataset_from_config
@@ -132,9 +134,7 @@ def __init__(
self.images_path = OrderedDict()
for i in self.images:
self.images_path[c] = OrderedDict()
- self.images_path[c]["data_file"] = join(
- self.dataset_directory, "%s.npz" % i
- )
+ self.images_path[c]["data_file"] = join(self.dataset_directory, "%s.npz" % i)
self.images_path[c]["properties_file"] = join(
self.dataset_directory, "%s.pkl" % i
)
@@ -160,18 +160,10 @@ def __getitem__(self, idx):
# randomly oversample the foreground classes
if self.oversample_next_sample == 1:
self.oversample_next_sample = 0
- item = self.oversample_foreground_class(
- case_all_data,
- True,
- properties,
- )
+ item = self.oversample_foreground_class(case_all_data, True, properties)
else:
self.oversample_next_sample = 1
- item = self.oversample_foreground_class(
- case_all_data,
- False,
- properties,
- )
+ item = self.oversample_foreground_class(case_all_data, False, properties)
# apply data augmentations
if self.train_test == "train":
@@ -181,12 +173,7 @@ def __getitem__(self, idx):
return np.squeeze(item["data"], axis=1), np.squeeze(item["target"], axis=1)
- def oversample_foreground_class(
- self,
- case_all_data,
- force_fg,
- properties,
- ):
+ def oversample_foreground_class(self, case_all_data, force_fg, properties):
# taken from nnunet
data_shape = (1, 1, *self.patch_size)
seg_shape = (1, 1, *self.patch_size)
@@ -242,7 +229,7 @@ def oversample_foreground_class(
# all
selected_class = None
voxels_of_that_class = None
- print("case does not contain any foreground classes", i)
+ print("case does not contain any foreground classes")
else:
selected_class = np.random.choice(foreground_classes)
diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/create_config.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/create_config.py
index d7c736f6c..802e99e64 100644
--- a/flamby/datasets/fed_kits19/dataset_creation_scripts/create_config.py
+++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/create_config.py
@@ -1,6 +1,6 @@
import argparse
-from flamby.utils import create_config, get_config_file_path, write_value_in_config
+from flamby.utils import create_config
if __name__ == "__main__":
parser = argparse.ArgumentParser()
diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/kits19_heterogenity_plot.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/kits19_heterogenity_plot.py
index f87ad4d16..f0a0f20d7 100644
--- a/flamby/datasets/fed_kits19/dataset_creation_scripts/kits19_heterogenity_plot.py
+++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/kits19_heterogenity_plot.py
@@ -6,7 +6,7 @@
import nibabel as nib
import numpy as np
import seaborn as sns
-from batchgenerators.utilities.file_and_folder_operations import *
+from batchgenerators.utilities.file_and_folder_operations import join
from matplotlib.lines import Line2D
from numpy import random
diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/parsing_and_adding_metadata.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/parsing_and_adding_metadata.py
index f2407cd28..9b0c96080 100644
--- a/flamby/datasets/fed_kits19/dataset_creation_scripts/parsing_and_adding_metadata.py
+++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/parsing_and_adding_metadata.py
@@ -1,4 +1,5 @@
-# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research
+# Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,7 +32,7 @@
subfolders,
)
-from flamby.datasets.fed_kits19.dataset_creation_scripts.utils.set_environment_variables import (
+from flamby.datasets.fed_kits19.dataset_creation_scripts.utils import (
set_environment_variables,
)
from flamby.utils import get_config_file_path, read_config, write_value_in_config
@@ -103,9 +104,7 @@ def read_csv_file(csv_file="../metadata/anony_sites.csv"):
data_length = len(client_data_idxx)
train_ids = int(0.8 * data_length)
for i in client_data_idxx[:train_ids]:
- writer.writerow(
- [i, silo_count, "train", "train_" + str(silo_count)]
- )
+ writer.writerow([i, silo_count, "train", "train_" + str(silo_count)])
for i in client_data_idxx[train_ids:]:
writer.writerow([i, silo_count, "test", "test_" + str(silo_count)])
@@ -122,7 +121,8 @@ def read_csv_file(csv_file="../metadata/anony_sites.csv"):
if __name__ == "__main__":
"""
- This is the KiTS dataset after Nick fixed all the labels that had errors. Downloaded on Jan 6th 2020
+ This is the KiTS dataset after Nick fixed all the labels that had errors.
+ Downloaded on Jan 6th 2020
"""
# parse python script input parameters
@@ -155,7 +155,7 @@ def read_csv_file(csv_file="../metadata/anony_sites.csv"):
case_ids, site_ids, unique_hospital_ids, thresholded_ids = read_csv_file()
print(thresholded_ids)
- if args.debug == True:
+ if args.debug:
train_patients = thresholded_ids[:25]
test_patients = all_cases[210:211] # we do not need the test data
else:
@@ -183,9 +183,7 @@ def read_csv_file(csv_file="../metadata/anony_sites.csv"):
json_dict["reference"] = "KiTS data for nnunet_library"
json_dict["licence"] = ""
json_dict["release"] = "0.0"
- json_dict["modality"] = {
- "0": "CT",
- }
+ json_dict["modality"] = {"0": "CT"}
json_dict["labels"] = {"0": "background", "1": "Kidney", "2": "Tumor"}
json_dict["numTraining"] = len(train_patient_names)
diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/run_nnUnet_plan_and_preprocess.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/run_nnUnet_plan_and_preprocess.py
index be8ef3e14..ceb7ab39b 100644
--- a/flamby/datasets/fed_kits19/dataset_creation_scripts/run_nnUnet_plan_and_preprocess.py
+++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/run_nnUnet_plan_and_preprocess.py
@@ -1,4 +1,5 @@
-# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research
+# Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +18,7 @@
# information
import sys
-from flamby.datasets.fed_kits19.dataset_creation_scripts.utils.set_environment_variables import (
+from flamby.datasets.fed_kits19.dataset_creation_scripts.utils import (
set_environment_variables,
)
from flamby.utils import get_config_file_path, write_value_in_config
@@ -47,14 +48,7 @@
for _ in range(2):
sys.argv.pop(index_num_threads)
- sys.argv = sys.argv + [
- "-t",
- "064",
- "-tf",
- args.num_threads,
- "-tl",
- args.num_threads,
- ]
+ sys.argv = sys.argv + ["-t", "064", "-tf", args.num_threads, "-tl", args.num_threads]
main()
path_to_config_file = get_config_file_path("fed_kits19", False)
diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/__init__.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/__init__.py
index e69de29bb..36dfb231b 100644
--- a/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/__init__.py
+++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/__init__.py
@@ -0,0 +1,2 @@
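+# Re-export the data-augmentation and environment helpers so that downstream
+# scripts can import them directly from this utils package.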
+from .data_augmentations import transformations
+from .set_environment_variables import set_environment_variables
diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/data_augmentations.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/data_augmentations.py
index 207477eb1..3fc89900b 100644
--- a/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/data_augmentations.py
+++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/data_augmentations.py
@@ -1,4 +1,5 @@
-# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research
+# Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -47,16 +48,11 @@
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import (
NonDetMultiThreadedAugmenter,
)
-except ImportError as ie:
+except ImportError:
NonDetMultiThreadedAugmenter = None
-def transformations(
- patch_size,
- params,
- border_val_seg=-1,
- regions=None,
-):
+def transformations(patch_size, params, border_val_seg=-1, regions=None):
assert (
params.get("mirror") is None
), "old version of params, use new keyword do_mirror"
@@ -72,7 +68,8 @@ def transformations(
SegChannelSelectionTransform(params.get("selected_seg_channels"))
)
- # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
+ # don't do color augmentations while in 2d mode with 3d data because the
+ # color channel is overloaded!!
if params.get("dummy_2D") is not None and params.get("dummy_2D"):
tr_transforms.append(Convert3DTo2DTransform())
patch_size_spatial = patch_size[1:]
diff --git a/flamby/datasets/fed_kits19/metric.py b/flamby/datasets/fed_kits19/metric.py
index 6cb533abe..27b254f1a 100644
--- a/flamby/datasets/fed_kits19/metric.py
+++ b/flamby/datasets/fed_kits19/metric.py
@@ -3,7 +3,21 @@
import torch.nn.functional as F
from tqdm import tqdm
-softmax_helper = lambda x: F.softmax(x, 1)
+
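+# A named function replaces the previous lambda assignment (flake8's E731
+# discourages binding a lambda expression to a name).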
+def softmax_helper(x):
+ """This function computes the softmax using torch functionnal on the 1-axis.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ The input.
+
+ Returns
+ -------
+ torch.Tensor
+ The softmax of x along dimension 1.
+ """
+ return F.softmax(x, 1)
def Dice_coef(output, target, eps=1e-5): # dice score used for evaluation
@@ -25,12 +39,7 @@ def metric(predictions, gt):
return (tk_dice + tu_dice) / 2
-def evaluate_dice_on_tests(
- model,
- test_dataloaders,
- metric,
- use_gpu=True,
-):
+def evaluate_dice_on_tests(model, test_dataloaders, metric, use_gpu=True):
"""This function takes a pytorch model and evaluate it on a list of\
dataloaders using the provided metric function.
diff --git a/flamby/datasets/fed_kits19/model.py b/flamby/datasets/fed_kits19/model.py
index 5d94d5c3e..1e8fd8b09 100644
--- a/flamby/datasets/fed_kits19/model.py
+++ b/flamby/datasets/fed_kits19/model.py
@@ -1,4 +1,5 @@
-# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research
+# Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/flamby/datasets/fed_lidc_idri/__init__.py b/flamby/datasets/fed_lidc_idri/__init__.py
index 864393d05..09a84c341 100644
--- a/flamby/datasets/fed_lidc_idri/__init__.py
+++ b/flamby/datasets/fed_lidc_idri/__init__.py
@@ -10,8 +10,5 @@
)
from flamby.datasets.fed_lidc_idri.dataset import FedLidcIdri, LidcIdriRaw, collate_fn
from flamby.datasets.fed_lidc_idri.loss import BaselineLoss
-from flamby.datasets.fed_lidc_idri.metric import (
- evaluate_dice_on_tests_by_chunks,
- metric,
-)
+from flamby.datasets.fed_lidc_idri.metric import evaluate_dice_on_tests_by_chunks, metric
from flamby.datasets.fed_lidc_idri.model import Baseline
diff --git a/flamby/datasets/fed_lidc_idri/benchmark.py b/flamby/datasets/fed_lidc_idri/benchmark.py
index b9a56d23f..21a63b717 100644
--- a/flamby/datasets/fed_lidc_idri/benchmark.py
+++ b/flamby/datasets/fed_lidc_idri/benchmark.py
@@ -109,9 +109,7 @@ def main(num_workers_torch, use_gpu=True, gpu_id=0, log=False, debug=False):
if log:
writer.add_scalar(
- "Loss/train/client",
- tot_loss / num_local_steps_per_epoch,
- e,
+ "Loss/train/client", tot_loss / num_local_steps_per_epoch, e
)
# Finally, evaluate DICE
@@ -143,10 +141,7 @@ def main(num_workers_torch, use_gpu=True, gpu_id=0, log=False, debug=False):
default=20,
)
parser.add_argument(
- "--gpu-id",
- type=int,
- default=0,
- help="PCI Bus id of the GPU to use.",
+ "--gpu-id", type=int, default=0, help="PCI Bus id of the GPU to use."
)
parser.add_argument(
"--cpu-only",
diff --git a/flamby/datasets/fed_lidc_idri/data_utils.py b/flamby/datasets/fed_lidc_idri/data_utils.py
index 1f651f78a..90e45df9f 100644
--- a/flamby/datasets/fed_lidc_idri/data_utils.py
+++ b/flamby/datasets/fed_lidc_idri/data_utils.py
@@ -184,11 +184,7 @@ def random_sampler(image, label, patch_shape=(128, 128, 128), n_samples=2):
paddings,
mode="reflect",
).squeeze()
- label = F.pad(
- label,
- paddings,
- mode="constant",
- )
+ label = F.pad(label, paddings, mode="constant")
# Extract patches, taking into account the shift in coordinates due to padding
image_patches = extract_patches(image, centroids + patch_shape, patch_shape)
@@ -232,12 +228,7 @@ def all_sampler(X, y, patch_shape=(128, 128, 128)):
def fast_sampler(
- X,
- y,
- patch_shape=(128, 128, 128),
- n_patches=2,
- ratio=1.0,
- center=False,
+ X, y, patch_shape=(128, 128, 128), n_patches=2, ratio=1.0, center=False
):
"""
Parameters
@@ -274,9 +265,9 @@ def fast_sampler(
# Add noise to centroids so that the nodules are not always centered in
# the patch:
if not center:
- noise = (
- torch.rand(centroids_1.shape[0], 3) * torch.max(patch_shape)
- ).long() % (patch_shape.div(2, rounding_mode="floor")[None, ...])
+ noise = (torch.rand(centroids_1.shape[0], 3) * torch.max(patch_shape)).long() % (
+ patch_shape.div(2, rounding_mode="floor")[None, ...]
+ )
centroids_1 += noise - patch_shape.div(4, rounding_mode="floor")
# Sample random centroids
@@ -297,16 +288,8 @@ def fast_sampler(
paddings = tuple(torch.stack([patch_shape, patch_shape], dim=-1).flatten())[::-1]
- X = F.pad(
- X[None, None, :, :, :],
- paddings,
- mode="reflect",
- ).squeeze()
- y = F.pad(
- y,
- paddings,
- mode="constant",
- )
+ X = F.pad(X[None, None, :, :, :], paddings, mode="reflect").squeeze()
+ y = F.pad(y, paddings, mode="constant")
# Extract patches, taking into account the shift in coordinates due to padding
image_patches = extract_patches(X, centroids + patch_shape, patch_shape)
diff --git a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/download_ct_scans.py b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/download_ct_scans.py
index ea79594af..d45f71ad7 100644
--- a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/download_ct_scans.py
+++ b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/download_ct_scans.py
@@ -242,9 +242,7 @@ def LIDC_to_niftis(extraction_results_dataframe, spacing=[1.0, 1.0, 1.0], debug=
extraction_results_dataframe.iterrows(),
)
progbar = tqdm.tqdm(
- loop,
- total=extraction_results_dataframe.shape[0],
- desc="Converting to NiFTIs...",
+ loop, total=extraction_results_dataframe.shape[0], desc="Converting to NiFTIs..."
)
converted_dicoms = Parallel(n_jobs=1, prefer="processes")(
delayed(convert_to_niftis)(*t, spacing=spacing) for t in progbar
diff --git a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/process_raw.py b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/process_raw.py
index 7378b7ad4..96be254fc 100644
--- a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/process_raw.py
+++ b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/process_raw.py
@@ -295,9 +295,7 @@ def opN(Nodule):
xs = roi.findall("{http://www.nih.gov}edgeMap/{http://www.nih.gov}xCoord")
ys = roi.findall("{http://www.nih.gov}edgeMap/{http://www.nih.gov}yCoord")
- zs = [
- locz,
- ] * len(xs)
+ zs = [locz] * len(xs)
xs = np.array([int(x.text) for x in xs])
ys = np.array([int(y.text) for y in ys])
diff --git a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/tciaclient.py b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/tciaclient.py
index 9bb059e2b..675e8538b 100644
--- a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/tciaclient.py
+++ b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/tciaclient.py
@@ -1,15 +1,16 @@
-# Copy pasted from https://github.com/nadirsaghar/TCIA-REST-API-Client/blob/master/tcia-rest-client-python/src/tciaclient.py
+# Copy pasted from https://github.com/nadirsaghar/TCIA-REST-API-Client
+# (tcia-rest-client-python/src/tciaclient.py)
#
-import math
import os
-import sys
import urllib.error
import urllib.parse
import urllib.request
#
-# Refer https://wiki.cancerimagingarchive.net/display/Public/REST+API+Usage+Guide for complete list of API
+# Refer to https://wiki.cancerimagingarchive.net/display/Public/REST+API+Usage+Guide
+# for the complete list of API calls
#
@@ -141,11 +142,7 @@ def get_body_part_values(
return resp
def get_patient_study(
- self,
- collection=None,
- patientId=None,
- studyInstanceUid=None,
- outputFormat="json",
+ self, collection=None, patientId=None, studyInstanceUid=None, outputFormat="json"
):
serviceUrl = self.baseUrl + "/query/" + self.GET_PATIENT_STUDY
diff --git a/flamby/datasets/fed_lidc_idri/lidc_heterogeneity_plot.py b/flamby/datasets/fed_lidc_idri/lidc_heterogeneity_plot.py
index 2e99c7aa9..dedf7aaaa 100644
--- a/flamby/datasets/fed_lidc_idri/lidc_heterogeneity_plot.py
+++ b/flamby/datasets/fed_lidc_idri/lidc_heterogeneity_plot.py
@@ -14,12 +14,7 @@ def make_plot():
list_x = np.linspace(0, 1, num=200)
for center in [0, 1, 2, 3]:
print(f"doing center {center}")
- ds = FedLidcIdri(
- train=True,
- pooled=False,
- center=center,
- debug=False,
- )
+ ds = FedLidcIdri(train=True, pooled=False, center=center, debug=False)
list_data = []
for k in tqdm(range(len(ds))):
data = ds[k][0].detach().cpu().ravel()
diff --git a/flamby/datasets/fed_synthetic/dataset_creation_scripts/download.py b/flamby/datasets/fed_synthetic/dataset_creation_scripts/download.py
index 96dcddfa1..4f1546f32 100755
--- a/flamby/datasets/fed_synthetic/dataset_creation_scripts/download.py
+++ b/flamby/datasets/fed_synthetic/dataset_creation_scripts/download.py
@@ -98,10 +98,7 @@ def main(output_folder, debug=False, **kwargs):
default=None,
)
parser.add_argument(
- "--noise-heterogeneity",
- type=float,
- help="Sample repartition.",
- default=None,
+ "--noise-heterogeneity", type=float, help="Sample repartition.", default=None
)
parser.add_argument(
"--features-heterogeneity",
diff --git a/flamby/datasets/fed_synthetic/synthetic_generator.py b/flamby/datasets/fed_synthetic/synthetic_generator.py
index 0dc7f4a95..e4bc2ad03 100644
--- a/flamby/datasets/fed_synthetic/synthetic_generator.py
+++ b/flamby/datasets/fed_synthetic/synthetic_generator.py
@@ -127,9 +127,7 @@ def generate_synthetic_dataset(
if noise_heterogeneity is None:
snr_locs = np.ones(n_centers) * snr
elif type(noise_heterogeneity) in [list, np.array]:
- assert (
- snr == 3
- ), "Option snr is incompatible with noise_heterogeneity as a list."
+ assert snr == 3, "Option snr is incompatible with noise_heterogeneity as a list."
snr_locs = np.array(noise_heterogeneity)
else:
raise ValueError(
diff --git a/flamby/datasets/fed_tcga_brca/benchmark.py b/flamby/datasets/fed_tcga_brca/benchmark.py
index 30db78f90..35270fbbc 100644
--- a/flamby/datasets/fed_tcga_brca/benchmark.py
+++ b/flamby/datasets/fed_tcga_brca/benchmark.py
@@ -199,21 +199,13 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument(
- "--GPU",
- type=int,
- default=0,
- help="GPU to run the training on (if available)",
+ "--GPU", type=int, default=0, help="GPU to run the training on (if available)"
)
parser.add_argument(
- "--workers",
- type=int,
- default=4,
- help="Numbers of workers for the dataloader",
+ "--workers", type=int, default=4, help="Numbers of workers for the dataloader"
)
parser.add_argument(
- "--log",
- action="store_true",
- help="Whether or not to dump tensorboard events.",
+ "--log", action="store_true", help="Whether or not to dump tensorboard events."
)
parser.add_argument(
"--log-period",
diff --git a/flamby/datasets/fed_tcga_brca/dataset.py b/flamby/datasets/fed_tcga_brca/dataset.py
index 19a27954e..a30386224 100644
--- a/flamby/datasets/fed_tcga_brca/dataset.py
+++ b/flamby/datasets/fed_tcga_brca/dataset.py
@@ -48,10 +48,7 @@ def __len__(self):
def __getitem__(self, idx):
x = self.data.iloc[idx, 1:40]
y = self.data.iloc[idx, 40:42]
- return (
- torch.tensor(x, dtype=self.X_dtype),
- torch.tensor(y, dtype=self.y_dtype),
- )
+ return (torch.tensor(x, dtype=self.X_dtype), torch.tensor(y, dtype=self.y_dtype))
class FedTcgaBrca(TcgaBrcaRaw):
diff --git a/flamby/datasets/split_utils.py b/flamby/datasets/split_utils.py
index 25585c206..ae297fa20 100644
--- a/flamby/datasets/split_utils.py
+++ b/flamby/datasets/split_utils.py
@@ -37,7 +37,7 @@ def split_indices_linear(original_table, dataset_sizes, num_target_centers):
current_idx = 0
for idx_new_client in range(num_target_centers - 1):
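+ # An explicit slice object is used below, presumably to avoid flake8's E203
+ # (whitespace before ':') complaint on black-formatted extended slices.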
assignment_new_client[
- current_idx : current_idx + num_samples_per_new_client
+ slice(current_idx, current_idx + num_samples_per_new_client)
] = idx_new_client
current_idx += num_samples_per_new_client
assignment_new_client[current_idx:] = num_target_centers - 1
@@ -219,10 +219,10 @@ def split_dataset(
_current_idx = 0
for idx_client_orig, length_client in enumerate(client_size_list):
original_table[split][
- 0, _current_idx : _current_idx + length_client
+ 0, np.arange(_current_idx, _current_idx + length_client)
] = idx_client_orig
original_table[split][
- 1, _current_idx : _current_idx + length_client
+ 1, np.arange(_current_idx, _current_idx + length_client)
] = np.arange(0, length_client)
_current_idx += length_client
diff --git a/flamby/extract_config.py b/flamby/extract_config.py
index 087ba46d4..ba74c2817 100644
--- a/flamby/extract_config.py
+++ b/flamby/extract_config.py
@@ -36,9 +36,7 @@ def main(args_cli):
"You should provide as many dataset names as you gave results"
" files or 1 if they all come from the same dataset."
)
- optimizers_classes = [
- e[1] for e in inspect.getmembers(torch.optim, inspect.isclass)
- ]
+ optimizers_classes = [e[1] for e in inspect.getmembers(torch.optim, inspect.isclass)]
csvs = [pd.read_csv(e) for e in csv_files]
for dname, csv, csvf in zip(dataset_names, csvs, csv_files):
config = {}
diff --git a/flamby/personalization_example/plot_perso_results.py b/flamby/personalization_example/plot_perso_results.py
index cebd604bb..96a16ed33 100644
--- a/flamby/personalization_example/plot_perso_results.py
+++ b/flamby/personalization_example/plot_perso_results.py
@@ -1,39 +1,21 @@
# Plot
import matplotlib.pyplot as plt
-
-# import numpy as np
import pandas as pd
import seaborn as sns
sns.set_theme(style="darkgrid")
-# datasets_names = ["Fed-Heart-Disease", "Fed-Camelyon16", "Fed-ISIC2019"]
-# n_repetitions = 5
-# rows = []
-# for d in datasets_names:
-# for se in np.arange(42, 42 + n_repetitions):
-# rows.append({"dataset": d, "seed": se, "perf": float(np.random.uniform(0., 1., 1)), "finetune": True})
-# rows.append({"dataset": d, "seed": se, "perf": float(np.random.uniform(0., 1., 1)), "finetune": False})
-
-# results = pd.DataFrame.from_dict(rows)
-
results = pd.read_csv("results_perso_vs_normal.csv")
results = results.rename(columns={"perf": "Performance"})
fig, ax = plt.subplots()
-g = sns.boxplot(
- data=results,
- x="dataset",
- y="Performance",
- hue="finetune",
- ax=ax,
-)
+g = sns.boxplot(data=results, x="dataset", y="Performance", hue="finetune", ax=ax)
ax.set_xlabel(None)
ax.set_ylim(0.0, 1.0)
mapping_dict = {"True": "Fine-tuned", "False": "Not fine-tuned"}
handles, labels = ax.get_legend_handles_labels()
-ax.legend(handles=handles, labels=[mapping_dict[l] for l in labels])
+ax.legend(handles=handles, labels=[mapping_dict[lab] for lab in labels])
plt.savefig("perso_vs_non_perso.pdf", dpi=100, bbox_inches="tight")
diff --git a/flamby/results/plot_results.py b/flamby/results/plot_results.py
index bf476324a..41849d2cc 100644
--- a/flamby/results/plot_results.py
+++ b/flamby/results/plot_results.py
@@ -127,7 +127,7 @@ def partial_share_y_axes(axs):
]
# Messing with palettes to keep the same color for pooled and strategies
- current_palette = [palette[0]] + palette[1 : (current_num_clients + 1)]
+ current_palette = [palette[0]] + palette[slice(1, current_num_clients + 1)]
current_palette += palette[7:]
assert len(current_palette) == len(current_methods_display)
# print(current_palette[len(current_methods_display) -1])
diff --git a/flamby/strategies/cyclic.py b/flamby/strategies/cyclic.py
index 86ce60e43..0edae86fd 100644
--- a/flamby/strategies/cyclic.py
+++ b/flamby/strategies/cyclic.py
@@ -145,9 +145,7 @@ def __init__(
self.deterministic_cycle = deterministic_cycle
- self._rng = (
- rng if (rng is not None) else np.random.default_rng(int(time.time()))
- )
+ self._rng = rng if (rng is not None) else np.random.default_rng(int(time.time()))
self._clients = self._shuffle_clients()
self._current_idx = -1
diff --git a/flamby/strategies/scaffold.py b/flamby/strategies/scaffold.py
index 453d6a7fb..b07214125 100644
--- a/flamby/strategies/scaffold.py
+++ b/flamby/strategies/scaffold.py
@@ -124,20 +124,14 @@ def __init__(
]
# initialize the corrections used by each client to 0s.
self.client_corrections_state_list = [
- [
- torch.zeros_like(torch.from_numpy(p))
- for p in _model._get_current_params()
- ]
+ [torch.zeros_like(torch.from_numpy(p)) for p in _model._get_current_params()]
for _model in self.models_list
]
self.client_lr = learning_rate
self.server_lr = server_learning_rate
def _local_optimization(
- self,
- _model: _Model,
- dataloader_with_memory,
- correction_state: List,
+ self, _model: _Model, dataloader_with_memory, correction_state: List
):
"""Carry out the local optimization step.
diff --git a/flamby/strategies/utils.py b/flamby/strategies/utils.py
index 4e54862b7..4f552cd86 100644
--- a/flamby/strategies/utils.py
+++ b/flamby/strategies/utils.py
@@ -270,9 +270,7 @@ def _prox_local_train(self, dataloader_with_memory, num_updates, mu):
_loss = _prox_loss.detach()
if mu > 0.0:
- squared_norm = compute_model_diff_squared_norm(
- model_initial, self.model
- )
+ squared_norm = compute_model_diff_squared_norm(model_initial, self.model)
_prox_loss += mu / 2 * squared_norm
# Backpropagation
diff --git a/integration/FedML/fedml_utils.py b/integration/FedML/fedml_utils.py
index 73734a855..1ecb78eba 100644
--- a/integration/FedML/fedml_utils.py
+++ b/integration/FedML/fedml_utils.py
@@ -95,11 +95,7 @@ def test(self, test_data, device, args):
def _test(self, test_data, device):
logging.info("Evaluating on Trainer ID: {}".format(self.id))
- test_metrics = {
- "test_correct": 0,
- "test_total": 0,
- "test_loss": 0,
- }
+ test_metrics = {"test_correct": 0, "test_total": 0, "test_loss": 0}
if not test_data:
logging.info("No test data for this trainer")
diff --git a/tests/benchmarks/test_fed_benchmark.py b/tests/benchmarks/test_fed_benchmark.py
index 1a5ae9d7a..ff758306f 100644
--- a/tests/benchmarks/test_fed_benchmark.py
+++ b/tests/benchmarks/test_fed_benchmark.py
@@ -32,9 +32,7 @@ def assert_dfs_equal(pair0, pair1, ignore_columns=[]):
ignore_columns : list, optional
The columns that comparison should exclude, by default []
"""
- ignore_columns = [
- col for col in ignore_columns if (col in pair0) and (col in pair1)
- ]
+ ignore_columns = [col for col in ignore_columns if (col in pair0) and (col in pair1)]
df1 = pair0.drop(columns=ignore_columns).fillna("-9")
df2 = pair1.drop(columns=ignore_columns).fillna("-9")[df1.columns]
assert (
diff --git a/tests/strategies/test_fed_avg.py b/tests/strategies/test_fed_avg.py
index e5c409b46..cd97e75e5 100644
--- a/tests/strategies/test_fed_avg.py
+++ b/tests/strategies/test_fed_avg.py
@@ -111,10 +111,7 @@ def test_fedavg_Isic():
]
)
test_aug = albumentations.Compose(
- [
- albumentations.CenterCrop(sz, sz),
- albumentations.Normalize(always_apply=True),
- ]
+ [albumentations.CenterCrop(sz, sz), albumentations.Normalize(always_apply=True)]
)
training_dls = [
dl(
diff --git a/tests/strategies/test_fed_prox.py b/tests/strategies/test_fed_prox.py
index 26982cc32..4483cdeb8 100644
--- a/tests/strategies/test_fed_prox.py
+++ b/tests/strategies/test_fed_prox.py
@@ -69,9 +69,7 @@ def test_fed_prox_integration(n_clients):
mu = 0.1
optimizer_class = torch.optim.Adam
- s = FedProx(
- train_dataloader, m, loss, optimizer_class, lr, num_updates, nrounds, mu
- )
+ s = FedProx(train_dataloader, m, loss, optimizer_class, lr, num_updates, nrounds, mu)
m = s.run()
def accuracy(y_true, y_pred):
diff --git a/tests/strategies/test_scaffold.py b/tests/strategies/test_scaffold.py
index 8e019efd0..08f4c55aa 100644
--- a/tests/strategies/test_scaffold.py
+++ b/tests/strategies/test_scaffold.py
@@ -83,10 +83,7 @@ def accuracy(y_true, y_pred):
cleanup()
-@pytest.mark.parametrize(
- "seed, lr",
- [(42, 0.01), (43, 0.001), (44, 0.0001), (45, 7e-5)],
-)
+@pytest.mark.parametrize("seed, lr", [(42, 0.01), (43, 0.001), (44, 0.0001), (45, 7e-5)])
def test_scaffold_algorithm(seed, lr):
r"""Scaffold should add a correction term in each of its update step.
In the first round, this correction step is 0. In each subsequent round,