From eabf77daccf5fb5ac38a4c99edf9aece3322c700 Mon Sep 17 00:00:00 2001 From: Dominic Orchard Date: Wed, 8 Nov 2023 23:05:09 +0000 Subject: [PATCH 1/3] update author --- pyproject-poetry.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject-poetry.toml b/pyproject-poetry.toml index a0630b71..885c22c2 100644 --- a/pyproject-poetry.toml +++ b/pyproject-poetry.toml @@ -66,7 +66,7 @@ namespaces = false # to disable scanning PEP 420 namespaces (true by default) name = "gz21_ocean_momentum" version = "0.2.0" description = "TODO" -authors = ["Sébastien Eustace "] +authors = ["Arthur Guillaumin "] readme = "README.md" packages = [{include = "poetry_demo"}] From d3947e51deafc0c57d871258d00965ef727eddea Mon Sep 17 00:00:00 2001 From: Dominic Orchard Date: Wed, 8 Nov 2023 23:04:17 +0000 Subject: [PATCH 2/3] fix training script parameter bug on default for learning rate --- src/gz21_ocean_momentum/trainScript.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gz21_ocean_momentum/trainScript.py b/src/gz21_ocean_momentum/trainScript.py index f7d36e76..391e50da 100755 --- a/src/gz21_ocean_momentum/trainScript.py +++ b/src/gz21_ocean_momentum/trainScript.py @@ -119,7 +119,7 @@ def check_str_is_None(string_in: str): parser.add_argument("--batchsize", type=int, default=8) parser.add_argument("--n_epochs", type=int, default=100) parser.add_argument( - "--learning_rate", type=learning_rates_from_string, default={"0\1e-3"} + "--learning_rate", type=learning_rates_from_string, default="0/1e-3" ) parser.add_argument("--train_split", type=float, default=0.8, help="Between 0 and 1") parser.add_argument( From ce70a3412baf12f1fadb285ab95b5b429d2501ac Mon Sep 17 00:00:00 2001 From: Dominic Orchard Date: Wed, 8 Nov 2023 23:03:19 +0000 Subject: [PATCH 3/3] fix in datasets module to match import structure --- src/gz21_ocean_momentum/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gz21_ocean_momentum/data/datasets.py b/src/gz21_ocean_momentum/data/datasets.py index 9e80d87d..eb6f2ebb 100644 --- a/src/gz21_ocean_momentum/data/datasets.py +++ b/src/gz21_ocean_momentum/data/datasets.py @@ -188,7 +188,7 @@ def add_targets_transform(self, transform): ) self.transforms["targets"].add_transform(transform) - def fit(self, x: torch.utils.data.Dataset): + def fit(self, x: torch.Dataset): """ Call the fit method of all array transforms in the list of features and target transforms on the passed Dataset.