From f1e05644c94ee26bdd1d2b9fc3431f4a5e6b2a34 Mon Sep 17 00:00:00 2001 From: jeandut Date: Fri, 18 Nov 2022 15:07:42 +0100 Subject: [PATCH] Trying to fix both tests and linting (#258) * trying to fix both tests and linting * fixing flake8 versiion * trying to fix flake8 * fixing linting errors * fixing linting CI * fixing linting CI * fixing linting CI * fixing linting CI * fixing linting CI * fixing linting CI * fixing linting CI * fixing linting CI * fixing linting CI * fixing linting CI * fix linting * fix linting * trying to make flake8 work --- .github/workflows/linting.yml | 7 ++- .github/workflows/pr_validation.yml | 2 + docs/conf.py | 4 +- flamby/benchmarks/benchmark_utils.py | 4 +- flamby/benchmarks/conf.py | 3 +- flamby/benchmarks/fed_benchmark.py | 10 +--- flamby/create_dataset_config.py | 5 +- flamby/datasets/fed_camelyon16/dataset.py | 3 +- .../dataset_creation_scripts/download.py | 16 +++--- .../dataset_creation_scripts/tiling_slides.py | 4 +- flamby/datasets/fed_camelyon16/model.py | 6 +-- flamby/datasets/fed_dummy_dataset.py | 10 ++-- flamby/datasets/fed_heart_disease/dataset.py | 14 ++--- .../dataset_creation_scripts/download.py | 11 ++-- flamby/datasets/fed_heart_disease/metric.py | 1 - flamby/datasets/fed_isic2019/benchmark.py | 24 ++------- .../dataset_creation_scripts/download_isic.py | 12 ++--- .../fed_isic2019/heterogeneity_pic.py | 15 ++---- .../dataset_creation_scripts/download.py | 14 +++-- flamby/datasets/fed_ixi/ixi_plotting.py | 6 +-- flamby/datasets/fed_ixi/loss.py | 8 ++- flamby/datasets/fed_ixi/metric.py | 2 - flamby/datasets/fed_ixi/model.py | 32 +++-------- flamby/datasets/fed_ixi/utils.py | 54 ++++++++++++------- flamby/datasets/fed_kits19/benchmark.py | 42 +++------------ flamby/datasets/fed_kits19/dataset.py | 37 +++++-------- .../dataset_creation_scripts/create_config.py | 2 +- .../kits19_heterogenity_plot.py | 2 +- .../parsing_and_adding_metadata.py | 18 +++---- .../run_nnUnet_plan_and_preprocess.py | 14 ++--- .../utils/__init__.py | 2 + .../utils/data_augmentations.py | 15 +++--- flamby/datasets/fed_kits19/metric.py | 23 +++++--- flamby/datasets/fed_kits19/model.py | 3 +- flamby/datasets/fed_lidc_idri/__init__.py | 5 +- flamby/datasets/fed_lidc_idri/benchmark.py | 9 +--- flamby/datasets/fed_lidc_idri/data_utils.py | 31 +++-------- .../download_ct_scans.py | 4 +- .../dataset_creation_scripts/process_raw.py | 4 +- .../dataset_creation_scripts/tciaclient.py | 15 +++--- .../fed_lidc_idri/lidc_heterogeneity_plot.py | 7 +-- .../dataset_creation_scripts/download.py | 5 +- .../fed_synthetic/synthetic_generator.py | 4 +- flamby/datasets/fed_tcga_brca/benchmark.py | 14 ++--- flamby/datasets/fed_tcga_brca/dataset.py | 5 +- flamby/datasets/split_utils.py | 6 +-- flamby/extract_config.py | 4 +- .../plot_perso_results.py | 22 +------- flamby/results/plot_results.py | 2 +- flamby/strategies/cyclic.py | 4 +- flamby/strategies/scaffold.py | 10 +--- flamby/strategies/utils.py | 4 +- integration/FedML/fedml_utils.py | 6 +-- tests/benchmarks/test_fed_benchmark.py | 4 +- tests/strategies/test_fed_avg.py | 5 +- tests/strategies/test_fed_prox.py | 4 +- tests/strategies/test_scaffold.py | 5 +- 57 files changed, 210 insertions(+), 394 deletions(-) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 0772bd138..3fa23b877 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -23,9 +23,14 @@ jobs: - name: Install dependencies run: pip install isort black==22.3.0 + pip install flake8 - name: Run black - run: black --check . 
+ run: black --line-length=89 --check . + + + - name: Run FLAKE8 + run: flake8 --max-line-length=89 --per-file-ignores="*/__init__.py:F401" ./flamby - name: Run isort run: isort . diff --git a/.github/workflows/pr_validation.yml b/.github/workflows/pr_validation.yml index ba1ae9385..4ec3eb1a1 100644 --- a/.github/workflows/pr_validation.yml +++ b/.github/workflows/pr_validation.yml @@ -15,6 +15,8 @@ jobs: steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 + with: + python-version: '3.10' - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/docs/conf.py b/docs/conf.py index 4a2ceafd3..3bd1ad39d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -175,7 +175,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, "FLamby.tex", "FLamby Documentation", "Collaboration", "manual"), + (master_doc, "FLamby.tex", "FLamby Documentation", "Collaboration", "manual") ] @@ -200,7 +200,7 @@ "FLamby", "One line description of project.", "Miscellaneous", - ), + ) ] diff --git a/flamby/benchmarks/benchmark_utils.py b/flamby/benchmarks/benchmark_utils.py index a5f888fa9..f9e46f2d8 100644 --- a/flamby/benchmarks/benchmark_utils.py +++ b/flamby/benchmarks/benchmark_utils.py @@ -568,9 +568,7 @@ def ensemble_perf_from_predictions( return ensemble_perf -def set_dataset_specific_config( - dataset_name, compute_ensemble_perf=False, use_gpu=True -): +def set_dataset_specific_config(dataset_name, compute_ensemble_perf=False, use_gpu=True): """_summary_ Parameters diff --git a/flamby/benchmarks/conf.py b/flamby/benchmarks/conf.py index b35cf8b7f..0eaab9f20 100644 --- a/flamby/benchmarks/conf.py +++ b/flamby/benchmarks/conf.py @@ -85,8 +85,7 @@ def get_dataset_args( for param in params: try: p = getattr( - __import__(f"flamby.datasets.{dataset_name}", fromlist=param), - param, + __import__(f"flamby.datasets.{dataset_name}", fromlist=param), param ) except AttributeError: p = None diff --git a/flamby/benchmarks/fed_benchmark.py b/flamby/benchmarks/fed_benchmark.py index e4b285d15..852762a19 100644 --- a/flamby/benchmarks/fed_benchmark.py +++ b/flamby/benchmarks/fed_benchmark.py @@ -464,10 +464,7 @@ def main(args_cli): parser = argparse.ArgumentParser() parser.add_argument( - "--GPU", - type=int, - default=0, - help="GPU to run the training on (if available)", + "--GPU", type=int, default=0, help="GPU to run the training on (if available)" ) parser.add_argument( "--cpu-only", @@ -488,10 +485,7 @@ def main(args_cli): help="Do 0 round and 0 epoch to check if the script is working", ) parser.add_argument( - "--workers", - type=int, - default=0, - help="Numbers of workers for the dataloader", + "--workers", type=int, default=0, help="Numbers of workers for the dataloader" ) parser.add_argument( "--learning_rate", diff --git a/flamby/create_dataset_config.py b/flamby/create_dataset_config.py index 4d8d33f2c..1d10db829 100644 --- a/flamby/create_dataset_config.py +++ b/flamby/create_dataset_config.py @@ -5,10 +5,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--path", - type=str, - help="The path where the dataset is located", - required=True, + "--path", type=str, help="The path where the dataset is located", required=True ) parser.add_argument( "--dataset-name", diff --git a/flamby/datasets/fed_camelyon16/dataset.py b/flamby/datasets/fed_camelyon16/dataset.py index b658ee4ad..587bca18a 100644 --- a/flamby/datasets/fed_camelyon16/dataset.py +++ 
b/flamby/datasets/fed_camelyon16/dataset.py @@ -86,7 +86,8 @@ def __init__( self.features_centers = [] self.features_sets = [] self.perms = {} - # We need this ist to be sorted for reproducibility but shuffled to avoid weirdness + # We need this list to be sorted for reproducibility but shuffled to + # avoid weirdness npys_list = sorted(self.tiles_dir.glob("*.npy")) random.seed(0) random.shuffle(npys_list) diff --git a/flamby/datasets/fed_camelyon16/dataset_creation_scripts/download.py b/flamby/datasets/fed_camelyon16/dataset_creation_scripts/download.py index 68483e887..5a874df5f 100644 --- a/flamby/datasets/fed_camelyon16/dataset_creation_scripts/download.py +++ b/flamby/datasets/fed_camelyon16/dataset_creation_scripts/download.py @@ -5,6 +5,7 @@ import sys from pathlib import Path +import numpy as np import pandas as pd from google_client import create_service from googleapiclient.errors import HttpError @@ -65,12 +66,11 @@ def main(path_to_secret, output_folder, port=6006, debug=False): len(train_df.index) + len(test_df.index) ) downloaded_images_status_file["Slide"] = None - downloaded_images_status_file.Slide.iloc[: len(train_df.index)] = train_df[ - "name" - ] - downloaded_images_status_file.Slide.iloc[len(train_df.index) :] = test_df[ - "name" - ] + total_size = len(train_df.index) + len(test_df.index) + train_idxs = np.arange(0, len(train_df.index)) + test_idxs = np.arange(len(train_df.index), total_size) + downloaded_images_status_file.Slide.iloc[train_idxs] = train_df["name"] + downloaded_images_status_file.Slide.iloc[test_idxs] = test_df["name"] downloaded_images_status_file.to_csv( downloaded_images_status_file_path, index=False ) @@ -92,7 +92,9 @@ def main(path_to_secret, output_folder, port=6006, debug=False): port=port, ) regex = "(?<=https://drive.google.com/file/d/)[a-zA-Z0-9]+" - # Resourcekey is now mandatory (credit @Kris in: https://stackoverflow.com/questions/71343002/downloading-files-from-public-google-drive-in-python-scoping-issues) + # Resourcekey is now mandatory (credit @Kris in: + # https://stackoverflow.com/questions/71343002/ + # downloading-files-from-public-google-drive-in-python-scoping-issues) regex_rkey = "(?<=resourcekey=).+" for current_df in [train_df, test_df]: for i in tqdm(range(len(current_df.index))): diff --git a/flamby/datasets/fed_camelyon16/dataset_creation_scripts/tiling_slides.py b/flamby/datasets/fed_camelyon16/dataset_creation_scripts/tiling_slides.py index bae07cd0a..0acdce2dd 100644 --- a/flamby/datasets/fed_camelyon16/dataset_creation_scripts/tiling_slides.py +++ b/flamby/datasets/fed_camelyon16/dataset_creation_scripts/tiling_slides.py @@ -52,9 +52,7 @@ def __len__(self): def __getitem__(self, idx): pil_image = self.slide.read_region( - self.coords[idx].astype("int_"), - self.level, - (self.tile_size, self.tile_size), + self.coords[idx].astype("int_"), self.level, (self.tile_size, self.tile_size) ).convert("RGB") if self.transform is not None: pil_image = self.transform(pil_image) diff --git a/flamby/datasets/fed_camelyon16/model.py b/flamby/datasets/fed_camelyon16/model.py index 1ce7400a4..73d328b11 100644 --- a/flamby/datasets/fed_camelyon16/model.py +++ b/flamby/datasets/fed_camelyon16/model.py @@ -10,15 +10,13 @@ class Baseline(nn.Module): def __init__(self): super(Baseline, self).__init__() # As per the article - self.O = 2048 # Original dimension of the input embeddings + self.Od = 2048 # Original dimension of the input embeddings self.M = 128 # New dimension of the input embedding self.L = 128 # Dimension of the new features 
after query and value projections self.K = 1000 # Number of elements in each bag - self.feature_extractor_part1 = nn.Sequential( - nn.Linear(self.O, self.M), - ) + self.feature_extractor_part1 = nn.Sequential(nn.Linear(self.Od, self.M)) # The Gated Attention using tanh and sigmoid from Eq 9 # from https://arxiv.org/abs/1802.04712 diff --git a/flamby/datasets/fed_dummy_dataset.py b/flamby/datasets/fed_dummy_dataset.py index fcb958a73..ffe4e356a 100644 --- a/flamby/datasets/fed_dummy_dataset.py +++ b/flamby/datasets/fed_dummy_dataset.py @@ -24,8 +24,9 @@ def __len__(self): return self.size def __getitem__(self, idx): - return torch.rand(3, 224, 224).to(self.X_dtype), torch.randint(0, 2, (1,)).to( - self.y_dtype + return ( + torch.rand(3, 224, 224).to(self.X_dtype), + torch.randint(0, 2, (1,)).to(self.y_dtype), ) @@ -53,10 +54,7 @@ def forward(self, X): m = Baseline() lo = BaselineLoss() dl = DataLoader( - FedDummyDataset(center=1, train=True), - batch_size=32, - shuffle=True, - num_workers=0, + FedDummyDataset(center=1, train=True), batch_size=32, shuffle=True, num_workers=0 ) it = iter(dl) X, y = next(it) diff --git a/flamby/datasets/fed_heart_disease/dataset.py b/flamby/datasets/fed_heart_disease/dataset.py index f7dc5cff2..62c367b2a 100644 --- a/flamby/datasets/fed_heart_disease/dataset.py +++ b/flamby/datasets/fed_heart_disease/dataset.py @@ -73,12 +73,7 @@ def __init__( self.y_dtype = y_dtype self.debug = debug - self.centers_number = { - "cleveland": 0, - "hungarian": 1, - "switzerland": 2, - "va": 3, - } + self.centers_number = {"cleveland": 0, "hungarian": 1, "switzerland": 2, "va": 3} self.features = pd.DataFrame() self.labels = pd.DataFrame() @@ -165,9 +160,7 @@ def __init__( } # We finally broadcast the means and stds over all datasets - self.mean_of_features = torch.zeros( - (len(self.features), 13), dtype=self.X_dtype - ) + self.mean_of_features = torch.zeros((len(self.features), 13), dtype=self.X_dtype) self.std_of_features = torch.ones((len(self.features), 13), dtype=self.X_dtype) for i in range(self.mean_of_features.shape[0]): self.mean_of_features[i] = self.centers_stats[self.centers[i]]["mean"] @@ -177,8 +170,7 @@ def __init__( to_select = [(self.sets[idx] == "train") for idx, _ in enumerate(self.features)] features_train = [fp for idx, fp in enumerate(self.features) if to_select[idx]] features_tensor_train = torch.cat( - [features_train[i][None, :] for i in range(len(features_train))], - axis=0, + [features_train[i][None, :] for i in range(len(features_train))], axis=0 ) self.mean_of_features_pooled_train = features_tensor_train.mean(axis=0) self.std_of_features_pooled_train = features_tensor_train.std(axis=0) diff --git a/flamby/datasets/fed_heart_disease/dataset_creation_scripts/download.py b/flamby/datasets/fed_heart_disease/dataset_creation_scripts/download.py index 8014a8305..db249ba34 100755 --- a/flamby/datasets/fed_heart_disease/dataset_creation_scripts/download.py +++ b/flamby/datasets/fed_heart_disease/dataset_creation_scripts/download.py @@ -20,12 +20,9 @@ def main(output_folder, debug=False): # location of the files in the UCI archive accept_license( - "https://archive-beta.ics.uci.edu/ml/datasets/heart+disease", - "fed_heart_disease", - ) - base_url = ( - "https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/" + "https://archive-beta.ics.uci.edu/ml/datasets/heart+disease", "fed_heart_disease" ) + base_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/" centers = ["cleveland", "hungarian", "switzerland", "va"] 
md5_hashes = [ "2d91a8ff69cfd9616aa47b59d6f843db", @@ -69,9 +66,7 @@ def main(output_folder, debug=False): sys.exit() # get status of download - downloaded_status_file_path = os.path.join( - output_folder, "download_status_file.csv" - ) + downloaded_status_file_path = os.path.join(output_folder, "download_status_file.csv") if not (os.path.exists(downloaded_status_file_path)): downloaded_status_file = pd.DataFrame() downloaded_status_file["Status"] = ["Not found"] * 4 diff --git a/flamby/datasets/fed_heart_disease/metric.py b/flamby/datasets/fed_heart_disease/metric.py index bd13ee11a..885d90bc3 100644 --- a/flamby/datasets/fed_heart_disease/metric.py +++ b/flamby/datasets/fed_heart_disease/metric.py @@ -1,5 +1,4 @@ import numpy as np -from sklearn.metrics import roc_auc_score def metric(y_true, y_pred): diff --git a/flamby/datasets/fed_isic2019/benchmark.py b/flamby/datasets/fed_isic2019/benchmark.py index e34c42f1e..7bf34b710 100644 --- a/flamby/datasets/fed_isic2019/benchmark.py +++ b/flamby/datasets/fed_isic2019/benchmark.py @@ -22,14 +22,7 @@ def train_model( - model, - optimizer, - scheduler, - dataloaders, - dataset_sizes, - device, - lossfunc, - num_epochs, + model, optimizer, scheduler, dataloaders, dataset_sizes, device, lossfunc, num_epochs ): """Training function Parameters @@ -224,16 +217,10 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument( - "--GPU", - type=int, - default=0, - help="GPU to run the training on (if available)", + "--GPU", type=int, default=0, help="GPU to run the training on (if available)" ) parser.add_argument( - "--workers", - type=int, - default=4, - help="Numbers of workers for the dataloader", + "--workers", type=int, default=4, help="Numbers of workers for the dataloader" ) args = parser.parse_args() @@ -243,10 +230,7 @@ def main(args): sz = 200 test_aug = albumentations.Compose( - [ - albumentations.CenterCrop(sz, sz), - albumentations.Normalize(always_apply=True), - ] + [albumentations.CenterCrop(sz, sz), albumentations.Normalize(always_apply=True)] ) test_dataset = dataset.FedIsic2019(train=False, pooled=True) test_dataloader = torch.utils.data.DataLoader( diff --git a/flamby/datasets/fed_isic2019/dataset_creation_scripts/download_isic.py b/flamby/datasets/fed_isic2019/dataset_creation_scripts/download_isic.py index b56e3a05e..247bdd161 100644 --- a/flamby/datasets/fed_isic2019/dataset_creation_scripts/download_isic.py +++ b/flamby/datasets/fed_isic2019/dataset_creation_scripts/download_isic.py @@ -76,23 +76,17 @@ for i, row in ISIC_2019_Training_Metadata.iterrows(): if pd.isnull(row["lesion_id"]): image = row["image"] - os.system( - "rm " + data_directory + "/ISIC_2019_Training_Input/" + image + ".jpg" - ) + os.system("rm " + data_directory + "/ISIC_2019_Training_Input/" + image + ".jpg") if image != ISIC_2019_Training_GroundTruth["image"][i]: print("Mismatch between Metadata and Ground Truth") ISIC_2019_Training_GroundTruth = ISIC_2019_Training_GroundTruth.drop(i) ISIC_2019_Training_Metadata = ISIC_2019_Training_Metadata.drop(i) # generating dataset field from lesion_id field in the metadata dataframe -ISIC_2019_Training_Metadata["dataset"] = ISIC_2019_Training_Metadata["lesion_id"].str[ - :4 -] +ISIC_2019_Training_Metadata["dataset"] = ISIC_2019_Training_Metadata["lesion_id"].str[:4] # join with HAM10000 metadata in order to expand the HAM datacenters -result = pd.merge( - ISIC_2019_Training_Metadata, HAM10000_metadata, how="left", on="image" -) +result = pd.merge(ISIC_2019_Training_Metadata, HAM10000_metadata, how="left", 
on="image") result["dataset"] = result["dataset_x"] + result["dataset_y"].astype(str) result.drop(["dataset_x", "dataset_y", "lesion_id"], axis=1, inplace=True) diff --git a/flamby/datasets/fed_isic2019/heterogeneity_pic.py b/flamby/datasets/fed_isic2019/heterogeneity_pic.py index d91818f8d..5edfa8fde 100644 --- a/flamby/datasets/fed_isic2019/heterogeneity_pic.py +++ b/flamby/datasets/fed_isic2019/heterogeneity_pic.py @@ -42,22 +42,13 @@ def forward(self, image): parser = argparse.ArgumentParser() parser.add_argument( - "--GPU", - type=int, - default=0, - help="GPU to run the training on (if available)", + "--GPU", type=int, default=0, help="GPU to run the training on (if available)" ) parser.add_argument( - "--workers", - type=int, - default=0, - help="Numbers of workers for the dataloader", + "--workers", type=int, default=0, help="Numbers of workers for the dataloader" ) parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for the UMPA and dataloading", + "--seed", type=int, default=42, help="The seed for the UMPA and dataloading" ) args = parser.parse_args() np.random.seed(args.seed) diff --git a/flamby/datasets/fed_ixi/dataset_creation_scripts/download.py b/flamby/datasets/fed_ixi/dataset_creation_scripts/download.py index 74c2ccca1..fd6c041de 100644 --- a/flamby/datasets/fed_ixi/dataset_creation_scripts/download.py +++ b/flamby/datasets/fed_ixi/dataset_creation_scripts/download.py @@ -21,10 +21,16 @@ def dl_ixi_tiny(output_folder, debug=False): The folder where to download the dataset. """ print( - "The IXI dataset is made available under the Creative Commons CC BY-SA 3.0 license.\n\ - If you use the IXI data please acknowledge the source of the IXI data, e.g. the following website: https://brain-development.org/ixi-dataset/\n\ - IXI Tiny is derived from the same source. Acknowledge the following reference on TorchIO : https://torchio.readthedocs.io/datasets.html#ixitiny\n\ - Pérez-García F, Sparks R, Ourselin S. TorchIO: a Python library for efficient loading, preprocessing, augmentation and patch-based sampling of medical images in deep learning. arXiv:2003.04696 [cs, eess, stat]. 2020. https://doi.org/10.48550/arXiv.2003.04696" + "The IXI dataset is made available under the Creative Commons CC BY-SA \ + 3.0 license.\n\ + If you use the IXI data please acknowledge the source of the IXI data, e.g.\ + the following website: https://brain-development.org/ixi-dataset/\ + IXI Tiny is derived from the same source. Acknowledge the following reference\ + on TorchIO : https://torchio.readthedocs.io/datasets.html#ixitiny\ + Pérez-García F, Sparks R, Ourselin S. TorchIO: a Python library for \ + efficient loading, preprocessing, augmentation and patch-based sampling \ + of medical images in deep learning. arXiv:2003.04696 [cs, eess, stat]. \ + 2020. 
https://doi.org/10.48550/arXiv.2003.04696" ) accept_license("https://brain-development.org/ixi-dataset/", "fed_ixi") os.makedirs(output_folder, exist_ok=True) diff --git a/flamby/datasets/fed_ixi/ixi_plotting.py b/flamby/datasets/fed_ixi/ixi_plotting.py index 3ed60c496..e791f5efd 100644 --- a/flamby/datasets/fed_ixi/ixi_plotting.py +++ b/flamby/datasets/fed_ixi/ixi_plotting.py @@ -53,11 +53,7 @@ def plot_histogram(axis, array, num_positions=100, label=None, alpha=0.05, color [0], [0], color="r", lw=4, label="Client {}".format(CENTER_LABELS_CORRESP["HH"]) ), Line2D( - [0], - [0], - color="b", - lw=4, - label="Client {}".format(CENTER_LABELS_CORRESP["IOP"]), + [0], [0], color="b", lw=4, label="Client {}".format(CENTER_LABELS_CORRESP["IOP"]) ), ] ax.legend(handles=legend_elements, loc="upper right") diff --git a/flamby/datasets/fed_ixi/loss.py b/flamby/datasets/fed_ixi/loss.py index 8a3cea80d..daafb82b3 100644 --- a/flamby/datasets/fed_ixi/loss.py +++ b/flamby/datasets/fed_ixi/loss.py @@ -7,7 +7,9 @@ def __init__(self): super(BaselineLoss, self).__init__() def forward(self, output: torch.Tensor, target: torch.Tensor): - """Get dice loss to evaluate the semantic segmentation model. Its value lies between 0 and 1. The more the loss is close to 0, the more the performance is good. + """Get dice loss to evaluate the semantic segmentation model. + Its value lies between 0 and 1. The more the loss is close to 0, + the more the performance is good. Parameters ---------- @@ -26,7 +28,9 @@ def forward(self, output: torch.Tensor, target: torch.Tensor): def get_dice_score(output, target, epsilon=1e-9): - """Get dice score to evaluate the semantic segmentation model. Its value lies between 0 and 1. The more the score is close to 1, the more the performance is good. + """Get dice score to evaluate the semantic segmentation model. + Its value lies between 0 and 1. The more the score is close to 1, + the more the performance is good. 
Parameters ---------- diff --git a/flamby/datasets/fed_ixi/metric.py b/flamby/datasets/fed_ixi/metric.py index 32ee2aa5f..005759d5d 100644 --- a/flamby/datasets/fed_ixi/metric.py +++ b/flamby/datasets/fed_ixi/metric.py @@ -1,6 +1,4 @@ import numpy as np -import torch -import torch.nn.functional as F def metric(y_true, y_pred): diff --git a/flamby/datasets/fed_ixi/model.py b/flamby/datasets/fed_ixi/model.py index db90ee1b3..f2e34ceaf 100644 --- a/flamby/datasets/fed_ixi/model.py +++ b/flamby/datasets/fed_ixi/model.py @@ -12,7 +12,7 @@ Requires: Python >=3.6 """ -#### UNet +# UNet class Baseline(nn.Module): @@ -117,11 +117,7 @@ def __init__( elif dimensions == 3: in_channels = 2 * out_channels_first_layer self.classifier = ConvolutionalBlock( - dimensions, - in_channels, - out_classes, - kernel_size=1, - activation=None, + dimensions, in_channels, out_classes, kernel_size=1, activation=None ) def forward(self, x): @@ -135,7 +131,7 @@ def forward(self, x): return x -#### Conv #### +# Conv class ConvolutionalBlock(nn.Module): @@ -218,16 +214,10 @@ def add_if_not_none(module_list, module): module_list.append(module) -#### Decoding #### +# Decoding CHANNELS_DIMENSION = 1 -UPSAMPLING_MODES = ( - "nearest", - "linear", - "bilinear", - "bicubic", - "trilinear", -) +UPSAMPLING_MODES = ("nearest", "linear", "bilinear", "bicubic", "trilinear") class Decoder(nn.Module): @@ -371,11 +361,7 @@ def get_upsampling_layer(upsampling_type: str) -> nn.Upsample: message = 'Upsampling type is "{}"' " but should be one of the following: {}" message = message.format(upsampling_type, UPSAMPLING_MODES) raise ValueError(message) - upsample = nn.Upsample( - scale_factor=2, - mode=upsampling_type, - align_corners=False, - ) + upsample = nn.Upsample(scale_factor=2, mode=upsampling_type, align_corners=False) return upsample @@ -395,7 +381,7 @@ def fix_upsampling_type(upsampling_type: str, dimensions: int): return upsampling_type -#### Encoding #### +# Encoding class Encoder(nn.Module): @@ -555,9 +541,7 @@ def out_channels(self): def get_downsampling_layer( - dimensions: int, - pooling_type: str, - kernel_size: int = 2, + dimensions: int, pooling_type: str, kernel_size: int = 2 ) -> nn.Module: class_name = "{}Pool{}d".format(pooling_type.capitalize(), dimensions) class_ = getattr(nn, class_name) diff --git a/flamby/datasets/fed_ixi/utils.py b/flamby/datasets/fed_ixi/utils.py index 5ee73b248..6229daeae 100644 --- a/flamby/datasets/fed_ixi/utils.py +++ b/flamby/datasets/fed_ixi/utils.py @@ -1,6 +1,6 @@ """Federated IXI Dataset utils - -A set of function that allow data management suited for the `IXI dataset `_. +A set of function that allow data management suited for the `IXI dataset +`_. """ @@ -14,11 +14,8 @@ from zipfile import ZipFile import nibabel as nib -import nibabel.processing as processing import numpy import numpy as np -from nibabel import Nifti1Header -from numpy import ndarray def _get_id_from_filename(x, verify_single_matches=True) -> Union[List[int], int]: @@ -55,7 +52,8 @@ def _get_id_from_filename(x, verify_single_matches=True) -> Union[List[int], int def _assembly_nifti_filename_regex( patient_id: int, modality: str ) -> Union[str, PathLike, Path]: - """Assembles NIFTI filename regular expression using the standard in the IXI dataset based on id and modality. + """Assembles NIFTI filename regular expression using the standard in the + IXI dataset based on id and modality. 
Parameters ---------- @@ -78,7 +76,8 @@ def _assembly_nifti_filename_regex( def _assembly_nifti_img_and_label_regex( patient_id: int, modality: str ) -> Tuple[Union[str, PathLike, Path], Union[str, PathLike, Path]]: - """Assembles NIFTI filename regular expression for image and label using the standard in the IXI tiny dataset based on id and modality. + """Assembles NIFTI filename regular expression for image and label using + the standard in the IXI tiny dataset based on id and modality. Parameters ---------- @@ -104,7 +103,8 @@ def _assembly_nifti_img_and_label_regex( def _find_file_in_tar(tar_file: TarFile, patient_id: int, modality) -> str: - """Searches the file in a TAR file that corresponds to a particular regular expression. + """Searches the file in a TAR file that corresponds to a particular + regular expression. Parameters ---------- @@ -134,7 +134,8 @@ def _find_file_in_tar(tar_file: TarFile, patient_id: int, modality) -> str: def _find_files_in_zip(zip_file: ZipFile, patient_id: int, modality) -> Tuple[str]: - """Searches the files in a ZIP file that corresponds to particular regular expressions. + """Searches the files in a ZIP file that corresponds to particular regular + expressions. Parameters ---------- @@ -165,7 +166,7 @@ def _find_files_in_zip(zip_file: ZipFile, patient_id: int, modality) -> Tuple[st result[1] = re.match(regex_label, filename).group() except AttributeError: continue - if result[0] != None and result[1] != None: + if result[0] is not None and result[1] is not None: return tuple(result) raise FileNotFoundError( @@ -176,7 +177,8 @@ def _find_files_in_zip(zip_file: ZipFile, patient_id: int, modality) -> Tuple[st def _extract_center_name_from_filename(filename: str): """Extracts center name from file dataset. - Unfortunately, IXI has the center encoded in the namefile rather than in the demographics information. + Unfortunately, IXI has the center encoded in the namefile rather than in + the demographics information. Parameters ---------- @@ -189,7 +191,8 @@ def _extract_center_name_from_filename(filename: str): Name of the center where the data comes from (e.g. Guys for the previous example) """ - # We decided to wrap a function for this for clarity and easier modularity for future expansion + # We decided to wrap a function for this for clarity and easier modularity + # for future expansion return filename.split("-")[1] @@ -201,7 +204,8 @@ def _load_nifti_image_by_id( Parameters ---------- tar_file : TarFile - `TarFile `_ object + `TarFile + `_ object patient_id : int Patient's ID whose image is to be extracted. modality : str @@ -214,7 +218,8 @@ def _load_nifti_image_by_id( img : ndarray NumPy array containing the intensities of the voxels. center_name : str - Name of the center the file comes from. In IXI this is encoded only in the filename. + Name of the center the file comes from. In IXI this is encoded only + in the filename. """ filename = _find_file_in_tar(tar_file, patient_id, modality) with tempfile.TemporaryDirectory() as td: @@ -237,7 +242,8 @@ def _load_nifti_image_and_label_by_id( Parameters ---------- zip_file : ZipFile - `ZipFile `_ object + `ZipFile + `_ object patient_id : int Patient's ID whose image is to be extracted. modality : str @@ -252,7 +258,8 @@ def _load_nifti_image_and_label_by_id( label : ndarray NumPy array containing the intensities of the voxels. center_name : str - Name of the center the file comes from. In IXI this is encoded only in the filename. + Name of the center the file comes from. 
In IXI this is encoded only in + the filename. """ img_filename, label_filename = _find_files_in_zip(zip_file, patient_id, modality) with tempfile.TemporaryDirectory() as td: @@ -313,11 +320,20 @@ def _create_train_test_split( Returns ------- train_test_hh : list - A list containing randomly generated dichotomous values. The size is the number of images from HH hospital. Dichotomous values (train and test) follow a train test split threshold (e. g. 70%). + A list containing randomly generated dichotomous values. + The size is the number of images from HH hospital. + Dichotomous values (train and test) follow a train test split + threshold (e. g. 70%). train_test_guys : list - A list containing randomly generated dichotomous values. The size is the number of images from Guys hospital. Dichotomous values (train and test) follow a train test split threshold (e. g. 70%). + A list containing randomly generated dichotomous values. + The size is the number of images from Guys hospital. + Dichotomous values (train and test) follow a train test split + threshold (e. g. 70%). train_test_iop : list - A list containing randomly generated dichotomous values. The size is the number of images from IOP hospital. Dichotomous values (train and test) follow a train test split threshold (e. g. 70%). + A list containing randomly generated dichotomous values. + The size is the number of images from IOP hospital. + Dichotomous values (train and test) follow a train test split + threshold (e. g. 70%). """ split_ratio = 0.8 diff --git a/flamby/datasets/fed_kits19/benchmark.py b/flamby/datasets/fed_kits19/benchmark.py index 45f1c94d0..43ea8a154 100644 --- a/flamby/datasets/fed_kits19/benchmark.py +++ b/flamby/datasets/fed_kits19/benchmark.py @@ -22,14 +22,7 @@ def train_model( - model, - optimizer, - scheduler, - dataloaders, - dataset_sizes, - device, - lossfunc, - num_epochs, + model, optimizer, scheduler, dataloaders, dataset_sizes, device, lossfunc, num_epochs ): """Training function Parameters @@ -105,9 +98,7 @@ def train_model( best_model_wts = copy.deepcopy(model.state_dict()) print( - "Training Loss: {:.4f} Validation Acc: {:.4f} ".format( - epoch_loss, epoch_acc - ) + "Training Loss: {:.4f} Validation Acc: {:.4f} ".format(epoch_loss, epoch_acc) ) training_loss_list.append(epoch_loss) training_dice_list.append(epoch_acc) @@ -134,7 +125,7 @@ def main(args): os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU) torch.use_deterministic_algorithms(False) - dict = check_dataset_from_config(dataset_name="fed_kits19", debug=False) + check_dataset_from_config(dataset_name="fed_kits19", debug=False) train_dataset = FedKits19(train=True, pooled=True) train_dataloader = torch.utils.data.DataLoader( @@ -152,7 +143,6 @@ def main(args): dataloaders = {"train": train_dataloader, "test": test_dataloader} dataset_sizes = {"train": len(train_dataset), "test": len(test_dataset)} - # device = torch.device("cuda:"+str(args.GPU) if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print("device", device) @@ -167,9 +157,7 @@ def main(args): lossfunc = BaselineLoss() - optimizer = torch.optim.Adam( - model.parameters(), LR, weight_decay=3e-5, amsgrad=True - ) + optimizer = torch.optim.Adam(model.parameters(), LR, weight_decay=3e-5, amsgrad=True) scheduler = lr_scheduler.ReduceLROnPlateau( optimizer, mode="min", @@ -201,29 +189,15 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument( - "--GPU", - type=int, - default=0, - help="GPU to run the training on 
(if available)", - ) - parser.add_argument( - "--workers", - type=int, - default=1, - help="Numbers of workers for the dataloader", + "--GPU", type=int, default=0, help="GPU to run the training on (if available)" ) parser.add_argument( - "--epochs", - type=int, - default=NUM_EPOCHS_POOLED, - help="Numbers of Epochs", + "--workers", type=int, default=1, help="Numbers of workers for the dataloader" ) parser.add_argument( - "--seed", - type=int, - default=0, - help="Seed", + "--epochs", type=int, default=NUM_EPOCHS_POOLED, help="Numbers of Epochs" ) + parser.add_argument("--seed", type=int, default=0, help="Seed") args = parser.parse_args() main(args) diff --git a/flamby/datasets/fed_kits19/dataset.py b/flamby/datasets/fed_kits19/dataset.py index 07090fc06..458614994 100644 --- a/flamby/datasets/fed_kits19/dataset.py +++ b/flamby/datasets/fed_kits19/dataset.py @@ -25,7 +25,11 @@ import numpy as np import pandas as pd import torch -from batchgenerators.utilities.file_and_folder_operations import * +from batchgenerators.utilities.file_and_folder_operations import ( + isfile, + join, + load_pickle, +) from nnunet.training.data_augmentation.default_data_augmentation import ( default_3D_augmentation_params, get_patch_size, @@ -33,11 +37,9 @@ from torch.utils.data import Dataset import flamby.datasets.fed_kits19 -from flamby.datasets.fed_kits19.dataset_creation_scripts.utils.data_augmentations import ( - transformations, -) -from flamby.datasets.fed_kits19.dataset_creation_scripts.utils.set_environment_variables import ( +from flamby.datasets.fed_kits19.dataset_creation_scripts.utils import ( set_environment_variables, + transformations, ) from flamby.utils import check_dataset_from_config @@ -132,9 +134,7 @@ def __init__( self.images_path = OrderedDict() for i in self.images: self.images_path[c] = OrderedDict() - self.images_path[c]["data_file"] = join( - self.dataset_directory, "%s.npz" % i - ) + self.images_path[c]["data_file"] = join(self.dataset_directory, "%s.npz" % i) self.images_path[c]["properties_file"] = join( self.dataset_directory, "%s.pkl" % i ) @@ -160,18 +160,10 @@ def __getitem__(self, idx): # randomly oversample the foreground classes if self.oversample_next_sample == 1: self.oversample_next_sample = 0 - item = self.oversample_foreground_class( - case_all_data, - True, - properties, - ) + item = self.oversample_foreground_class(case_all_data, True, properties) else: self.oversample_next_sample = 1 - item = self.oversample_foreground_class( - case_all_data, - False, - properties, - ) + item = self.oversample_foreground_class(case_all_data, False, properties) # apply data augmentations if self.train_test == "train": @@ -181,12 +173,7 @@ def __getitem__(self, idx): return np.squeeze(item["data"], axis=1), np.squeeze(item["target"], axis=1) - def oversample_foreground_class( - self, - case_all_data, - force_fg, - properties, - ): + def oversample_foreground_class(self, case_all_data, force_fg, properties): # taken from nnunet data_shape = (1, 1, *self.patch_size) seg_shape = (1, 1, *self.patch_size) @@ -242,7 +229,7 @@ def oversample_foreground_class( # all selected_class = None voxels_of_that_class = None - print("case does not contain any foreground classes", i) + print("case does not contain any foreground classes") else: selected_class = np.random.choice(foreground_classes) diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/create_config.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/create_config.py index d7c736f6c..802e99e64 100644 --- 
a/flamby/datasets/fed_kits19/dataset_creation_scripts/create_config.py +++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/create_config.py @@ -1,6 +1,6 @@ import argparse -from flamby.utils import create_config, get_config_file_path, write_value_in_config +from flamby.utils import create_config if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/kits19_heterogenity_plot.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/kits19_heterogenity_plot.py index f87ad4d16..f0a0f20d7 100644 --- a/flamby/datasets/fed_kits19/dataset_creation_scripts/kits19_heterogenity_plot.py +++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/kits19_heterogenity_plot.py @@ -6,7 +6,7 @@ import nibabel as nib import numpy as np import seaborn as sns -from batchgenerators.utilities.file_and_folder_operations import * +from batchgenerators.utilities.file_and_folder_operations import join from matplotlib.lines import Line2D from numpy import random diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/parsing_and_adding_metadata.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/parsing_and_adding_metadata.py index f2407cd28..9b0c96080 100644 --- a/flamby/datasets/fed_kits19/dataset_creation_scripts/parsing_and_adding_metadata.py +++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/parsing_and_adding_metadata.py @@ -1,4 +1,5 @@ -# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# Copyright 2020 Division of Medical Image Computing, German Cancer Research +# Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,7 +32,7 @@ subfolders, ) -from flamby.datasets.fed_kits19.dataset_creation_scripts.utils.set_environment_variables import ( +from flamby.datasets.fed_kits19.dataset_creation_scripts.utils import ( set_environment_variables, ) from flamby.utils import get_config_file_path, read_config, write_value_in_config @@ -103,9 +104,7 @@ def read_csv_file(csv_file="../metadata/anony_sites.csv"): data_length = len(client_data_idxx) train_ids = int(0.8 * data_length) for i in client_data_idxx[:train_ids]: - writer.writerow( - [i, silo_count, "train", "train_" + str(silo_count)] - ) + writer.writerow([i, silo_count, "train", "train_" + str(silo_count)]) for i in client_data_idxx[train_ids:]: writer.writerow([i, silo_count, "test", "test_" + str(silo_count)]) @@ -122,7 +121,8 @@ def read_csv_file(csv_file="../metadata/anony_sites.csv"): if __name__ == "__main__": """ - This is the KiTS dataset after Nick fixed all the labels that had errors. Downloaded on Jan 6th 2020 + This is the KiTS dataset after Nick fixed all the labels that had errors. 
+ Downloaded on Jan 6th 2020 """ # parse python script input parameters @@ -155,7 +155,7 @@ def read_csv_file(csv_file="../metadata/anony_sites.csv"): case_ids, site_ids, unique_hospital_ids, thresholded_ids = read_csv_file() print(thresholded_ids) - if args.debug == True: + if args.debug: train_patients = thresholded_ids[:25] test_patients = all_cases[210:211] # we do not need the test data else: @@ -183,9 +183,7 @@ def read_csv_file(csv_file="../metadata/anony_sites.csv"): json_dict["reference"] = "KiTS data for nnunet_library" json_dict["licence"] = "" json_dict["release"] = "0.0" - json_dict["modality"] = { - "0": "CT", - } + json_dict["modality"] = {"0": "CT"} json_dict["labels"] = {"0": "background", "1": "Kidney", "2": "Tumor"} json_dict["numTraining"] = len(train_patient_names) diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/run_nnUnet_plan_and_preprocess.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/run_nnUnet_plan_and_preprocess.py index be8ef3e14..ceb7ab39b 100644 --- a/flamby/datasets/fed_kits19/dataset_creation_scripts/run_nnUnet_plan_and_preprocess.py +++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/run_nnUnet_plan_and_preprocess.py @@ -1,4 +1,5 @@ -# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# Copyright 2020 Division of Medical Image Computing, German Cancer Research +# Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +18,7 @@ # information import sys -from flamby.datasets.fed_kits19.dataset_creation_scripts.utils.set_environment_variables import ( +from flamby.datasets.fed_kits19.dataset_creation_scripts.utils import ( set_environment_variables, ) from flamby.utils import get_config_file_path, write_value_in_config @@ -47,14 +48,7 @@ for _ in range(2): sys.argv.pop(index_num_threads) - sys.argv = sys.argv + [ - "-t", - "064", - "-tf", - args.num_threads, - "-tl", - args.num_threads, - ] + sys.argv = sys.argv + ["-t", "064", "-tf", args.num_threads, "-tl", args.num_threads] main() path_to_config_file = get_config_file_path("fed_kits19", False) diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/__init__.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/__init__.py index e69de29bb..36dfb231b 100644 --- a/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/__init__.py +++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/__init__.py @@ -0,0 +1,2 @@ +from .data_augmentations import transformations +from .set_environment_variables import set_environment_variables diff --git a/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/data_augmentations.py b/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/data_augmentations.py index 207477eb1..3fc89900b 100644 --- a/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/data_augmentations.py +++ b/flamby/datasets/fed_kits19/dataset_creation_scripts/utils/data_augmentations.py @@ -1,4 +1,5 @@ -# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# Copyright 2020 Division of Medical Image Computing, German Cancer Research +# Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -47,16 +48,11 @@ from batchgenerators.dataloading.nondet_multi_threaded_augmenter import ( NonDetMultiThreadedAugmenter, ) -except ImportError as ie: +except ImportError: NonDetMultiThreadedAugmenter = None -def transformations( - patch_size, - params, - border_val_seg=-1, - regions=None, -): +def transformations(patch_size, params, border_val_seg=-1, regions=None): assert ( params.get("mirror") is None ), "old version of params, use new keyword do_mirror" @@ -72,7 +68,8 @@ def transformations( SegChannelSelectionTransform(params.get("selected_seg_channels")) ) - # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!! + # don't do color augmentations while in 2d mode with 3d data because the + # color channel is overloaded!! if params.get("dummy_2D") is not None and params.get("dummy_2D"): tr_transforms.append(Convert3DTo2DTransform()) patch_size_spatial = patch_size[1:] diff --git a/flamby/datasets/fed_kits19/metric.py b/flamby/datasets/fed_kits19/metric.py index 6cb533abe..27b254f1a 100644 --- a/flamby/datasets/fed_kits19/metric.py +++ b/flamby/datasets/fed_kits19/metric.py @@ -3,7 +3,21 @@ import torch.nn.functional as F from tqdm import tqdm -softmax_helper = lambda x: F.softmax(x, 1) + +def softmax_helper(x): + """This function computes the softmax using torch functionnal on the 1-axis. + + Parameters + ---------- + x : torch.Tensor + The input. + + Returns + ------- + torch.Tensor + Output + """ + return F.softmax(x, 1) def Dice_coef(output, target, eps=1e-5): # dice score used for evaluation @@ -25,12 +39,7 @@ def metric(predictions, gt): return (tk_dice + tu_dice) / 2 -def evaluate_dice_on_tests( - model, - test_dataloaders, - metric, - use_gpu=True, -): +def evaluate_dice_on_tests(model, test_dataloaders, metric, use_gpu=True): """This function takes a pytorch model and evaluate it on a list of\ dataloaders using the provided metric function. diff --git a/flamby/datasets/fed_kits19/model.py b/flamby/datasets/fed_kits19/model.py index 5d94d5c3e..1e8fd8b09 100644 --- a/flamby/datasets/fed_kits19/model.py +++ b/flamby/datasets/fed_kits19/model.py @@ -1,4 +1,5 @@ -# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# Copyright 2020 Division of Medical Image Computing, German Cancer Research +# Center (DKFZ), Heidelberg, Germany # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
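
Note on the fed_kits19 metric.py hunk above: it replaces the module-level lambda
softmax_helper with a named, documented function, which is what flake8's E731 rule
asks for while leaving every call site unchanged. A minimal stand-alone sketch of the
same refactor, assuming only torch as a dependency (the tensor shape below is
illustrative, not taken from the patch):

    import torch
    import torch.nn.functional as F

    # Before (flagged by flake8 E731: do not assign a lambda expression):
    # softmax_helper = lambda x: F.softmax(x, 1)

    # After: an equivalent named function, softmaxing over the class axis (dim=1).
    def softmax_helper(x: torch.Tensor) -> torch.Tensor:
        return F.softmax(x, dim=1)

    logits = torch.randn(2, 3, 16, 16, 16)   # (batch, classes, D, H, W)
    probs = softmax_helper(logits)
    # Probabilities along the class axis sum to one.
    assert torch.allclose(probs.sum(dim=1), torch.ones_like(probs.sum(dim=1)))
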
diff --git a/flamby/datasets/fed_lidc_idri/__init__.py b/flamby/datasets/fed_lidc_idri/__init__.py index 864393d05..09a84c341 100644 --- a/flamby/datasets/fed_lidc_idri/__init__.py +++ b/flamby/datasets/fed_lidc_idri/__init__.py @@ -10,8 +10,5 @@ ) from flamby.datasets.fed_lidc_idri.dataset import FedLidcIdri, LidcIdriRaw, collate_fn from flamby.datasets.fed_lidc_idri.loss import BaselineLoss -from flamby.datasets.fed_lidc_idri.metric import ( - evaluate_dice_on_tests_by_chunks, - metric, -) +from flamby.datasets.fed_lidc_idri.metric import evaluate_dice_on_tests_by_chunks, metric from flamby.datasets.fed_lidc_idri.model import Baseline diff --git a/flamby/datasets/fed_lidc_idri/benchmark.py b/flamby/datasets/fed_lidc_idri/benchmark.py index b9a56d23f..21a63b717 100644 --- a/flamby/datasets/fed_lidc_idri/benchmark.py +++ b/flamby/datasets/fed_lidc_idri/benchmark.py @@ -109,9 +109,7 @@ def main(num_workers_torch, use_gpu=True, gpu_id=0, log=False, debug=False): if log: writer.add_scalar( - "Loss/train/client", - tot_loss / num_local_steps_per_epoch, - e, + "Loss/train/client", tot_loss / num_local_steps_per_epoch, e ) # Finally, evaluate DICE @@ -143,10 +141,7 @@ def main(num_workers_torch, use_gpu=True, gpu_id=0, log=False, debug=False): default=20, ) parser.add_argument( - "--gpu-id", - type=int, - default=0, - help="PCI Bus id of the GPU to use.", + "--gpu-id", type=int, default=0, help="PCI Bus id of the GPU to use." ) parser.add_argument( "--cpu-only", diff --git a/flamby/datasets/fed_lidc_idri/data_utils.py b/flamby/datasets/fed_lidc_idri/data_utils.py index 1f651f78a..90e45df9f 100644 --- a/flamby/datasets/fed_lidc_idri/data_utils.py +++ b/flamby/datasets/fed_lidc_idri/data_utils.py @@ -184,11 +184,7 @@ def random_sampler(image, label, patch_shape=(128, 128, 128), n_samples=2): paddings, mode="reflect", ).squeeze() - label = F.pad( - label, - paddings, - mode="constant", - ) + label = F.pad(label, paddings, mode="constant") # Extract patches, taking into account the shift in coordinates due to padding image_patches = extract_patches(image, centroids + patch_shape, patch_shape) @@ -232,12 +228,7 @@ def all_sampler(X, y, patch_shape=(128, 128, 128)): def fast_sampler( - X, - y, - patch_shape=(128, 128, 128), - n_patches=2, - ratio=1.0, - center=False, + X, y, patch_shape=(128, 128, 128), n_patches=2, ratio=1.0, center=False ): """ Parameters @@ -274,9 +265,9 @@ def fast_sampler( # Add noise to centroids so that the nodules are not always centered in # the patch: if not center: - noise = ( - torch.rand(centroids_1.shape[0], 3) * torch.max(patch_shape) - ).long() % (patch_shape.div(2, rounding_mode="floor")[None, ...]) + noise = (torch.rand(centroids_1.shape[0], 3) * torch.max(patch_shape)).long() % ( + patch_shape.div(2, rounding_mode="floor")[None, ...] 
+ ) centroids_1 += noise - patch_shape.div(4, rounding_mode="floor") # Sample random centroids @@ -297,16 +288,8 @@ def fast_sampler( paddings = tuple(torch.stack([patch_shape, patch_shape], dim=-1).flatten())[::-1] - X = F.pad( - X[None, None, :, :, :], - paddings, - mode="reflect", - ).squeeze() - y = F.pad( - y, - paddings, - mode="constant", - ) + X = F.pad(X[None, None, :, :, :], paddings, mode="reflect").squeeze() + y = F.pad(y, paddings, mode="constant") # Extract patches, taking into account the shift in coordinates due to padding image_patches = extract_patches(X, centroids + patch_shape, patch_shape) diff --git a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/download_ct_scans.py b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/download_ct_scans.py index ea79594af..d45f71ad7 100644 --- a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/download_ct_scans.py +++ b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/download_ct_scans.py @@ -242,9 +242,7 @@ def LIDC_to_niftis(extraction_results_dataframe, spacing=[1.0, 1.0, 1.0], debug= extraction_results_dataframe.iterrows(), ) progbar = tqdm.tqdm( - loop, - total=extraction_results_dataframe.shape[0], - desc="Converting to NiFTIs...", + loop, total=extraction_results_dataframe.shape[0], desc="Converting to NiFTIs..." ) converted_dicoms = Parallel(n_jobs=1, prefer="processes")( delayed(convert_to_niftis)(*t, spacing=spacing) for t in progbar diff --git a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/process_raw.py b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/process_raw.py index 7378b7ad4..96be254fc 100644 --- a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/process_raw.py +++ b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/process_raw.py @@ -295,9 +295,7 @@ def opN(Nodule): xs = roi.findall("{http://www.nih.gov}edgeMap/{http://www.nih.gov}xCoord") ys = roi.findall("{http://www.nih.gov}edgeMap/{http://www.nih.gov}yCoord") - zs = [ - locz, - ] * len(xs) + zs = [locz] * len(xs) xs = np.array([int(x.text) for x in xs]) ys = np.array([int(y.text) for y in ys]) diff --git a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/tciaclient.py b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/tciaclient.py index 9bb059e2b..675e8538b 100644 --- a/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/tciaclient.py +++ b/flamby/datasets/fed_lidc_idri/dataset_creation_scripts/tciaclient.py @@ -1,15 +1,16 @@ -# Copy pasted from https://github.com/nadirsaghar/TCIA-REST-API-Client/blob/master/tcia-rest-client-python/src/tciaclient.py +# Copy pasted from +# https://github.com/nadirsaghar/TCIA-REST-API-Client/blob/master/ +# tcia-rest-client-python/src/tciaclient.py # -import math import os -import sys import urllib.error import urllib.parse import urllib.request # -# Refer https://wiki.cancerimagingarchive.net/display/Public/REST+API+Usage+Guide for complete list of API +# Refer to https://wiki.cancerimagingarchive.net/display/ +# Public/REST+API+Usage+Guide for complete list of API # @@ -141,11 +142,7 @@ def get_body_part_values( return resp def get_patient_study( - self, - collection=None, - patientId=None, - studyInstanceUid=None, - outputFormat="json", + self, collection=None, patientId=None, studyInstanceUid=None, outputFormat="json" ): serviceUrl = self.baseUrl + "/query/" + self.GET_PATIENT_STUDY diff --git a/flamby/datasets/fed_lidc_idri/lidc_heterogeneity_plot.py b/flamby/datasets/fed_lidc_idri/lidc_heterogeneity_plot.py index 2e99c7aa9..dedf7aaaa 100644 --- 
a/flamby/datasets/fed_lidc_idri/lidc_heterogeneity_plot.py +++ b/flamby/datasets/fed_lidc_idri/lidc_heterogeneity_plot.py @@ -14,12 +14,7 @@ def make_plot(): list_x = np.linspace(0, 1, num=200) for center in [0, 1, 2, 3]: print(f"doing center {center}") - ds = FedLidcIdri( - train=True, - pooled=False, - center=center, - debug=False, - ) + ds = FedLidcIdri(train=True, pooled=False, center=center, debug=False) list_data = [] for k in tqdm(range(len(ds))): data = ds[k][0].detach().cpu().ravel() diff --git a/flamby/datasets/fed_synthetic/dataset_creation_scripts/download.py b/flamby/datasets/fed_synthetic/dataset_creation_scripts/download.py index 96dcddfa1..4f1546f32 100755 --- a/flamby/datasets/fed_synthetic/dataset_creation_scripts/download.py +++ b/flamby/datasets/fed_synthetic/dataset_creation_scripts/download.py @@ -98,10 +98,7 @@ def main(output_folder, debug=False, **kwargs): default=None, ) parser.add_argument( - "--noise-heterogeneity", - type=float, - help="Sample repartition.", - default=None, + "--noise-heterogeneity", type=float, help="Sample repartition.", default=None ) parser.add_argument( "--features-heterogeneity", diff --git a/flamby/datasets/fed_synthetic/synthetic_generator.py b/flamby/datasets/fed_synthetic/synthetic_generator.py index 0dc7f4a95..e4bc2ad03 100644 --- a/flamby/datasets/fed_synthetic/synthetic_generator.py +++ b/flamby/datasets/fed_synthetic/synthetic_generator.py @@ -127,9 +127,7 @@ def generate_synthetic_dataset( if noise_heterogeneity is None: snr_locs = np.ones(n_centers) * snr elif type(noise_heterogeneity) in [list, np.array]: - assert ( - snr == 3 - ), "Option snr is incompatible with noise_heterogeneity as a list." + assert snr == 3, "Option snr is incompatible with noise_heterogeneity as a list." snr_locs = np.array(noise_heterogeneity) else: raise ValueError( diff --git a/flamby/datasets/fed_tcga_brca/benchmark.py b/flamby/datasets/fed_tcga_brca/benchmark.py index 30db78f90..35270fbbc 100644 --- a/flamby/datasets/fed_tcga_brca/benchmark.py +++ b/flamby/datasets/fed_tcga_brca/benchmark.py @@ -199,21 +199,13 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument( - "--GPU", - type=int, - default=0, - help="GPU to run the training on (if available)", + "--GPU", type=int, default=0, help="GPU to run the training on (if available)" ) parser.add_argument( - "--workers", - type=int, - default=4, - help="Numbers of workers for the dataloader", + "--workers", type=int, default=4, help="Numbers of workers for the dataloader" ) parser.add_argument( - "--log", - action="store_true", - help="Whether or not to dump tensorboard events.", + "--log", action="store_true", help="Whether or not to dump tensorboard events." 
) parser.add_argument( "--log-period", diff --git a/flamby/datasets/fed_tcga_brca/dataset.py b/flamby/datasets/fed_tcga_brca/dataset.py index 19a27954e..a30386224 100644 --- a/flamby/datasets/fed_tcga_brca/dataset.py +++ b/flamby/datasets/fed_tcga_brca/dataset.py @@ -48,10 +48,7 @@ def __len__(self): def __getitem__(self, idx): x = self.data.iloc[idx, 1:40] y = self.data.iloc[idx, 40:42] - return ( - torch.tensor(x, dtype=self.X_dtype), - torch.tensor(y, dtype=self.y_dtype), - ) + return (torch.tensor(x, dtype=self.X_dtype), torch.tensor(y, dtype=self.y_dtype)) class FedTcgaBrca(TcgaBrcaRaw): diff --git a/flamby/datasets/split_utils.py b/flamby/datasets/split_utils.py index 25585c206..ae297fa20 100644 --- a/flamby/datasets/split_utils.py +++ b/flamby/datasets/split_utils.py @@ -37,7 +37,7 @@ def split_indices_linear(original_table, dataset_sizes, num_target_centers): current_idx = 0 for idx_new_client in range(num_target_centers - 1): assignment_new_client[ - current_idx : current_idx + num_samples_per_new_client + slice(current_idx, current_idx + num_samples_per_new_client) ] = idx_new_client current_idx += num_samples_per_new_client assignment_new_client[current_idx:] = num_target_centers - 1 @@ -219,10 +219,10 @@ def split_dataset( _current_idx = 0 for idx_client_orig, length_client in enumerate(client_size_list): original_table[split][ - 0, _current_idx : _current_idx + length_client + 0, np.arange(_current_idx, _current_idx + length_client) ] = idx_client_orig original_table[split][ - 1, _current_idx : _current_idx + length_client + 1, np.arange(_current_idx, _current_idx + length_client) ] = np.arange(0, length_client) _current_idx += length_client diff --git a/flamby/extract_config.py b/flamby/extract_config.py index 087ba46d4..ba74c2817 100644 --- a/flamby/extract_config.py +++ b/flamby/extract_config.py @@ -36,9 +36,7 @@ def main(args_cli): "You should provide as many dataset names as you gave results" " files or 1 if they all come from the same dataset." 
) - optimizers_classes = [ - e[1] for e in inspect.getmembers(torch.optim, inspect.isclass) - ] + optimizers_classes = [e[1] for e in inspect.getmembers(torch.optim, inspect.isclass)] csvs = [pd.read_csv(e) for e in csv_files] for dname, csv, csvf in zip(dataset_names, csvs, csv_files): config = {} diff --git a/flamby/personalization_example/plot_perso_results.py b/flamby/personalization_example/plot_perso_results.py index cebd604bb..96a16ed33 100644 --- a/flamby/personalization_example/plot_perso_results.py +++ b/flamby/personalization_example/plot_perso_results.py @@ -1,39 +1,21 @@ # Plot import matplotlib.pyplot as plt - -# import numpy as np import pandas as pd import seaborn as sns sns.set_theme(style="darkgrid") -# datasets_names = ["Fed-Heart-Disease", "Fed-Camelyon16", "Fed-ISIC2019"] -# n_repetitions = 5 -# rows = [] -# for d in datasets_names: -# for se in np.arange(42, 42 + n_repetitions): -# rows.append({"dataset": d, "seed": se, "perf": float(np.random.uniform(0., 1., 1)), "finetune": True}) -# rows.append({"dataset": d, "seed": se, "perf": float(np.random.uniform(0., 1., 1)), "finetune": False}) - -# results = pd.DataFrame.from_dict(rows) - results = pd.read_csv("results_perso_vs_normal.csv") results = results.rename(columns={"perf": "Performance"}) fig, ax = plt.subplots() -g = sns.boxplot( - data=results, - x="dataset", - y="Performance", - hue="finetune", - ax=ax, -) +g = sns.boxplot(data=results, x="dataset", y="Performance", hue="finetune", ax=ax) ax.set_xlabel(None) ax.set_ylim(0.0, 1.0) mapping_dict = {"True": "Fine-tuned", "False": "Not fine-tuned"} handles, labels = ax.get_legend_handles_labels() -ax.legend(handles=handles, labels=[mapping_dict[l] for l in labels]) +ax.legend(handles=handles, labels=[mapping_dict[lab] for lab in labels]) plt.savefig("perso_vs_non_perso.pdf", dpi=100, bbox_inches="tight") diff --git a/flamby/results/plot_results.py b/flamby/results/plot_results.py index bf476324a..41849d2cc 100644 --- a/flamby/results/plot_results.py +++ b/flamby/results/plot_results.py @@ -127,7 +127,7 @@ def partial_share_y_axes(axs): ] # Messing with palettes to keep the same color for pooled and strategies - current_palette = [palette[0]] + palette[1 : (current_num_clients + 1)] + current_palette = [palette[0]] + palette[slice(1, current_num_clients + 1)] current_palette += palette[7:] assert len(current_palette) == len(current_methods_display) # print(current_palette[len(current_methods_display) -1]) diff --git a/flamby/strategies/cyclic.py b/flamby/strategies/cyclic.py index 86ce60e43..0edae86fd 100644 --- a/flamby/strategies/cyclic.py +++ b/flamby/strategies/cyclic.py @@ -145,9 +145,7 @@ def __init__( self.deterministic_cycle = deterministic_cycle - self._rng = ( - rng if (rng is not None) else np.random.default_rng(int(time.time())) - ) + self._rng = rng if (rng is not None) else np.random.default_rng(int(time.time())) self._clients = self._shuffle_clients() self._current_idx = -1 diff --git a/flamby/strategies/scaffold.py b/flamby/strategies/scaffold.py index 453d6a7fb..b07214125 100644 --- a/flamby/strategies/scaffold.py +++ b/flamby/strategies/scaffold.py @@ -124,20 +124,14 @@ def __init__( ] # initialize the corrections used by each client to 0s. 
self.client_corrections_state_list = [ - [ - torch.zeros_like(torch.from_numpy(p)) - for p in _model._get_current_params() - ] + [torch.zeros_like(torch.from_numpy(p)) for p in _model._get_current_params()] for _model in self.models_list ] self.client_lr = learning_rate self.server_lr = server_learning_rate def _local_optimization( - self, - _model: _Model, - dataloader_with_memory, - correction_state: List, + self, _model: _Model, dataloader_with_memory, correction_state: List ): """Carry out the local optimization step. diff --git a/flamby/strategies/utils.py b/flamby/strategies/utils.py index 4e54862b7..4f552cd86 100644 --- a/flamby/strategies/utils.py +++ b/flamby/strategies/utils.py @@ -270,9 +270,7 @@ def _prox_local_train(self, dataloader_with_memory, num_updates, mu): _loss = _prox_loss.detach() if mu > 0.0: - squared_norm = compute_model_diff_squared_norm( - model_initial, self.model - ) + squared_norm = compute_model_diff_squared_norm(model_initial, self.model) _prox_loss += mu / 2 * squared_norm # Backpropagation diff --git a/integration/FedML/fedml_utils.py b/integration/FedML/fedml_utils.py index 73734a855..1ecb78eba 100644 --- a/integration/FedML/fedml_utils.py +++ b/integration/FedML/fedml_utils.py @@ -95,11 +95,7 @@ def test(self, test_data, device, args): def _test(self, test_data, device): logging.info("Evaluating on Trainer ID: {}".format(self.id)) - test_metrics = { - "test_correct": 0, - "test_total": 0, - "test_loss": 0, - } + test_metrics = {"test_correct": 0, "test_total": 0, "test_loss": 0} if not test_data: logging.info("No test data for this trainer") diff --git a/tests/benchmarks/test_fed_benchmark.py b/tests/benchmarks/test_fed_benchmark.py index 1a5ae9d7a..ff758306f 100644 --- a/tests/benchmarks/test_fed_benchmark.py +++ b/tests/benchmarks/test_fed_benchmark.py @@ -32,9 +32,7 @@ def assert_dfs_equal(pair0, pair1, ignore_columns=[]): ignore_columns : list, optional The columns that comparison should exclude, by default [] """ - ignore_columns = [ - col for col in ignore_columns if (col in pair0) and (col in pair1) - ] + ignore_columns = [col for col in ignore_columns if (col in pair0) and (col in pair1)] df1 = pair0.drop(columns=ignore_columns).fillna("-9") df2 = pair1.drop(columns=ignore_columns).fillna("-9")[df1.columns] assert ( diff --git a/tests/strategies/test_fed_avg.py b/tests/strategies/test_fed_avg.py index e5c409b46..cd97e75e5 100644 --- a/tests/strategies/test_fed_avg.py +++ b/tests/strategies/test_fed_avg.py @@ -111,10 +111,7 @@ def test_fedavg_Isic(): ] ) test_aug = albumentations.Compose( - [ - albumentations.CenterCrop(sz, sz), - albumentations.Normalize(always_apply=True), - ] + [albumentations.CenterCrop(sz, sz), albumentations.Normalize(always_apply=True)] ) training_dls = [ dl( diff --git a/tests/strategies/test_fed_prox.py b/tests/strategies/test_fed_prox.py index 26982cc32..4483cdeb8 100644 --- a/tests/strategies/test_fed_prox.py +++ b/tests/strategies/test_fed_prox.py @@ -69,9 +69,7 @@ def test_fed_prox_integration(n_clients): mu = 0.1 optimizer_class = torch.optim.Adam - s = FedProx( - train_dataloader, m, loss, optimizer_class, lr, num_updates, nrounds, mu - ) + s = FedProx(train_dataloader, m, loss, optimizer_class, lr, num_updates, nrounds, mu) m = s.run() def accuracy(y_true, y_pred): diff --git a/tests/strategies/test_scaffold.py b/tests/strategies/test_scaffold.py index 8e019efd0..08f4c55aa 100644 --- a/tests/strategies/test_scaffold.py +++ b/tests/strategies/test_scaffold.py @@ -83,10 +83,7 @@ def accuracy(y_true, y_pred): 
cleanup() -@pytest.mark.parametrize( - "seed, lr", - [(42, 0.01), (43, 0.001), (44, 0.0001), (45, 7e-5)], -) +@pytest.mark.parametrize("seed, lr", [(42, 0.01), (43, 0.001), (44, 0.0001), (45, 7e-5)]) def test_scaffold_algorithm(seed, lr): r"""Scaffold should add a correction term in each of its update step. In the first round, this correction step is 0. In each subsequent round,
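
Note on the final (truncated) hunk: the docstring of test_scaffold_algorithm checks that
Scaffold adds a correction term to every local update and that this correction is zero
in the first round. A minimal sketch of the corrected client step, following the standard
SCAFFOLD formulation from Karimireddy et al. (2020) rather than FLamby's exact
_local_optimization code; the function and variable names below are illustrative:

    import torch

    def scaffold_local_step(params, grads, server_c, client_c, lr):
        # Corrected SGD step: y <- y - lr * (grad - c_i + c)
        return [
            p - lr * (g - ci + c)
            for p, g, ci, c in zip(params, grads, client_c, server_c)
        ]

    def update_client_correction(server_c, client_c, server_params,
                                 local_params, num_updates, lr):
        # "Option II" control-variate update: c_i+ = c_i - c + (x - y_i) / (K * lr)
        return [
            ci - c + (x - y) / (num_updates * lr)
            for ci, c, x, y in zip(client_c, server_c, server_params, local_params)
        ]

    # With both corrections initialized to zeros (matching the zeros_like
    # initialization in the scaffold.py hunk above), the first-round step
    # reduces to plain SGD.
    x = [torch.zeros(3)]
    y = scaffold_local_step(x, [torch.ones(3)], [torch.zeros(3)], [torch.zeros(3)], lr=0.1)
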