diff --git a/apebench/__init__.py b/apebench/__init__.py
index a90fc31..02d4519 100644
--- a/apebench/__init__.py
+++ b/apebench/__init__.py
@@ -20,6 +20,8 @@
 from ._utils import (
     aggregate_gmean,
     check_for_nan,
+    compute_pvalues_against_best,
+    cumulative_aggregation,
     melt_data,
     melt_loss,
     melt_metrics,
@@ -57,4 +59,6 @@
     "split_train",
     "components",
     "check_for_nan",
+    "cumulative_aggregation",
+    "compute_pvalues_against_best",
 ]
diff --git a/apebench/_utils.py b/apebench/_utils.py
index 4d45174..2c5c727 100644
--- a/apebench/_utils.py
+++ b/apebench/_utils.py
@@ -1,8 +1,10 @@
-from typing import Union
+from typing import Literal, Union
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 import pandas as pd
+from scipy import stats
 from scipy.stats import gmean
 
 from ._base_scenario import BaseScenario
@@ -250,3 +252,257 @@ def check_for_nan(scene: BaseScenario):
         del test_data_coarse
     except NotImplementedError:
         return
+
+
+def cumulative_aggregation(
+    df: pd.DataFrame,
+    grouping_cols: Union[str, list[str]],
+    rolling_col: str,
+    agg_fn_name: str = "mean",
+    prefix: str = "cum",
+) -> pd.DataFrame:
+    """
+    Apply a cumulative aggregation to a DataFrame. This can be used to
+    cumulatively aggregate a metric over time.
+
+    If you are only interested in the aggregation over some rows (but not
+    cumulatively), you can use the [`apebench.aggregate_gmean`][] function.
+
+    **Arguments:**
+
+    - `df`: The DataFrame to aggregate.
+    - `grouping_cols`: The columns to group by. Supply a list of column names
+        such that if `df.groupby(grouping_cols)` is called, the groups contain
+        only different time steps (or whatever you want to aggregate over).
+    - `rolling_col`: The column to aggregate. Usually, this is the name of the
+        metric in a long DataFrame.
+    - `agg_fn_name`: The aggregation function to use. Must be one of `"mean"`,
+        `"gmean"` (geometric mean), or `"sum"`.
+    - `prefix`: The prefix of the new column, which will be named
+        `f"{prefix}{agg_fn_name}_{rolling_col}"`.
+
+    **Returns:**
+
+    - A DataFrame with the cumulatively aggregated columns added.
+
+    !!! example
+
+        Train a feedforward ConvNet to emulate advection and then display the
+        mean nRMSE rollout together with its cumulative mean and gmean
+        aggregation. The only column that varies is the `"seed"` column, so we
+        group by that.
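+
+        To see the mechanics in isolation, here is a minimal sketch on a toy
+        frame (the `seed`/`err` column names are illustrative only):
+
+        ```python
+        import pandas as pd
+        import apebench
+
+        toy = pd.DataFrame({"seed": [0, 0, 0], "err": [1.0, 2.0, 4.0]})
+        # Adds a column "cumgmean_err" holding the expanding geometric
+        # mean of "err", i.e., [1.0, 1.414..., 2.0]
+        apebench.cumulative_aggregation(toy, "seed", "err", agg_fn_name="gmean")
+        ```
+
+        The full scenario-based example: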
+
+        ```python
+        import apebench
+        import matplotlib.pyplot as plt
+        import seaborn as sns
+
+        advection_scenario = apebench.scenarios.difficulty.Advection()
+
+        data, trained_net = advection_scenario(num_seeds=3)
+
+        metric_df = apebench.melt_metrics(data)
+
+        metric_df = apebench.cumulative_aggregation(
+            metric_df, "seed", "mean_nRMSE", agg_fn_name="mean"
+        )
+        metric_df = apebench.cumulative_aggregation(
+            metric_df, "seed", "mean_nRMSE", agg_fn_name="gmean"
+        )
+
+        fig, ax = plt.subplots()
+        sns.lineplot(
+            data=metric_df, x="time_step",
+            y="mean_nRMSE", ax=ax, label="original"
+        )
+        sns.lineplot(
+            data=metric_df, x="time_step",
+            y="cummean_mean_nRMSE", ax=ax, label="cummean"
+        )
+        sns.lineplot(
+            data=metric_df, x="time_step",
+            y="cumgmean_mean_nRMSE", ax=ax, label="cumgmean"
+        )
+        ```
+
+        ![Metric rollout, plain and cumulatively aggregated via mean and
+        gmean](https://github.com/user-attachments/assets/ba342f26-a6ef-4a67-8a2c-2b1d94ddfde1)
+    """
+    agg_fn = {
+        "mean": np.mean,
+        "gmean": stats.gmean,
+        "sum": np.sum,
+    }[agg_fn_name]
+    # Apply the aggregation over an expanding window within each group
+    return df.groupby(grouping_cols, observed=True, group_keys=False).apply(
+        lambda x: x.assign(
+            **{
+                f"{prefix}{agg_fn_name}_{rolling_col}": x[rolling_col]
+                .expanding()
+                .apply(agg_fn)
+            }
+        )
+    )
+
+
+def compute_pvalues_against_best(
+    df: pd.DataFrame,
+    grouping_cols: list[str],
+    sorting_cols: list[str],
+    value_col: str,
+    performance_indicator="mean",
+    alternative: Literal["two-sided", "less", "greater"] = "two-sided",
+    equal_var: bool = True,
+    pivot: bool = False,
+):
+    """
+    Perform a t-test of the best configuration in a group against all other
+    configurations and return the p-values. "Best" is defined as the
+    configuration with the lowest aggregated value (typically the mean).
+
+    In the returned DataFrame, the best configuration has a p-value of 1.0
+    against itself (or 0.5 in case `alternative` is not `"two-sided"`) and
+    lower p-values against all other configurations. Only if the p-value
+    against another configuration is below the significance level (typically
+    0.05) can the best configuration be considered significantly better.
+
+    **Arguments:**
+
+    - `df`: The DataFrame to compute the p-values for.
+    - `grouping_cols`: The columns to group by. Select the `grouping_cols` such
+        that when `df.groupby(grouping_cols + sorting_cols)` is called, the
+        groups contain only the individual samples to be used for the
+        hypothesis test. For example, if you want to compare network
+        architecture and training methodology for all time steps and
+        investigated scenarios, then you would use
+        `grouping_cols = ["scenario", "time_step"]` together with
+        `sorting_cols = ["net", "train"]`.
+    - `sorting_cols`: The columns to sort by. Once the DataFrame is grouped by
+        `grouping_cols`, the configuration out of all combinations in
+        `sorting_cols` with the lowest aggregated value (typically the mean) is
+        considered the best.
+    - `value_col`: The column to use for the hypothesis test. Typically, this
+        is the column with a test metric. (The default metric in APEBench is
+        `mean_nRMSE`.)
+    - `performance_indicator`: The aggregation to determine the best
+        configuration. Typically, this is the mean, which is also what is
+        checked with the hypothesis test.
+    - `alternative`: The alternative hypothesis to test. Must be one of
+        `"two-sided"`, `"less"`, or `"greater"`.
+    - `equal_var`: Whether to assume equal variance in the two samples. (This
+        is a parameter of the t-test.)
+    - `pivot`: Whether to pivot the DataFrame to have the p-values in a matrix
+        form.
+        This is useful for directly comparing across the `sorting_cols`.
+
+    See also [this SciPy
+    page](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_ind.html)
+    for more details on the t-test.
+
+    **Returns:**
+
+    - A DataFrame with the p-values of the hypothesis test. If `pivot` is
+        `True`, the DataFrame is pivoted such that the p-values are in a
+        matrix form.
+
+    !!! example
+
+        Test for the significantly better architecture when emulating advection
+        under `advection_gamma=2.0`.
+
+        ```python
+        CONFIGS = [
+            {
+                "scenario": "diff_adv",
+                "task": "predict",
+                "net": net,
+                "train": "one",
+                "start_seed": 0,
+                "num_seeds": 10,
+                "advection_gamma": 2.0,
+            }
+            for net in ["Conv;34;10;relu", "FNO;12;8;4;gelu", "UNet;12;2;relu"]
+        ]
+
+        metric_df, _, _, _ = apebench.run_study_convenience(
+            CONFIGS, "conv_vs_fno_vs_unet"
+        )
+
+        p_value_df_pivoted = apebench.compute_pvalues_against_best(
+            metric_df, ["time_step",], ["net",], "mean_nRMSE", pivot=True
+        )
+
+        print(
+            p_value_df_pivoted.query(
+                "time_step in [1, 5, 10, 50, 100, 200]"
+            ).round(4).to_markdown()
+        )
+        ```
+
+        |   time_step |   Conv;34;10;relu |   FNO;12;8;4;gelu |   UNet;12;2;relu |
+        |------------:|------------------:|------------------:|-----------------:|
+        |           1 |                 1 |            0.0017 |           0      |
+        |           5 |                 1 |            0.0013 |           0      |
+        |          10 |                 1 |            0.0007 |           0      |
+        |          50 |                 1 |            0.0001 |           0.0457 |
+        |         100 |                 1 |            0      |           0.0601 |
+        |         200 |                 1 |            0      |           0.0479 |
+
+        We can also visualize the p-values over time:
+
+        ```python
+        import seaborn as sns
+        import matplotlib.pyplot as plt
+
+        p_value_df = apebench.compute_pvalues_against_best(
+            metric_df, ["time_step",], ["net",], "mean_nRMSE", "mean"
+        )
+
+        sns.lineplot(
+            p_value_df,
+            x="time_step",
+            y="p_value",
+            hue="net",
+        )
+        plt.hlines(0.05, 0, 200, linestyles="--", colors="black")
+
+        plt.yscale("log")
+        ```
+        ![p-value rollout over time](https://github.com/user-attachments/assets/4e99ed9b-735f-499c-8374-a7d9d94d1aca)
+
+        Since the p-value of the FNO against the ConvNet (the best architecture
+        according to its mean performance) is consistently below 0.05, we can
+        confidently say that the ConvNet is significantly better than the FNO.
+        However, the UNet's p-value rises above 0.05 after ~50 time steps;
+        hence, for this temporal horizon, we cannot reject the null hypothesis.
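+
+        With the long format returned by `pivot=False`, such statements can
+        also be checked programmatically; a small sketch (0.05 is merely the
+        conventional significance level):
+
+        ```python
+        # time steps at which the UNet is statistically indistinguishable
+        # from the best configuration
+        p_value_df.query("net == 'UNet;12;2;relu' and p_value >= 0.05")
+        ```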
+ """ + stats_df = ( + df.groupby(grouping_cols + sorting_cols, observed=True, group_keys=True) + .agg( + performance_indicator=(value_col, performance_indicator), + mean=(value_col, "mean"), + std=(value_col, "std"), + count=(value_col, "count"), + ) + .reset_index() + ) + + stats_df = stats_df.groupby(grouping_cols, observed=True, group_keys=False).apply( + lambda x: x.assign( + p_value=stats.ttest_ind_from_stats( + mean1=x["mean"].values[x["performance_indicator"].argmin()], + std1=x["std"].values[x["performance_indicator"].argmin()], + nobs1=x["count"].values[x["performance_indicator"].argmin()], + mean2=x["mean"].values, + std2=x["std"].values, + nobs2=x["count"].values, + alternative=alternative, + equal_var=equal_var, + ).pvalue + ) + ) + + if not pivot: + return stats_df + + p_value_df = stats_df.pivot( + index=grouping_cols, + columns=sorting_cols, + values="p_value", + ) + + return p_value_df diff --git a/docs/api/postprocess/dataframe_manipulation.md b/docs/api/postprocess/dataframe_manipulation.md index 04aaeee..07044cd 100644 --- a/docs/api/postprocess/dataframe_manipulation.md +++ b/docs/api/postprocess/dataframe_manipulation.md @@ -12,4 +12,12 @@ --- -::: apebench.relative_by_config \ No newline at end of file +::: apebench.cumulative_aggregation + +--- + +::: apebench.relative_by_config + +--- + +::: apebench.compute_pvalues_against_best \ No newline at end of file diff --git a/docs/examples/statistical_analysis.ipynb b/docs/examples/statistical_analysis.ipynb new file mode 100644 index 0000000..1559425 --- /dev/null +++ b/docs/examples/statistical_analysis.ipynb @@ -0,0 +1,303 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Statistical Analysis of emulator performance\n", + "\n", + "This notebook compares a feed-forward ConvNet and a FNO for emulating the 1D\n", + "advection equation. 
+    "answer which model is better, and under which conditions.\n",
+    "\n",
+    "We will discuss the following:\n",
+    "\n",
+    "- training both emulators across multiple seeds and comparing their error\n",
+    "  rollouts,\n",
+    "- checking whether the test metric (and its mean over the test samples) is\n",
+    "  normally distributed, and\n",
+    "- what this implies for the choice of hypothesis test."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import apebench\n",
+    "import seaborn as sns\n",
+    "import matplotlib.pyplot as plt\n",
+    "import equinox as eqx\n",
+    "import exponax as ex\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "from scipy import stats"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "CONFIGS = [\n",
+    "    {\n",
+    "        \"scenario\": \"diff_adv\",\n",
+    "        \"task\": \"predict\",\n",
+    "        \"net\": net,\n",
+    "        \"train\": \"one\",\n",
+    "        \"start_seed\": 0,\n",
+    "        \"num_seeds\": 20,\n",
+    "    }\n",
+    "    for net in [\n",
+    "        \"Conv;34;10;relu\",\n",
+    "        \"FNO;12;18;4;gelu\",\n",
+    "    ]\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "(\n",
+    "    df_metric,\n",
+    "    df_loss,\n",
+    "    _,\n",
+    "    network_list,\n",
+    ") = apebench.run_study_convenience(\n",
+    "    CONFIGS,\n",
+    "    \"statistical_analysis\",\n",
+    "    do_loss=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sns.lineplot(df_loss, x=\"update_step\", y=\"train_loss\", hue=\"net\")\n",
+    "plt.yscale(\"log\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sns.lineplot(df_metric, x=\"time_step\", y=\"mean_nRMSE\", hue=\"net\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "advection_scenario = apebench.scenarios.difficulty.Advection(\n",
+    "    advection_gamma=2.0, num_test_samples=300\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fno_data, fno_models = advection_scenario(\n",
+    "    network_config=\"FNO;12;18;4;gelu\",\n",
+    "    num_seeds=20,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fno_models"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_ic_set = advection_scenario.get_test_ic_set()\n",
+    "test_trj = advection_scenario.get_test_data()\n",
+    "test_trj_no_init = test_trj[:, 1:]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# vmap over the seed axis of the model weights and over the test ICs\n",
+    "fno_rollout = eqx.filter_vmap(\n",
+    "    lambda m: jax.vmap(ex.rollout(m, advection_scenario.test_temporal_horizon))(\n",
+    "        test_ic_set\n",
+    "    )\n",
+    ")(fno_models)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fno_rollout.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_trj.shape, test_trj_no_init.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# nRMSE per seed, test sample, and time step\n",
+    "metric_rollout = jax.vmap(\n",
+    "    jax.vmap(jax.vmap(ex.metrics.nRMSE)),\n",
+    "    in_axes=(0, None),\n",
+    ")(fno_rollout, test_trj_no_init)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metric_rollout.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.hist(metric_rollout[0, :, 0])"
+   ]
+  },
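+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As an additional, purely illustrative check of the same slice, a Q-Q plot\n",
+    "against a fitted normal distribution:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# points close to the straight line indicate approximate normality\n",
+    "stats.probplot(metric_rollout[0, :, 0], plot=plt);"
+   ]
+  },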
"metadata": {}, + "outputs": [], + "source": [ + "plt.hist(metric_rollout[0, :, 0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "p_value_for_normality = jnp.array(\n", + " [\n", + " [\n", + " stats.normaltest(metric_rollout[s, :, t]).pvalue\n", + " for t in range(advection_scenario.test_temporal_horizon)\n", + " ]\n", + " for s in range(20)\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For most seeds, most of the time snap shots, the distribution over the 30 test\n", + "samples is likely not normally distributed.\n", + "\n", + "Well... actually, it is" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.semilogy(p_value_for_normality[:20].T)\n", + "plt.hlines(0.05, 0, 200, colors=\"r\", linestyles=\"--\", linewidth=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mean_metric_rollout = jnp.mean(metric_rollout, axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "p_value_for_normality_mean = jnp.array(\n", + " [\n", + " stats.shapiro(mean_metric_rollout[:, t]).pvalue\n", + " for t in range(advection_scenario.test_temporal_horizon)\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(p_value_for_normality_mean)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Hence\n", + "we might need a non-parametric test to compare the two models???" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "apebench_dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/mkdocs.yml b/mkdocs.yml index 286e4da..93f98b6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -99,6 +99,7 @@ nav: - Benchmark Flax linen models: 'examples/benchmark_flax_models_with_linen.ipynb' - Benchmark PyTorch models: 'examples/benchmark_pytorch_models.ipynb' - Randomness & Reproducibility: 'examples/sources_of_randomness_and_reproducibility.ipynb' + - Statistical Analysis: 'examples/statistical_analysis.ipynb' - Difficulty & Receptive Field in 1D Advection: 'examples/difficulty_and_receptive_field_advection_1d.ipynb' - Scenarios: - Overview: 'api/scenarios/overview.md'