generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Change API to make more flexible in practice (#33)
* total change * format * format * format * oh well then * more * pre-commit * format * add more examples Co-authored-by: Nathan Simpson <[email protected]>
- Loading branch information
Nathan Simpson
and
Nathan Simpson
authored
Jun 24, 2022
1 parent
17b85c0
commit 5b8157c
Showing
26 changed files
with
3,605 additions
and
4,524 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
repos: | ||
- repo: https://github.com/psf/black | ||
rev: 22.1.0 | ||
rev: 22.3.0 | ||
hooks: | ||
- id: black-jupyter | ||
|
||
|
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
celluloid | ||
git+http://github.com/scikit-hep/pyhf.git@make_difffable_model_ctor | ||
plothelp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import optax\n", | ||
"from jaxopt import OptaxSolver\n", | ||
"import relaxed\n", | ||
"from celluloid import Camera\n", | ||
"from functools import partial\n", | ||
"import matplotlib.lines as mlines\n", | ||
"\n", | ||
"# matplotlib settings\n", | ||
"plt.rc(\"figure\", figsize=(6, 3), dpi=220, facecolor=\"w\")\n", | ||
"plt.rc(\"legend\", fontsize=6)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Optimising a simple one-bin analysis with `relaxed`\n", | ||
"\n", | ||
"Let's define an analysis with a predicted number of signal and background events, with some uncertainty on the background estimate. We'll abstract the analysis configuration into a single parameter $\\phi$ like so:\n", | ||
"\n", | ||
"$$s = 15 + \\phi $$\n", | ||
"$$b = 45 - 2 \\phi $$\n", | ||
"$$\\sigma_b = 0.5 + 0.1*\\phi^2 $$\n", | ||
"\n", | ||
"Note that $s \\propto \\phi$ and $\\propto -2\\phi$, so increasing $\\phi$ corresponds to increasing the signal/backround ratio. However, our uncertainty scales like $\\phi^2$, so we're also going to compromise in our certainty of the background count as we do that. This kind of tradeoff between $s/b$ ratio and uncertainty is important for the discovery of a new signal, so we can't get away with optimising $s/b$ alone.\n", | ||
"\n", | ||
"To illustrate this, we'll plot the discovery significance for this model with and without uncertainty." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# model definition\n", | ||
"def yields(phi, uncertainty=True):\n", | ||
" s = 15 + phi\n", | ||
" b = 45 - 2 * phi\n", | ||
" db = (\n", | ||
" 0.5 + 0.1 * phi**2 if uncertainty else jnp.zeros_like(phi) + 0.001\n", | ||
" ) # small enough to be negligible\n", | ||
" return jnp.asarray([s]), jnp.asarray([b]), jnp.asarray([db])\n", | ||
"\n", | ||
"\n", | ||
"# our analysis pipeline, from phi to p-value\n", | ||
"def pipeline(phi, return_yields=False, uncertainty=True):\n", | ||
" y = yields(phi, uncertainty=uncertainty)\n", | ||
" # use a dummy version of pyhf for simplicity + compatibility with jax\n", | ||
" model = relaxed.dummy_pyhf.uncorrelated_background(*y)\n", | ||
" nominal_pars = jnp.array([1.0, 1.0])\n", | ||
" data = model.expected_data(nominal_pars) # we expect the nominal model\n", | ||
" # do the hypothesis test (and fit model pars with gradient descent)\n", | ||
" pvalue = relaxed.infer.hypotest(\n", | ||
" 0.0, # value of mu for the alternative hypothesis\n", | ||
" data,\n", | ||
" model,\n", | ||
" test_stat=\"q0\", # discovery significance test\n", | ||
" lr=1e-3,\n", | ||
" expected_pars=nominal_pars, # optionally providing MLE pars in advance\n", | ||
" )\n", | ||
" if return_yields:\n", | ||
" return pvalue, y\n", | ||
" else:\n", | ||
" return pvalue\n", | ||
"\n", | ||
"\n", | ||
"# calculate p-values for a range of phi values\n", | ||
"phis = jnp.linspace(0, 10, 100)\n", | ||
"\n", | ||
"# with uncertainty\n", | ||
"pipe = partial(pipeline, return_yields=True, uncertainty=True)\n", | ||
"pvals, ys = jax.vmap(pipe)(phis) # map over phi grid\n", | ||
"# without uncertainty\n", | ||
"pipe_no_uncertainty = partial(pipeline, uncertainty=False)\n", | ||
"pvals_no_uncertainty = jax.vmap(pipe_no_uncertainty)(phis)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"fig, axs = plt.subplots(2, 1, sharex=True)\n", | ||
"axs[0].plot(phis, pvals, label=\"with uncertainty\", color=\"C2\")\n", | ||
"axs[0].plot(phis, pvals_no_uncertainty, label=\"no uncertainty\", color=\"C4\")\n", | ||
"axs[0].set_ylabel(\"$p$-value\")\n", | ||
"# plot vertical dotted line at minimum of p-values + s/b\n", | ||
"best_phi = phis[jnp.argmin(pvals)]\n", | ||
"axs[0].axvline(x=best_phi, linestyle=\"dotted\", color=\"C2\", label=\"optimal p-value\")\n", | ||
"axs[0].axvline(\n", | ||
" x=phis[jnp.argmin(pvals_no_uncertainty)],\n", | ||
" linestyle=\"dotted\",\n", | ||
" color=\"C4\",\n", | ||
" label=r\"optimal $s/b$\",\n", | ||
")\n", | ||
"axs[0].legend(loc=\"upper left\", ncol=2)\n", | ||
"s, b, db = ys\n", | ||
"s, b, db = s.ravel(), b.ravel(), db.ravel() # everything is [[x]] for pyhf\n", | ||
"axs[1].fill_between(phis, s + b, b, color=\"C9\", label=\"signal\")\n", | ||
"axs[1].fill_between(phis, b, color=\"C1\", label=\"background\")\n", | ||
"axs[1].fill_between(phis, b - db, b + db, facecolor=\"k\", alpha=0.2, label=r\"$\\sigma_b$\")\n", | ||
"axs[1].set_xlabel(\"$\\phi$\")\n", | ||
"axs[1].set_ylabel(\"yield\")\n", | ||
"axs[1].legend(loc=\"lower left\")\n", | ||
"plt.suptitle(\"Discovery p-values, with and without uncertainty\")\n", | ||
"plt.tight_layout()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Using gradient descent, we can optimise this analysis in an uncertainty-aware way by directly optimising $\\phi$ for the lowest discovery p-value. Here's how you do that:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# The fast way!\n", | ||
"# use the OptaxSolver wrapper from jaxopt to perform the minimisation\n", | ||
"# set a couple of tolerance kwargs to make sure we don't get stuck\n", | ||
"solver = OptaxSolver(pipeline, opt=optax.adam(1e-3), tol=1e-8, maxiter=10000)\n", | ||
"pars = 9.0 # random init\n", | ||
"result = solver.run(pars).params\n", | ||
"print(\n", | ||
" f\"our solution: phi={result:.5f}\\ntrue optimum: phi={phis[jnp.argmin(pvals)]:.5f}\\nbest s/b: phi=10\"\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# The longer way (but with plots)!\n", | ||
"pipe = partial(pipeline, return_yields=True, uncertainty=True)\n", | ||
"solver = OptaxSolver(pipe, opt=optax.adam(1e-1), has_aux=True)\n", | ||
"pars = 9.0\n", | ||
"state = solver.init_state(pars) # we're doing init, update steps instead of .run()\n", | ||
"\n", | ||
"plt.rc(\"figure\", figsize=(6, 3), dpi=220, facecolor=\"w\")\n", | ||
"plt.rc(\"legend\", fontsize=8)\n", | ||
"fig, axs = plt.subplots(1, 2)\n", | ||
"cam = Camera(fig)\n", | ||
"steps = 5 # increase me for better results! (100ish works well)\n", | ||
"for i in range(steps):\n", | ||
" pars, state = solver.update(pars, state)\n", | ||
" s, b, db = state.aux\n", | ||
" val = state.value\n", | ||
" ax = axs[0]\n", | ||
" cv = ax.plot(phis, pvals, c=\"C0\")\n", | ||
" cvs = ax.plot(phis, pvals_no_uncertainty, c=\"green\")\n", | ||
" current = ax.scatter(pars, val, c=\"C0\")\n", | ||
" ax.set_xlabel(r\"analysis config $\\phi$\")\n", | ||
" ax.set_ylabel(\"p-value\")\n", | ||
" ax.legend(\n", | ||
" [\n", | ||
" mlines.Line2D([], [], color=\"C0\"),\n", | ||
" mlines.Line2D([], [], color=\"green\"),\n", | ||
" current,\n", | ||
" ],\n", | ||
" [\"p-value (with uncert)\", \"p-value (without uncert)\", \"current value\"],\n", | ||
" frameon=False,\n", | ||
" )\n", | ||
" ax.text(0.3, 0.61, f\"step {i}\", transform=ax.transAxes)\n", | ||
" ax = axs[1]\n", | ||
" ax.set_ylim((0, 80))\n", | ||
" b1 = ax.bar(0.5, b, facecolor=\"C1\", label=\"b\")\n", | ||
" b2 = ax.bar(0.5, s, bottom=b, facecolor=\"C9\", label=\"s\")\n", | ||
" b3 = ax.bar(\n", | ||
" 0.5, db, bottom=b - db / 2, facecolor=\"k\", alpha=0.5, label=r\"$\\sigma_b$\"\n", | ||
" )\n", | ||
" ax.set_ylabel(\"yield\")\n", | ||
" ax.set_xticks([])\n", | ||
" ax.legend([b1, b2, b3], [\"b\", \"s\", r\"$\\sigma_b$\"], frameon=False)\n", | ||
" plt.tight_layout()\n", | ||
" cam.snap()\n", | ||
"\n", | ||
"ani = cam.animate()\n", | ||
"# uncomment this to save and view the animation!\n", | ||
"# ani.save(\"ap00.gif\", fps=9)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"interpreter": { | ||
"hash": "22d6333b89854cd01c2018f3ca2f5a59a2cde2765fbca789ff36cfad48ca629b" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3.9.12 ('venv': venv)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.12" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.