diff --git a/notebooks/benchmarks_sandbox.ipynb b/notebooks/benchmarks_sandbox.ipynb index ab68c2a4..fa3887f5 100644 --- a/notebooks/benchmarks_sandbox.ipynb +++ b/notebooks/benchmarks_sandbox.ipynb @@ -9,7 +9,7 @@ "\n", "**Author**: Ivan Zvonkov\n", "\n", - "**Last Modified**: Jan 17, 2024\n", + "**Last Modified**: Feb 6, 2024\n", "\n", "**Description**: Code for benchmarking against different variations in models." ] @@ -64,7 +64,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/1v/87y9n_d5143c_6cp072v3b1c0000gn/T/ipykernel_33350/4119029012.py:4: DtypeWarning: Columns (17) have mixed types.Specify dtype option on import or set low_memory=False.\n", + "/var/folders/1v/87y9n_d5143c_6cp072v3b1c0000gn/T/ipykernel_25504/4119029012.py:4: DtypeWarning: Columns (17) have mixed types.Specify dtype option on import or set low_memory=False.\n", " df = d.load_df(to_np=True, disable_tqdm=True)\n" ] } @@ -1198,12 +1198,13 @@ }, { "cell_type": "code", - "execution_count": 144, - "id": "ec38c410", + "execution_count": 39, + "id": "0d54af51", "metadata": {}, "outputs": [], "source": [ - "!pip install einops -q" + "%load_ext autoreload\n", + "%autoreload 2" ] }, { @@ -1218,12 +1219,12 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 96, "id": "7b6e4903", "metadata": {}, "outputs": [], "source": [ - "from src.single_file_presto_v2 import Presto, DEVICE\n", + "from src.single_file_presto_v2 import Presto, DEVICE, Aggregate\n", "\n", "import numpy as np\n", "import torch\n", @@ -1232,16 +1233,6 @@ "from torch.utils.data import Dataset, DataLoader" ] }, - { - "cell_type": "code", - "execution_count": 64, - "id": "9ff1198a", - "metadata": {}, - "outputs": [], - "source": [ - "torch.tensor??" - ] - }, { "cell_type": "markdown", "id": "209caaaf", @@ -1252,7 +1243,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 97, "id": "aeec5ba1", "metadata": {}, "outputs": [], @@ -1282,17 +1273,7 @@ }, { "cell_type": "code", - "execution_count": 80, - "id": "d4a70d1f", - "metadata": {}, - "outputs": [], - "source": [ - "dataset = PrestoDataset(val_df)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, + "execution_count": 98, "id": "eaf12d47", "metadata": {}, "outputs": [], @@ -1302,7 +1283,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 99, "id": "87006785", "metadata": {}, "outputs": [], @@ -1313,7 +1294,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 100, "id": "a4cfaf77", "metadata": {}, "outputs": [], @@ -1324,21 +1305,35 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 101, "id": "cd9e5493", "metadata": {}, "outputs": [], "source": [ - "def generate_encodings(dataset):\n", + "def generate_encodings(dataset, aggregate):\n", " dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=False)\n", " feature_list = []\n", " for (x, latlons, dw, start_month, _) in tqdm(dataloader, desc=\"Encodings\", leave=False):\n", " with torch.no_grad():\n", - " encodings = (pretrained_model(x, dynamic_world=dw, latlons=latlons, month=start_month).cpu().numpy())\n", + " encodings = (pretrained_model(\n", + " x, dynamic_world=dw, latlons=latlons, month=start_month, aggregate=aggregate\n", + " ).cpu().numpy())\n", " feature_list.append(encodings)\n", " return np.concatenate(feature_list)" ] }, + { + "cell_type": "code", + "execution_count": 102, + "id": "a98ac408", + "metadata": {}, + "outputs": [], + "source": [ + "# Use Sklearn scaling of encodings\n", + "from sklearn.pipeline import make_pipeline\n", + "from sklearn.preprocessing import StandardScaler" + ] + }, { "cell_type": "markdown", "id": "3da1c2e7", @@ -1349,14 +1344,14 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 104, "id": "e5b4ae73", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a4a4a0d01e0943faab809de125d203d6", + "model_id": "84cc0dcb39ea4ebb851fa5b918546d3c", "version_major": 2, "version_minor": 0 }, @@ -1399,7 +1394,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Mali_lower_CEO_2019: 0.5882352941176471\n" + "Mali_lower_CEO_2019: 0.6198830409356726\n" ] }, { @@ -1434,7 +1429,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "Togo: 0.7563025210084034\n" + "Togo: 0.7317073170731708\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" ] }, { @@ -1465,11 +1474,25 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Rwanda: 0.6273458445040215\n" + "Rwanda: 0.6847290640394088\n" ] }, { @@ -1500,11 +1523,25 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Uganda: 0.46464646464646464\n" + "Uganda: 0.5098039215686275\n" ] }, { @@ -1535,11 +1572,25 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Ethiopia_Tigray_2020: 0.6412213740458016\n" + "Ethiopia_Tigray_2020: 0.671480144404332\n" ] }, { @@ -1574,7 +1625,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "Ethiopia_Tigray_2021: 0.6699029126213593\n" + "Ethiopia_Tigray_2021: 0.7222222222222223\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" ] }, { @@ -1609,7 +1674,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "Ethiopia_Bure_Jimma_2019: 0.8193548387096774\n" + "Ethiopia_Bure_Jimma_2019: 0.8571428571428571\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" ] }, { @@ -1640,11 +1719,25 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Ethiopia_Bure_Jimma_2020: 0.8803088803088803\n" + "Ethiopia_Bure_Jimma_2020: 0.8673835125448028\n" ] }, { @@ -1675,11 +1768,25 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Malawi_CEO_2020: 0.14285714285714288\n" + "Malawi_CEO_2020: 0.4079601990049751\n" ] }, { @@ -1710,11 +1817,25 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Tanzania_CEO_2019: 0.8032200357781754\n" + "Tanzania_CEO_2019: 0.8313155770782888\n" ] }, { @@ -1749,7 +1870,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Sudan_Blue_Nile_CEO_2019: 0.8589341692789969\n" + "Sudan_Blue_Nile_CEO_2019: 0.9201101928374655\n" ] }, { @@ -1784,7 +1905,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "SudanBlueNileCEO2020: 0.7680608365019012\n" + "SudanBlueNileCEO2020: 0.7789473684210527\n" ] }, { @@ -1815,11 +1936,25 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Senegal_CEO_2022: 0.4444444444444444\n" + "Senegal_CEO_2022: 0.6244343891402715\n" ] }, { @@ -1854,7 +1989,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "SudanAlGadarefCEO2019: 0.5993031358885018\n" + "SudanAlGadarefCEO2019: 0.5892857142857143\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" ] }, { @@ -1885,11 +2034,25 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "SudanAlGadarefCEO2020: 0.6222222222222222\n" + "SudanAlGadarefCEO2020: 0.7209775967413442\n" ] }, { @@ -1920,11 +2083,25 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "SudanGedarefDarfurAlJazirah2022: 0.763157894736842\n" + "SudanGedarefDarfurAlJazirah2022: 0.7615658362989324\n" ] }, { @@ -1959,7 +2136,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "Uganda_NorthCEO2022: 0.28205128205128205\n" + "Uganda_NorthCEO2022: 0.4333333333333333\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniconda/base/envs/landcover-mapping/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" ] } ], @@ -1974,16 +2165,18 @@ " test_df = df[is_test] \n", " train_df = df[~is_test & is_local_lat & is_local_lon]\n", " \n", - " train_dataset = PrestoDataset(train_df, start_month=2)\n", - " test_dataset = PrestoDataset(test_df, start_month=2) \n", - " X_train = generate_encodings(train_dataset)\n", - " X_test = generate_encodings(test_dataset)\n", + " train_dataset = PrestoDataset(train_df, start_month=1)\n", + " test_dataset = PrestoDataset(test_df, start_month=1) \n", + " X_train = generate_encodings(train_dataset, Aggregate.BAND_GROUPS_MEAN)\n", + " X_test = generate_encodings(test_dataset, Aggregate.BAND_GROUPS_MEAN)\n", " \n", " y_train = train_df[\"is_crop\"].to_list() \n", " y_test = test_df[\"is_crop\"].to_list()\n", " \n", - " #model = LogisticRegression(class_weight=\"balanced\", max_iter=1000, random_state=DEFAULT_SEED)\n", - " model = RandomForestClassifier(class_weight=\"balanced\", random_state=DEFAULT_SEED)\n", + " model = LogisticRegression(class_weight=\"balanced\", max_iter=1000, random_state=DEFAULT_SEED)\n", + " #pipe = make_pipeline(StandardScaler(), model)\n", + " #pipe.fit(X_train, y_train)\n", + " #model = RandomForestClassifier(class_weight=\"balanced\", random_state=DEFAULT_SEED)\n", " model.fit(X_train, y_train)\n", " \n", " #y_pred = model.predict(X_test)\n", @@ -1996,18 +2189,20 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 105, "id": "83931aea", "metadata": {}, "outputs": [], "source": [ + "benchmark_name = \"Presto LR Feb-Feb F1 Score (no DW, band group encodings, per group LayerNorm)\"\n", "for dataset, f1 in f1_scores.items():\n", - " presto_benchmark.loc[presto_benchmark[\"Name\"] == dataset, \"Presto RF Mar-Mar F1 Score (no DW)\"] = f1" + " presto_benchmark.loc[presto_benchmark[\"Name\"] == dataset, benchmark_name] = f1\n", + " " ] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 106, "id": "042873fd", "metadata": {}, "outputs": [ @@ -2036,8 +2231,10 @@ " Total\n", " Crop Amount\n", " Crop Rate\n", - " Presto RF Mar-Mar F1 Score (no DW)\n", - " Presto RF Feb-Feb F1 Score (no DW)\n", + " Presto LR Mar-Mar F1 Score (no DW, band group encodings, no norm)\n", + " Presto LR Mar-Mar F1 Score (no DW, band group encodings, per group norm)\n", + " Presto LR Mar-Mar F1 Score (no DW, band group encodings, sklearn StandardScaler)\n", + " Presto LR Feb-Feb F1 Score (no DW, band group encodings, per group LayerNorm)\n", " \n", " \n", " \n", @@ -2047,8 +2244,10 @@ " 271\n", " 94\n", " 0.347\n", - " 0.588235\n", - " 0.641711\n", + " 0.602273\n", + " 0.609195\n", + " 0.576087\n", + " 0.619883\n", " \n", " \n", " 3\n", @@ -2056,8 +2255,10 @@ " 310\n", " 107\n", " 0.345\n", - " 0.756303\n", " 0.742857\n", + " 0.741935\n", + " 0.742857\n", + " 0.731707\n", " \n", " \n", " 4\n", @@ -2065,8 +2266,10 @@ " 555\n", " 191\n", " 0.344\n", - " 0.627346\n", - " 0.576819\n", + " 0.666667\n", + " 0.682927\n", + " 0.663342\n", + " 0.684729\n", " \n", " \n", " 5\n", @@ -2074,8 +2277,10 @@ " 456\n", " 52\n", " 0.114\n", - " 0.464646\n", - " 0.437500\n", + " 0.472050\n", + " 0.465409\n", + " 0.465409\n", + " 0.509804\n", " \n", " \n", " 6\n", @@ -2083,8 +2288,10 @@ " 507\n", " 139\n", " 0.274\n", - " 0.641221\n", - " 0.651341\n", + " 0.666667\n", + " 0.654412\n", + " 0.654275\n", + " 0.671480\n", " \n", " \n", " 7\n", @@ -2092,8 +2299,10 @@ " 367\n", " 120\n", " 0.327\n", - " 0.669903\n", - " 0.689320\n", + " 0.712963\n", + " 0.694836\n", + " 0.712329\n", + " 0.722222\n", " \n", " \n", " 8\n", @@ -2101,8 +2310,10 @@ " 498\n", " 161\n", " 0.323\n", - " 0.819355\n", - " 0.825806\n", + " 0.868263\n", + " 0.867257\n", + " 0.861446\n", + " 0.857143\n", " \n", " \n", " 9\n", @@ -2110,8 +2321,10 @@ " 455\n", " 129\n", " 0.284\n", - " 0.880309\n", - " 0.849421\n", + " 0.827338\n", + " 0.845878\n", + " 0.826568\n", + " 0.867384\n", " \n", " \n", " 10\n", @@ -2119,8 +2332,10 @@ " 457\n", " 67\n", " 0.147\n", - " 0.142857\n", - " 0.117647\n", + " 0.340659\n", + " 0.381443\n", + " 0.352273\n", + " 0.407960\n", " \n", " \n", " 12\n", @@ -2128,8 +2343,10 @@ " 2037\n", " 626\n", " 0.307\n", - " 0.803220\n", - " 0.809991\n", + " 0.827200\n", + " 0.843875\n", + " 0.844051\n", + " 0.831316\n", " \n", " \n", " 14\n", @@ -2137,8 +2354,10 @@ " 526\n", " 173\n", " 0.329\n", - " 0.858934\n", - " 0.863354\n", + " 0.911602\n", + " 0.922652\n", + " 0.901099\n", + " 0.920110\n", " \n", " \n", " 16\n", @@ -2146,8 +2365,10 @@ " 610\n", " 90\n", " 0.148\n", - " 0.444444\n", - " 0.458333\n", + " 0.622642\n", + " 0.611111\n", + " 0.616114\n", + " 0.624434\n", " \n", " \n", " 18\n", @@ -2155,8 +2376,10 @@ " 533\n", " 121\n", " 0.227\n", - " 0.768061\n", - " 0.783270\n", + " 0.782007\n", + " 0.787671\n", + " 0.773050\n", + " 0.778947\n", " \n", " \n", " 19\n", @@ -2164,8 +2387,10 @@ " 533\n", " 135\n", " 0.253\n", - " 0.599303\n", - " 0.578397\n", + " 0.602410\n", + " 0.600000\n", + " 0.591900\n", + " 0.589286\n", " \n", " \n", " 21\n", @@ -2173,8 +2398,10 @@ " 532\n", " 202\n", " 0.380\n", - " 0.622222\n", - " 0.630986\n", + " 0.714588\n", + " 0.700210\n", + " 0.711864\n", + " 0.720978\n", " \n", " \n", " 23\n", @@ -2182,8 +2409,10 @@ " 375\n", " 127\n", " 0.339\n", - " 0.763158\n", - " 0.741379\n", + " 0.792857\n", + " 0.767025\n", + " 0.787234\n", + " 0.761566\n", " \n", " \n", " 24\n", @@ -2191,8 +2420,10 @@ " 319\n", " 56\n", " 0.176\n", - " 0.282051\n", - " 0.271605\n", + " 0.448000\n", + " 0.466667\n", + " 0.412698\n", + " 0.433333\n", " \n", " \n", "\n", @@ -2218,27 +2449,84 @@ "23 SudanGedarefDarfurAlJazirah2022 375 127 0.339 \n", "24 Uganda_NorthCEO2022 319 56 0.176 \n", "\n", - " Presto RF Mar-Mar F1 Score (no DW) Presto RF Feb-Feb F1 Score (no DW) \n", - "1 0.588235 0.641711 \n", - "3 0.756303 0.742857 \n", - "4 0.627346 0.576819 \n", - "5 0.464646 0.437500 \n", - "6 0.641221 0.651341 \n", - "7 0.669903 0.689320 \n", - "8 0.819355 0.825806 \n", - "9 0.880309 0.849421 \n", - "10 0.142857 0.117647 \n", - "12 0.803220 0.809991 \n", - "14 0.858934 0.863354 \n", - "16 0.444444 0.458333 \n", - "18 0.768061 0.783270 \n", - "19 0.599303 0.578397 \n", - "21 0.622222 0.630986 \n", - "23 0.763158 0.741379 \n", - "24 0.282051 0.271605 " - ] - }, - "execution_count": 28, + " Presto LR Mar-Mar F1 Score (no DW, band group encodings, no norm) \\\n", + "1 0.602273 \n", + "3 0.742857 \n", + "4 0.666667 \n", + "5 0.472050 \n", + "6 0.666667 \n", + "7 0.712963 \n", + "8 0.868263 \n", + "9 0.827338 \n", + "10 0.340659 \n", + "12 0.827200 \n", + "14 0.911602 \n", + "16 0.622642 \n", + "18 0.782007 \n", + "19 0.602410 \n", + "21 0.714588 \n", + "23 0.792857 \n", + "24 0.448000 \n", + "\n", + " Presto LR Mar-Mar F1 Score (no DW, band group encodings, per group norm) \\\n", + "1 0.609195 \n", + "3 0.741935 \n", + "4 0.682927 \n", + "5 0.465409 \n", + "6 0.654412 \n", + "7 0.694836 \n", + "8 0.867257 \n", + "9 0.845878 \n", + "10 0.381443 \n", + "12 0.843875 \n", + "14 0.922652 \n", + "16 0.611111 \n", + "18 0.787671 \n", + "19 0.600000 \n", + "21 0.700210 \n", + "23 0.767025 \n", + "24 0.466667 \n", + "\n", + " Presto LR Mar-Mar F1 Score (no DW, band group encodings, sklearn StandardScaler) \\\n", + "1 0.576087 \n", + "3 0.742857 \n", + "4 0.663342 \n", + "5 0.465409 \n", + "6 0.654275 \n", + "7 0.712329 \n", + "8 0.861446 \n", + "9 0.826568 \n", + "10 0.352273 \n", + "12 0.844051 \n", + "14 0.901099 \n", + "16 0.616114 \n", + "18 0.773050 \n", + "19 0.591900 \n", + "21 0.711864 \n", + "23 0.787234 \n", + "24 0.412698 \n", + "\n", + " Presto LR Feb-Feb F1 Score (no DW, band group encodings, per group LayerNorm) \n", + "1 0.619883 \n", + "3 0.731707 \n", + "4 0.684729 \n", + "5 0.509804 \n", + "6 0.671480 \n", + "7 0.722222 \n", + "8 0.857143 \n", + "9 0.867384 \n", + "10 0.407960 \n", + "12 0.831316 \n", + "14 0.920110 \n", + "16 0.624434 \n", + "18 0.778947 \n", + "19 0.589286 \n", + "21 0.720978 \n", + "23 0.761566 \n", + "24 0.433333 " + ] + }, + "execution_count": 106, "metadata": {}, "output_type": "execute_result" } diff --git a/src/single_file_presto_v2.py b/src/single_file_presto_v2.py index 01b2e814..a361079e 100644 --- a/src/single_file_presto_v2.py +++ b/src/single_file_presto_v2.py @@ -358,7 +358,9 @@ def band_groups_mean( mask = (kept_indices >= min_idx) & (kept_indices < max_idx) # we assume kept_elements is the same for all batches kept_elements = sum(mask[0, :]) - groups.append(x[mask.bool()].view(batch_size, kept_elements, embedding_dim).mean(dim=1)) + one_group = x[mask.bool()].view(batch_size, kept_elements, embedding_dim).mean(dim=1) + one_group_normed = self.norm(one_group) + groups.append(one_group_normed) cur_idx = max_idx return torch.cat(groups, dim=1) @@ -455,7 +457,7 @@ def forward( if aggregate == Aggregate.MEAN: return self.norm(x.mean(dim=1)) elif aggregate == Aggregate.BAND_GROUPS_MEAN: - return self.norm(self.band_groups_mean(x, kept_indices, num_timesteps)) + return self.band_groups_mean(x, kept_indices, num_timesteps) return self.norm(x), kept_indices, removed_indices